cagpjax
Computation-Aware Gaussian Processes for GPJax.
Modules:
- computations – Kernel computation methods.
- distributions – Probability distributions.
- linalg – Linear algebra functions.
- models – Gaussian process models.
- operators – Custom linear operators.
- policies – Linear solver policies.
- solvers – Linear solvers.
Classes:
- BlockDiagonalSparse – Block-diagonal sparse linear operator.
- BlockSparsePolicy – Block-sparse linear solver policy.
- ComputationAwareGP – Computation-aware Gaussian Process model.
- LanczosPolicy – Lanczos-based policy for eigenvalue decomposition approximation.
BlockDiagonalSparse
Bases: LinearOperator
Block-diagonal sparse linear operator.
This operator represents a block-diagonal matrix in which the blocks are contiguous and each block holds a single column vector, so that each row has exactly one non-zero entry.
Parameters:
- nz_values (Float[Array, N]) – Non-zero values to be distributed across diagonal blocks.
- n_blocks (int) – Number of diagonal blocks in the matrix.
Examples
>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 6x3 block-diagonal matrix with 3 blocks
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(6, 3)
>>>
>>> # Apply to identity matrices
>>> op @ jnp.eye(3)
Array([[1., 0., 0.],
       [2., 0., 0.],
       [0., 3., 0.],
       [0., 4., 0.],
       [0., 0., 5.],
       [0., 0., 6.]], dtype=float32)
Source code in src/cagpjax/operators/block_diagonal_sparse.py
BlockSparsePolicy
BlockSparsePolicy(n_actions: int, n: int | None = None, nz_values: Float[Array, N] | Variable[Float[Array, N]] | None = None, key: PRNGKeyArray | None = None, **kwargs)
Bases: AbstractBatchLinearSolverPolicy
Block-sparse linear solver policy.
This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The action matrix has the following structure:
\[
S = \begin{bmatrix} s_1 & & & \\ & s_2 & & \\ & & \ddots & \\ & & & s_{n_{\mathrm{actions}}} \end{bmatrix},
\]
where each \(s_i\) is a column vector occupying a contiguous block of rows. These are stacked and stored as a single trainable parameter nz_values.
Initialize the block sparse policy.
Parameters:
- n_actions (int) – Number of actions to use.
- n (int | None, default: None) – Number of rows and columns of the full operator. Must be provided if nz_values is not provided.
- nz_values (Float[Array, N] | Variable[Float[Array, N]] | None, default: None) – Non-zero values of the block-diagonal sparse matrix (shape (n,)). If not provided, random actions are sampled using key if provided.
- key (PRNGKeyArray | None, default: None) – Random key for sampling actions if nz_values is not provided.
- **kwargs – Additional keyword arguments for jax.random.normal (e.g. dtype).
Methods:
- to_actions – Convert to block diagonal sparse action operators.
Attributes:
- n_actions (int) – Number of actions to be used.
Source code in src/cagpjax/policies/block_sparse.py
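A minimal construction sketch (hedged: the cagpjax.policies import path mirrors the source location above, and the sizes are illustrative):
>>> import jax
>>> from cagpjax.policies import BlockSparsePolicy
>>>
>>> # 4 learnable actions for a 100x100 operator; nz_values (length 100) is
>>> # sampled with the provided key since it is not given explicitly.
>>> policy = BlockSparsePolicy(n_actions=4, n=100, key=jax.random.PRNGKey(0))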
to_actions
Convert to block diagonal sparse action operators.
Parameters:
- A (LinearOperator) – Linear operator (unused).
Returns:
- BlockDiagonalSparse (LinearOperator) – Sparse action structure representing the blocks.
Source code in src/cagpjax/policies/block_sparse.py
ComputationAwareGP
ComputationAwareGP(posterior: ConjugatePosterior, policy: AbstractBatchLinearSolverPolicy, solver: AbstractLinearSolver[_LinearSolverState] = Cholesky(1e-06))
Bases: Module, Generic[_LinearSolverState]
Computation-aware Gaussian Process model.
This model implements scalable GP inference by using batch linear solver policies to project the kernel and data to a lower-dimensional subspace, while accounting for the extra uncertainty imposed by observing only this subspace.
Attributes:
- posterior (ConjugatePosterior) – The original (exact) posterior.
- policy (AbstractBatchLinearSolverPolicy) – The batch linear solver policy.
- solver (AbstractLinearSolver[_LinearSolverState]) – The linear solver method to use for solving linear systems with positive semi-definite operators.
Notes
- Only single-output models are currently supported.
Initialize the Computation-Aware GP model.
Parameters:
- posterior (ConjugatePosterior) – GPJax conjugate posterior.
- policy (AbstractBatchLinearSolverPolicy) – The batch linear solver policy that defines the subspace into which the data is projected.
- solver (AbstractLinearSolver[_LinearSolverState], default: Cholesky(1e-06)) – The linear solver method to use for solving linear systems with positive semi-definite operators.
Methods:
- elbo – Compute the evidence lower bound.
- init – Compute the state of the conditioned GP posterior.
- predict – Compute the predictive distribution of the GP at the test inputs.
- prior_kl – Compute the KL divergence between the CaGP posterior and the GP prior.
- variational_expectation – Compute the variational expectation.
Source code in src/cagpjax/models/cagp.py
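A usage sketch (hedged: the prior/likelihood construction assumes a recent GPJax API and toy hyperparameters; only the ComputationAwareGP arguments are taken from this page):
>>> import gpjax as gpx
>>> import jax
>>> import jax.numpy as jnp
>>> from cagpjax.models import ComputationAwareGP
>>> from cagpjax.policies import LanczosPolicy
>>>
>>> # Toy 1D regression data.
>>> key = jax.random.PRNGKey(0)
>>> X = jnp.linspace(0.0, 1.0, 200).reshape(-1, 1)
>>> y = jnp.sin(10.0 * X) + 0.1 * jax.random.normal(key, X.shape)
>>> D = gpx.Dataset(X=X, y=y)
>>>
>>> # Exact GPJax conjugate posterior: prior * Gaussian likelihood.
>>> prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
>>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>>
>>> # Computation-aware GP that projects onto 20 Lanczos actions.
>>> model = ComputationAwareGP(posterior=posterior, policy=LanczosPolicy(n_actions=20, key=key))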
elbo
Compute the evidence lower bound.
Computes the evidence lower bound (ELBO) under this model's variational distribution.
Note
This should be used instead of gpjax.objectives.elbo
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- ScalarFloat – ELBO value (scalar).
Source code in src/cagpjax/models/cagp.py
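A short sketch of using the ELBO as a training objective, continuing from the example above (treating the negative ELBO as a minimization loss is an assumption about the intended workflow):
>>> state = model.init(D)       # condition on the training data (see init below)
>>> loss = -model.elbo(state)   # scalar negative ELBO, e.g. for gradient-based optimization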
init
Compute the state of the conditioned GP posterior.
Parameters:
- train_data (Dataset) – The training data used to fit the GP.
Returns:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned CaGP posterior, which stores any necessary intermediate values for prediction and computing objectives.
Source code in src/cagpjax/models/cagp.py
predict
predict(state: ComputationAwareGPState[_LinearSolverState], test_inputs: Float[Array, 'N D'] | None = None) -> GaussianDistribution
Compute the predictive distribution of the GP at the test inputs.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
- test_inputs (Float[Array, 'N D'] | None, default: None) – The test inputs at which to make predictions. If not provided, predictions are made at the training inputs.
Returns:
- GaussianDistribution (GaussianDistribution) – The predictive distribution of the GP at the test inputs.
Source code in src/cagpjax/models/cagp.py
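A prediction sketch continuing from the state above (the mean()/stddev() accessors are assumed from GPJax's GaussianDistribution interface):
>>> X_test = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
>>> pred = model.predict(state, test_inputs=X_test)
>>> mean, std = pred.mean(), pred.stddev()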
prior_kl
Compute the KL divergence between the CaGP posterior and the GP prior.
Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP posterior approximation and \(p(f)\) is the GP prior.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- ScalarFloat – KL divergence value (scalar).
Source code in src/cagpjax/models/cagp.py
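Continuing the same session, a one-line sketch:
>>> kl = model.prior_kl(state)  # scalar KL[q(f) || p(f)]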
variational_expectation
Compute the variational expectation.
Compute the pointwise expected log-likelihood under the variational distribution.
Note
This should be used instead of gpjax.objectives.variational_expectation
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- expectation (Float[Array, N]) – The pointwise expected log-likelihood under the variational distribution.
Source code in src/cagpjax/models/cagp.py
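A short sketch continuing from the state above; assembling the ELBO from its parts is shown only conceptually here (use elbo above in practice):
>>> ve = model.variational_expectation(state)       # shape (N,): pointwise expected log-likelihood
>>> approx_elbo = ve.sum() - model.prior_kl(state)  # conceptually, variational expectations minus prior KL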
LanczosPolicy
LanczosPolicy(n_actions: int | None, key: PRNGKeyArray | None = None, grad_rtol: float | None = 0.0)
Bases: AbstractBatchLinearSolverPolicy
Lanczos-based policy for eigenvalue decomposition approximation.
This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors
of the linear operator \(A\).
Attributes:
- n_actions (int) – Number of Lanczos vectors/actions to compute.
- key (PRNGKeyArray | None) – Random key for reproducible Lanczos iterations.
Initialize the Lanczos policy.
Parameters:
- n_actions (int | None) – Number of Lanczos vectors to compute.
- key (PRNGKeyArray | None, default: None) – Random key for initialization.
- grad_rtol (float | None, default: 0.0) – Cutoff for treating eigenvalues as similar, used to improve gradient computation for (almost-)degenerate matrices. If None or negative, all eigenvalues are treated as distinct (see cagpjax.linalg.eigh for more details).
Methods:
- to_actions – Compute action matrix.
Source code in src/cagpjax/policies/lanczos.py
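A minimal construction sketch (hedged: the cagpjax.policies import path mirrors the source location above):
>>> import jax
>>> from cagpjax.policies import LanczosPolicy
>>>
>>> # Approximate the top 20 eigenvectors of the kernel operator as actions.
>>> policy = LanczosPolicy(n_actions=20, key=jax.random.PRNGKey(0))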
to_actions
Compute action matrix.
Parameters:
- A (LinearOperator) – Symmetric linear operator representing the linear system.
Returns:
- LinearOperator – Linear operator containing the Lanczos vectors as columns.