cagpjax.solvers

Modules:

  • base

    Base classes for linear solvers and methods.

  • cholesky

    Linear solvers based on Cholesky decomposition.

  • pseudoinverse

Classes:

  • AbstractLinearSolver

    Base class for linear solvers.

  • Cholesky

    Solve a linear system using the Cholesky decomposition.

  • PseudoInverse

    Solve a linear system using the Moore-Penrose pseudoinverse.

AbstractLinearSolver

Bases: Module, Generic[_LinearSolverState]

Base class for linear solvers.

These solvers are used to exactly or approximately solve the linear system \(Ax = b\) for \(x\), where \(A\) is a positive (semi-)definite (PSD) linear operator.

Methods:

  • init

    Construct a solver state.

  • inv_congruence_transform

    Compute the inverse congruence transform \(B^T x\) for \(x\) in \(Ax = B\).

  • inv_quad

    Compute the inverse quadratic form \(b^T x\), for \(x\) in \(Ax = b\).

  • logdet

    Compute the logarithm of the (pseudo-)determinant of \(A\).

  • solve

    Compute a solution to the linear system \(Ax = b\).

  • trace_solve

    Compute \(\mathrm{trace}(X)\) in \(AX=B\) for PSD \(A\) and \(B\).

  • unwhiten

    Given an IID standard normal vector \(z\), return \(x\) with covariance \(A\).
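
The methods above share one pattern: init factors \(A\) once, and every other method reuses the returned state. The sketch below illustrates that pattern in plain NumPy for a dense positive-definite \(A\). The names ToyCholeskyState and ToyCholeskySolver are hypothetical and exist only for this illustration; they are not part of cagpjax.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyCholeskyState:
    """Hypothetical solver state: caches the lower Cholesky factor of A."""

    L: np.ndarray


class ToyCholeskySolver:
    """Hypothetical dense solver mirroring the init/solve split (not cagpjax code)."""

    def init(self, A: np.ndarray) -> ToyCholeskyState:
        # Factor A = L L^T once; every later call reuses L.
        return ToyCholeskyState(L=np.linalg.cholesky(A))

    def solve(self, state: ToyCholeskyState, b: np.ndarray) -> np.ndarray:
        # Solve L y = b, then L^T x = y. A general solver is used for brevity;
        # a triangular solver would exploit the structure of L.
        y = np.linalg.solve(state.L, b)
        return np.linalg.solve(state.L.T, y)

    def logdet(self, state: ToyCholeskyState) -> float:
        # log det(A) = 2 * sum(log(diag(L))).
        return float(2.0 * np.sum(np.log(np.diag(state.L))))


A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])

solver = ToyCholeskySolver()
state = solver.init(A)      # the factorization happens here
x = solver.solve(state, b)  # reuses the cached factor
ld = solver.logdet(state)   # reuses it again, with no new factorization
```

Keeping the factorization in an explicit state object lets several quantities (solves, log-determinants, quadratic forms) be computed from one decomposition.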

init abstractmethod

init(A: LinearOperator) -> _LinearSolverState

Construct a solver state.

Parameters:

  • A

    (LinearOperator) –

    Positive (semi-)definite linear operator.

Returns:

  • _LinearSolverState

    State of the linear solver, which stores any necessary intermediate values.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def init(self, A: LinearOperator) -> _LinearSolverState:
    """Construct a solver state.

    Arguments:
        A: Positive (semi-)definite linear operator.

    Returns:
        State of the linear solver, which stores any necessary intermediate values.
    """
    pass

inv_congruence_transform abstractmethod

inv_congruence_transform(state: _LinearSolverState, B: LinearOperator | Float[Array, 'N K']) -> LinearOperator | Float[Array, 'K K']

Compute the inverse congruence transform \(B^T x\) for \(x\) in \(Ax = B\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • B

    (LinearOperator | Float[Array, 'N K']) –

    Linear operator or array to be applied.

Returns:

  • LinearOperator | Float[Array, 'K K']

    Linear operator or array resulting from the congruence transform.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def inv_congruence_transform(
    self, state: _LinearSolverState, B: LinearOperator | Float[Array, "N K"]
) -> LinearOperator | Float[Array, "K K"]:
    """Compute the inverse congruence transform $B^T x$ for $x$ in $Ax = B$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        B: Linear operator or array to be applied.

    Returns:
        Linear operator or array resulting from the congruence transform.
    """
    pass
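
Written out, the inverse congruence transform is \(B^T A^{-1} B\). As an illustration only, the NumPy sketch below evaluates it for a dense positive-definite \(A\) through a Cholesky factor and checks the result against a direct solve. The matrices and the dense-array setting are assumptions made for the example; concrete cagpjax solvers may compute this differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Build a symmetric positive-definite A (N x N) and a right-hand side B (N x K).
N, K = 5, 2
M = rng.standard_normal((N, N))
A = M @ M.T + N * np.eye(N)
B = rng.standard_normal((N, K))

# With A = L L^T, B^T A^{-1} B = (L^{-1} B)^T (L^{-1} B).
L = np.linalg.cholesky(A)
W = np.linalg.solve(L, B)  # W = L^{-1} B, shape (N, K)
result = W.T @ W           # shape (K, K), symmetric by construction

# Reference: solve A X = B directly, then form B^T X.
reference = B.T @ np.linalg.solve(A, B)
```

Forming \(W^T W\) keeps the output symmetric up to rounding, which the direct product \(B^T X\) does not guarantee.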

inv_quad

inv_quad(state: _LinearSolverState, b: Float[Array, 'N #1']) -> ScalarFloat

Compute the inverse quadratic form \(b^T x\), for \(x\) in \(Ax = b\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • b

    (Float[Array, 'N #1']) –

    Right-hand side of the linear system.

Source code in src/cagpjax/solvers/base.py
def inv_quad(
    self, state: _LinearSolverState, b: Float[Array, "N #1"]
) -> ScalarFloat:
    """Compute the inverse quadratic form $b^T x$, for $x$ in $Ax = b$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        b: Right-hand side of the linear system.
    """
    return self.inv_congruence_transform(state, b[:, None]).squeeze()
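
The default implementation treats the inverse quadratic form as a congruence transform with a single column. A small NumPy sketch of that reduction, using a hand-picked \(2 \times 2\) system chosen only for illustration:

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.5, 1.0]])
b = np.array([1.0, -1.0])

# Congruence transform with B = b as an (N, 1) column gives a (1, 1) array.
B = b[:, None]
congruence = B.T @ np.linalg.solve(A, B)

# Squeezing removes the singleton dimensions, leaving the scalar b^T A^{-1} b.
inv_quad = congruence.squeeze()
```

For this system \(b^T A^{-1} b = 16/7\), which can be checked by hand from \(A^{-1} = \frac{1}{1.75}\begin{pmatrix} 1 & -0.5 \\ -0.5 & 2 \end{pmatrix}\).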

logdet abstractmethod

logdet(state: _LinearSolverState) -> ScalarFloat

Compute the logarithm of the (pseudo-)determinant of \(A\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def logdet(self, state: _LinearSolverState) -> ScalarFloat:
    """Compute the logarithm of the (pseudo-)determinant of $A$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
    """
    pass
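
For reference, two standard ways to evaluate these quantities on dense matrices are sketched below in NumPy: the Cholesky identity for positive-definite \(A\), and a sum over nonzero eigenvalues for the pseudo-determinant of a singular PSD matrix. The cutoff used to decide which eigenvalues count as nonzero is an assumption of this example.

```python
import numpy as np

# Positive-definite case: log det(A) = 2 * sum(log(diag(L))) with A = L L^T.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
L = np.linalg.cholesky(A)
logdet_chol = 2.0 * np.sum(np.log(np.diag(L)))

# Singular PSD case: the pseudo-determinant is the product of the nonzero
# eigenvalues, so its logarithm is the sum of their logarithms.
S = np.diag([3.0, 2.0, 0.0])
eigvals = np.linalg.eigvalsh(S)
cutoff = 1e-12 * eigvals.max()  # illustrative threshold, assumed for this example
log_pdet = np.sum(np.log(eigvals[eigvals > cutoff]))
```

Here \(\det A = 11\) and the pseudo-determinant of \(S\) is \(3 \cdot 2 = 6\).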

solve abstractmethod

solve(state: _LinearSolverState, b: Float[Array, 'N #K']) -> Float[Array, 'N #K']

Compute a solution to the linear system \(Ax = b\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • b

    (Float[Array, 'N #K']) –

    Right-hand side of the linear system.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def solve(
    self, state: _LinearSolverState, b: Float[Array, "N #K"]
) -> Float[Array, "N #K"]:
    """Compute a solution to the linear system $Ax = b$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        b: Right-hand side of the linear system.
    """
    pass

trace_solve abstractmethod

trace_solve(state: _LinearSolverState, state_other: _LinearSolverState) -> ScalarFloat

Compute \(\mathrm{trace}(X)\) in \(AX=B\) for PSD \(A\) and \(B\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver obtained by applying init to an operator representing \(A\).

  • state_other

    (_LinearSolverState) –

    Another state obtained by applying init to an operator representing \(B\).

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def trace_solve(
    self, state: _LinearSolverState, state_other: _LinearSolverState
) -> ScalarFloat:
    r"""Compute $\mathrm{trace}(X)$ in $AX=B$ for PSD $A$ and $B$.

    Arguments:
        state: State of the linear solver obtained by applying [`init`][..init] to an operator
            representing $A$
        state_other: Another state obtained by applying `init` to an operator representing $B$.
    """
    pass
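
Because both arguments are solver states, the trace can in principle be computed from two factorizations without forming \(X\). One such identity, assuming both \(A\) and \(B\) are positive-definite with Cholesky factors \(L_A\) and \(L_B\), is \(\mathrm{trace}(A^{-1} B) = \lVert L_A^{-1} L_B \rVert_F^2\). The NumPy sketch below checks it numerically; it illustrates the identity and is not necessarily how any cagpjax solver implements trace_solve.

```python
import numpy as np

rng = np.random.default_rng(1)

# Two symmetric positive-definite matrices of the same size.
N = 4
Ma = rng.standard_normal((N, N))
Mb = rng.standard_normal((N, N))
A = Ma @ Ma.T + N * np.eye(N)
B = Mb @ Mb.T + N * np.eye(N)

# Factor each matrix once, as init would for each operator.
L_A = np.linalg.cholesky(A)
L_B = np.linalg.cholesky(B)

# trace(A^{-1} B) = trace(L_B^T L_A^{-T} L_A^{-1} L_B) = ||L_A^{-1} L_B||_F^2.
C = np.linalg.solve(L_A, L_B)
trace_from_factors = np.sum(C**2)

# Reference: solve A X = B densely and take the trace of X.
trace_direct = np.trace(np.linalg.solve(A, B))
```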

unwhiten abstractmethod

unwhiten(state: _LinearSolverState, z: Float[Array, 'N #K']) -> Float[Array, 'N #K']

Given an IID standard normal vector \(z\), return \(x\) with covariance \(A\).

Parameters:

  • state

    (_LinearSolverState) –

    State of the linear solver returned by init.

  • z

    (Float[Array, 'N #K']) –

    IID standard normal vector.

Source code in src/cagpjax/solvers/base.py
@abstractmethod
def unwhiten(
    self, state: _LinearSolverState, z: Float[Array, "N #K"]
) -> Float[Array, "N #K"]:
    """Given an IID standard normal vector $z$, return $x$ with covariance $A$.

    Arguments:
        state: State of the linear solver returned by [`init`][..init].
        z: IID standard normal vector.
    """
    pass
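
A common way to unwhiten, sketched here in NumPy under the assumption of a dense positive-definite \(A\), is to multiply by a Cholesky factor: if \(A = L L^T\) and \(z\) has identity covariance, then \(x = L z\) has covariance \(L L^T = A\). Other square roots of \(A\) work equally well, and this example does not claim to match any particular cagpjax solver.

```python
import numpy as np

rng = np.random.default_rng(2)

A = np.array([[2.0, 0.6], [0.6, 1.0]])
L = np.linalg.cholesky(A)

# One draw: z ~ N(0, I) mapped to x = L z.
z = rng.standard_normal(2)
x = L @ z

# Exact covariance of x: L Cov(z) L^T = L I L^T.
cov_x = L @ np.eye(2) @ L.T

# Monte Carlo check: columns of Z are independent draws of z.
Z = rng.standard_normal((2, 200_000))
X = L @ Z
sample_cov = np.cov(X)  # rows are variables, columns are samples
```

The sample covariance matches \(A\) only up to Monte Carlo error, which shrinks as the number of draws grows.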

Cholesky

Cholesky(jitter: ScalarFloat | None = None)

Bases: AbstractLinearSolver[CholeskyState]

Solve a linear system using the Cholesky decomposition.

Due to numerical imprecision, Cholesky factorization may fail even for positive-definite \(A\). Optionally, a small amount of jitter (\(\epsilon\)) can be added to \(A\) to ensure positive-definiteness. Note that the resulting system solved is slightly different from the original system.

Attributes:

  • jitter (ScalarFloat | None) –

    Small amount of jitter to add to \(A\) to ensure positive-definiteness.

Source code in src/cagpjax/solvers/cholesky.py
def __init__(self, jitter: ScalarFloat | None = None):
    self.jitter = jitter
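
To make the effect of jitter concrete, the NumPy sketch below factors a singular PSD matrix after adding \(\epsilon I\). Adding the jitter to the diagonal is the usual convention and is assumed here; the value of eps is illustrative. The residuals show that the solution satisfies the jittered system accurately and the original system only approximately.

```python
import numpy as np

# Rank-1 matrix: positive semi-definite but singular.
v = np.array([1.0, 2.0, 3.0])
A = np.outer(v, v)

# Attempt the plain factorization and record whether it was accepted.
try:
    np.linalg.cholesky(A)
    plain_ok = True
except np.linalg.LinAlgError:
    plain_ok = False

# Add jitter to the diagonal (assumed convention) and factor again.
eps = 1e-6  # illustrative value
A_jittered = A + eps * np.eye(3)
L = np.linalg.cholesky(A_jittered)

# Solve (A + eps I) x = b for a right-hand side in the range of A.
b = A @ np.array([1.0, 0.0, 0.0])
x = np.linalg.solve(L.T, np.linalg.solve(L, b))

residual_jittered = np.linalg.norm(A_jittered @ x - b)  # system actually solved
residual_original = np.linalg.norm(A @ x - b)           # system originally posed
```

Larger jitter makes the factorization more robust but moves the solved system further from the original one.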

PseudoInverse

PseudoInverse(rtol: ScalarFloat | None = None, grad_rtol: float | None = None, alg: Algorithm = Eigh())

Bases: AbstractLinearSolver[PseudoInverseState]

Solve a linear system using the Moore-Penrose pseudoinverse.

This solver computes the least-squares solution \(x = A^+ b\) for any \(A\), where \(A^+\) is the Moore-Penrose pseudoinverse. This is equivalent to the exact solution for non-singular \(A\), generalizes to singular \(A\), and improves stability for almost-singular \(A\). Note, however, that the pseudoinverse is discontinuous where the rank of \(A\) changes, so if the rank of \(A\) depends on hyperparameters being optimized, the optimization problem may be ill-posed.
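
The discontinuity can be seen on a \(2 \times 2\) example. In the NumPy sketch below, pinv_diag is a hypothetical helper written for this illustration: as \(t \to 0\) the entry \(1/t\) of the pseudoinverse of \(\mathrm{diag}(1, t)\) grows without bound, but at \(t = 0\) the rank drops and the same entry is \(0\).

```python
import numpy as np


def pinv_diag(t: float) -> np.ndarray:
    """Pseudoinverse of diag(1, t), a PSD matrix whose rank drops at t = 0."""
    return np.linalg.pinv(np.diag([1.0, t]))


# The (1, 1) entry equals 1/t for t > 0 and jumps to 0 at t = 0.
entries = {t: pinv_diag(t)[1, 1] for t in (1e-1, 1e-3, 1e-5, 0.0)}
```

If a hyperparameter plays the role of \(t\), an objective built on \(A^+\) inherits this jump, which is why rank changes during optimization are problematic.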

Note that if \(A\) is (almost-)degenerate (some eigenvalues repeat), then the gradient of its solves in JAX may be non-computable or numerically unstable (see jax#669). For degenerate operators, it may be necessary to increase grad_rtol to improve stability of gradients. See cagpjax.linalg.eigh for more details.

Attributes:

  • rtol (ScalarFloat | None) –

    Specifies the cutoff for small eigenvalues. Eigenvalues smaller than rtol * largest_nonzero_eigenvalue are treated as zero. The default is determined based on the floating point precision of the dtype of the operator (see jax.numpy.linalg.pinv).

  • grad_rtol (float | None) –

    Specifies the cutoff for similar eigenvalues, used to improve gradient computation for (almost-)degenerate matrices. If not provided, the default is 0.0. If None or negative, all eigenvalues are treated as distinct.

  • alg (Algorithm) –

    Algorithm for eigenvalue decomposition passed to cagpjax.linalg.eigh.
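
As a rough illustration of how an eigenvalue cutoff such as rtol acts, the NumPy sketch below applies a pseudoinverse through an eigendecomposition of a symmetric PSD matrix. The helper pinv_solve_psd is hypothetical, the cutoff rule follows the description of rtol above, and the handling of grad_rtol and alg is not modeled.

```python
import numpy as np


def pinv_solve_psd(A: np.ndarray, b: np.ndarray, rtol: float) -> np.ndarray:
    """Hypothetical helper: x = A^+ b for symmetric PSD A via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(A)
    # Eigenvalues at or below rtol * (largest eigenvalue) are treated as zero.
    keep = eigvals > rtol * eigvals.max()
    # Invert only the retained eigenvalues; the rest contribute nothing.
    inv_vals = np.where(keep, 1.0 / np.where(keep, eigvals, 1.0), 0.0)
    return eigvecs @ (inv_vals * (eigvecs.T @ b))


# Rank-1 PSD matrix with eigenvalues {14, 0, 0}; Cholesky-based solves do not apply.
v = np.array([1.0, 2.0, 3.0])
A = np.outer(v, v)
b = np.array([1.0, 0.0, 0.0])

x = pinv_solve_psd(A, b, rtol=1e-10)  # illustrative rtol
x_ref = np.linalg.pinv(A) @ b         # NumPy's pseudoinverse for comparison
```

For this matrix \(A^+ = v v^T / 196\), so the solution is \(x = v / 196\).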

Source code in src/cagpjax/solvers/pseudoinverse.py
def __init__(
    self,
    rtol: ScalarFloat | None = None,
    grad_rtol: float | None = None,
    alg: cola.linalg.Algorithm = Eigh(),
):
    self.rtol = rtol
    self.grad_rtol = grad_rtol
    self.alg = alg