cagpjax.models.cagp

Computation-aware Gaussian Process models.

Classes:

ComputationAwareGP

ComputationAwareGP(posterior: ConjugatePosterior, policy: AbstractBatchLinearSolverPolicy, solver: AbstractLinearSolver[_LinearSolverState] = Cholesky(1e-06))

Bases: Module, Generic[_LinearSolverState]

Computation-aware Gaussian Process model.

This model implements scalable GP inference by using batch linear solver policies to project the kernel and data to a lower-dimensional subspace, while accounting for the extra uncertainty imposed by observing only this subspace.

Attributes:

  • posterior

    (ConjugatePosterior) –

    GPJax conjugate posterior.

  • policy

    (AbstractBatchLinearSolverPolicy) –

    Batch linear solver policy defining the projection subspace.

  • solver

    (AbstractLinearSolver[_LinearSolverState]) –

    Linear solver for positive semi-definite systems.

Notes
  • Only single-output models are currently supported.

Initialize the Computation-Aware GP model.

Parameters:

  • posterior

    (ConjugatePosterior) –

    GPJax conjugate posterior.

  • policy

    (AbstractBatchLinearSolverPolicy) –

    The batch linear solver policy that defines the subspace into which the data is projected.

  • solver

    (AbstractLinearSolver[_LinearSolverState], default: Cholesky(1e-06) ) –

    The linear solver method to use for solving linear systems with positive semi-definite operators.

Methods:

  • elbo

    Compute the evidence lower bound.

  • init

    Compute the state of the conditioned GP posterior.

  • predict

    Compute the predictive distribution of the GP at the test inputs.

  • prior_kl

Compute the KL divergence between the CaGP posterior and the GP prior.

  • variational_expectation

    Compute the variational expectation.

Source code in src/cagpjax/models/cagp.py
def __init__(
    self,
    posterior: ConjugatePosterior,
    policy: AbstractBatchLinearSolverPolicy,
    solver: AbstractLinearSolver[_LinearSolverState] = Cholesky(1e-6),
):
    """Initialize the Computation-Aware GP model.

    Args:
        posterior: GPJax conjugate posterior.
        policy: The batch linear solver policy that defines the subspace into
            which the data is projected.
        solver: The linear solver method to use for solving linear systems with
            positive semi-definite operators.
    """
    self.posterior = posterior
    self.policy = policy
    self.solver = solver

elbo

elbo(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat

Compute the evidence lower bound.

Computes the evidence lower bound (ELBO) under this model's variational distribution.

Note

This should be used instead of gpjax.objectives.elbo.

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init

Returns:

  • ScalarFloat

    ELBO value (scalar).

Source code in src/cagpjax/models/cagp.py
def elbo(self, state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat:
    """Compute the evidence lower bound.

    Computes the evidence lower bound (ELBO) under this model's variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.elbo``

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        ELBO value (scalar).
    """
    var_exp = self.variational_expectation(state)
    kl = self.prior_kl(state)
    return jnp.sum(var_exp) - kl
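
The returned quantity is the standard variational bound; in the notation of this class's methods it corresponds term-by-term to `jnp.sum(var_exp) - kl` above (a general statement of the ELBO, not code from this package):

```latex
\mathrm{ELBO} = \sum_{i=1}^{N} \mathbb{E}_{q(f_i)}\!\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left[q(f) \,\|\, p(f)\right]
```

Here the first term is `variational_expectation` summed over the training points and the second is `prior_kl`.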

init

init(train_data: Dataset) -> ComputationAwareGPState[_LinearSolverState]

Compute the state of the conditioned GP posterior.

Parameters:

  • train_data

    (Dataset) –

    The training data used to fit the GP.

Returns:

  • state ( ComputationAwareGPState[_LinearSolverState] ) –

    State of the conditioned CaGP posterior, which stores any necessary intermediate values for prediction and computing objectives.

Source code in src/cagpjax/models/cagp.py
def init(self, train_data: Dataset) -> ComputationAwareGPState[_LinearSolverState]:
    """Compute the state of the conditioned GP posterior.

    Args:
        train_data: The training data used to fit the GP.

    Returns:
        state: State of the conditioned CaGP posterior, which stores any necessary
            intermediate values for prediction and computing objectives.
    """
    # Ensure we have supervised training data
    if train_data.X is None or train_data.y is None:
        raise ValueError("Training data must be supervised.")

    # Unpack training data
    x = jnp.atleast_2d(train_data.X)
    y = jnp.atleast_1d(train_data.y).squeeze()

    # Unpack prior and likelihood
    prior = self.posterior.prior
    likelihood = self.posterior.likelihood

    # Mean and covariance of prior-predictive distribution
    mean_prior = prior.mean_function(x).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = prior.mean_function.constant[...]
        mean_prior = mean_prior.astype(constant.dtype)
    cov_xx = lazify(prior.kernel.gram(x))
    obs_cov = diag_like(cov_xx, likelihood.obs_stddev[...] ** 2)
    cov_prior = cov_xx + obs_cov

    # Project quantities to subspace
    actions = self.policy.to_actions(cov_prior)
    obs_cov_proj = congruence_transform(actions, obs_cov)
    cov_prior_proj = congruence_transform(actions, cov_prior)
    cov_prior_proj_state = self.solver.init(cov_prior_proj)

    residual_proj = actions.T @ (y - mean_prior)
    repr_weights_proj = self.solver.solve(cov_prior_proj_state, residual_proj)

    return ComputationAwareGPState(
        train_data=train_data,
        actions=actions,
        obs_cov_proj=obs_cov_proj,
        cov_prior_proj_state=cov_prior_proj_state,
        residual_proj=residual_proj,
        repr_weights_proj=repr_weights_proj,
    )
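
The projection at the heart of `init` is a congruence transform: with an actions operator S of shape (N, M), the N × N prior covariance collapses to the M × M matrix SᵀAS. A minimal NumPy sketch of this arithmetic, using hypothetical stand-ins for the Gram matrix, actions, and solver (not this package's API):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins (hypothetical): N training points, M << N actions.
N, M = 50, 5
A = rng.standard_normal((N, N))
cov_xx = A @ A.T + N * np.eye(N)       # SPD stand-in for the prior Gram matrix
sigma2 = 0.1                           # observation noise variance
cov_prior = cov_xx + sigma2 * np.eye(N)

S = rng.standard_normal((N, M))        # actions: columns span the subspace
y = rng.standard_normal(N)
mean_prior = np.zeros(N)

# Congruence transform: project the N x N system down to M x M.
cov_prior_proj = S.T @ cov_prior @ S            # (M, M)
residual_proj = S.T @ (y - mean_prior)          # (M,)

# Projected representer weights via a Cholesky solve (the role of `solver`).
L = np.linalg.cholesky(cov_prior_proj)
repr_weights_proj = np.linalg.solve(L.T, np.linalg.solve(L, residual_proj))
```

All downstream work (prediction, KL, ELBO) then happens with M × M quantities rather than N × N ones.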

predict

predict(state: ComputationAwareGPState[_LinearSolverState], test_inputs: Float[Array, 'N D'] | None = None) -> GaussianDistribution

Compute the predictive distribution of the GP at the test inputs.

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init

  • test_inputs

    (Float[Array, 'N D'] | None, default: None ) –

    The test inputs at which to make predictions. If not provided, predictions are made at the training inputs.

Returns:

  • GaussianDistribution ( GaussianDistribution ) –

    The predictive distribution of the GP at the test inputs.

Source code in src/cagpjax/models/cagp.py
def predict(
    self,
    state: ComputationAwareGPState[_LinearSolverState],
    test_inputs: Float[Array, "N D"] | None = None,
) -> GaussianDistribution:
    """Compute the predictive distribution of the GP at the test inputs.

    Args:
        state: State of the conditioned GP computed by [`init`][..init]
        test_inputs: The test inputs at which to make predictions. If not provided,
            predictions are made at the training inputs.

    Returns:
        GaussianDistribution: The predictive distribution of the GP at the
            test inputs.
    """
    train_data = state.train_data
    assert train_data.X is not None  # help out pyright
    x = train_data.X

    # Predictions at test points
    z = test_inputs if test_inputs is not None else x
    prior = self.posterior.prior
    mean_z = prior.mean_function(z).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = prior.mean_function.constant[...]
        mean_z = mean_z.astype(constant.dtype)
    cov_zz = lazify(prior.kernel.gram(z))
    cov_zx = cov_zz if test_inputs is None else prior.kernel.cross_covariance(z, x)
    cov_zx_proj = cov_zx @ state.actions

    # Posterior predictive distribution
    mean_pred = jnp.atleast_1d(mean_z + cov_zx_proj @ state.repr_weights_proj)
    cov_pred = cov_zz - self.solver.inv_congruence_transform(
        state.cov_prior_proj_state, cov_zx_proj.T
    )
    cov_pred = cola.PSD(cov_pred)

    return GaussianDistribution(mean_pred, cov_pred, solver=self.solver)
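
The predictive equations above can be sketched in plain NumPy. With hypothetical stand-ins for the kernel matrices and actions (not this package's API), the posterior mean and covariance at test points follow the projected representer form used in `predict`:

```python
import numpy as np

rng = np.random.default_rng(1)
N, M, P = 40, 4, 7                     # train points, actions, test points

# Hypothetical SPD joint Gram matrix over [train, test] inputs.
A = rng.standard_normal((N + P, N + P))
K = A @ A.T + (N + P) * np.eye(N + P)
cov_xx, cov_zz = K[:N, :N], K[N:, N:]
cov_zx = K[N:, :N]
cov_prior = cov_xx + 0.1 * np.eye(N)   # add observation noise on training block

S = rng.standard_normal((N, M))        # actions
y = rng.standard_normal(N)             # observations (zero prior mean assumed)

cov_prior_proj = S.T @ cov_prior @ S
repr_weights_proj = np.linalg.solve(cov_prior_proj, S.T @ y)

# Projected cross-covariance, then predictive mean and covariance.
cov_zx_proj = cov_zx @ S               # (P, M)
mean_pred = cov_zx_proj @ repr_weights_proj
cov_pred = cov_zz - cov_zx_proj @ np.linalg.solve(cov_prior_proj, cov_zx_proj.T)
```

Because only an M × M system is solved, the subtracted term is smaller than in exact GP regression, leaving extra predictive variance that accounts for observing the data only through the subspace.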

prior_kl

prior_kl(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat

Compute the KL divergence between the CaGP posterior and the GP prior.

Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP posterior approximation and \(p(f)\) is the GP prior.

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init

Returns:

  • ScalarFloat

    KL divergence value (scalar).

Source code in src/cagpjax/models/cagp.py
def prior_kl(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> ScalarFloat:
    r"""Compute KL divergence between CaGP posterior and GP prior.

    Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
    posterior approximation and $p(f)$ is the GP prior.

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        KL divergence value (scalar).
    """
    obs_cov_proj_solver_state = self.solver.init(state.obs_cov_proj)

    kl = (
        _kl_divergence_from_solvers(
            self.solver,
            state.residual_proj,
            obs_cov_proj_solver_state,
            jnp.zeros_like(state.residual_proj),
            state.cov_prior_proj_state,
        )
        - 0.5
        * congruence_transform(
            state.repr_weights_proj.T, state.obs_cov_proj
        ).squeeze()
    )

    return kl
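
For reference, the generic closed-form KL divergence between two M-dimensional Gaussians, of which the projected computation above is a specialization (this is the general identity, not this package's exact reduction):

```latex
\mathrm{KL}\left[\mathcal{N}(\mu_1, \Sigma_1) \,\|\, \mathcal{N}(\mu_0, \Sigma_0)\right]
  = \frac{1}{2}\left[\operatorname{tr}\!\left(\Sigma_0^{-1}\Sigma_1\right)
  + (\mu_0 - \mu_1)^\top \Sigma_0^{-1} (\mu_0 - \mu_1)
  - M + \log\frac{\det\Sigma_0}{\det\Sigma_1}\right]
```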

variational_expectation

variational_expectation(state: ComputationAwareGPState[_LinearSolverState]) -> Float[Array, 'N']

Compute the variational expectation.

Compute the pointwise expected log-likelihood under the variational distribution.

Note

This should be used instead of gpjax.objectives.variational_expectation.

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init

Returns:

  • expectation ( Float[Array, N] ) –

    The pointwise expected log-likelihood under the variational distribution.

Source code in src/cagpjax/models/cagp.py
def variational_expectation(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> Float[Array, "N"]:
    """Compute the variational expectation.

    Compute the pointwise expected log-likelihood under the variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.variational_expectation``

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        expectation: The pointwise expected log-likelihood under the variational distribution.
    """

    # Unpack data
    y = state.train_data.y

    # Predict and compute expectation
    qpred = self.predict(state)
    mean = qpred.mean
    variance = qpred.variance
    expectation = self.posterior.likelihood.expected_log_likelihood(
        y, mean[:, None], variance[:, None]
    )

    return expectation
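
For a Gaussian likelihood, the pointwise expected log-likelihood under q(f) = N(μ, v) has the closed form -½ log(2πσ²) - ((y - μ)² + v) / (2σ²). A minimal NumPy sketch under that assumption (the helper name is hypothetical, not this package's API):

```python
import numpy as np

def expected_log_likelihood(y, mean, variance, obs_var):
    """Pointwise E_q[log N(y | f, obs_var)] for q(f) = N(mean, variance)."""
    return (
        -0.5 * np.log(2 * np.pi * obs_var)
        - ((y - mean) ** 2 + variance) / (2 * obs_var)
    )

y = np.array([0.5, -1.0, 2.0])
mean = np.array([0.4, -0.8, 1.5])
variance = np.array([0.1, 0.2, 0.05])
vals = expected_log_likelihood(y, mean, variance, obs_var=0.3)
```

The extra `variance` term in the numerator is what distinguishes this from a plain Gaussian log-density: predictive uncertainty at each point lowers the expected log-likelihood.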

ComputationAwareGPState dataclass

ComputationAwareGPState(train_data: Dataset, actions: LinearOperator, obs_cov_proj: LinearOperator, cov_prior_proj_state: _LinearSolverState, residual_proj: Float[Array, 'M'], repr_weights_proj: Float[Array, 'M'])

Bases: Generic[_LinearSolverState]

Projected quantities for computation-aware GP inference.

Parameters:

  • train_data

    (Dataset) –

Training data consisting of N inputs, each with D dimensions.

  • actions

    (LinearOperator) –

    Actions operator; transpose of operator projecting from N-dimensional space to M-dimensional subspace.

  • obs_cov_proj

    (LinearOperator) –

    Projected covariance of likelihood.

  • cov_prior_proj_state

    (_LinearSolverState) –

    Linear solver state for cov_prior_proj.

  • residual_proj

    (Float[Array, M]) –

    Projected residuals between observations and prior mean.

  • repr_weights_proj

    (Float[Array, M]) –

    Projected representer weights.