
cagpjax.solvers

AbstractLinearSolver

Bases: Module

Base class for linear solvers.

These solvers are used to exactly or approximately solve the linear system \(Ax = b\) for \(x\), where \(A\) is a positive (semi-)definite (PSD) linear operator.

Solvers should always be constructed by an AbstractLinearSolverMethod.

Source code in src/cagpjax/solvers/base.py
class AbstractLinearSolver(nnx.Module):
    """
    Base class for linear solvers.

    These solvers are used to exactly or approximately solve the linear
    system $Ax = b$ for $x$, where $A$ is a positive (semi-)definite (PSD)
    linear operator.

    Solvers should always be constructed by an `AbstractLinearSolverMethod`.
    """

    @abstractmethod
    def solve(self, b: Float[Array, "N #K"]) -> Float[Array, "N #K"]:
        """Compute a solution to the linear system $Ax = b$.

        Arguments:
            b: Right-hand side of the linear system.
        """
        pass

    @abstractmethod
    def logdet(self) -> ScalarFloat:
        """Compute the logarithm of the (pseudo-)determinant of $A$."""
        pass

    @abstractmethod
    def inv_congruence_transform(
        self, B: LinearOperator | Float[Array, "N K"]
    ) -> LinearOperator | Float[Array, "K K"]:
        """Compute the inverse congruence transform $B^T x$ for $x$ in $Ax = B$.

        Arguments:
            B: Linear operator or array to be applied.

        Returns:
            Linear operator or array resulting from the congruence transform.
        """
        pass

    def inv_quad(self, b: Float[Array, "N #1"]) -> ScalarFloat:
        """Compute the inverse quadratic form $b^T x$, for $x$ in $Ax = b$.

        Arguments:
            b: Right-hand side of the linear system.
        """
        return self.inv_congruence_transform(b[:, None]).squeeze()

    @abstractmethod
    def trace_solve(self, B: Self) -> ScalarFloat:
        r"""Compute $\mathrm{trace}(X)$ in $AX=B$ for PSD $B$.

        Arguments:
            B: An `AbstractLinearSolver` of the same type as `self` representing
                the PSD linear operator $B$.
        """
        pass
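To make the interface concrete, here is a minimal sketch that implements every method for a dense, symmetric positive-definite matrix using a Cholesky factor \(A = LL^T\). It is a hypothetical illustration, not cagpjax's own Cholesky solver: the class name DenseCholeskySolver is invented, A is assumed to be a plain jnp array rather than a LinearOperator, and flax nnx is not used. inv_quad reuses the base-class default.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


class DenseCholeskySolver:
    """Hypothetical dense solver mirroring the AbstractLinearSolver interface.

    Assumes A is a dense, symmetric positive-definite jnp array.
    """

    def __init__(self, A):
        self.L = jnp.linalg.cholesky(A)  # lower-triangular L with A = L L^T

    def solve(self, b):
        # Two triangular solves: L y = b, then L^T x = y.
        y = solve_triangular(self.L, b, lower=True)
        return solve_triangular(self.L.T, y, lower=False)

    def logdet(self):
        # log det A = 2 * sum(log(diag(L)))
        return 2.0 * jnp.sum(jnp.log(jnp.diag(self.L)))

    def inv_congruence_transform(self, B):
        # B^T A^{-1} B = W^T W with W = L^{-1} B
        W = solve_triangular(self.L, B, lower=True)
        return W.T @ W

    def inv_quad(self, b):
        # Same default as the base class: view b as an N x 1 matrix.
        return self.inv_congruence_transform(b[:, None]).squeeze()

    def trace_solve(self, B):
        # B is another DenseCholeskySolver wrapping PSD B = L_B L_B^T, so
        # trace(A^{-1} B) = ||L^{-1} L_B||_F^2.
        W = solve_triangular(self.L, B.L, lower=True)
        return jnp.sum(W**2)


A = jnp.array([[4.0, 2.0], [2.0, 3.0]])  # det A = 4 * 3 - 2 * 2 = 8
b = jnp.array([1.0, 2.0])
solver = DenseCholeskySolver(A)
x = solver.solve(b)     # satisfies A x ≈ b
ld = solver.logdet()    # ≈ log(8)
q = solver.inv_quad(b)  # ≈ 11 / 8 = 1.375
t = solver.trace_solve(DenseCholeskySolver(jnp.diag(jnp.array([2.0, 1.0]))))  # ≈ 10 / 8 = 1.25
```

Note that trace_solve only needs the factor of \(B\), never \(B\) itself, which is one reason an API might ask for a solver of the same type.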

inv_congruence_transform(B) abstractmethod

Compute the inverse congruence transform \(B^T x\), where \(x\) solves \(Ax = B\); for invertible \(A\) this equals \(B^T A^{-1} B\).

Parameters:

- B (LinearOperator | Float[Array, 'N K'], required): Linear operator or array to be applied.

Returns:

- LinearOperator | Float[Array, 'K K']: Linear operator or array resulting from the congruence transform.

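As a small numeric check of the definition, the sketch below uses dense jnp arrays in place of a LinearOperator (values chosen for illustration). It computes \(B^T X\) with \(X\) solving \(AX = B\) and compares it with the Cholesky form \(W^T W\), \(W = L^{-1} B\). The result is a symmetric \(K \times K\) matrix, matching the 'K K' return annotation; whether a particular cagpjax solver evaluates it this way is not stated here.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, solve_triangular

A = jnp.array([[4.0, 2.0, 0.0],
               [2.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])  # N = 3, symmetric positive-definite
B = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])       # N x K with K = 2

# Direct definition: solve A X = B, then form B^T X (a K x K matrix).
X = cho_solve(cho_factor(A, lower=True), B)
direct = B.T @ X

# Cholesky form: with A = L L^T, B^T A^{-1} B = W^T W where W = L^{-1} B.
L = jnp.linalg.cholesky(A)
W = solve_triangular(L, B, lower=True)
via_factor = W.T @ W
```

The factored form needs only one triangular solve and is symmetric by construction.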

inv_quad(b)

Compute the inverse quadratic form \(b^T x\), where \(x\) solves \(Ax = b\); for invertible \(A\) this equals \(b^T A^{-1} b\).

Parameters:

- b (Float[Array, N], required): Right-hand side of the linear system.
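A tiny worked example of the default implementation traces the shapes. It assumes b is passed as a 1-D array of length N (which is what the b[:, None] indexing expects), and it uses a hypothetical dense helper in place of inv_congruence_transform. With \(A = \mathrm{diag}(2, 4)\) and \(b = (2, 4)\), \(b^T A^{-1} b = 2^2/2 + 4^2/4 = 6\).

```python
import jax.numpy as jnp


def inv_congruence_transform_dense(A, B):
    """Hypothetical dense stand-in: returns B^T A^{-1} B for an N x K array B."""
    return B.T @ jnp.linalg.solve(A, B)


A = jnp.diag(jnp.array([2.0, 4.0]))
b = jnp.array([2.0, 4.0])                        # shape (N,) = (2,)

B = b[:, None]                                   # shape (2, 1): b as an N x 1 matrix
q_matrix = inv_congruence_transform_dense(A, B)  # shape (1, 1)
q = q_matrix.squeeze()                           # shape (): the scalar b^T A^{-1} b
# b^T A^{-1} b = 2**2 / 2 + 4**2 / 4 = 2 + 4 = 6
```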

logdet() abstractmethod

Compute the logarithm of the (pseudo-)determinant of \(A\).

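Two standard ways to evaluate this quantity are sketched below on dense jnp arrays. For positive-definite \(A = LL^T\), \(\log\det A = 2\sum_i \log L_{ii}\). For singular PSD \(A\), the pseudo-determinant is the product of the nonzero eigenvalues. The relative cutoff used to decide which eigenvalues count as nonzero is an assumption for illustration; the cutoff used by cagpjax's solvers is not specified here.

```python
import jax.numpy as jnp

# Positive-definite case: with A = L L^T, log det A = 2 * sum(log(diag(L))).
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])      # det A = 4 * 3 - 2 * 2 = 8
L = jnp.linalg.cholesky(A)
logdet_chol = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
sign, logdet_ref = jnp.linalg.slogdet(A)     # reference value

# Singular PSD case: the pseudo-determinant is the product of the nonzero
# eigenvalues, so its log sums log(lambda) over eigenvalues above a cutoff.
A_sing = jnp.array([[1.0, 1.0], [1.0, 1.0]])  # eigenvalues 0 and 2
evals = jnp.linalg.eigvalsh(A_sing)
keep = evals > 1e-6 * jnp.max(evals)          # illustrative cutoff (assumption)
pseudo_logdet = jnp.sum(jnp.log(jnp.where(keep, evals, 1.0)))  # dropped terms add log(1) = 0
```

Substituting 1.0 before taking the log keeps both the value and its gradient finite for the discarded eigenvalues.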

solve(b) abstractmethod

Compute a solution to the linear system \(Ax = b\).

Parameters:

- b (Float[Array, N], required): Right-hand side of the linear system.
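The source annotation "N #K" suggests that b may carry K columns, each a separate right-hand side (in jaxtyping, a # prefix marks a dimension that may also have size 1). The sketch below, using jax.scipy.linalg on dense arrays rather than a cagpjax solver, shows that one factorization can solve all K systems at once and that the result matches column-by-column solves.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
B = jnp.array([[1.0, 0.0, 2.0],
               [2.0, 1.0, 0.0]])   # N = 2 rows, K = 3 right-hand sides

factor = cho_factor(A, lower=True)  # factor A once ...
X = cho_solve(factor, B)            # ... then solve all K systems together

# Reference: solve each column A x_k = b_k separately.
X_cols = jnp.stack([jnp.linalg.solve(A, B[:, k]) for k in range(B.shape[1])], axis=1)
```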

trace_solve(B) abstractmethod

Compute \(\mathrm{trace}(X)\), where \(X\) solves \(AX = B\) for a PSD \(B\); for invertible \(A\) this equals \(\mathrm{trace}(A^{-1} B)\).

Parameters:

- B (Self, required): An AbstractLinearSolver of the same type as self representing the PSD linear operator \(B\).
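Requiring B to be a solver of the same type means a factorization of \(B\) is available. One way such a factorization can be used (an illustration, not necessarily what cagpjax does) is the identity \(\mathrm{trace}(A^{-1} B) = \lVert L^{-1} C \rVert_F^2\) for \(A = LL^T\) and \(B = CC^T\), which avoids forming \(A^{-1} B\). The sketch checks it on small dense arrays; with the values below the trace is \(7/8\).

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
C = jnp.array([[1.0, 0.0], [1.0, 1.0]])
B = C @ C.T                              # PSD by construction: [[1, 1], [1, 2]]

# Reference: solve A X = B explicitly and take the trace.
reference = jnp.trace(jnp.linalg.solve(A, B))

# Factored form: with A = L L^T and B = C C^T,
# trace(A^{-1} B) = trace(C^T L^{-T} L^{-1} C) = ||L^{-1} C||_F^2.
L = jnp.linalg.cholesky(A)
W = solve_triangular(L, C, lower=True)
factored = jnp.sum(W**2)
```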

AbstractLinearSolverMethod

Bases: Module

Base class for linear solver methods.

These methods are used to construct AbstractLinearSolver instances.

Source code in src/cagpjax/solvers/base.py
class AbstractLinearSolverMethod(nnx.Module):
    """
    Base class for linear solver methods.

    These methods are used to construct `AbstractLinearSolver` instances.
    """

    @abstractmethod
    def __call__(self, A: LinearOperator) -> AbstractLinearSolver:
        """Construct a solver from the positive (semi-)definite linear operator."""
        pass

__call__(A) abstractmethod

Construct a solver from a positive (semi-)definite linear operator \(A\).

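The split between a method (which holds configuration) and a solver (which is bound to one operator \(A\)) is a factory pattern. A method can be created once and then called on each new operator, for example whenever \(A\) changes. The sketch below illustrates the pattern with hypothetical plain-Python classes; SolverMethod, DenseMethod, and DenseSolver are invented names, flax nnx and cola are not used, and adding jitter to the diagonal is an assumption.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp


class DenseSolver:
    """Hypothetical solver bound to one operator A."""

    def __init__(self, A):
        self.A = A

    def solve(self, b):
        return jnp.linalg.solve(self.A, b)


class SolverMethod(ABC):
    """Hypothetical stand-in for AbstractLinearSolverMethod (no nnx, no cola)."""

    @abstractmethod
    def __call__(self, A) -> DenseSolver:
        """Build a solver for the PSD operator A."""


class DenseMethod(SolverMethod):
    """Stores configuration only; builds a new solver for each A it is called on."""

    def __init__(self, jitter=None):
        self.jitter = jitter

    def __call__(self, A):
        if self.jitter is not None:
            # Assumption for illustration: jitter is added to the diagonal.
            A = A + self.jitter * jnp.eye(A.shape[0], dtype=A.dtype)
        return DenseSolver(A)


method = DenseMethod()                   # configured once, independent of A
A = jnp.array([[2.0, 0.0], [0.0, 4.0]])
solver = method(A)                       # one solver per operator
x = solver.solve(jnp.array([2.0, 4.0]))  # [1.0, 1.0]
```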

Cholesky

Bases: AbstractLinearSolverMethod

Solve a linear system using the Cholesky decomposition.

Due to floating-point round-off, the Cholesky factorization may fail even when \(A\) is positive-definite in exact arithmetic. Optionally, a small amount of jitter (\(\epsilon\)) can be added to \(A\) to ensure positive-definiteness. Note that the system actually solved then differs slightly from the original one.

Attributes:

- jitter (ScalarFloat | None): Small amount of jitter to add to \(A\) to ensure positive-definiteness.

Source code in src/cagpjax/solvers/cholesky.py
class Cholesky(AbstractLinearSolverMethod):
    """
    Solve a linear system using the Cholesky decomposition.

    Due to numerical imprecision, Cholesky factorization may fail even for
    positive-definite $A$. Optionally, a small amount of `jitter` ($\\epsilon$) can
    be added to $A$ to ensure positive-definiteness. Note that the resulting system
    solved is slightly different from the original system.

    Attributes:
        jitter: Small amount of jitter to add to $A$ to ensure positive-definiteness.
    """

    jitter: ScalarFloat | None

    def __init__(self, jitter: ScalarFloat | None = None):
        self.jitter = jitter

    @override
    def __call__(self, A: LinearOperator) -> AbstractLinearSolver:
        return CholeskySolver(A, jitter=self.jitter)
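The effect of jitter can be seen on a PSD matrix that is exactly singular. The sketch below assumes the jitter is added to the diagonal, i.e. the solver factors \(A + \epsilon I\); the text above only says it is "added to \(A\)", and the helper jittered_cholesky_solve is hypothetical rather than part of cagpjax.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def jittered_cholesky_solve(A, b, jitter=None):
    """Hypothetical helper: solve (A + jitter * I) x = b with a Cholesky factor.

    Adding the jitter to the diagonal (A + jitter * I) is an assumption.
    """
    if jitter is not None:
        A = A + jitter * jnp.eye(A.shape[0], dtype=A.dtype)
    L = jnp.linalg.cholesky(A)                    # A = L L^T
    y = solve_triangular(L, b, lower=True)        # L y = b
    return solve_triangular(L.T, y, lower=False)  # L^T x = y


A = jnp.array([[1.0, 1.0], [1.0, 1.0]])  # PSD but singular: eigenvalues 0 and 2
b = jnp.array([1.0, 1.0])                # eigenvector of A with eigenvalue 2

# Without jitter the second pivot is exactly zero, so the factorization breaks
# down; JAX typically signals this with NaNs rather than raising an error.
L_plain = jnp.linalg.cholesky(A)

eps = 1e-3
x = jittered_cholesky_solve(A, b, jitter=eps)  # solves the perturbed system
# b is an eigenvector of A + eps*I with eigenvalue 2 + eps, so x = b / (2 + eps).
```

The answer is exact for \(A + \epsilon I\), not for \(A\), which is the trade-off the description above points out.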

PseudoInverse

Bases: AbstractLinearSolverMethod

Solve a linear system using the Moore-Penrose pseudoinverse.

This solver computes the minimum-norm least-squares solution \(x = A^+ b\) for any \(A\), where \(A^+\) is the Moore-Penrose pseudoinverse. For non-singular \(A\) this coincides with the exact solution, but it also generalizes to singular \(A\) and improves stability for almost-singular \(A\). Note, however, that the pseudoinverse is discontinuous where the rank of \(A\) changes, so if the rank of \(A\) depends on hyperparameters being optimized, the optimization problem may be ill-posed.

Note that if \(A\) is (almost-)degenerate (i.e., some of its eigenvalues are repeated or nearly repeated), then gradients of solves with \(A\) in JAX may not be computable or may be numerically unstable (see jax#669). For degenerate operators, it may be necessary to increase grad_rtol to improve the stability of gradients. See cagpjax.linalg.eigh for more details.
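The two claims above can be checked numerically with jnp.linalg.pinv, used here as a stand-in for the PseudoInverse method (whose internals also involve cagpjax.linalg.eigh). For a non-singular \(A\) the pseudoinverse solution equals the ordinary solution; for a singular PSD \(A\) with infinitely many solutions, it selects the one with the smallest norm.

```python
import jax.numpy as jnp

# Non-singular A: the pseudoinverse solution equals the ordinary solution.
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([1.0, 2.0])
x_pinv = jnp.linalg.pinv(A) @ b
x_exact = jnp.linalg.solve(A, b)

# Singular PSD A: A x = b holds for every x = [t, 1 - t];
# the minimum-norm choice is t = 0.5.
A_sing = jnp.array([[1.0, 1.0], [1.0, 1.0]])
b_sing = jnp.array([1.0, 1.0])
x_min_norm = jnp.linalg.pinv(A_sing) @ b_sing
```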

Attributes:

- rtol (ScalarFloat | None): Specifies the cutoff for small eigenvalues. Eigenvalues smaller than rtol * largest_nonzero_eigenvalue are treated as zero. The default is determined based on the floating-point precision of the dtype of the operator (see jax.numpy.linalg.pinv).
- grad_rtol (float | None): Specifies the cutoff for similar eigenvalues, used to improve gradient computation for (almost-)degenerate matrices. If not provided, the default is 0.0. If None or negative, all eigenvalues are treated as distinct.
- alg (Algorithm): Algorithm for eigenvalue decomposition passed to cagpjax.linalg.eigh.

Source code in src/cagpjax/solvers/pseudoinverse.py
class PseudoInverse(AbstractLinearSolverMethod):
    """
    Solve a linear system using the Moore-Penrose pseudoinverse.

    This solver computes the least-squares solution $x = A^+ b$ for any $A$,
    where $A^+$ is the Moore-Penrose pseudoinverse. This is equivalent to
    the exact solution for non-singular $A$ but generalizes to singular $A$
    and improves stability for almost-singular $A$; note, however, that if the
    rank of $A$ is dependent on hyperparameters being optimized, because the
    pseudoinverse is discontinuous, the optimization problem may be ill-posed.

    Note that if $A$ is (almost-)degenerate (some eigenvalues repeat), then
    the gradient of its solves in JAX may be non-computable or numerically unstable
    (see [jax#669](https://github.com/jax-ml/jax/issues/669)).
    For degenerate operators, it may be necessary to increase `grad_rtol` to improve
    stability of gradients.
    See [`cagpjax.linalg.eigh`][] for more details.

    Attributes:
        rtol: Specifies the cutoff for small eigenvalues.
              Eigenvalues smaller than `rtol * largest_nonzero_eigenvalue` are treated as zero.
              The default is determined based on the floating point precision of the dtype
              of the operator (see [`jax.numpy.linalg.pinv`][]).
        grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
            gradient computation for (almost-)degenerate matrices.
            If not provided, the default is 0.0.
            If None or negative, all eigenvalues are treated as distinct.
        alg: Algorithm for eigenvalue decomposition passed to [`cagpjax.linalg.eigh`][].
    """

    rtol: ScalarFloat | None
    grad_rtol: float | None
    alg: cola.linalg.Algorithm

    def __init__(
        self,
        rtol: ScalarFloat | None = None,
        grad_rtol: float | None = None,
        alg: cola.linalg.Algorithm = Eigh(),
    ):
        self.rtol = rtol
        self.grad_rtol = grad_rtol
        self.alg = alg

    @override
    def __call__(self, A: LinearOperator) -> AbstractLinearSolver:
        return PseudoInverseSolver(
            A, rtol=self.rtol, grad_rtol=self.grad_rtol, alg=self.alg
        )
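Because the pseudoinverse here is built from an eigendecomposition \(A = V \Lambda V^T\), the rtol cutoff can be read as: invert only the eigenvalues above rtol times the largest one and zero out the rest. The sketch below illustrates that rule with a hypothetical helper, eigh_pinv_solve, that uses jnp.linalg.eigh instead of cagpjax.linalg.eigh and ignores grad_rtol; it is not cagpjax's PseudoInverseSolver.

```python
import jax.numpy as jnp


def eigh_pinv_solve(A, b, rtol):
    """Hypothetical helper: x = A^+ b for symmetric PSD A via an eigendecomposition.

    Eigenvalues at or below rtol * (largest eigenvalue) are treated as zero.
    """
    evals, evecs = jnp.linalg.eigh(A)             # A = V diag(evals) V^T
    keep = evals > rtol * jnp.max(evals)
    safe = jnp.where(keep, evals, 1.0)            # avoid dividing by ~0
    inv_evals = jnp.where(keep, 1.0 / safe, 0.0)  # 1/lambda on kept modes, 0 elsewhere
    return evecs @ (inv_evals * (evecs.T @ b))


A = jnp.diag(jnp.array([1.0, 1e-3]))  # eigenvalues 1 and 1e-3
b = jnp.array([1.0, 1.0])

x_keep = eigh_pinv_solve(A, b, rtol=1e-4)  # both modes kept: [1, 1000]
x_cut = eigh_pinv_solve(A, b, rtol=1e-2)   # small mode dropped: [1, 0]
```

Raising rtol past \(10^{-3}\) changes the answer discontinuously, which is the ill-posedness the description warns about when the rank depends on optimized hyperparameters.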