
cagpjax.solvers.pseudoinverse

Classes:

  • PseudoInverse

    Solve a linear system using the Moore-Penrose pseudoinverse.

PseudoInverse

PseudoInverse(rtol: ScalarFloat | None = None, grad_rtol: float | None = None, alg: Algorithm = Eigh())

Bases: AbstractLinearSolver[PseudoInverseState]

Solve a linear system using the Moore-Penrose pseudoinverse.

This solver computes the minimum-norm least-squares solution \(x = A^+ b\) for any \(A\), where \(A^+\) is the Moore-Penrose pseudoinverse. For non-singular \(A\) this coincides with the exact solution \(x = A^{-1} b\), but it also handles singular \(A\) and improves numerical stability for almost-singular \(A\). Note, however, that the pseudoinverse is discontinuous wherever the rank of \(A\) changes, so if the rank of \(A\) depends on hyperparameters being optimized, the optimization problem may be ill-posed.
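Assuming \(A\) is symmetric, as an eigendecomposition-based algorithm such as the default alg=Eigh() requires, \(A^+ b\) can be formed from \(A = Q \Lambda Q^\top\) by inverting only the eigenvalues above a relative cutoff and treating the rest as zero. The sketch below uses plain jax.numpy on dense matrices rather than cagpjax or CoLA operators. The helper pinv_solve_eigh is hypothetical, and its default tolerance of \(10 \cdot n \cdot \varepsilon\) is an assumption based on our understanding of jax.numpy.linalg.pinv, not a description of cagpjax's implementation.

```python
import jax.numpy as jnp


def pinv_solve_eigh(A, b, rtol=None):
    """Return x = A^+ b for a symmetric matrix A via its eigendecomposition.

    Hypothetical helper for illustration only; not part of cagpjax.
    """
    n = A.shape[-1]
    if rtol is None:
        # Assumed default, mirroring our understanding of jnp.linalg.pinv.
        rtol = 10.0 * n * jnp.finfo(A.dtype).eps
    w, Q = jnp.linalg.eigh(A)  # A = Q diag(w) Q^T
    keep = jnp.abs(w) > rtol * jnp.max(jnp.abs(w))
    # Invert only eigenvalues above the cutoff; the rest are treated as zero.
    # The inner where avoids dividing by zero in the discarded entries.
    w_inv = jnp.where(keep, 1.0 / jnp.where(keep, w, 1.0), 0.0)
    return Q @ (w_inv * (Q.T @ b))


# Non-singular A: the pseudoinverse solve agrees with an ordinary solve.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x_solve = jnp.linalg.solve(A, b)
x_pinv = pinv_solve_eigh(A, b)

# Singular A (rank 1 projector S = u u^T): A x = c has no exact solution in
# general, but A^+ c is the minimum-norm least-squares solution.
u = jnp.array([1.0, 2.0, 2.0]) / 3.0
S = jnp.outer(u, u)
c = jnp.array([1.0, 0.0, -1.0])
x_sing = pinv_solve_eigh(S, c)
x_ref = jnp.linalg.pinv(S) @ c  # reference via JAX's SVD-based pinv
```

The nested jnp.where is a deliberate choice: an outer where alone would still evaluate 1/0 for discarded eigenvalues, and under jax.grad the resulting inf can leak NaN into gradients even though that branch is masked out.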

If \(A\) is (almost-)degenerate, i.e. some of its eigenvalues (nearly) coincide, then gradients of solves with \(A\) computed in JAX may be non-finite or numerically unstable (see jax#669). For such operators, it may be necessary to increase grad_rtol to stabilize the gradients. See cagpjax.linalg.eigh for more details.
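The failure mode can be reproduced with plain JAX. The sketch below uses jax.numpy.linalg.eigh directly, not cagpjax.linalg.eigh, and differentiates the same scalar computed two ways for a matrix with a repeated eigenvalue. The solve itself is perfectly well defined, but the derivative of the eigenvectors involves \(1/(\lambda_j - \lambda_i)\), which is infinite for the repeated pair, so the eigendecomposition-based gradient is expected to contain inf or NaN entries. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp


def loss_via_eigh(A, b):
    # Solve A x = b through the eigendecomposition A = Q diag(w) Q^T.
    w, Q = jnp.linalg.eigh(A)
    x = Q @ ((Q.T @ b) / w)
    return jnp.sum(x)


def loss_via_solve(A, b):
    # The same scalar, computed with a direct solve for comparison.
    return jnp.sum(jnp.linalg.solve(A, b))


# Well-conditioned matrix whose eigenvalue 2.0 is repeated (degenerate).
A = jnp.diag(jnp.array([2.0, 2.0, 3.0]))
b = jnp.array([1.0, 2.0, 3.0])

g_eigh = jax.grad(loss_via_eigh)(A, b)
g_solve = jax.grad(loss_via_solve)(A, b)

# The eigenvector derivative contains 1 / (w_j - w_i), which is infinite for
# the repeated pair, so the eigh-based gradient is not expected to be finite.
print("eigh-based gradient finite: ", bool(jnp.all(jnp.isfinite(g_eigh))))
print("solve-based gradient finite:", bool(jnp.all(jnp.isfinite(g_solve))))
```

Per the grad_rtol description, cagpjax's eigh can treat sufficiently similar eigenvalues as non-distinct to avoid dividing by their gap; this sketch does not model that mechanism.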

Attributes:

  • rtol (ScalarFloat | None) –

    Specifies the cutoff for small eigenvalues. Eigenvalues smaller than rtol * largest_nonzero_eigenvalue are treated as zero. The default depends on the floating-point precision of the operator's dtype (see jax.numpy.linalg.pinv).

  • grad_rtol (float | None) –

    Specifies the cutoff for similar eigenvalues, used to improve gradient computation for (almost-)degenerate matrices. If not provided, the default is 0.0. If None or negative, all eigenvalues are treated as distinct.

  • alg (Algorithm) –

    Algorithm for eigenvalue decomposition passed to cagpjax.linalg.eigh.
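To illustrate what rtol controls, the sketch below applies an eigendecomposition-based pseudoinverse solve to an almost-singular diagonal matrix, once with a default cutoff and once with rtol = 0. It uses plain jax.numpy rather than cagpjax; the helper pinv_solve_eigh is hypothetical, the cutoff is taken relative to the largest eigenvalue magnitude, and the default tolerance of \(10 \cdot n \cdot \varepsilon\) is an assumption based on our understanding of jax.numpy.linalg.pinv, not a statement of cagpjax's behavior.

```python
import jax.numpy as jnp


def pinv_solve_eigh(A, b, rtol=None):
    # Hypothetical helper (not part of cagpjax): x = A^+ b for symmetric A.
    n = A.shape[-1]
    if rtol is None:
        # Assumed default, mirroring our understanding of jnp.linalg.pinv.
        rtol = 10.0 * n * jnp.finfo(A.dtype).eps
    w, Q = jnp.linalg.eigh(A)
    # Eigenvalues at or below rtol * max|w| are treated as exactly zero.
    keep = jnp.abs(w) > rtol * jnp.max(jnp.abs(w))
    w_inv = jnp.where(keep, 1.0 / jnp.where(keep, w, 1.0), 0.0)
    return Q @ (w_inv * (Q.T @ b))


# Almost-singular A: the second eigenvalue is tiny relative to the first.
A = jnp.diag(jnp.array([1.0, 1e-10]))
b = jnp.array([1.0, 1.0])

x_default = pinv_solve_eigh(A, b)         # tiny eigenvalue treated as zero
x_no_cutoff = pinv_solve_eigh(A, b, 0.0)  # every nonzero eigenvalue inverted
print(x_default, x_no_cutoff)
```

With the default cutoff, the \(10^{-10}\) eigenvalue is treated as zero and the corresponding component of \(x\) is dropped; with rtol = 0 it is inverted and that component grows to about \(10^{10}\), which is the kind of instability the cutoff guards against.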

Source code in src/cagpjax/solvers/pseudoinverse.py
def __init__(
    self,
    rtol: ScalarFloat | None = None,
    grad_rtol: float | None = None,
    alg: cola.linalg.Algorithm = Eigh(),
):
    self.rtol = rtol
    self.grad_rtol = grad_rtol
    self.alg = alg