
cagpjax.solvers.pseudoinverse

Classes:

  • PseudoInverse

    Solve a linear system using the Moore-Penrose pseudoinverse.

PseudoInverse

Bases: AbstractLinearSolver[PseudoInverseState]

Solve a linear system using the Moore-Penrose pseudoinverse.

This solver computes the least-squares solution \(x = A^+ b\) for any \(A\), where \(A^+\) is the Moore-Penrose pseudoinverse. For non-singular \(A\) this coincides with the exact solution \(A^{-1} b\). For singular \(A\) it yields the minimum-norm least-squares solution, and for almost-singular \(A\) it improves numerical stability. Note, however, that the pseudoinverse is discontinuous wherever the rank of \(A\) changes. If that rank depends on hyperparameters being optimized, the optimization problem may therefore be ill-posed.
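The equivalence described above can be checked numerically. The following is a minimal sketch that uses only jax.numpy (jnp.linalg.pinv and jnp.linalg.solve), not the cagpjax API. For a non-singular matrix the pseudoinverse solve matches the exact solve. For a singular matrix it returns the minimum-norm least-squares solution, which satisfies the normal equations and has no component in the null space of \(A\).

```python
import jax.numpy as jnp

# Non-singular symmetric positive-definite matrix: A^+ b equals A^{-1} b.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x_pinv = jnp.linalg.pinv(A) @ b
x_exact = jnp.linalg.solve(A, b)

# Singular symmetric matrix (rank 1): A x = b has no exact solution for this b,
# but A^+ b is the minimum-norm least-squares solution.
u = jnp.array([1.0, 2.0])
A_sing = jnp.outer(u, u)            # rank 1; null space spanned by [2, -1]
b_sing = jnp.array([1.0, 0.0])
x_ls = jnp.linalg.pinv(A_sing) @ b_sing

# Least-squares optimality: the normal equations A^T (A x - b) = 0 hold.
normal_residual = A_sing.T @ (A_sing @ x_ls - b_sing)

# Minimum norm: x has no component along the null-space direction.
null_dir = jnp.array([2.0, -1.0])
null_component = null_dir @ x_ls
```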

Note that if \(A\) is (almost-)degenerate (i.e., some of its eigenvalues repeat or nearly repeat), then gradients of solves with \(A\) computed in JAX may be undefined or numerically unstable (see jax#669). For such operators, it may be necessary to increase grad_rtol to stabilize the gradients. See cagpjax.linalg.eigh for more details.
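The instability comes from the derivative of the eigenvectors of a symmetric matrix, which contains a factor \(1/(\lambda_j - \lambda_i)\) for every pair of eigenvalues \(i \neq j\). These factors blow up as two eigenvalues approach each other. The sketch below builds that factor matrix in plain jax.numpy and shows one assumed way a relative tolerance like grad_rtol could zero out near-degenerate pairs. The helper eigvec_grad_factor and its tolerance rule (gap at most grad_rtol times the largest eigenvalue magnitude) are hypothetical illustrations, not the implementation in cagpjax.linalg.eigh.

```python
import jax.numpy as jnp

def eigvec_grad_factor(w, grad_rtol=None):
    """Hypothetical helper: F[i, j] = 1 / (w[j] - w[i]) for i != j, 0 on the diagonal.

    If grad_rtol is given, pairs whose gap is at most grad_rtol * max|w| are
    treated as degenerate and their entries are set to zero (an assumed rule).
    """
    gaps = w[None, :] - w[:, None]                  # gaps[i, j] = w[j] - w[i]
    off_diag = jnp.logical_not(jnp.eye(w.shape[0], dtype=bool))
    if grad_rtol is None:
        distinct = off_diag                         # all eigenvalues treated as distinct
    else:
        tol = grad_rtol * jnp.max(jnp.abs(w))
        distinct = jnp.logical_and(off_diag, jnp.abs(gaps) > tol)
    # Divide only where the pair is kept; substitute 1.0 elsewhere to avoid 1/0.
    safe_gaps = jnp.where(distinct, gaps, 1.0)
    return jnp.where(distinct, 1.0 / safe_gaps, 0.0)

# The first two eigenvalues nearly repeat.
w = jnp.array([1.0, 1.0 + 1e-6, 2.0])
F_plain = eigvec_grad_factor(w)                     # contains entries of order 1e6
F_cut = eigvec_grad_factor(w, grad_rtol=1e-4)       # near-degenerate pair zeroed
```

Zeroing these entries keeps the factors finite, but it also drops the contribution of the near-degenerate pairs, so gradients computed this way are an approximation whose quality depends on the tolerance chosen.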

Attributes:

  • rtol (ScalarFloat | None) –

    Specifies the cutoff for small eigenvalues. Eigenvalues smaller than rtol * largest_nonzero_eigenvalue are treated as zero. The default is determined by the floating-point precision of the operator's dtype (see jax.numpy.linalg.pinv).

  • grad_rtol (float | None) –

    Specifies the cutoff for similar eigenvalues, used to improve gradient computation for (almost-)degenerate matrices. If None (default), all eigenvalues are treated as distinct.

  • alg (Algorithm) –

    Algorithm for the eigenvalue decomposition, passed to cagpjax.linalg.eigh.
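The effect of the rtol cutoff can be illustrated with a hand-rolled eigendecomposition solve. This is a minimal sketch in plain jax.numpy: the helper pinv_solve_eigh is hypothetical, assumes a symmetric \(A\), and applies the cutoff relative to the largest eigenvalue magnitude. It does not reproduce the default rtol, grad_rtol, or alg handling of PseudoInverse; the result is only compared against jnp.linalg.pinv.

```python
import jax.numpy as jnp

def pinv_solve_eigh(A, b, rtol):
    """Hypothetical sketch: x = A^+ b for symmetric A via an eigendecomposition.

    Eigenvalues with |w| <= rtol * max|w| are treated as zero (assumed rule).
    """
    w, V = jnp.linalg.eigh(A)                       # A = V diag(w) V^T
    keep = jnp.abs(w) > rtol * jnp.max(jnp.abs(w))
    # Invert kept eigenvalues only; the inner where avoids evaluating 1/0.
    w_inv = jnp.where(keep, 1.0 / jnp.where(keep, w, 1.0), 0.0)
    return V @ (w_inv * (V.T @ b))

# Rank-2 symmetric PSD matrix in 3 dimensions (one eigenvalue is zero).
B = jnp.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
A = B @ B.T                                         # null space spanned by [2, -2, 1]
b = jnp.array([1.0, -1.0, 0.5])
x = pinv_solve_eigh(A, b, rtol=1e-5)
x_ref = jnp.linalg.pinv(A) @ b
```

With a cutoff that is too small, the round-off-level eigenvalue standing in for the zero one would be kept, and its huge reciprocal would dominate the solution; tying the default cutoff to the floating-point precision of the dtype guards against this.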