cagpjax.solvers
Modules:
- base: Base classes for linear solvers and methods.
- cholesky: Linear solvers based on Cholesky decomposition.
- pseudoinverse
Classes:
- AbstractLinearSolver: Base class for linear solvers.
- Cholesky: Solve a linear system using the Cholesky decomposition.
- PseudoInverse: Solve a linear system using the Moore-Penrose pseudoinverse.
AbstractLinearSolver
Bases: Module, Generic[_LinearSolverState]
Base class for linear solvers.
These solvers are used to exactly or approximately solve the linear system \(Ax = b\) for \(x\), where \(A\) is a positive (semi-)definite (PSD) linear operator.
Methods:
- init: Construct a solver state.
- inv_congruence_transform: Compute the inverse congruence transform \(B^T x\) for \(x\) in \(Ax = B\).
- inv_quad: Compute the inverse quadratic form \(b^T x\) for \(x\) in \(Ax = b\).
- logdet: Compute the logarithm of the (pseudo-)determinant of \(A\).
- solve: Compute a solution to the linear system \(Ax = b\).
- trace_solve: Compute \(\mathrm{trace}(X)\) in \(AX = B\) for PSD \(A\) and \(B\).
- unwhiten: Given an IID standard normal vector \(z\), return \(x\) with covariance \(A\).
init
abstractmethod
Construct a solver state.
Parameters:
- A (LinearOperator): Positive (semi-)definite linear operator.
Returns:
- _LinearSolverState: State of the linear solver, which stores any necessary intermediate values.
Source code in src/cagpjax/solvers/base.py
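Splitting construction (init) from use (solve and the other methods) lets expensive work, such as a factorization, happen once and be reused across many right-hand sides. The following is a minimal sketch of that state-passing pattern in plain JAX, assuming \(A\) is a dense positive-definite array; DenseCholeskyState, init, and solve here are hypothetical stand-ins, not the cagpjax classes.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


class DenseCholeskyState(NamedTuple):
    # Hypothetical state: stores the lower Cholesky factor so it can be reused.
    L: jnp.ndarray


def init(A: jnp.ndarray) -> DenseCholeskyState:
    # Factor once: A = L L^T (assumes A is dense and positive definite).
    return DenseCholeskyState(L=jnp.linalg.cholesky(A))


def solve(state: DenseCholeskyState, b: jnp.ndarray) -> jnp.ndarray:
    # Reuse the stored factor: two triangular solves give x with A x = b.
    return cho_solve((state.L, True), b)


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
state = init(A)                           # factorization happens here only
x1 = solve(state, jnp.array([1.0, 2.0]))  # first right-hand side
x2 = solve(state, jnp.array([0.0, 1.0]))  # second right-hand side, no refactoring
```

Because the state is an ordinary immutable value, it can be passed through jax.jit-compiled functions like any other pytree.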
inv_congruence_transform
abstractmethod
inv_congruence_transform(state: _LinearSolverState, B: LinearOperator | Float[Array, 'N K']) -> LinearOperator | Float[Array, 'K K']
Compute the inverse congruence transform \(B^T x\) for \(x\) in \(Ax = B\).
Parameters:
- state (_LinearSolverState): State of the linear solver returned by init.
- B (LinearOperator | Float[Array, 'N K']): Linear operator or array to be applied.
Returns:
- LinearOperator | Float[Array, 'K K']: Linear operator or array resulting from the congruence transform.
Source code in src/cagpjax/solvers/base.py
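Since \(x = A^{-1}B\), the result is \(B^T A^{-1} B\). For a dense positive-definite \(A = LL^T\), this needs only one triangular solve, because \(B^T A^{-1} B = (L^{-1}B)^T (L^{-1}B)\). The sketch below checks that identity in plain JAX on made-up arrays; it assumes dense inputs and is not the cagpjax implementation, which also accepts LinearOperator arguments.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 0.5],
               [0.0, 0.5, 2.0]])          # positive definite, N = 3
B = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])               # shape (N, K) = (3, 2)

L = jnp.linalg.cholesky(A)                # A = L L^T
W = solve_triangular(L, B, lower=True)    # W = L^{-1} B
C = W.T @ W                               # B^T A^{-1} B, shape (K, K)
```

The result is \(K \times K\) and symmetric, which is why the return type is a square operator or array rather than an \(N \times K\) one.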
inv_quad
Compute the inverse quadratic form \(b^T x\) for \(x\) in \(Ax = b\).
Parameters:
- state (_LinearSolverState): State of the linear solver returned by init.
- b (Float[Array, N]): Right-hand side of the linear system.
Source code in src/cagpjax/solvers/base.py
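The inverse quadratic form \(b^T A^{-1} b\) can be computed either by solving first and taking a dot product, or, when \(A = LL^T\), as the squared norm \(\lVert L^{-1} b \rVert^2\). A small plain-JAX sketch comparing both routes on a made-up \(2 \times 2\) system (not the cagpjax code path):

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Route 1: solve A x = b, then take b^T x.
x = jnp.linalg.solve(A, b)
q_solve = b @ x

# Route 2: with A = L L^T, b^T A^{-1} b = ||L^{-1} b||^2.
L = jnp.linalg.cholesky(A)
w = solve_triangular(L, b, lower=True)
q_chol = w @ w
```

Both values equal \(15/11\) here, since \(A^{-1} b = (1/11, 7/11)\).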
logdet
abstractmethod
Compute the logarithm of the (pseudo-)determinant of \(A\).
Parameters:
- state (_LinearSolverState): State of the linear solver returned by init.
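For positive-definite \(A = LL^T\), \(\log\det A = 2\sum_i \log L_{ii}\). For singular PSD \(A\), the pseudo-determinant is the product of the nonzero eigenvalues, so its logarithm sums the log-eigenvalues above a cutoff. The plain-JAX sketch below shows both cases; the cutoff value is an illustrative assumption, not the one cagpjax uses.

```python
import jax.numpy as jnp

# Positive-definite case: log det A = 2 * sum(log(diag(L))) with A = L L^T.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])   # det A = 11
L = jnp.linalg.cholesky(A)
logdet_chol = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))

# Singular PSD case: sum the logs of the eigenvalues treated as nonzero.
S = jnp.array([[1.0, 1.0], [1.0, 1.0]])   # eigenvalues 0 and 2
evals = jnp.linalg.eigvalsh(S)
cutoff = 1e-6 * jnp.max(evals)            # illustrative relative cutoff (assumption)
keep = evals > cutoff
# Nested where keeps log() away from (near-)zero eigenvalues, also under jit.
log_pdet = jnp.sum(jnp.where(keep, jnp.log(jnp.where(keep, evals, 1.0)), 0.0))
```

Here logdet_chol is \(\log 11\) and log_pdet is \(\log 2\).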
solve
abstractmethod
Compute a solution to the linear system \(Ax = b\).
Parameters:
- state (_LinearSolverState): State of the linear solver returned by init.
- b (Float[Array, N]): Right-hand side of the linear system.
Source code in src/cagpjax/solvers/base.py
trace_solve
abstractmethod
Compute \(\mathrm{trace}(X)\) in \(AX=B\) for PSD \(A\) and \(B\).
Parameters:
- state (_LinearSolverState): State of the linear solver obtained by applying init to an operator representing \(A\).
- state_other (_LinearSolverState): Another state obtained by applying init to an operator representing \(B\).
Source code in src/cagpjax/solvers/base.py
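Since \(X = A^{-1}B\), the quantity is \(\mathrm{trace}(A^{-1}B)\). When both matrices have Cholesky factors, \(A = L_A L_A^T\) and \(B = L_B L_B^T\), cyclicity of the trace gives \(\mathrm{trace}(A^{-1}B) = \lVert L_A^{-1} L_B \rVert_F^2\), which is one way two solver states can be combined. A plain-JAX sketch of that identity, assuming dense positive-definite inputs (not necessarily how cagpjax computes it):

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
B = jnp.array([[2.0, 0.5], [0.5, 1.0]])

# Direct route: X = A^{-1} B, then trace(X).
tr_direct = jnp.trace(jnp.linalg.solve(A, B))

# Factor-based route: trace(A^{-1} B) = ||L_A^{-1} L_B||_F^2.
L_A = jnp.linalg.cholesky(A)
L_B = jnp.linalg.cholesky(B)
W = solve_triangular(L_A, L_B, lower=True)  # W = L_A^{-1} L_B
tr_factor = jnp.sum(W**2)                   # squared Frobenius norm
```

For these matrices both routes give \(9/11\).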
unwhiten
abstractmethod
Given an IID standard normal vector \(z\), return \(x\) with covariance \(A\).
Parameters:
- state (_LinearSolverState): State of the linear solver returned by init.
- z (Float[Array, N]): IID standard normal vector.
Source code in src/cagpjax/solvers/base.py
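If \(z \sim \mathcal{N}(0, I)\) and \(A = LL^T\), then \(x = Lz\) has covariance \(L I L^T = A\). The plain-JAX sketch below uses the Cholesky factor as the square root; this is an assumption for illustration, and any \(L\) with \(LL^T = A\) (for example one built from an eigendecomposition when \(A\) is singular) would serve equally well.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
L = jnp.linalg.cholesky(A)                 # A = L L^T


def unwhiten(z):
    # If z ~ N(0, I), then L z ~ N(0, L L^T) = N(0, A).
    return L @ z


key = jax.random.PRNGKey(0)
Z = jax.random.normal(key, (100_000, 2))   # rows: IID standard normal vectors
X = jax.vmap(unwhiten)(Z)                  # rows: samples with covariance A
emp_cov = jnp.cov(X, rowvar=False)         # should be close to A
```

The empirical covariance only approximates \(A\); the exact statement is \(LL^T = A\).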
Cholesky
Bases: AbstractLinearSolver[CholeskyState]
Solve a linear system using the Cholesky decomposition.
Due to numerical imprecision, Cholesky factorization may fail even for
positive-definite \(A\). Optionally, a small amount of jitter (\(\epsilon\)) can
be added to \(A\) to ensure positive-definiteness. Note that the system actually
solved, with the jitter included, is then slightly different from the original system.
Attributes:
- jitter (ScalarFloat | None): Small amount of jitter to add to \(A\) to ensure positive-definiteness.
Source code in src/cagpjax/solvers/cholesky.py
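A common source of trouble is a kernel matrix with duplicated inputs, which is only positive semi-definite. The plain-JAX sketch below assumes the jitter is added to the diagonal, \(A + \epsilon I\) (the usual convention, but an assumption here), and uses an illustrative \(\epsilon\); it shows that the factorization then succeeds and that the solve is for the jittered system, not the original one.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# RBF kernel matrix with a duplicated input: rows 0 and 1 are identical,
# so K is singular and its Cholesky factorization may fail.
t = jnp.array([0.0, 0.0, 1.0])
K = jnp.exp(-0.5 * (t[:, None] - t[None, :]) ** 2)

jitter = 1e-4                              # illustrative value (assumption)
K_jit = K + jitter * jnp.eye(K.shape[0])   # assumes jitter is added to the diagonal
L = jnp.linalg.cholesky(K_jit)             # K_jit is positive definite

b = jnp.array([1.0, 1.0, 0.5])
x_sol = cho_solve((L, True), b)            # solves (K + jitter * I) x = b, not K x = b
```

A larger \(\epsilon\) makes the factorization more robust but moves the solved system further from the original one.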
PseudoInverse
PseudoInverse(rtol: ScalarFloat | None = None, grad_rtol: float | None = None, alg: Algorithm = Eigh())
Bases: AbstractLinearSolver[PseudoInverseState]
Solve a linear system using the Moore-Penrose pseudoinverse.
This solver computes the minimum-norm least-squares solution \(x = A^+ b\) for any \(A\), where \(A^+\) is the Moore-Penrose pseudoinverse. This coincides with the exact solution for non-singular \(A\), but it also generalizes to singular \(A\) and improves stability for almost-singular \(A\). Note, however, that the pseudoinverse is discontinuous, so if the rank of \(A\) depends on the hyperparameters being optimized, the optimization problem may be ill-posed.
Note that if \(A\) is (almost-)degenerate (i.e., some of its eigenvalues are repeated
or nearly repeated), then gradients of solves with \(A\) in JAX may be non-computable
or numerically unstable (see jax#669).
For degenerate operators, it may be necessary to increase grad_rtol to improve
stability of gradients.
See cagpjax.linalg.eigh for more details.
Attributes:
- rtol (ScalarFloat | None): Specifies the cutoff for small eigenvalues. Eigenvalues smaller than rtol * largest_nonzero_eigenvalue are treated as zero. The default is determined based on the floating point precision of the dtype of the operator (see jax.numpy.linalg.pinv).
- grad_rtol (float | None): Specifies the cutoff for similar eigenvalues, used to improve gradient computation for (almost-)degenerate matrices. If not provided, the default is 0.0. If None or negative, all eigenvalues are treated as distinct.
- alg (Algorithm): Algorithm for eigenvalue decomposition passed to cagpjax.linalg.eigh.
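For symmetric PSD \(A = Q\,\mathrm{diag}(\lambda)\,Q^T\), the pseudoinverse solve is \(x = Q\,\mathrm{diag}(\lambda^+)\,Q^T b\), where \(\lambda_i^+ = 1/\lambda_i\) for eigenvalues above the cutoff and \(0\) otherwise. The plain-JAX sketch below uses a hypothetical helper, pinv_solve, with a relative cutoff mirroring the rtol description; the rtol value is illustrative, and the helper ignores grad_rtol and alg, so it is not the cagpjax implementation.

```python
import jax.numpy as jnp


def pinv_solve(A, b, rtol):
    """Hypothetical helper: x = A^+ b for symmetric PSD A via eigh."""
    lam, Q = jnp.linalg.eigh(A)                 # A = Q diag(lam) Q^T
    keep = lam > rtol * jnp.max(lam)            # eigenvalues below the cutoff count as zero
    safe = jnp.where(keep, lam, 1.0)            # avoid dividing by (near-)zero eigenvalues
    lam_pinv = jnp.where(keep, 1.0 / safe, 0.0)
    return Q @ (lam_pinv * (Q.T @ b))           # Q diag(lam^+) Q^T b


# Non-singular A: matches the exact solution (1/11, 7/11).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
x_exact = pinv_solve(A, jnp.array([1.0, 2.0]), rtol=1e-6)

# Singular A (eigenvalues 0 and 2): minimum-norm least-squares solution.
S = jnp.array([[1.0, 1.0], [1.0, 1.0]])
x_min_norm = pinv_solve(S, jnp.array([1.0, 3.0]), rtol=1e-6)

# Almost-singular A: the tiny eigenvalue is dropped instead of inverted.
N = jnp.diag(jnp.array([1.0, 1e-8]))
x_cut = pinv_solve(N, jnp.array([1.0, 1.0]), rtol=1e-6)
```

In the last case an exact solve would return a second component of about \(10^8\), while the cutoff sets it to zero; this is the stability gain, and also the discontinuity, described above.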