cagpjax.solvers.base

Base classes for linear solvers and methods.

Classes:

AbstractLinearSolver

Bases: Module, Generic[_LinearSolverState]

Base class for linear solvers.

These solvers are used to exactly or approximately solve the linear system \(Ax = b\) for \(x\), where \(A\) is a positive (semi-)definite (PSD) linear operator.
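To make the "exactly or approximately" distinction concrete, the sketch below solves one small PSD system two ways: with a direct Cholesky solve and with a truncated run of conjugate gradients. It uses plain `jax.numpy` arrays in place of a `LinearOperator` and arbitrary matrix values. It illustrates the kind of computation a concrete solver might wrap and is not the cagpjax implementation.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.sparse.linalg import cg

# Arbitrary symmetric positive-definite test matrix: A = M M^T + 3 I.
M = jnp.array([[2.0, 0.5, 0.0],
               [0.5, 1.0, 0.3],
               [0.0, 0.3, 1.5]])
A = M @ M.T + 3.0 * jnp.eye(3)
b = jnp.array([1.0, 2.0, 3.0])

# Exact (direct) solve: factor A = L L^T once, then two triangular solves.
x_exact = cho_solve(cho_factor(A, lower=True), b)

# Approximate (iterative) solve: conjugate gradients capped at two iterations.
x_approx, _ = cg(A, b, maxiter=2)

print(jnp.linalg.norm(A @ x_exact - b))   # residual at floating-point round-off level
print(jnp.linalg.norm(A @ x_approx - b))  # residual after at most two CG steps
```

A direct method pays its cost up front in the factorization. An iterative method trades accuracy for cost through its iteration budget, which is why the interface allows for approximate solutions.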

Methods:

  • init

    Construct a solver state.

  • inv_congruence_transform

    Compute the inverse congruence transform \(B^T X\) for \(X\) in \(AX = B\).

  • inv_quad

    Compute the inverse quadratic form \(b^T x\) for \(x\) in \(Ax = b\).

  • logdet

    Compute the logarithm of the (pseudo-)determinant of \(A\).

  • solve

    Compute a solution to the linear system \(Ax = b\).

  • trace_solve

    Compute \(\mathrm{trace}(X)\) for \(X\) in \(AX = B\), where \(A\) and \(B\) are PSD.

  • unwhiten

    Given an IID standard normal vector \(z\), return \(x\) with covariance \(A\).

init abstractmethod

init(A: LinearOperator) -> _LinearSolverState

Construct a solver state.

Parameters:

  • A

    (LinearOperator) –

    Positive (semi-)definite linear operator.

Returns:

  • _LinearSolverState

    State of the linear solver, which stores any necessary intermediate values.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def init(self, A: LinearOperator) -> _LinearSolverState:
    """Construct a solver state.

    Arguments:
        A: Positive (semi-)definite linear operator.

    Returns:
        State of the linear solver, which stores any necessary intermediate values.
    """
    pass
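A concrete `init` usually does the expensive one-off work (a factorization, a preconditioner) and stores the result in the state for the other methods to reuse. As a minimal sketch, assuming a dense `jax.Array` in place of a `LinearOperator`, a Cholesky-based solver could keep its factor in a `NamedTuple`, which JAX treats as a pytree so the state can pass through `jax.jit`. `CholeskyState`, `init` and the `jitter` argument below are hypothetical names, not cagpjax API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CholeskyState(NamedTuple):
    """Hypothetical solver state: lower factor L with A + jitter * I = L L^T."""
    L: jax.Array


def init(A: jax.Array, jitter: float = 1e-6) -> CholeskyState:
    """Factor a dense PSD matrix once so later methods can reuse the factor.

    The (hypothetical) `jitter` is added to the diagonal so that a merely
    semi-definite A still admits a Cholesky factorization.
    """
    n = A.shape[0]
    return CholeskyState(L=jnp.linalg.cholesky(A + jitter * jnp.eye(n, dtype=A.dtype)))


A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
state = init(A)
```

With the factor cached in the state, later solves cost only triangular back-substitutions rather than a fresh factorization.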

inv_congruence_transform abstractmethod

inv_congruence_transform(state: _LinearSolverState, B: LinearOperator | Float[Array, 'N K']) -> LinearOperator | Float[Array, 'K K']

Compute the inverse congruence transform \(B^T X\) for \(X\) in \(AX = B\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • B

    (LinearOperator | Float[Array, 'N K']) –

    Linear operator or array to be applied.

Returns:

  • LinearOperator | Float[Array, 'K K']

    Linear operator or array resulting from the congruence transform.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def inv_congruence_transform(
    self, state: _LinearSolverState, B: LinearOperator | Float[Array, "N K"]
) -> LinearOperator | Float[Array, "K K"]:
    """Compute the inverse congruence transform $B^T x$ for $x$ in $Ax = B$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        B: Linear operator or array to be applied.

    Returns:
        Linear operator or array resulting from the congruence transform.
    """
    pass
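For a dense positive-definite \(A = LL^T\), this transform equals \(B^T A^{-1} B = (L^{-1}B)^T (L^{-1}B)\). One triangular solve therefore suffices, and the \(K \times K\) result is symmetric PSD by construction. The sketch below assumes plain `jax.numpy` arrays and a precomputed Cholesky factor; the function name is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def inv_congruence_transform_dense(L, B):
    """Return B^T A^{-1} B given the lower Cholesky factor L of A (A = L L^T)."""
    W = solve_triangular(L, B, lower=True)  # W = L^{-1} B, shape (N, K)
    return W.T @ W                          # B^T L^{-T} L^{-1} B = W^T W, shape (K, K)


A = jnp.array([[4.0, 2.0, 0.0],
               [2.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
B = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
L = jnp.linalg.cholesky(A)
C = inv_congruence_transform_dense(L, B)  # shape (2, 2), symmetric PSD
```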

inv_quad

inv_quad(state: _LinearSolverState, b: Float[Array, N]) -> ScalarFloat

Compute the inverse quadratic form \(b^T x\) for \(x\) in \(Ax = b\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • b

    (Float[Array, N]) –

    Right-hand side of the linear system.

Source code in src/cagpjax/solvers/base.py
def inv_quad(
    self, state: _LinearSolverState, b: Float[Array, "N #1"]
) -> ScalarFloat:
    """Compute the inverse quadratic form $b^T x$, for $x$ in $Ax = b$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        b: Right-hand side of the linear system.
    """
    return self.inv_congruence_transform(state, b[:, None]).squeeze()
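The default `inv_quad` is derived from `inv_congruence_transform`. Viewing the single right-hand side \(b\) as an \(N \times 1\) matrix, \(b^T A^{-1} b\) is the lone entry of the \(1 \times 1\) result. The sketch below mirrors that reshape-and-squeeze pattern, with a dense Cholesky factor standing in for a solver state; both function names are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def inv_congruence_transform_dense(L, B):
    # Hypothetical stand-in for a concrete solver: B^T A^{-1} B with A = L L^T.
    W = solve_triangular(L, B, lower=True)
    return W.T @ W


def inv_quad_dense(L, b):
    # Mirrors the default implementation: promote b to an (N, 1) matrix,
    # take the 1x1 congruence transform, and squeeze it to a scalar.
    return inv_congruence_transform_dense(L, b[:, None]).squeeze()


A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([1.0, -1.0])
q = inv_quad_dense(jnp.linalg.cholesky(A), b)  # b^T A^{-1} b = 11/8
```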

logdet abstractmethod

logdet(state: _LinearSolverState) -> ScalarFloat

Compute the logarithm of the (pseudo-)determinant of \(A\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def logdet(self, state: _LinearSolverState) -> ScalarFloat:
    """Compute the logarithm of the (pseudo-)determinant of $A$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
    """
    pass
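Two common ways to obtain this quantity from dense arrays are sketched below. For positive-definite \(A = LL^T\), the Cholesky identity \(\log\det A = 2\sum_i \log L_{ii}\) applies. For a singular PSD \(A\), the log pseudo-determinant is the sum of the logs of its nonzero eigenvalues. Both function names and the `rtol` cutoff are illustrative assumptions, not cagpjax API.

```python
import jax.numpy as jnp


def logdet_cholesky(A):
    # For positive-definite A = L L^T: log det A = 2 * sum(log(diag(L))).
    L = jnp.linalg.cholesky(A)
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))


def pseudo_logdet(A, rtol=1e-6):
    # For a singular PSD A: sum the logs of the eigenvalues treated as nonzero.
    # The relative cutoff `rtol` is an illustrative choice, not a library default.
    eigs = jnp.linalg.eigvalsh(A)
    keep = eigs > rtol * jnp.max(eigs)
    # Substitute 1.0 (log 1 = 0) for discarded eigenvalues to avoid log(0).
    return jnp.sum(jnp.log(jnp.where(keep, eigs, 1.0)))


A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
print(logdet_cholesky(A))                                    # log(8), since det A = 8
print(pseudo_logdet(jnp.diag(jnp.array([2.0, 3.0, 0.0]))))  # log(6)
```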

solve abstractmethod

solve(state: _LinearSolverState, b: Float[Array, N]) -> Float[Array, N]

Compute a solution to the linear system \(Ax = b\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • b

    (Float[Array, N]) –

    Right-hand side of the linear system.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def solve(
    self, state: _LinearSolverState, b: Float[Array, "N #K"]
) -> Float[Array, "N #K"]:
    """Compute a solution to the linear system $Ax = b$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        b: Right-hand side of the linear system.
    """
    pass
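The source annotation `Float[Array, "N #K"]` suggests that `b` may hold \(K\) right-hand sides as columns. As a sketch of that contract with dense arrays, assuming a Cholesky factor plays the role of the solver state, `jax.scipy.linalg.cho_solve` accepts either a vector or a matrix and solves each column independently:

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
factor = cho_factor(A, lower=True)  # stand-in for a solver state

b_vec = jnp.array([1.0, -1.0])     # shape (N,)
b_mat = jnp.array([[1.0, 0.0],
                   [-1.0, 1.0]])   # shape (N, K): K = 2 right-hand sides

x_vec = cho_solve(factor, b_vec)  # shape (N,)
x_mat = cho_solve(factor, b_mat)  # shape (N, K); column j solves A x = b_mat[:, j]
```

Because the first column of `b_mat` equals `b_vec`, the first column of `x_mat` matches `x_vec`.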

trace_solve abstractmethod

trace_solve(state: _LinearSolverState, state_other: _LinearSolverState) -> ScalarFloat

Compute \(\mathrm{trace}(X)\) for \(X\) in \(AX = B\), where \(A\) and \(B\) are PSD.

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver obtained by applying init to an operator representing \(A\).

  • state_other

    (_LinearSolverState) –

    Another state obtained by applying init to an operator representing \(B\).

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def trace_solve(
    self, state: _LinearSolverState, state_other: _LinearSolverState
) -> ScalarFloat:
    r"""Compute $\mathrm{trace}(X)$ in $AX=B$ for PSD $A$ and $B$.

    Arguments:
        state: State of the linear solver obtained by applying [`init`][..init] to an operator
            representing $A$
        state_other: Another state obtained by applying `init` to an operator representing $B$.
    """
    pass
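Taking a second state lets an implementation exploit structure in \(B\) as well as in \(A\). For example, if both states held Cholesky factors \(A = L_A L_A^T\) and \(B = L_B L_B^T\), then by the cyclic property of the trace, \(\mathrm{trace}(A^{-1}B) = \mathrm{trace}(L_B^T L_A^{-T} L_A^{-1} L_B) = \lVert L_A^{-1} L_B \rVert_F^2\). The dense sketch below assumes exactly that choice of state, which is hypothetical and not necessarily what cagpjax solvers store.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def trace_solve_dense(L_A, L_B):
    """trace(A^{-1} B) from lower Cholesky factors A = L_A L_A^T, B = L_B L_B^T.

    Equals ||L_A^{-1} L_B||_F^2, so one triangular solve replaces
    forming A^{-1} B explicitly.
    """
    W = solve_triangular(L_A, L_B, lower=True)  # W = L_A^{-1} L_B
    return jnp.sum(W * W)                       # squared Frobenius norm of W


A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
B = jnp.array([[2.0, 1.0], [1.0, 2.0]])
t = trace_solve_dense(jnp.linalg.cholesky(A), jnp.linalg.cholesky(B))  # 1.25
```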

unwhiten abstractmethod

unwhiten(state: _LinearSolverState, z: Float[Array, N]) -> Float[Array, N]

Given an IID standard normal vector \(z\), return \(x\) with covariance \(A\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • z

    (Float[Array, N]) –

    IID standard normal vector.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def unwhiten(
    self, state: _LinearSolverState, z: Float[Array, "N #K"]
) -> Float[Array, "N #K"]:
    """Given an IID standard normal vector $z$, return $x$ with covariance $A$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        z: IID standard normal vector.
    """
    pass
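Unwhitening works with any factor \(L\) satisfying \(LL^T = A\): if \(z \sim \mathcal{N}(0, I)\), then \(x = Lz\) has covariance \(L I L^T = A\). The sketch below uses a dense Cholesky factor as that \(L\), a hypothetical choice for illustration, and checks the claim empirically by unwhitening a batch of draws stored column-wise, as in the `"N #K"` layout.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
L = jnp.linalg.cholesky(A)  # any L with L L^T = A works; Cholesky is one choice


def unwhiten_dense(L, z):
    # x = L z has Cov[x] = L Cov[z] L^T = L I L^T = A.
    return L @ z


key = jax.random.PRNGKey(0)
n_samples = 100_000
Z = jax.random.normal(key, (2, n_samples))  # each column is an IID N(0, I) draw
X = unwhiten_dense(L, Z)
empirical_cov = X @ X.T / n_samples  # close to A, up to Monte Carlo error
```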