cagpjax.policies
Modules:

- base – Abstract base classes for linear solver policies.
- block_sparse – Block-sparse policy.
- lanczos – Lanczos-based policies.
- orthogonalization – Orthogonalization policy.
- pseudoinput – Pseudo-input linear solver policy.
Classes:

- AbstractBatchLinearSolverPolicy – Abstract base class for policies that produce action matrices.
- AbstractLinearSolverPolicy – Abstract base class for all linear solver policies.
- BlockSparsePolicy – Block-sparse linear solver policy.
- LanczosPolicy – Lanczos-based policy for eigenvalue decomposition approximation.
- OrthogonalizationPolicy – Orthogonalization policy.
- PseudoInputPolicy – Pseudo-input linear solver policy.
AbstractBatchLinearSolverPolicy
Bases: AbstractLinearSolverPolicy, ABC
Abstract base class for policies that produce action matrices.
Methods:

- to_actions – Compute all actions used to solve the linear system \(Ax=b\).

Attributes:

- n_actions (int) – Number of actions in this policy.
to_actions
abstractmethod
Compute all actions used to solve the linear system \(Ax=b\).
For a matrix \(A\) with shape (n, n), the action matrix has shape
(n, n_actions).
Parameters:

- A (LinearOperator) – Linear operator representing the linear system.

Returns:

- LinearOperator – Linear operator representing the action matrix.
Source code in src/cagpjax/policies/base.py
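As a rough illustration of this contract (not how cagpjax consumes the operator internally), the sketch below uses plain dense arrays in place of LinearOperator objects: an action matrix of shape (n, n_actions) compresses the system \(Ax=b\) onto a small subspace. All names and sizes here are hypothetical.

```python
# Illustration of the to_actions shape contract using dense arrays.
# A concrete policy would return a LinearOperator, not a jnp array.
import jax.numpy as jnp

n, n_actions = 6, 3
A = 2.0 * jnp.eye(n)             # stand-in for the (n, n) system operator
b = jnp.ones(n)

# Coordinate actions: the first n_actions columns of the identity.
S = jnp.eye(n)[:, :n_actions]    # action matrix, shape (n, n_actions)

# The actions compress A x = b onto span(S):
A_small = S.T @ A @ S            # shape (n_actions, n_actions)
b_small = S.T @ b                # shape (n_actions,)
x_approx = S @ jnp.linalg.solve(A_small, b_small)  # approximation living in span(S)
print(A_small.shape, x_approx.shape)               # (3, 3) (6,)
```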
AbstractLinearSolverPolicy
Bases: Module
Abstract base class for all linear solver policies.
Policies define actions used to solve a linear system \(A x = b\), where \(A\) is a square linear operator.
BlockSparsePolicy
BlockSparsePolicy(n_actions: int, n: int | None = None, nz_values: Float[Array, N] | Variable[Float[Array, N]] | None = None, key: PRNGKeyArray | None = None, **kwargs)
Bases: AbstractBatchLinearSolverPolicy
Block-sparse linear solver policy.
This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. Schematically, the action matrix has the structure

\[
S = \begin{pmatrix} z_1 & & \\ & \ddots & \\ & & z_{n_\text{actions}} \end{pmatrix},
\]

where each block \(z_i\) is a column vector of non-zero values. These blocks are stacked and stored as a single trainable parameter nz_values of shape (n,).
Initialize the block sparse policy.
Parameters:

- n_actions (int) – Number of actions to use.
- n (int | None, default: None) – Number of rows and columns of the full operator. Must be provided if nz_values is not provided.
- nz_values (Float[Array, N] | Variable[Float[Array, N]] | None, default: None) – Non-zero values of the block-diagonal sparse matrix (shape (n,)). If not provided, random actions are sampled using key if it is provided.
- key (PRNGKeyArray | None, default: None) – Random key for sampling actions if nz_values is not provided.
- **kwargs – Additional keyword arguments for jax.random.normal (e.g. dtype).
Methods:

- to_actions – Convert to block diagonal sparse action operators.

Attributes:

- n_actions (int) – Number of actions to be used.
Source code in src/cagpjax/policies/block_sparse.py
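A minimal construction sketch following the signature above; the sizes and key are hypothetical, and nz_values is sampled randomly because it is not supplied:

```python
import jax
from cagpjax.policies import BlockSparsePolicy

# 4 learnable block-sparse actions for a 100 x 100 operator.
policy = BlockSparsePolicy(n_actions=4, n=100, key=jax.random.PRNGKey(0))
print(policy.n_actions)  # 4
# policy.to_actions(A) would return the BlockDiagonalSparse action operator.
```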
to_actions
Convert to block diagonal sparse action operators.
Parameters:

- A (LinearOperator) – Linear operator (unused).

Returns:

- BlockDiagonalSparse (LinearOperator) – Sparse action structure representing the blocks.
Source code in src/cagpjax/policies/block_sparse.py
LanczosPolicy
LanczosPolicy(n_actions: int | None, key: PRNGKeyArray | None = None, grad_rtol: float | None = 0.0)
Bases: AbstractBatchLinearSolverPolicy
Lanczos-based policy for eigenvalue decomposition approximation.
This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors
of the linear operator \(A\).
Attributes:

- n_actions (int) – Number of Lanczos vectors/actions to compute.
- key (PRNGKeyArray | None) – Random key for reproducible Lanczos iterations.
Initialize the Lanczos policy.
Parameters:

- n_actions (int | None) – Number of Lanczos vectors to compute.
- key (PRNGKeyArray | None, default: None) – Random key for initialization.
- grad_rtol (float | None, default: 0.0) – Cutoff for treating eigenvalues as similar, used to improve gradient computation for (almost-)degenerate matrices. If None or negative, all eigenvalues are treated as distinct (see cagpjax.linalg.eigh for more details).
Methods:

- to_actions – Compute action matrix.
Source code in src/cagpjax/policies/lanczos.py
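A minimal construction sketch following the signature above (the number of actions and key are hypothetical); the to_actions call is left as a comment because it requires a symmetric LinearOperator for the system:

```python
import jax
from cagpjax.policies import LanczosPolicy

# 10 Lanczos vectors approximating the top eigenvectors of the operator.
policy = LanczosPolicy(n_actions=10, key=jax.random.PRNGKey(1))
# actions = policy.to_actions(A)  # A: symmetric LinearOperator, actions: (n, 10)
```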
to_actions
Compute action matrix.
Parameters:

- A (LinearOperator) – Symmetric linear operator representing the linear system.

Returns:

- LinearOperator – Linear operator containing the Lanczos vectors as columns.
Source code in src/cagpjax/policies/lanczos.py
OrthogonalizationPolicy
OrthogonalizationPolicy(base_policy: AbstractBatchLinearSolverPolicy, method: OrthogonalizationMethod = OrthogonalizationMethod.QR, n_reortho: int = 0)
Bases: AbstractBatchLinearSolverPolicy
Orthogonalization policy.
This policy orthogonalizes (if necessary) the action operator produced by the base policy.
Parameters:

- base_policy (AbstractBatchLinearSolverPolicy) – The base policy that produces the action operator to be orthogonalized.
- method (OrthogonalizationMethod, default: QR) – The method to use for orthogonalization.
- n_reortho (int, default: 0) – The number of times to re-orthogonalize each column. Re-orthogonalizing once is generally sufficient to improve orthogonality for Gram-Schmidt variants (see e.g. 10.1007/s00211-005-0615-4).
Source code in src/cagpjax/policies/orthogonalization.py
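A minimal sketch of wrapping a base policy, using the default QR method; the base policy and its sizes are hypothetical:

```python
import jax
from cagpjax.policies import BlockSparsePolicy, OrthogonalizationPolicy

base = BlockSparsePolicy(n_actions=8, n=200, key=jax.random.PRNGKey(0))
# Orthogonalize the block-sparse actions; method and n_reortho keep their defaults.
policy = OrthogonalizationPolicy(base)
```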
PseudoInputPolicy
PseudoInputPolicy(pseudo_inputs: Float[Array, 'M D'] | Parameter[Float[Array, 'M D']], train_inputs_or_dataset: Float[Array, 'N D'] | Dataset, kernel: AbstractKernel)
Bases: AbstractBatchLinearSolverPolicy
Pseudo-input linear solver policy.
This policy constructs actions from the cross-covariance between the training inputs and pseudo-inputs in the same input space. These pseudo-inputs are conceptually similar to inducing points and can be marked as trainable.
Parameters:

- pseudo_inputs (Float[Array, 'M D'] | Parameter[Float[Array, 'M D']]) – Pseudo-inputs for the kernel. If wrapped as a gpjax.parameters.Parameter, they will be treated as trainable.
- train_inputs_or_dataset (Float[Array, 'N D'] | Dataset) – Training inputs or a dataset containing training inputs. These must be the same inputs in the same order as the training data used to condition the CaGP model.
- kernel (AbstractKernel) – Kernel for the GP prior. It must accept the training inputs and pseudo_inputs as arguments to its cross_covariance method.
Note
When training with many pseudo-inputs, it is common for the cross-covariance matrix to
become poorly conditioned. Performance can be significantly improved by orthogonalizing
the actions using an OrthogonalizationPolicy.
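A minimal construction sketch; the input shapes are hypothetical, and a GPJax RBF kernel is assumed to satisfy the AbstractKernel interface expected here. Following the note above, the actions are also wrapped in an OrthogonalizationPolicy:

```python
import jax.numpy as jnp
import gpjax
from cagpjax.policies import OrthogonalizationPolicy, PseudoInputPolicy

X = jnp.linspace(0.0, 1.0, 200).reshape(-1, 1)  # training inputs, shape (N, D)
Z = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)   # pseudo-inputs, shape (M, D)
kernel = gpjax.kernels.RBF()

policy = PseudoInputPolicy(Z, X, kernel)
# Orthogonalizing helps when the cross-covariance becomes poorly conditioned.
policy = OrthogonalizationPolicy(policy)
```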