cagpjax.models.cagp
Computation-aware Gaussian Process models.
Classes:
- ComputationAwareGP – Computation-aware Gaussian Process model.
- ComputationAwareGPState – Projected quantities for computation-aware GP inference.
ComputationAwareGP
ComputationAwareGP(posterior: ConjugatePosterior, policy: AbstractBatchLinearSolverPolicy, solver: AbstractLinearSolver[_LinearSolverState] = Cholesky(1e-06))
Bases: Module, Generic[_LinearSolverState]
Computation-aware Gaussian Process model.
This model implements scalable GP inference by using a batch linear solver policy to project the kernel and data onto a lower-dimensional subspace, while accounting for the extra uncertainty that arises from observing only this subspace.
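The extra uncertainty can be made concrete in a small dense sketch. Assuming the standard computation-aware construction (zero prior mean, Gaussian noise, \(\hat{K} = K + \sigma^2 I\), and an N×M actions matrix \(S\)), the CaGP posterior covariance at test inputs exceeds the exact posterior covariance by \(k(X_*, X)\,[\hat{K}^{-1} - S (S^\top \hat{K} S)^{-1} S^\top]\,k(X, X_*)\). The helper below is hypothetical (not part of cagpjax) and forms the bracketed matrix so its positive semi-definiteness can be checked numerically.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the dense solves accurate


def computational_uncertainty(K_hat, actions):
    """Return K_hat^{-1} - S (S^T K_hat S)^{-1} S^T for an N x M actions matrix S.

    Sandwiched between k(X*, X) and k(X, X*), this matrix gives the variance the
    projected posterior carries on top of the exact posterior (assumed construction).
    """
    G = actions.T @ K_hat @ actions                      # M x M projected Gram matrix
    exact = jnp.linalg.inv(K_hat)                        # what an exact GP would use
    projected = actions @ jnp.linalg.solve(G, actions.T) # what the projected GP uses
    return exact - projected
```

With full-rank actions (\(S = I\)) the matrix vanishes and the exact GP is recovered; with fewer actions than data points it is positive semi-definite but nonzero.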
Attributes:
- posterior (ConjugatePosterior) – The original (exact) posterior.
- policy (AbstractBatchLinearSolverPolicy) – The batch linear solver policy.
- solver (AbstractLinearSolver[_LinearSolverState]) – The linear solver method to use for solving linear systems with positive semi-definite operators.
Notes
- Only single-output models are currently supported.
Initialize the Computation-Aware GP model.
Parameters:
- posterior (ConjugatePosterior) – GPJax conjugate posterior.
- policy (AbstractBatchLinearSolverPolicy) – The batch linear solver policy that defines the subspace onto which the data is projected.
- solver (AbstractLinearSolver[_LinearSolverState], default: Cholesky(1e-06)) – The linear solver method to use for solving linear systems with positive semi-definite operators.
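Two simple action choices illustrate what a policy contributes. Both functions below are hypothetical stand-ins, not cagpjax policy classes: one draws M random Gaussian action vectors, the other stacks unit vectors so that \(S^\top y\) simply selects a subset of the observations. Assuming a zero prior mean and the standard projected mean \(k(X_*, X) S (S^\top \hat{K} S)^{-1} S^\top y\), the subset choice reproduces an exact GP trained on the selected points only.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def random_gaussian_actions(key, n, m):
    """Hypothetical policy: M i.i.d. standard-normal action vectors (N x M)."""
    return jax.random.normal(key, (n, m))


def subset_actions(n, indices):
    """Hypothetical policy: unit vectors selecting the training points in `indices`."""
    return jnp.eye(n)[:, indices]


def projected_mean(k_test_train, K_hat, y, actions):
    """Zero-prior-mean projected predictive mean k(X*, X) S (S^T K_hat S)^{-1} S^T y."""
    G = actions.T @ K_hat @ actions          # only an M x M system is solved
    return k_test_train @ actions @ jnp.linalg.solve(G, actions.T @ y)
```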
Methods:
- elbo – Compute the evidence lower bound.
- init – Compute the state of the conditioned GP posterior.
- predict – Compute the predictive distribution of the GP at the test inputs.
- prior_kl – Compute the KL divergence between the CaGP posterior and the GP prior.
- variational_expectation – Compute the variational expectation.
Source code in src/cagpjax/models/cagp.py
elbo
Compute the evidence lower bound.
Computes the evidence lower bound (ELBO) under this model's variational distribution.
Note
Use this method instead of gpjax.objectives.elbo.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- ScalarFloat – ELBO value (scalar).
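Assuming the usual decomposition of the ELBO into the summed pointwise expected log-likelihood minus \(\mathrm{KL}[q(f) || p(f)]\), a dense sketch at the training inputs looks as follows. It assumes a zero prior mean and a Gaussian likelihood with noise variance `noise`, and builds \(q(f)\) from the actions with the standard CaGP formulas; `cagp_elbo`, `gaussian_kl_dense` and `log_marginal_likelihood` are hypothetical helpers, not cagpjax functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)


def gaussian_kl_dense(m_q, C_q, m_p, C_p):
    """KL[N(m_q, C_q) || N(m_p, C_p)] with dense covariances."""
    n = m_q.shape[0]
    diff = m_p - m_q
    maha = diff @ jnp.linalg.solve(C_p, diff)
    _, logdet_p = jnp.linalg.slogdet(C_p)
    _, logdet_q = jnp.linalg.slogdet(C_q)
    return 0.5 * (jnp.trace(jnp.linalg.solve(C_p, C_q)) - n + maha + logdet_p - logdet_q)


def cagp_elbo(K, noise, actions, y):
    """ELBO of a CaGP posterior at the training inputs (zero prior mean, Gaussian noise)."""
    n = y.shape[0]
    G = actions.T @ (K + noise * jnp.eye(n)) @ actions  # S^T K_hat S
    KS = K @ actions
    w = jnp.linalg.solve(G, actions.T @ y)               # projected representer weights
    m_q = KS @ w                                         # posterior mean K S w
    C_q = K - KS @ jnp.linalg.solve(G, KS.T)             # posterior covariance
    # Pointwise expected log-likelihood E_q[log N(y_i | f_i, noise)]
    var_exp = -0.5 * jnp.log(2.0 * jnp.pi * noise) - ((y - m_q) ** 2 + jnp.diag(C_q)) / (2.0 * noise)
    return jnp.sum(var_exp) - gaussian_kl_dense(m_q, C_q, jnp.zeros(n), K)


def log_marginal_likelihood(K, noise, y):
    """Exact log p(y), for comparison with the bound."""
    n = y.shape[0]
    return multivariate_normal.logpdf(y, jnp.zeros(n), K + noise * jnp.eye(n))
```

With \(S = I\) the CaGP posterior equals the exact posterior, so the ELBO matches the log marginal likelihood; with fewer actions it stays below it.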
init
Compute the state of the conditioned GP posterior.
Parameters:
- train_data (Dataset) – The training data used to fit the GP.
Returns:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned CaGP posterior, which stores the intermediate values needed for prediction and for computing objectives.
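The projected quantities stored by init can be pictured with dense matrices. This sketch assumes a Gaussian likelihood with noise variance `noise`, takes `cov_prior_proj` to mean \(S^\top (K + \sigma^2 I) S\), and uses a Cholesky factor as a stand-in for the linear solver state. `projected_state` and these definitions are illustrative assumptions, not the library's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def projected_state(K, prior_mean, y, noise, actions):
    """Hypothetical dense analogue of the fields of ComputationAwareGPState."""
    obs_cov_proj = noise * actions.T @ actions                # S^T (noise * I) S
    cov_prior_proj = actions.T @ K @ actions + obs_cov_proj   # S^T (K + noise * I) S (assumed)
    chol = cho_factor(cov_prior_proj)                         # stand-in for the solver state
    residual_proj = actions.T @ (y - prior_mean)              # S^T (y - m(X))
    repr_weights_proj = cho_solve(chol, residual_proj)        # (S^T K_hat S)^{-1} S^T (y - m(X))
    return {
        "actions": actions,
        "obs_cov_proj": obs_cov_proj,
        "cov_prior_proj": cov_prior_proj,
        "cov_prior_proj_chol": chol,
        "residual_proj": residual_proj,
        "repr_weights_proj": repr_weights_proj,
    }
```

Every stored array is M×M or M-dimensional apart from the actions themselves, which is what keeps later solves cheap when M is much smaller than N.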
predict
predict(state: ComputationAwareGPState[_LinearSolverState], test_inputs: Float[Array, 'N D'] | None = None) -> GaussianDistribution
Compute the predictive distribution of the GP at the test inputs.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
- test_inputs (Float[Array, 'N D'] | None, default: None) – The test inputs at which to make predictions. If not provided, predictions are made at the training inputs.
Returns:
- GaussianDistribution – The predictive distribution of the GP at the test inputs.
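Given the projected quantities, the predictive distribution only needs M-dimensional solves. In this hypothetical dense sketch (not the method's implementation), `repr_weights_proj` is assumed to be \(w = (S^\top \hat{K} S)^{-1} S^\top (y - m(X))\), and the Cholesky factor of \(S^\top \hat{K} S\) plays the role of the solver state.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve  # cho_factor builds cov_prior_proj_chol

jax.config.update("jax_enable_x64", True)


def predict_from_projection(k_test_train, k_test_test, test_prior_mean, actions,
                            cov_prior_proj_chol, repr_weights_proj):
    """Dense sketch of a CaGP predictive mean and covariance.

    k_test_train: (N*, N) cross-covariance k(X*, X)
    k_test_test:  (N*, N*) prior covariance k(X*, X*)
    """
    k_proj = k_test_train @ actions                      # (N*, M): k(X*, X) S
    mean = test_prior_mean + k_proj @ repr_weights_proj  # m(X*) + k(X*, X) S w
    # k(X*, X*) - k(X*, X) S (S^T K_hat S)^{-1} S^T k(X, X*)
    cov = k_test_test - k_proj @ cho_solve(cov_prior_proj_chol, k_proj.T)
    return mean, cov
```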
prior_kl
Compute the KL divergence between the CaGP posterior and the GP prior.
Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP posterior approximation and \(p(f)\) is the GP prior.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- ScalarFloat – KL divergence value (scalar).
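Under the standard CaGP construction (an assumption about this implementation), \(q(f)\) at the training inputs has mean \(m(X) + K S w\) and covariance \(K - K S G^{-1} S^\top K\) with \(G = S^\top (K + \Sigma) S\). Combining the Gaussian KL formula with Sylvester's determinant identity then reduces the divergence to M×M quantities: \(\mathrm{KL} = \tfrac{1}{2}[\operatorname{tr}(G^{-1}\Sigma_S) - M + w^\top (G - \Sigma_S) w + \log\det G - \log\det \Sigma_S]\), where \(\Sigma_S = S^\top \Sigma S\) is the projected likelihood covariance. Both functions below are hypothetical; mapping their arguments to the state fields is an assumption, and the dense version is included only as a check.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def prior_kl_projected(cov_prior_proj, obs_cov_proj, repr_weights_proj):
    """KL[q(f) || p(f)] from projected quantities (assumed definitions).

    cov_prior_proj:    G = S^T (K + Sigma) S
    obs_cov_proj:      Sigma_S = S^T Sigma S
    repr_weights_proj: w = G^{-1} S^T (y - m(X))
    """
    m = cov_prior_proj.shape[0]
    trace_term = jnp.trace(jnp.linalg.solve(cov_prior_proj, obs_cov_proj))
    # G - Sigma_S = S^T K S, so this is w^T S^T K S w
    quad_term = repr_weights_proj @ (cov_prior_proj - obs_cov_proj) @ repr_weights_proj
    _, logdet_g = jnp.linalg.slogdet(cov_prior_proj)
    _, logdet_s = jnp.linalg.slogdet(obs_cov_proj)
    return 0.5 * (trace_term - m + quad_term + logdet_g - logdet_s)


def gaussian_kl_dense(m_q, C_q, m_p, C_p):
    """Reference KL[N(m_q, C_q) || N(m_p, C_p)] with dense N x N covariances."""
    n = m_q.shape[0]
    diff = m_p - m_q
    maha = diff @ jnp.linalg.solve(C_p, diff)
    _, logdet_p = jnp.linalg.slogdet(C_p)
    _, logdet_q = jnp.linalg.slogdet(C_q)
    return 0.5 * (jnp.trace(jnp.linalg.solve(C_p, C_q)) - n + maha + logdet_p - logdet_q)
```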
variational_expectation
Compute the variational expectation.
Compute the pointwise expected log-likelihood under the variational distribution.
Note
Use this method instead of gpjax.objectives.variational_expectation.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- expectation (Float[Array, N]) – The pointwise expected log-likelihood under the variational distribution.
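For the Gaussian likelihood of a conjugate posterior, each pointwise term has a closed form: \(\mathbb{E}_{q}[\log \mathcal{N}(y_i \mid f_i, \sigma^2)] = -\tfrac{1}{2}\log(2\pi\sigma^2) - \frac{(y_i - \mu_i)^2 + v_i}{2\sigma^2}\), where \(\mu_i\) and \(v_i\) are the marginal mean and variance of \(q(f_i)\). The function below is a hypothetical stand-alone version of this formula, not the method's implementation.

```python
import jax.numpy as jnp


def gaussian_variational_expectation(y, mean, var, noise):
    """Pointwise E_{f_i ~ N(mean_i, var_i)}[log N(y_i | f_i, noise)].

    Uses E[(y - f)^2] = (y - mean)^2 + var, so the posterior variance
    lowers the expected log-likelihood by var / (2 * noise).
    """
    return -0.5 * jnp.log(2.0 * jnp.pi * noise) - ((y - mean) ** 2 + var) / (2.0 * noise)
```

With zero variance this reduces to the Gaussian log-density of the observations, which is a quick sanity check.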
ComputationAwareGPState
dataclass
ComputationAwareGPState(train_data: Dataset, actions: LinearOperator, obs_cov_proj: LinearOperator, cov_prior_proj_state: _LinearSolverState, residual_proj: Float[Array, M], repr_weights_proj: Float[Array, M])
Bases: Generic[_LinearSolverState]
Projected quantities for computation-aware GP inference.
Parameters:
- train_data (Dataset) – Training data with N inputs of dimension D.
- actions (LinearOperator) – Actions operator of shape N × M; its transpose projects from the N-dimensional data space onto the M-dimensional subspace.
- obs_cov_proj (LinearOperator) – Projected covariance of the likelihood.
- cov_prior_proj_state (_LinearSolverState) – Linear solver state for cov_prior_proj.
- residual_proj (Float[Array, M]) – Projected residuals between the observations and the prior mean.
- repr_weights_proj (Float[Array, M]) – Projected representer weights.