cagpjax

Computation-Aware Gaussian Processes for GPJax.

Classes:

  • BlockDiagonalSparse – Block-diagonal sparse linear operator.

  • BlockSparsePolicy – Block-sparse linear solver policy.

  • ComputationAwareGP – Computation-aware Gaussian Process model.

  • LanczosPolicy – Lanczos-based policy for eigenvalue decomposition approximation.

BlockDiagonalSparse

BlockDiagonalSparse(nz_values: Float[Array, N], n_blocks: int)

Bases: LinearOperator

Block-diagonal sparse linear operator.

This operator represents a block-diagonal matrix in which each block is a contiguous column vector, so that each row contains exactly one non-zero value.

Parameters:

  • nz_values

    (Float[Array, N]) –

    Non-zero values to be distributed across diagonal blocks.

  • n_blocks

    (int) –

    Number of diagonal blocks in the matrix.

Examples

>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 6x3 block-diagonal operator with 3 blocks
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(6, 3)
>>>
>>> # Apply to identity matrices
>>> op @ jnp.eye(3)
Array([[1., 0., 0.],
       [2., 0., 0.],
       [0., 3., 0.],
       [0., 4., 0.],
       [0., 0., 5.],
       [0., 0., 6.]], dtype=float32)
Source code in src/cagpjax/operators/block_diagonal_sparse.py
def __init__(self, nz_values: Float[Array, "N"], n_blocks: int):
    n = nz_values.shape[0]
    super().__init__(nz_values.dtype, (n, n_blocks), annotations={ScaledOrthogonal})
    self.nz_values = nz_values

BlockSparsePolicy

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The matrix has the following structure:

\[ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_{n_{\text{actions}}} \end{bmatrix} \]

The block vectors \(s_1, \ldots, s_{n_{\text{actions}}}\) are stacked and stored as a single trainable parameter nz_values.
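To make the structure concrete, here is an illustrative plain-JAX sketch (not the library's implementation) that builds the dense equivalent of this action matrix for the simple case where the values divide evenly into contiguous blocks:

```python
import jax.numpy as jnp

# Illustrative sketch: dense equivalent of the block-diagonal sparse action
# matrix S, assuming the length of nz_values divides evenly into n_blocks.
def dense_block_sparse(nz_values, n_blocks):
    n = nz_values.shape[0]
    block_size = n // n_blocks
    S = jnp.zeros((n, n_blocks))
    for j in range(n_blocks):
        rows = slice(j * block_size, (j + 1) * block_size)
        # Place block j as one contiguous column segment.
        S = S.at[rows, j].set(nz_values[rows])
    return S

S = dense_block_sparse(jnp.arange(1.0, 7.0), n_blocks=3)
# Columns hold the blocks [1, 2], [3, 4], [5, 6] along the diagonal,
# matching the BlockDiagonalSparse example above.
```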

Attributes:

  • n_actions (int) –

    Number of actions to use.

  • nz_values (Float[Array, N] | AbstractUnwrappable[Float[Array, N]]) –

    Non-zero values of the block-diagonal sparse matrix.

Methods:

  • from_random

    Initialize policy from block-normalized random samples.

  • to_actions

    Convert to block diagonal sparse action operators.

from_random classmethod

from_random(key: PRNGKeyArray, num_datapoints: int, n_actions: int, *, sampler: Callable[[PRNGKeyArray, tuple[int, ...], Any], Float[Array, ' N']] = jax.random.normal, dtype: Any = None) -> BlockSparsePolicy

Initialize policy from block-normalized random samples.

Parameters:

  • key

    (PRNGKeyArray) –

    Random key used to sample initial values.

  • num_datapoints

    (int) –

    Number of rows in the resulting operator.

  • n_actions

    (int) –

    Number of action columns in the resulting operator.

  • sampler

    (Callable[[PRNGKeyArray, tuple[int, ...], Any], Float[Array, ' N']], default: normal ) –

    Callable with signature (key, shape, dtype) -> values.

  • dtype

    (Any, default: None ) –

    Optional dtype forwarded to sampler.

Source code in src/cagpjax/policies/block_sparse.py
@classmethod
def from_random(
    cls,
    key: PRNGKeyArray,
    num_datapoints: int,
    n_actions: int,
    *,
    sampler: Callable[
        [PRNGKeyArray, tuple[int, ...], Any], Float[Array, " N"]
    ] = jax.random.normal,
    dtype: Any = None,
) -> "BlockSparsePolicy":
    """Initialize policy from block-normalized random samples.

    Args:
        key: Random key used to sample initial values.
        num_datapoints: Number of rows in the resulting operator.
        n_actions: Number of action columns in the resulting operator.
        sampler: Callable with signature ``(key, shape, dtype) -> values``.
        dtype: Optional dtype forwarded to ``sampler``.
    """
    if num_datapoints < 1:
        raise ValueError("num_datapoints must be at least 1")
    nz_values = sampler(key, (num_datapoints,), dtype)
    nz_values = _normalize_by_blocks(nz_values, n_actions)
    return cls(n_actions=n_actions, nz_values=nz_values)
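The helper `_normalize_by_blocks` used above is internal to the library. A plausible stand-in (an assumption about its behavior, not the library's exact code) rescales each contiguous block to unit Euclidean norm, so every action column starts with norm one:

```python
import jax.numpy as jnp

# Hypothetical stand-in for the internal `_normalize_by_blocks` helper:
# rescale each contiguous block of `nz_values` to unit Euclidean norm.
# Assumes the length of nz_values is evenly divisible by n_blocks.
def normalize_by_blocks(nz_values, n_blocks):
    blocks = jnp.split(nz_values, n_blocks)
    return jnp.concatenate([b / jnp.linalg.norm(b) for b in blocks])

values = normalize_by_blocks(jnp.arange(1.0, 7.0), n_blocks=3)
```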

to_actions

to_actions(A: LinearOperator, *, key: PRNGKeyArray | None = None) -> LinearOperator

Convert to block diagonal sparse action operators.

Parameters:

  • A

    (LinearOperator) –

    Linear operator (unused).

  • key

    (PRNGKeyArray | None, default: None ) –

    Optional random key (unused).

Returns:

  • BlockDiagonalSparse ( LinearOperator ) –

    Sparse action structure representing the blocks.

Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(
    self, A: LinearOperator, *, key: PRNGKeyArray | None = None
) -> LinearOperator:
    """Convert to block diagonal sparse action operators.

    Args:
        A: Linear operator (unused).
        key: Optional random key (unused).

    Returns:
        BlockDiagonalSparse: Sparse action structure representing the blocks.
    """
    return BlockDiagonalSparse(paramax.unwrap(self.nz_values), self.n_actions)

ComputationAwareGP

Bases: Module, Generic[_LinearSolverState]

Computation-aware Gaussian Process model.

This model implements scalable GP inference by using batch linear solver policies to project the kernel and data to a lower-dimensional subspace, while accounting for the extra uncertainty imposed by observing only this subspace.

Attributes:

  • posterior (ConjugatePosterior) –

    The original (exact) posterior.

  • policy (AbstractBatchLinearSolverPolicy) –

    The batch linear solver policy that defines the subspace into which the data is projected.

  • solver (AbstractLinearSolver[_LinearSolverState]) –

    The linear solver method to use for solving linear systems with positive semi-definite operators.

Notes
  • Only single-output models are currently supported.
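The projection idea can be sketched in plain JAX (this is a minimal illustration of the math, not the library's API): compress the n × n prior covariance and the residual into an `n_actions`-dimensional subspace via an action matrix \(S\), then solve the small projected system for the representer weights.

```python
import jax
import jax.numpy as jnp

# Minimal sketch of the CaGP projection: all names here are illustrative.
n, n_actions = 6, 2
x = jnp.linspace(0.0, 1.0, n)[:, None]
K = jnp.exp(-0.5 * (x - x.T) ** 2)       # RBF prior covariance
cov_prior = K + 0.1 * jnp.eye(n)         # plus observation noise
y = jnp.sin(3.0 * x).squeeze()

S = jax.random.normal(jax.random.PRNGKey(0), (n, n_actions))
cov_proj = S.T @ cov_prior @ S           # congruence transform S^T (K + sigma^2 I) S
resid_proj = S.T @ y                     # projected residual (zero prior mean)
weights = jnp.linalg.solve(cov_proj, resid_proj)  # representer weights
mean_post = K @ S @ weights              # posterior mean at the training inputs
```

The full model additionally tracks the extra predictive uncertainty from observing only this subspace, as done in `init` and `predict` below.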

Methods:

  • elbo

    Compute the evidence lower bound.

  • init

    Compute the state of the conditioned GP posterior.

  • predict

    Compute the predictive distribution of the GP at the test inputs.

  • prior_kl

Compute KL divergence between CaGP posterior and GP prior.

  • variational_expectation

    Compute the variational expectation.

elbo

elbo(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat

Compute the evidence lower bound.

Computes the evidence lower bound (ELBO) under this model's variational distribution.

Note

This should be used instead of gpjax.objectives.elbo.

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init.

Returns:

  • ScalarFloat

    ELBO value (scalar).

Source code in src/cagpjax/models/cagp.py
def elbo(self, state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat:
    """Compute the evidence lower bound.

    Computes the evidence lower bound (ELBO) under this model's variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.elbo``

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        ELBO value (scalar).
    """
    var_exp = self.variational_expectation(state)
    kl = self.prior_kl(state)
    return jnp.sum(var_exp) - kl

init

init(train_data: Dataset, *, key: PRNGKeyArray | None = None) -> ComputationAwareGPState[_LinearSolverState]

Compute the state of the conditioned GP posterior.

Parameters:

  • train_data

    (Dataset) –

    The training data used to fit the GP.

  • key

    (PRNGKeyArray | None, default: None ) –

    Optional random key forwarded to policy action construction.

Returns:

  • state ( ComputationAwareGPState[_LinearSolverState] ) –

    State of the conditioned CaGP posterior, which stores any necessary intermediate values for prediction and computing objectives.

Source code in src/cagpjax/models/cagp.py
def init(
    self, train_data: Dataset, *, key: PRNGKeyArray | None = None
) -> ComputationAwareGPState[_LinearSolverState]:
    """Compute the state of the conditioned GP posterior.

    Args:
        train_data: The training data used to fit the GP.
        key: Optional random key forwarded to policy action construction.

    Returns:
        state: State of the conditioned CaGP posterior, which stores any necessary
            intermediate values for prediction and computing objectives.
    """
    # Ensure we have supervised training data
    if train_data.X is None or train_data.y is None:
        raise ValueError("Training data must be supervised.")

    # Unpack training data
    x = jnp.atleast_2d(train_data.X)
    y = jnp.atleast_1d(train_data.y).squeeze()

    # Unpack prior and likelihood
    prior = self.posterior.prior
    likelihood = self.posterior.likelihood

    # Mean and covariance of prior-predictive distribution
    mean_prior = prior.mean_function(x).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = paramax.unwrap(prior.mean_function.constant)
        mean_prior = mean_prior.astype(constant.dtype)
    cov_xx = lazify(prior.kernel.gram(x))
    obs_cov = diag_like(cov_xx, paramax.unwrap(likelihood.obs_stddev) ** 2)
    cov_prior = cov_xx + obs_cov

    # Project quantities to subspace
    actions = self.policy.to_actions(cov_prior, key=key)
    obs_cov_proj = congruence_transform(actions, obs_cov)
    cov_prior_proj = congruence_transform(actions, cov_prior)
    cov_prior_proj_state = self.solver.init(cov_prior_proj)

    residual_proj = actions.T @ (y - mean_prior)
    repr_weights_proj = self.solver.solve(cov_prior_proj_state, residual_proj)

    return ComputationAwareGPState(
        train_data=train_data,
        actions=actions,
        obs_cov_proj=obs_cov_proj,
        cov_prior_proj_state=cov_prior_proj_state,
        residual_proj=residual_proj,
        repr_weights_proj=repr_weights_proj,
    )

predict

predict(state: ComputationAwareGPState[_LinearSolverState], test_inputs: Float[Array, 'N D'] | None = None) -> GaussianDistribution

Compute the predictive distribution of the GP at the test inputs.

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init

  • test_inputs

    (Float[Array, 'N D'] | None, default: None ) –

    The test inputs at which to make predictions. If not provided, predictions are made at the training inputs.

Returns:

  • GaussianDistribution ( GaussianDistribution ) –

    The predictive distribution of the GP at the test inputs.

Source code in src/cagpjax/models/cagp.py
def predict(
    self,
    state: ComputationAwareGPState[_LinearSolverState],
    test_inputs: Float[Array, "N D"] | None = None,
) -> GaussianDistribution:
    """Compute the predictive distribution of the GP at the test inputs.

    Args:
        state: State of the conditioned GP computed by [`init`][..init]
        test_inputs: The test inputs at which to make predictions. If not provided,
            predictions are made at the training inputs.

    Returns:
        GaussianDistribution: The predictive distribution of the GP at the
            test inputs.
    """
    train_data = state.train_data
    assert train_data.X is not None  # help out pyright
    x = train_data.X

    # Predictions at test points
    z = test_inputs if test_inputs is not None else x
    prior = self.posterior.prior
    mean_z = prior.mean_function(z).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = paramax.unwrap(prior.mean_function.constant)
        mean_z = mean_z.astype(constant.dtype)
    cov_zz = lazify(prior.kernel.gram(z))
    cov_zx = (
        cov_zz
        if test_inputs is None
        else lazify(prior.kernel.cross_covariance(z, x))
    )
    cov_zx_proj = cov_zx @ state.actions

    # Posterior predictive distribution
    mean_pred = jnp.atleast_1d(mean_z + cov_zx_proj @ state.repr_weights_proj)
    cov_pred = cov_zz - self.solver.inv_congruence_transform(
        state.cov_prior_proj_state, cov_zx_proj.T
    )
    cov_pred = cola.PSD(cov_pred)

    return GaussianDistribution(mean_pred, cov_pred, solver=self.solver)

prior_kl

prior_kl(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat

Compute KL divergence between CaGP posterior and GP prior.

Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP posterior approximation and \(p(f)\) is the GP prior.
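For reference, this specializes the standard KL divergence between multivariate Gaussians \(q = \mathcal{N}(\mu_q, \Sigma_q)\) and \(p = \mathcal{N}(\mu_p, \Sigma_p)\) in \(k\) dimensions, evaluated here on the projected quantities:

\[ \mathrm{KL}[q \,\|\, p] = \frac{1}{2} \left( \operatorname{tr}\!\left(\Sigma_p^{-1} \Sigma_q\right) + (\mu_p - \mu_q)^\top \Sigma_p^{-1} (\mu_p - \mu_q) - k + \ln \frac{\det \Sigma_p}{\det \Sigma_q} \right) \]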

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init.

Returns:

  • ScalarFloat

    KL divergence value (scalar).

Source code in src/cagpjax/models/cagp.py
def prior_kl(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> ScalarFloat:
    r"""Compute KL divergence between CaGP posterior and GP prior.

    Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
    posterior approximation and $p(f)$ is the GP prior.

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        KL divergence value (scalar).
    """
    obs_cov_proj_solver_state = self.solver.init(state.obs_cov_proj)

    kl = (
        _kl_divergence_from_solvers(
            self.solver,
            state.residual_proj,
            obs_cov_proj_solver_state,
            jnp.zeros_like(state.residual_proj),
            state.cov_prior_proj_state,
        )
        - 0.5
        * congruence_transform(
            state.repr_weights_proj.T, state.obs_cov_proj
        ).squeeze()
    )

    return kl

variational_expectation

variational_expectation(state: ComputationAwareGPState[_LinearSolverState]) -> Float[Array, N]

Compute the variational expectation.

Compute the pointwise expected log-likelihood under the variational distribution.

Note

This should be used instead of gpjax.objectives.variational_expectation.

Parameters:

  • state

    (ComputationAwareGPState[_LinearSolverState]) –

    State of the conditioned GP computed by init.

Returns:

  • expectation ( Float[Array, N] ) –

    The pointwise expected log-likelihood under the variational distribution.

Source code in src/cagpjax/models/cagp.py
def variational_expectation(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> Float[Array, "N"]:
    """Compute the variational expectation.

    Compute the pointwise expected log-likelihood under the variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.variational_expectation``

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        expectation: The pointwise expected log-likelihood under the variational distribution.
    """

    # Unpack data
    y = state.train_data.y

    # Predict and compute expectation
    qpred = self.predict(state)
    mean = qpred.mean
    variance = qpred.variance
    expectation = self.posterior.likelihood.expected_log_likelihood(
        y, mean[:, None], variance[:, None]
    )

    return expectation

LanczosPolicy

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors of the linear operator \(A\).

Attributes:

  • n_actions (int) –

    Number of Lanczos vectors/actions to compute.

  • grad_rtol (float | None) –

    Relative tolerance used to group nearly equal eigenvalues, which improves gradient computation for (almost-)degenerate matrices. Defaults to 0.0 if not provided. If None or negative, all eigenvalues are treated as distinct (see cagpjax.linalg.eigh for details).
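A dense stand-in for this policy can be sketched with `jnp.linalg.eigh` in place of an iterative Lanczos run (an illustration of what the actions represent, not the library's implementation):

```python
import jax.numpy as jnp

# Dense sketch of the Lanczos policy's output: take the top `n_actions`
# eigenvectors of a symmetric operator A as the action columns.
def top_eigvec_actions(A, n_actions):
    _, eigvecs = jnp.linalg.eigh(A)   # eigenvalues in ascending order
    return eigvecs[:, -n_actions:]

A = jnp.array([[2.0, 1.0], [1.0, 2.0]])
actions = top_eigvec_actions(A, n_actions=1)
# The top eigenvector of A is proportional to [1, 1] (eigenvalue 3).
```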

Methods:

to_actions

to_actions(A: LinearOperator, *, key: PRNGKeyArray | None = None) -> LinearOperator

Compute action matrix.

Parameters:

  • A

    (LinearOperator) –

    Symmetric linear operator representing the linear system.

  • key

    (PRNGKeyArray | None, default: None ) –

    Random key used to initialize the Lanczos run.

Returns:

  • LinearOperator

    Linear operator containing the Lanczos vectors as columns.

Source code in src/cagpjax/policies/lanczos.py
@override
def to_actions(
    self, A: LinearOperator, *, key: PRNGKeyArray | None = None
) -> LinearOperator:
    """Compute action matrix.

    Args:
        A: Symmetric linear operator representing the linear system.
        key: Random key used to initialize the Lanczos run.

    Returns:
        Linear operator containing the Lanczos vectors as columns.
    """
    vecs = eigh(
        A, alg=Lanczos(self.n_actions, key=key), grad_rtol=self.grad_rtol
    ).eigenvectors
    return vecs