Bases: Module, Generic[_LinearSolverState]
Computation-aware Gaussian Process model.
This model implements scalable GP inference by using batch linear solver
policies to project the kernel and data onto a lower-dimensional subspace, while
accounting for the extra uncertainty introduced by observing only this subspace.
Attributes:
- posterior (ConjugatePosterior) – The original (exact) posterior.
- policy (AbstractBatchLinearSolverPolicy) – The batch linear solver policy that defines the subspace into which the data is projected.
- solver (AbstractLinearSolver[_LinearSolverState]) – The linear solver method to use for solving linear systems with positive semi-definite operators.
Notes
- Only single-output models are currently supported.
Methods:
- elbo – Compute the evidence lower bound.
- init – Compute the state of the conditioned GP posterior.
- predict – Compute the predictive distribution of the GP at the test inputs.
- prior_kl – Compute the KL divergence between the CaGP posterior and the GP prior.
- variational_expectation – Compute the variational expectation.
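To make the pieces concrete, here is a minimal end-to-end usage sketch. The class name ComputationAwareGP, its import path, and the constructor signature are assumptions based on this page; policy and solver stand in for concrete AbstractBatchLinearSolverPolicy and AbstractLinearSolver instances, which must be constructed separately from the library's implementations (not shown).

import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr

from cagpjax.models.cagp import ComputationAwareGP  # assumed class name and path

# Toy supervised dataset
X = jnp.linspace(0.0, 1.0, 100)[:, None]
y = jnp.sin(2.0 * jnp.pi * X)
data = gpx.Dataset(X=X, y=y)

# Exact conjugate posterior from GPJax
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF())
likelihood = gpx.likelihoods.Gaussian(num_datapoints=data.n)
posterior = prior * likelihood

# `policy` and `solver` are placeholders for concrete policy/solver instances
model = ComputationAwareGP(posterior=posterior, policy=policy, solver=solver)

# Condition on the data, then evaluate the objective and predictions
state = model.init(data, key=jr.PRNGKey(0))
loss = -model.elbo(state)        # scalar objective to minimize
pred = model.predict(state, X)   # GaussianDistribution at the test inputs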
elbo
elbo(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat
Compute the evidence lower bound.
Computes the evidence lower bound (ELBO) under this model's variational distribution.
Note
This should be used instead of gpjax.objectives.elbo.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- ScalarFloat – ELBO value (scalar).
Source code in src/cagpjax/models/cagp.py
def elbo(self, state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat:
    """Compute the evidence lower bound.

    Computes the evidence lower bound (ELBO) under this model's variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.elbo``.

    Args:
        state: State of the conditioned GP computed by [`init`][..init].

    Returns:
        ELBO value (scalar).
    """
    var_exp = self.variational_expectation(state)
    kl = self.prior_kl(state)
    return jnp.sum(var_exp) - kl
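In symbols, this is the standard variational decomposition computed by the body above:

\[
\mathrm{ELBO} = \sum_{i=1}^{N} \mathbb{E}_{q(f_i)}\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left[q(f) \,\|\, p(f)\right],
\]

where the sum is variational_expectation aggregated over the training points and the KL term is prior_kl.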
init
init(train_data: Dataset, *, key: PRNGKeyArray | None = None) -> ComputationAwareGPState[_LinearSolverState]
Compute the state of the conditioned GP posterior.
Parameters:
- train_data (Dataset) – The training data used to fit the GP.
- key (PRNGKeyArray | None, default: None) – Optional random key forwarded to policy action construction.
Returns:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned CaGP posterior, which stores any necessary intermediate values for prediction and computing objectives.
Source code in src/cagpjax/models/cagp.py
def init(
    self, train_data: Dataset, *, key: PRNGKeyArray | None = None
) -> ComputationAwareGPState[_LinearSolverState]:
    """Compute the state of the conditioned GP posterior.

    Args:
        train_data: The training data used to fit the GP.
        key: Optional random key forwarded to policy action construction.

    Returns:
        state: State of the conditioned CaGP posterior, which stores any necessary
            intermediate values for prediction and computing objectives.
    """
    # Ensure we have supervised training data
    if train_data.X is None or train_data.y is None:
        raise ValueError("Training data must be supervised.")

    # Unpack training data
    x = jnp.atleast_2d(train_data.X)
    y = jnp.atleast_1d(train_data.y).squeeze()

    # Unpack prior and likelihood
    prior = self.posterior.prior
    likelihood = self.posterior.likelihood

    # Mean and covariance of prior-predictive distribution
    mean_prior = prior.mean_function(x).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = paramax.unwrap(prior.mean_function.constant)
        mean_prior = mean_prior.astype(constant.dtype)
    cov_xx = lazify(prior.kernel.gram(x))
    obs_cov = diag_like(cov_xx, paramax.unwrap(likelihood.obs_stddev) ** 2)
    cov_prior = cov_xx + obs_cov

    # Project quantities to subspace
    actions = self.policy.to_actions(cov_prior, key=key)
    obs_cov_proj = congruence_transform(actions, obs_cov)
    cov_prior_proj = congruence_transform(actions, cov_prior)
    cov_prior_proj_state = self.solver.init(cov_prior_proj)
    residual_proj = actions.T @ (y - mean_prior)
    repr_weights_proj = self.solver.solve(cov_prior_proj_state, residual_proj)

    return ComputationAwareGPState(
        train_data=train_data,
        actions=actions,
        obs_cov_proj=obs_cov_proj,
        cov_prior_proj_state=cov_prior_proj_state,
        residual_proj=residual_proj,
        repr_weights_proj=repr_weights_proj,
    )
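Reading the body above as linear algebra: with the action matrix \(S\) returned by the policy (one action per column), prior mean \(m_x\), and \(\hat{K} = K_{xx} + \sigma^2 I\), and assuming congruence_transform(S, A) evaluates \(S^\top A S\), the state stores

\[
\hat{K}_{\mathrm{proj}} = S^\top \hat{K} S, \qquad r_{\mathrm{proj}} = S^\top (y - m_x), \qquad v = \hat{K}_{\mathrm{proj}}^{-1}\, r_{\mathrm{proj}},
\]

where \(v\) are the projected representer weights (repr_weights_proj), obtained through the linear solver rather than an explicit inverse.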
predict
predict(state: ComputationAwareGPState[_LinearSolverState], test_inputs: Float[Array, 'N D'] | None = None) -> GaussianDistribution
Compute the predictive distribution of the GP at the test inputs.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
- test_inputs (Float[Array, 'N D'] | None, default: None) – The test inputs at which to make predictions. If not provided, predictions are made at the training inputs.
Returns:
- GaussianDistribution – The predictive distribution of the GP at the test inputs.
Source code in src/cagpjax/models/cagp.py
def predict(
    self,
    state: ComputationAwareGPState[_LinearSolverState],
    test_inputs: Float[Array, "N D"] | None = None,
) -> GaussianDistribution:
    """Compute the predictive distribution of the GP at the test inputs.

    Args:
        state: State of the conditioned GP computed by [`init`][..init]
        test_inputs: The test inputs at which to make predictions. If not provided,
            predictions are made at the training inputs.

    Returns:
        GaussianDistribution: The predictive distribution of the GP at the
            test inputs.
    """
    train_data = state.train_data
    assert train_data.X is not None  # help out pyright
    x = train_data.X

    # Predictions at test points
    z = test_inputs if test_inputs is not None else x
    prior = self.posterior.prior
    mean_z = prior.mean_function(z).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = paramax.unwrap(prior.mean_function.constant)
        mean_z = mean_z.astype(constant.dtype)
    cov_zz = lazify(prior.kernel.gram(z))
    cov_zx = (
        cov_zz
        if test_inputs is None
        else lazify(prior.kernel.cross_covariance(z, x))
    )
    cov_zx_proj = cov_zx @ state.actions

    # Posterior predictive distribution
    mean_pred = jnp.atleast_1d(mean_z + cov_zx_proj @ state.repr_weights_proj)
    cov_pred = cov_zz - self.solver.inv_congruence_transform(
        state.cov_prior_proj_state, cov_zx_proj.T
    )
    cov_pred = cola.PSD(cov_pred)
    return GaussianDistribution(mean_pred, cov_pred, solver=self.solver)
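In matrix form, and assuming inv_congruence_transform(A_state, B) evaluates \(B^\top A^{-1} B\) for the operator held in A_state, the predictive distribution at test inputs \(z\) is

\[
\mu_\ast = m(z) + K_{zx} S\, v, \qquad \Sigma_\ast = K_{zz} - K_{zx} S \left(S^\top \hat{K} S\right)^{-1} S^\top K_{xz},
\]

with \(S\) the actions and \(v\) the projected representer weights stored by init.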
prior_kl
prior_kl(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat
Compute the KL divergence between the CaGP posterior and the GP prior.
Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP
posterior approximation and \(p(f)\) is the GP prior.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- ScalarFloat – KL divergence value (scalar).
Source code in src/cagpjax/models/cagp.py
def prior_kl(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> ScalarFloat:
    r"""Compute the KL divergence between the CaGP posterior and the GP prior.

    Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
    posterior approximation and $p(f)$ is the GP prior.

    Args:
        state: State of the conditioned GP computed by [`init`][..init].

    Returns:
        KL divergence value (scalar).
    """
    obs_cov_proj_solver_state = self.solver.init(state.obs_cov_proj)
    kl = (
        _kl_divergence_from_solvers(
            self.solver,
            state.residual_proj,
            obs_cov_proj_solver_state,
            jnp.zeros_like(state.residual_proj),
            state.cov_prior_proj_state,
        )
        - 0.5
        * congruence_transform(
            state.repr_weights_proj.T, state.obs_cov_proj
        ).squeeze()
    )
    return kl
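For reference, the closed form between \(k\)-dimensional Gaussians that a helper like _kl_divergence_from_solvers presumably evaluates (with the solver supplying the solves and log-determinants) is

\[
\mathrm{KL}\left[\mathcal{N}(\mu_1, \Sigma_1) \,\|\, \mathcal{N}(\mu_2, \Sigma_2)\right] = \tfrac{1}{2}\left(\operatorname{tr}\!\left(\Sigma_2^{-1}\Sigma_1\right) + (\mu_2 - \mu_1)^\top \Sigma_2^{-1} (\mu_2 - \mu_1) - k + \log\frac{\det \Sigma_2}{\det \Sigma_1}\right),
\]

applied here, judging from the call above, with \(\mu_1 = r_{\mathrm{proj}}\), \(\Sigma_1 = S^\top \sigma^2 I\, S\), \(\mu_2 = 0\), and \(\Sigma_2 = \hat{K}_{\mathrm{proj}}\).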
variational_expectation
variational_expectation(state: ComputationAwareGPState[_LinearSolverState]) -> Float[Array, 'N']
Compute the variational expectation.
Computes the pointwise expected log-likelihood under the variational distribution.
Note
This should be used instead of gpjax.objectives.variational_expectation.
Parameters:
- state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.
Returns:
- expectation (Float[Array, 'N']) – The pointwise expected log-likelihood under the variational distribution.
Source code in src/cagpjax/models/cagp.py
def variational_expectation(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> Float[Array, "N"]:
    """Compute the variational expectation.

    Computes the pointwise expected log-likelihood under the variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.variational_expectation``.

    Args:
        state: State of the conditioned GP computed by [`init`][..init].

    Returns:
        expectation: The pointwise expected log-likelihood under the variational distribution.
    """
    # Unpack data
    y = state.train_data.y

    # Predict and compute expectation
    qpred = self.predict(state)
    mean = qpred.mean
    variance = qpred.variance
    expectation = self.posterior.likelihood.expected_log_likelihood(
        y, mean[:, None], variance[:, None]
    )
    return expectation
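For the Gaussian likelihood wrapped by a ConjugatePosterior, this expectation has the closed form

\[
\mathbb{E}_{q(f_i)}\left[\log \mathcal{N}\!\left(y_i \mid f_i, \sigma^2\right)\right] = -\tfrac{1}{2}\log\!\left(2\pi\sigma^2\right) - \frac{(y_i - \mu_i)^2 + \sigma_{f_i}^2}{2\sigma^2},
\]

where \(\mu_i\) and \(\sigma_{f_i}^2\) are the predictive mean and variance at the \(i\)-th training input returned by predict.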