cagpjax

Computation-Aware Gaussian Processes for GPJax.

Classes:

BlockDiagonalSparse

BlockDiagonalSparse(nz_values: Float[Array, N], n_blocks: int)

Bases: LinearOperator

Block-diagonal sparse linear operator.

This operator represents a block-diagonal matrix in which the blocks are contiguous and each block is a single column vector, so that exactly one entry in each row is non-zero.

Parameters:

  • nz_values (Float[Array, N]) – Non-zero values to be distributed across the diagonal blocks.

  • n_blocks (int) – Number of diagonal blocks in the matrix.

Examples

>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 6x3 block-diagonal operator with 3 blocks
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(6, 3)
>>>
>>> # Apply to the 3x3 identity matrix
>>> op @ jnp.eye(3)
Array([[1., 0., 0.],
       [2., 0., 0.],
       [0., 3., 0.],
       [0., 4., 0.],
       [0., 0., 5.],
       [0., 0., 6.]], dtype=float32)
Source code in src/cagpjax/operators/block_diagonal_sparse.py
def __init__(self, nz_values: Float[Array, "N"], n_blocks: int):
    n = nz_values.shape[0]
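    # Shape is (N, n_blocks): each column stores one contiguous block of nz_values.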
    super().__init__(nz_values.dtype, (n, n_blocks), annotations={ScaledOrthogonal})
    self.nz_values = nz_values

BlockSparsePolicy

BlockSparsePolicy(n_actions: int, n: int | None = None, nz_values: Float[Array, N] | Variable[Float[Array, N]] | None = None, key: PRNGKeyArray | None = None, **kwargs)

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The matrix has the following structure:

\[ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_{n_\text{actions}} \end{bmatrix} \]

The block vectors \(s_i\) are stacked and stored as a single trainable parameter nz_values (see the example below).

Initialize the block sparse policy.

Parameters:

  • n_actions (int) – Number of actions to use.

  • n (int | None, default: None) – Number of rows and columns of the full operator. Must be provided if nz_values is not provided.

  • nz_values (Float[Array, N] | Variable[Float[Array, N]] | None, default: None) – Non-zero values of the block-diagonal sparse matrix (shape (n,)). If not provided, random actions are sampled using the key if provided.

  • key (PRNGKeyArray | None, default: None) – Random key for sampling actions if nz_values is not provided.

  • **kwargs – Additional keyword arguments for jax.random.normal (e.g. dtype).

Methods:

  • to_actions – Convert to block diagonal sparse action operators.

Attributes:

  • n_actions (int) – Number of actions to be used.
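
Examples

A minimal sketch of constructing the policy and materializing its actions (the cagpjax.policies import path is inferred from the source location below). Because the actions are input-independent, the operator passed to to_actions is ignored, and the sampled values depend on the random key.

>>> import jax
>>> import jax.numpy as jnp
>>> from cagpjax.policies import BlockSparsePolicy
>>>
>>> # 6 learnable non-zero values split across 3 contiguous blocks of size 2
>>> policy = BlockSparsePolicy(n_actions=3, n=6, key=jax.random.PRNGKey(0))
>>> actions = policy.to_actions(jnp.eye(6))  # the argument is unused by this policy
>>> actions.shape
(6, 3)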

Source code in src/cagpjax/policies/block_sparse.py
def __init__(
    self,
    n_actions: int,
    n: int | None = None,
    nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
    key: PRNGKeyArray | None = None,
    **kwargs,
):
    """Initialize the block sparse policy.

    Args:
        n_actions: Number of actions to use.
        n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
            not provided.
        nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
            provided, random actions are sampled using the key if provided.
        key: Random key for sampling actions if ``nz_values`` is not provided.
        **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)
    """
    if nz_values is None:
        if n is None:
            raise ValueError("n must be provided if nz_values is not provided")
        if key is None:
            key = jax.random.PRNGKey(0)
        block_size = n // n_actions
        nz_values = jax.random.normal(key, (n,), **kwargs)
        nz_values /= jnp.sqrt(block_size)
    elif n is not None:
        warnings.warn("n is ignored because nz_values is provided")

    if not isinstance(nz_values, nnx.Variable):
        nz_values = Real(nz_values)

    self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
    self._n_actions: int = n_actions

n_actions property

n_actions: int

Number of actions to be used.

to_actions

to_actions(A: LinearOperator) -> LinearOperator

Convert to block diagonal sparse action operators.

Parameters:

  • A (LinearOperator) – Linear operator (unused).

Returns:

  • BlockDiagonalSparse (LinearOperator) – Sparse action structure representing the blocks.

Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Convert to block diagonal sparse action operators.

    Args:
        A: Linear operator (unused).

    Returns:
        BlockDiagonalSparse: Sparse action structure representing the blocks.
    """
    return BlockDiagonalSparse(self.nz_values[...], self.n_actions)

ComputationAwareGP

ComputationAwareGP(posterior: ConjugatePosterior, policy: AbstractBatchLinearSolverPolicy, solver: AbstractLinearSolver[_LinearSolverState] = Cholesky(1e-06))

Bases: Module, Generic[_LinearSolverState]

Computation-aware Gaussian Process model.

This model implements scalable GP inference by using batch linear solver policies to project the kernel and data to a lower-dimensional subspace, while accounting for the extra uncertainty imposed by observing only this subspace.

Attributes:

  • posterior (ConjugatePosterior) – GPJax conjugate posterior.

  • policy (AbstractBatchLinearSolverPolicy) – The batch linear solver policy that defines the subspace into which the data is projected.

  • solver (AbstractLinearSolver[_LinearSolverState]) – The linear solver method used to solve linear systems with positive semi-definite operators.

Notes
  • Only single-output models are currently supported.

Initialize the Computation-Aware GP model.

Parameters:

  • posterior (ConjugatePosterior) – GPJax conjugate posterior.

  • policy (AbstractBatchLinearSolverPolicy) – The batch linear solver policy that defines the subspace into which the data is projected.

  • solver (AbstractLinearSolver[_LinearSolverState], default: Cholesky(1e-06)) – The linear solver method to use for solving linear systems with positive semi-definite operators.

Methods:

  • elbo – Compute the evidence lower bound.

  • init – Compute the state of the conditioned GP posterior.

  • predict – Compute the predictive distribution of the GP at the test inputs.

  • prior_kl – Compute the KL divergence between the CaGP posterior and the GP prior.

  • variational_expectation – Compute the variational expectation.
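
Examples

A minimal end-to-end sketch. The cagpjax import paths are inferred from the source locations shown in this reference, and the GPJax calls follow its standard API; exact signatures may vary between versions.

>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> from cagpjax.models import ComputationAwareGP
>>> from cagpjax.policies import BlockSparsePolicy
>>>
>>> # Toy 1D regression data
>>> x = jnp.linspace(0.0, 1.0, 64).reshape(-1, 1)
>>> y = jnp.sin(2.0 * jnp.pi * x)
>>> data = gpx.Dataset(X=x, y=y)
>>>
>>> # Standard GPJax conjugate posterior
>>> prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
>>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=data.n)
>>>
>>> # Condition in an 8-dimensional subspace instead of solving the full 64x64 system
>>> cagp = ComputationAwareGP(posterior, BlockSparsePolicy(n_actions=8, n=data.n))
>>> state = cagp.init(data)
>>> loss = -cagp.elbo(state)  # negative ELBO, suitable as a training objective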

Source code in src/cagpjax/models/cagp.py
def __init__(
    self,
    posterior: ConjugatePosterior,
    policy: AbstractBatchLinearSolverPolicy,
    solver: AbstractLinearSolver[_LinearSolverState] = Cholesky(1e-6),
):
    """Initialize the Computation-Aware GP model.

    Args:
        posterior: GPJax conjugate posterior.
        policy: The batch linear solver policy that defines the subspace into
            which the data is projected.
        solver: The linear solver method to use for solving linear systems with
            positive semi-definite operators.
    """
    self.posterior = posterior
    self.policy = policy
    self.solver = solver

elbo

elbo(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat

Compute the evidence lower bound.

Computes the evidence lower bound (ELBO) under this model's variational distribution.

Note

This should be used instead of gpjax.objectives.elbo.

Parameters:

  • state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.

Returns:

  • ScalarFloat – ELBO value (scalar).
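
Examples

Continuing the class-level sketch above, the negative ELBO can serve directly as a loss function:

>>> def loss_fn(model, data):
...     state = model.init(data)   # recondition on the training data
...     return -model.elbo(state)  # minimize the negative ELBO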

Source code in src/cagpjax/models/cagp.py
def elbo(self, state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat:
    """Compute the evidence lower bound.

    Computes the evidence lower bound (ELBO) under this model's variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.elbo``

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        ELBO value (scalar).
    """
    var_exp = self.variational_expectation(state)
    kl = self.prior_kl(state)
    return jnp.sum(var_exp) - kl

init

init(train_data: Dataset) -> ComputationAwareGPState[_LinearSolverState]

Compute the state of the conditioned GP posterior.

Parameters:

  • train_data (Dataset) – The training data used to fit the GP.

Returns:

  • state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned CaGP posterior, which stores any necessary intermediate values for prediction and computing objectives.

Source code in src/cagpjax/models/cagp.py
def init(self, train_data: Dataset) -> ComputationAwareGPState[_LinearSolverState]:
    """Compute the state of the conditioned GP posterior.

    Args:
        train_data: The training data used to fit the GP.

    Returns:
        state: State of the conditioned CaGP posterior, which stores any necessary
            intermediate values for prediction and computing objectives.
    """
    # Ensure we have supervised training data
    if train_data.X is None or train_data.y is None:
        raise ValueError("Training data must be supervised.")

    # Unpack training data
    x = jnp.atleast_2d(train_data.X)
    y = jnp.atleast_1d(train_data.y).squeeze()

    # Unpack prior and likelihood
    prior = self.posterior.prior
    likelihood = self.posterior.likelihood

    # Mean and covariance of prior-predictive distribution
    mean_prior = prior.mean_function(x).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = prior.mean_function.constant[...]
        mean_prior = mean_prior.astype(constant.dtype)
    cov_xx = lazify(prior.kernel.gram(x))
    obs_cov = diag_like(cov_xx, likelihood.obs_stddev[...] ** 2)
    cov_prior = cov_xx + obs_cov

    # Project quantities to subspace
    actions = self.policy.to_actions(cov_prior)
    obs_cov_proj = congruence_transform(actions, obs_cov)
    cov_prior_proj = congruence_transform(actions, cov_prior)
    cov_prior_proj_state = self.solver.init(cov_prior_proj)

    residual_proj = actions.T @ (y - mean_prior)
    repr_weights_proj = self.solver.solve(cov_prior_proj_state, residual_proj)

    return ComputationAwareGPState(
        train_data=train_data,
        actions=actions,
        obs_cov_proj=obs_cov_proj,
        cov_prior_proj_state=cov_prior_proj_state,
        residual_proj=residual_proj,
        repr_weights_proj=repr_weights_proj,
    )

predict

predict(state: ComputationAwareGPState[_LinearSolverState], test_inputs: Float[Array, 'N D'] | None = None) -> GaussianDistribution

Compute the predictive distribution of the GP at the test inputs.

Parameters:

  • state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.

  • test_inputs (Float[Array, 'N D'] | None, default: None) – The test inputs at which to make predictions. If not provided, predictions are made at the training inputs.

Returns:

  • GaussianDistribution – The predictive distribution of the GP at the test inputs.
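
Examples

Continuing the class-level sketch above; the 100 test points and printed shapes are illustrative.

>>> z = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)
>>> pred = cagp.predict(state, test_inputs=z)
>>> pred.mean.shape, pred.variance.shape
((100,), (100,))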

Source code in src/cagpjax/models/cagp.py
def predict(
    self,
    state: ComputationAwareGPState[_LinearSolverState],
    test_inputs: Float[Array, "N D"] | None = None,
) -> GaussianDistribution:
    """Compute the predictive distribution of the GP at the test inputs.

    Args:
        state: State of the conditioned GP computed by [`init`][..init]
        test_inputs: The test inputs at which to make predictions. If not provided,
            predictions are made at the training inputs.

    Returns:
        GaussianDistribution: The predictive distribution of the GP at the
            test inputs.
    """
    train_data = state.train_data
    assert train_data.X is not None  # help out pyright
    x = train_data.X

    # Predictions at test points
    z = test_inputs if test_inputs is not None else x
    prior = self.posterior.prior
    mean_z = prior.mean_function(z).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        constant = prior.mean_function.constant[...]
        mean_z = mean_z.astype(constant.dtype)
    cov_zz = lazify(prior.kernel.gram(z))
    cov_zx = cov_zz if test_inputs is None else prior.kernel.cross_covariance(z, x)
    cov_zx_proj = cov_zx @ state.actions

    # Posterior predictive distribution
    mean_pred = jnp.atleast_1d(mean_z + cov_zx_proj @ state.repr_weights_proj)
    cov_pred = cov_zz - self.solver.inv_congruence_transform(
        state.cov_prior_proj_state, cov_zx_proj.T
    )
    cov_pred = cola.PSD(cov_pred)

    return GaussianDistribution(mean_pred, cov_pred, solver=self.solver)

prior_kl

prior_kl(state: ComputationAwareGPState[_LinearSolverState]) -> ScalarFloat

Compute the KL divergence between the CaGP posterior and the GP prior.

Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP posterior approximation and \(p(f)\) is the GP prior.

Parameters:

  • state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.

Returns:

  • ScalarFloat – KL divergence value (scalar).

Source code in src/cagpjax/models/cagp.py
def prior_kl(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> ScalarFloat:
    r"""Compute KL divergence between CaGP posterior and GP prior..

    Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
    posterior approximation and $p(f)$ is the GP prior.

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        KL divergence value (scalar).
    """
    obs_cov_proj_solver_state = self.solver.init(state.obs_cov_proj)

    kl = (
        _kl_divergence_from_solvers(
            self.solver,
            state.residual_proj,
            obs_cov_proj_solver_state,
            jnp.zeros_like(state.residual_proj),
            state.cov_prior_proj_state,
        )
        - 0.5
        * congruence_transform(
            state.repr_weights_proj.T, state.obs_cov_proj
        ).squeeze()
    )

    return kl

variational_expectation

variational_expectation(state: ComputationAwareGPState[_LinearSolverState]) -> Float[Array, N]

Compute the variational expectation.

Compute the pointwise expected log-likelihood under the variational distribution.

Note

This should be used instead of gpjax.objectives.variational_expectation.

Parameters:

  • state (ComputationAwareGPState[_LinearSolverState]) – State of the conditioned GP computed by init.

Returns:

  • expectation (Float[Array, N]) – The pointwise expected log-likelihood under the variational distribution.

Source code in src/cagpjax/models/cagp.py
def variational_expectation(
    self, state: ComputationAwareGPState[_LinearSolverState]
) -> Float[Array, "N"]:
    """Compute the variational expectation.

    Compute the pointwise expected log-likelihood under the variational distribution.

    Note:
        This should be used instead of ``gpjax.objectives.variational_expectation``

    Args:
        state: State of the conditioned GP computed by [`init`][..init]

    Returns:
        expectation: The pointwise expected log-likelihood under the variational distribution.
    """

    # Unpack data
    y = state.train_data.y

    # Predict and compute expectation
    qpred = self.predict(state)
    mean = qpred.mean
    variance = qpred.variance
    expectation = self.posterior.likelihood.expected_log_likelihood(
        y, mean[:, None], variance[:, None]
    )

    return expectation

LanczosPolicy

LanczosPolicy(n_actions: int | None, key: PRNGKeyArray | None = None, grad_rtol: float | None = 0.0)

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors of the linear operator \(A\).

Attributes:

  • n_actions (int) – Number of Lanczos vectors/actions to compute.

  • key (PRNGKeyArray | None) – Random key for reproducible Lanczos iterations.

Initialize the Lanczos policy.

Parameters:

  • n_actions (int | None) – Number of Lanczos vectors to compute.

  • key (PRNGKeyArray | None, default: None) – Random key for initialization.

  • grad_rtol (float | None, default: 0.0) – Cutoff for treating eigenvalues as similar, used to improve gradient computation for (almost-)degenerate matrices. If not provided, the default is 0.0. If None or negative, all eigenvalues are treated as distinct (see cagpjax.linalg.eigh for more details).

Methods:

  • to_actions – Compute action matrix.
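
Examples

A minimal sketch on a small dense PSD operator. The cola operator wrappers (cola.ops.Dense, cola.PSD) are the ones also used in the source excerpts in this reference; eigenvector signs and ordering may vary.

>>> import jax
>>> import jax.numpy as jnp
>>> import cola
>>> from cagpjax.policies import LanczosPolicy
>>>
>>> # Diagonal PSD operator with eigenvalues 4 > 3 > 2 > 1
>>> A = cola.PSD(cola.ops.Dense(jnp.diag(jnp.array([4.0, 3.0, 2.0, 1.0]))))
>>> policy = LanczosPolicy(n_actions=2, key=jax.random.PRNGKey(0))
>>> actions = policy.to_actions(A)  # columns approximate the top-2 eigenvectors
>>> actions.shape
(4, 2)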

Source code in src/cagpjax/policies/lanczos.py
def __init__(
    self,
    n_actions: int | None,
    key: PRNGKeyArray | None = None,
    grad_rtol: float | None = 0.0,
):
    """Initialize the Lanczos policy.

    Args:
        n_actions: Number of Lanczos vectors to compute.
        key: Random key for initialization.
        grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
            gradient computation for (almost-)degenerate matrices.
            If not provided, the default is 0.0.
            If None or negative, all eigenvalues are treated as distinct.
            (see [`cagpjax.linalg.eigh`][] for more details)
    """
    self._n_actions: int = n_actions
    self.key = key
    self.grad_rtol = grad_rtol

to_actions

to_actions(A: LinearOperator) -> LinearOperator

Compute action matrix.

Parameters:

  • A (LinearOperator) – Symmetric linear operator representing the linear system.

Returns:

  • LinearOperator – Linear operator containing the Lanczos vectors as columns.

Source code in src/cagpjax/policies/lanczos.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Compute action matrix.

    Args:
        A: Symmetric linear operator representing the linear system.

    Returns:
        Linear operator containing the Lanczos vectors as columns.
    """
    vecs = eigh(
        A, alg=Lanczos(self.n_actions, key=self.key), grad_rtol=self.grad_rtol
    ).eigenvectors
    return vecs