Skip to content

cagpjax

Computation-Aware Gaussian Processes for GPJax.

BlockDiagonalSparse

Bases: LinearOperator

Block-diagonal sparse linear operator.

This operator represents a block-diagonal matrix structure where the blocks are contiguous, and each contains a row vector, so that exactly one value is non-zero in each column.

Parameters:

Name Type Description Default
nz_values Float[Array, N]

Non-zero values to be distributed across diagonal blocks.

required
n_blocks int

Number of diagonal blocks in the matrix.

required

Examples

>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 3x6 block-diagonal matrix with 3 blocks
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(3, 6)
>>>
>>> # Apply to a vector
>>> x = jnp.ones(6)
>>> result = op @ x
Source code in src/cagpjax/operators/block_diagonal_sparse.py
class BlockDiagonalSparse(LinearOperator):
    """Sparse linear operator whose non-zeros form contiguous row-vector blocks.

    The operator has one row per block. Each row holds a contiguous run of entries
    taken from ``nz_values`` and is zero elsewhere, so exactly one value is non-zero
    in each column. Every block spans ``n // n_blocks`` columns; when ``n`` is not a
    multiple of ``n_blocks``, the final block additionally absorbs the leftover
    columns.

    Args:
        nz_values: Non-zero values to be distributed across diagonal blocks.
        n_blocks: Number of diagonal blocks in the matrix.

    Examples
    --------
    ```python
    >>> import jax.numpy as jnp
    >>> from cagpjax.operators import BlockDiagonalSparse
    >>>
    >>> # Create a 3x6 block-diagonal matrix with 3 blocks
    >>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    >>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
    >>> print(op.shape)
    (3, 6)
    >>>
    >>> # Apply to a vector
    >>> x = jnp.ones(6)
    >>> result = op @ x
    ```
    """

    def __init__(self, nz_values: Float[Array, "N"], n_blocks: int):
        # One row per block, one column per stored value.
        shape = (n_blocks, nz_values.shape[0])
        super().__init__(nz_values.dtype, shape)
        self.nz_values = nz_values

    def _matmat(self, X: Float[Array, "N #M"]) -> Float[Array, "K #M"]:
        """Multiply the operator with ``X`` block by block.

        Args:
            X: Array whose leading axis matches the number of columns of the operator.

        Returns:
            Array with one leading entry per block and the trailing axes of ``X``.
        """
        n_rows, n_cols = self.shape
        width = n_cols // n_rows
        # If the columns do not divide evenly, the last row is handled separately
        # because its block is wider than the others.
        n_regular = n_rows - 1 if n_cols % n_rows else n_rows
        split = n_regular * width
        trailing = X.shape[1:]

        # Equal-width blocks: one batched dot product per block.
        if n_regular > 0:
            coeffs = self.nz_values[:split].reshape(n_regular, width)
            stacked = X[:split, ...].reshape(n_regular, width, -1)
            out = jnp.einsum("ik,ikj->ij", coeffs, stacked)
        else:
            out = jnp.empty((0, *trailing), dtype=X.dtype)

        # Remaining columns (if any) form a single extra row.
        if n_cols > split:
            tail_X = X[split:, ...].reshape(n_cols - split, -1)
            tail_vals = self.nz_values[split:]
            tail_out = tail_vals[None, :] @ tail_X
            out = jnp.concatenate([out, tail_out], axis=0)

        return out.reshape(-1, *trailing)

BlockSparsePolicy

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The matrix has the following structure:

\[ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_{\text{n_actions}} \end{bmatrix} \]

The blocks \(s_i\) are stacked and stored as a single trainable parameter nz_values.

Source code in src/cagpjax/policies/block_sparse.py
class BlockSparsePolicy(AbstractBatchLinearSolverPolicy):
    r"""Block-sparse linear solver policy.

    This policy uses a fixed block-diagonal sparse structure to define
    independent learnable actions. The matrix has the following structure:

    $$
    S = \begin{bmatrix}
        s_1 & 0 & \cdots & 0 \\
        0 & s_2 & \cdots & 0 \\
        \vdots & \vdots & \ddots & \vdots \\
        0 & 0 & \cdots & s_{\text{n_actions}}
    \end{bmatrix}
    $$

    The blocks $s_i$ are stacked and stored as a single trainable parameter
    ``nz_values``.
    """

    def __init__(
        self,
        n_actions: int,
        n: int | None = None,
        nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
        key: PRNGKeyArray | None = None,
        **kwargs,
    ):
        """Initialize the block sparse policy.

        Args:
            n_actions: Number of actions to use.
            n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
                not provided.
            nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
                provided, random actions are sampled using the key if provided.
            key: Random key for sampling actions if ``nz_values`` is not provided. Defaults to
                ``jax.random.PRNGKey(0)``.
            **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)

        Raises:
            ValueError: If neither ``n`` nor ``nz_values`` is provided, or if actions are sampled
                and ``n_actions`` does not satisfy ``1 <= n_actions <= n``.
        """
        if nz_values is None:
            if n is None:
                raise ValueError("n must be provided if nz_values is not provided")
            # Without this check, n < n_actions gives block_size == 0 and the scaling
            # below silently turns every sampled value into inf.
            if n_actions < 1 or n < n_actions:
                raise ValueError(
                    "n_actions must satisfy 1 <= n_actions <= n when sampling actions, "
                    f"got n_actions={n_actions} and n={n}"
                )
            if key is None:
                key = jax.random.PRNGKey(0)
            block_size = n // n_actions
            nz_values = jax.random.normal(key, (n,), **kwargs)
            # Scale so that a block of ``block_size`` entries has expected squared norm 1.
            nz_values /= jnp.sqrt(block_size)
        elif n is not None:
            warnings.warn("n is ignored because nz_values is provided")

        # Plain arrays are wrapped in ``Real``; variables passed in are stored as-is.
        if not isinstance(nz_values, nnx.Variable):
            nz_values = Real(nz_values)

        self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
        self._n_actions: int = n_actions

    @property
    @override
    def n_actions(self) -> int:
        """Number of actions to be used."""
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Convert to block diagonal sparse action operators.

        Args:
            A: Linear operator (unused).

        Returns:
            Transposed[BlockDiagonalSparse]: Sparse action structure representing the blocks.
        """
        return BlockDiagonalSparse(self.nz_values.value, self.n_actions).T

n_actions property

Number of actions to be used.

__init__(n_actions, n=None, nz_values=None, key=None, **kwargs)

Initialize the block sparse policy.

Parameters:

Name Type Description Default
n_actions int

Number of actions to use.

required
n int | None

Number of rows and columns of the full operator. Must be provided if nz_values is not provided.

None
nz_values Float[Array, N] | Variable[Float[Array, N]] | None

Non-zero values of the block-diagonal sparse matrix (shape (n,)). If not provided, random actions are sampled using the key if provided.

None
key PRNGKeyArray | None

Random key for sampling actions if nz_values is not provided.

None
**kwargs

Additional keyword arguments for jax.random.normal (e.g. dtype)

{}
Source code in src/cagpjax/policies/block_sparse.py
def __init__(
    self,
    n_actions: int,
    n: int | None = None,
    nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
    key: PRNGKeyArray | None = None,
    **kwargs,
):
    """Initialize the block sparse policy.

    Args:
        n_actions: Number of actions to use.
        n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
            not provided.
        nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
            provided, random actions are sampled using the key if provided.
        key: Random key for sampling actions if ``nz_values`` is not provided. Defaults to
            ``jax.random.PRNGKey(0)``.
        **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)

    Raises:
        ValueError: If neither ``n`` nor ``nz_values`` is provided, or if actions are sampled
            and ``n_actions`` does not satisfy ``1 <= n_actions <= n``.
    """
    if nz_values is None:
        if n is None:
            raise ValueError("n must be provided if nz_values is not provided")
        # Without this check, n < n_actions gives block_size == 0 and the scaling
        # below silently turns every sampled value into inf.
        if n_actions < 1 or n < n_actions:
            raise ValueError(
                "n_actions must satisfy 1 <= n_actions <= n when sampling actions, "
                f"got n_actions={n_actions} and n={n}"
            )
        if key is None:
            key = jax.random.PRNGKey(0)
        block_size = n // n_actions
        nz_values = jax.random.normal(key, (n,), **kwargs)
        # Scale so that a block of ``block_size`` entries has expected squared norm 1.
        nz_values /= jnp.sqrt(block_size)
    elif n is not None:
        warnings.warn("n is ignored because nz_values is provided")

    # Plain arrays are wrapped in ``Real``; variables passed in are stored as-is.
    if not isinstance(nz_values, nnx.Variable):
        nz_values = Real(nz_values)

    self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
    self._n_actions: int = n_actions

to_actions(A)

Convert to block diagonal sparse action operators.

Parameters:

Name Type Description Default
A LinearOperator

Linear operator (unused).

required

Returns:

Type Description
LinearOperator

Transposed[BlockDiagonalSparse]: Sparse action structure representing the blocks.

Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Build the action operator from the stored non-zero values.

    Args:
        A: Linear operator (unused).

    Returns:
        Transposed[BlockDiagonalSparse]: Sparse action structure representing the blocks.
    """
    # One block per action; the actions are the transpose of that operator.
    blocks = BlockDiagonalSparse(self.nz_values.value, self.n_actions)
    return blocks.T

ComputationAwareGP

Bases: AbstractComputationAwareGP

Computation-aware Gaussian Process model.

This model implements scalable GP inference by using batch linear solver policies to project the kernel and data to a lower-dimensional subspace, while accounting for the extra uncertainty imposed by observing only this subspace.

Attributes:

Name Type Description
posterior

The original (exact) posterior.

policy AbstractBatchLinearSolverPolicy

The batch linear solver policy.

jitter ScalarFloat

Numerical jitter for stability.

Notes
  • Only single-output models are currently supported.
Source code in src/cagpjax/models/cagp.py
class ComputationAwareGP(AbstractComputationAwareGP):
    """Computation-aware Gaussian Process model.

    This model implements scalable GP inference by using batch linear solver
    policies to project the kernel and data to a lower-dimensional subspace, while
    accounting for the extra uncertainty imposed by observing only this subspace.

    Attributes:
        posterior: The original (exact) posterior.
        policy: The batch linear solver policy.
        jitter: Numerical jitter for stability.

    Notes:
        - Only single-output models are currently supported.
    """

    def __init__(
        self,
        posterior: ConjugatePosterior,
        policy: AbstractBatchLinearSolverPolicy,
        jitter: ScalarFloat = 1e-6,
    ):
        """Initialize the Computation-Aware GP model.

        Args:
            posterior: GPJax conjugate posterior.
            policy: The batch linear solver policy that defines the subspace into
                which the data is projected.
            jitter: A small positive constant added to the diagonal of a covariance
                matrix when necessary to ensure numerical stability.
        """
        super().__init__(posterior)
        self.policy: AbstractBatchLinearSolverPolicy = policy
        self.jitter: ScalarFloat = jitter
        # Stays ``None`` until ``condition`` stores the projected quantities.
        self._posterior_params: _ProjectedPosteriorParameters | None = None

    @property
    def is_conditioned(self) -> bool:
        """Whether the model has been conditioned on training data."""
        return self._posterior_params is not None

    def condition(self, train_data: Dataset) -> None:
        """Compute and store the projected quantities of the conditioned GP posterior.

        Args:
            train_data: The training data used to fit the GP.

        Raises:
            ValueError: If ``train_data`` has no inputs or no targets.
        """
        # Ensure we have supervised training data
        if train_data.X is None or train_data.y is None:
            raise ValueError("Training data must be supervised.")

        # Unpack training data
        x = jnp.atleast_2d(train_data.X)
        # Squeeze first, then restore one dimension: with a single training point
        # the squeeze yields a 0-d array, which cannot be projected below.
        y = jnp.atleast_1d(jnp.atleast_1d(train_data.y).squeeze())

        # Unpack prior and likelihood
        prior = self.posterior.prior
        likelihood = self.posterior.likelihood

        # Mean and covariance of prior-predictive distribution
        mean_prior = prior.mean_function(x).squeeze()
        # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
        if isinstance(prior.mean_function, Constant):
            mean_prior = mean_prior.astype(prior.mean_function.constant.value.dtype)
        cov_xx = prior.kernel.gram(x)
        obs_cov = diag_like(cov_xx, likelihood.obs_stddev.value**2)
        cov_prior = cov_xx + obs_cov

        # Project quantities to subspace
        proj = self.policy.to_actions(cov_prior).T
        obs_cov_proj = congruence_transform(proj, obs_cov)
        cov_prior_proj = congruence_transform(proj, cov_prior)
        cov_prior_lchol_proj = lower_cholesky(cov_prior_proj, jitter=self.jitter)

        # Apply the inverse factor and then its transpose to the projected residual.
        residual_proj = proj @ (y - mean_prior)
        inv_cov_prior_lchol_proj = cola.linalg.inv(cov_prior_lchol_proj)
        repr_weights_proj = inv_cov_prior_lchol_proj.T @ (
            inv_cov_prior_lchol_proj @ residual_proj
        )

        self._posterior_params = _ProjectedPosteriorParameters(
            x=x,
            proj=proj,
            obs_cov_proj=obs_cov_proj,
            cov_prior_lchol_proj=cov_prior_lchol_proj,
            residual_proj=residual_proj,
            repr_weights_proj=repr_weights_proj,
        )

    @override
    def predict(
        self, test_inputs: Float[Array, "N D"] | None = None
    ) -> GaussianDistribution:
        """Compute the predictive distribution of the GP at the test inputs.

        ``condition`` must be called before this method can be used.

        Args:
            test_inputs: The test inputs at which to make predictions. If not provided,
                predictions are made at the training inputs.

        Returns:
            GaussianDistribution: The predictive distribution of the GP at the
                test inputs.

        Raises:
            ValueError: If the model has not been conditioned yet.
        """
        if not self.is_conditioned:
            raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

        # help out pyright
        assert self._posterior_params is not None

        # Unpack posterior parameters
        x = self._posterior_params.x
        proj = self._posterior_params.proj
        cov_prior_lchol_proj = self._posterior_params.cov_prior_lchol_proj
        repr_weights_proj = self._posterior_params.repr_weights_proj

        # Predictions at test points
        z = test_inputs if test_inputs is not None else x
        prior = self.posterior.prior
        mean_z = prior.mean_function(z).squeeze()
        # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
        if isinstance(prior.mean_function, Constant):
            mean_z = mean_z.astype(prior.mean_function.constant.value.dtype)
        cov_zz = prior.kernel.gram(z)
        # When predicting at the training inputs the Gram matrix is reused as the cross-covariance.
        cov_xz = cov_zz if test_inputs is None else prior.kernel.cross_covariance(x, z)
        cov_xz_proj = proj @ cov_xz

        # Posterior predictive distribution
        mean_pred = jnp.atleast_1d(mean_z + cov_xz_proj.T @ repr_weights_proj)
        right_shift_factor = cola.linalg.inv(cov_prior_lchol_proj) @ cov_xz_proj
        cov_pred = cov_zz - right_shift_factor.T @ right_shift_factor
        cov_pred = cola.PSD(cov_pred + diag_like(cov_pred, self.jitter))

        return GaussianDistribution(mean_pred, cov_pred)

    def prior_kl(self) -> ScalarFloat:
        r"""Compute KL divergence between CaGP posterior and GP prior.

        Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
        posterior approximation and $p(f)$ is the GP prior.

        ``condition`` must be called before this method can be used.

        Returns:
            KL divergence value (scalar).

        Raises:
            ValueError: If the model has not been conditioned yet.
        """
        if not self.is_conditioned:
            raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

        # help out pyright
        assert self._posterior_params is not None

        # Unpack posterior parameters
        obs_cov_proj = self._posterior_params.obs_cov_proj
        cov_prior_lchol_proj = self._posterior_params.cov_prior_lchol_proj
        residual_proj = self._posterior_params.residual_proj
        repr_weights_proj = self._posterior_params.repr_weights_proj

        obs_cov_lchol_proj = lower_cholesky(obs_cov_proj, jitter=self.jitter)

        kl = (
            _kl_divergence_from_cholesky(
                residual_proj,
                obs_cov_lchol_proj,
                jnp.zeros_like(residual_proj),
                cov_prior_lchol_proj,
            )
            - 0.5 * congruence_transform(repr_weights_proj.T, obs_cov_proj).squeeze()
        )

        return kl

is_conditioned property

Whether the model has been conditioned on training data.

__init__(posterior, policy, jitter=1e-06)

Initialize the Computation-Aware GP model.

Parameters:

Name Type Description Default
posterior ConjugatePosterior

GPJax conjugate posterior.

required
policy AbstractBatchLinearSolverPolicy

The batch linear solver policy that defines the subspace into which the data is projected.

required
jitter ScalarFloat

A small positive constant added to the diagonal of a covariance matrix when necessary to ensure numerical stability.

1e-06
Source code in src/cagpjax/models/cagp.py
def __init__(
    self,
    posterior: ConjugatePosterior,
    policy: AbstractBatchLinearSolverPolicy,
    jitter: ScalarFloat = 1e-6,
):
    """Initialize the Computation-Aware GP model.

    Args:
        posterior: GPJax conjugate posterior.
        policy: The batch linear solver policy that defines the subspace into
            which the data is projected.
        jitter: A small positive constant added to the diagonal of a covariance
            matrix when necessary to ensure numerical stability.
    """
    super().__init__(posterior)
    self.policy: AbstractBatchLinearSolverPolicy = policy
    self.jitter: ScalarFloat = jitter
    # Stays ``None`` until ``condition`` stores the projected quantities.
    self._posterior_params: _ProjectedPosteriorParameters | None = None

condition(train_data)

Compute and store the projected quantities of the conditioned GP posterior.

Parameters:

Name Type Description Default
train_data Dataset

The training data used to fit the GP.

required
Source code in src/cagpjax/models/cagp.py
def condition(self, train_data: Dataset) -> None:
    """Compute and store the projected quantities of the conditioned GP posterior.

    Args:
        train_data: The training data used to fit the GP.

    Raises:
        ValueError: If ``train_data`` has no inputs or no targets.
    """
    # Ensure we have supervised training data
    if train_data.X is None or train_data.y is None:
        raise ValueError("Training data must be supervised.")

    # Unpack training data
    x = jnp.atleast_2d(train_data.X)
    # Squeeze first, then restore one dimension: with a single training point
    # the squeeze yields a 0-d array, which cannot be projected below.
    y = jnp.atleast_1d(jnp.atleast_1d(train_data.y).squeeze())

    # Unpack prior and likelihood
    prior = self.posterior.prior
    likelihood = self.posterior.likelihood

    # Mean and covariance of prior-predictive distribution
    mean_prior = prior.mean_function(x).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        mean_prior = mean_prior.astype(prior.mean_function.constant.value.dtype)
    cov_xx = prior.kernel.gram(x)
    obs_cov = diag_like(cov_xx, likelihood.obs_stddev.value**2)
    cov_prior = cov_xx + obs_cov

    # Project quantities to subspace
    proj = self.policy.to_actions(cov_prior).T
    obs_cov_proj = congruence_transform(proj, obs_cov)
    cov_prior_proj = congruence_transform(proj, cov_prior)
    cov_prior_lchol_proj = lower_cholesky(cov_prior_proj, jitter=self.jitter)

    # Apply the inverse factor and then its transpose to the projected residual.
    residual_proj = proj @ (y - mean_prior)
    inv_cov_prior_lchol_proj = cola.linalg.inv(cov_prior_lchol_proj)
    repr_weights_proj = inv_cov_prior_lchol_proj.T @ (
        inv_cov_prior_lchol_proj @ residual_proj
    )

    self._posterior_params = _ProjectedPosteriorParameters(
        x=x,
        proj=proj,
        obs_cov_proj=obs_cov_proj,
        cov_prior_lchol_proj=cov_prior_lchol_proj,
        residual_proj=residual_proj,
        repr_weights_proj=repr_weights_proj,
    )

predict(test_inputs=None)

Compute the predictive distribution of the GP at the test inputs.

condition must be called before this method can be used.

Parameters:

Name Type Description Default
test_inputs Float[Array, 'N D'] | None

The test inputs at which to make predictions. If not provided, predictions are made at the training inputs.

None

Returns:

Name Type Description
GaussianDistribution GaussianDistribution

The predictive distribution of the GP at the test inputs.

Source code in src/cagpjax/models/cagp.py
@override
def predict(
    self, test_inputs: Float[Array, "N D"] | None = None
) -> GaussianDistribution:
    """Return the predictive distribution of the conditioned model.

    ``condition`` must be called before this method can be used.

    Args:
        test_inputs: The test inputs at which to make predictions. If not provided,
            predictions are made at the training inputs.

    Returns:
        GaussianDistribution: The predictive distribution of the GP at the
            test inputs.

    Raises:
        ValueError: If the model has not been conditioned yet.
    """
    if not self.is_conditioned:
        raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

    params = self._posterior_params
    # narrow the Optional for pyright
    assert params is not None

    x = params.x
    at_training_inputs = test_inputs is None
    z = x if at_training_inputs else test_inputs

    # Prior mean at the prediction points
    prior = self.posterior.prior
    mean_z = prior.mean_function(z).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        mean_z = mean_z.astype(prior.mean_function.constant.value.dtype)

    # Prior covariances; the Gram matrix doubles as the cross-covariance when
    # predicting at the training inputs.
    cov_zz = prior.kernel.gram(z)
    if at_training_inputs:
        cov_xz = cov_zz
    else:
        cov_xz = prior.kernel.cross_covariance(x, z)
    cov_xz_proj = params.proj @ cov_xz

    # Posterior mean and covariance, with jitter added to the diagonal
    mean_pred = jnp.atleast_1d(mean_z + cov_xz_proj.T @ params.repr_weights_proj)
    scaled_cross_cov = cola.linalg.inv(params.cov_prior_lchol_proj) @ cov_xz_proj
    cov_pred = cov_zz - scaled_cross_cov.T @ scaled_cross_cov
    cov_pred = cola.PSD(cov_pred + diag_like(cov_pred, self.jitter))

    return GaussianDistribution(mean_pred, cov_pred)

prior_kl()

Compute KL divergence between CaGP posterior and GP prior.

Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP posterior approximation and \(p(f)\) is the GP prior.

condition must be called before this method can be used.

Returns:

Type Description
ScalarFloat

KL divergence value (scalar).

Source code in src/cagpjax/models/cagp.py
def prior_kl(self) -> ScalarFloat:
    r"""Compute KL divergence between CaGP posterior and GP prior.

    Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
    posterior approximation and $p(f)$ is the GP prior.

    ``condition`` must be called before this method can be used.

    Returns:
        KL divergence value (scalar).

    Raises:
        ValueError: If the model has not been conditioned yet.
    """
    if not self.is_conditioned:
        raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

    params = self._posterior_params
    # narrow the Optional for pyright
    assert params is not None

    obs_cov_proj = params.obs_cov_proj
    residual_proj = params.residual_proj
    obs_cov_lchol_proj = lower_cholesky(obs_cov_proj, jitter=self.jitter)

    # First term: helper evaluated with a zero vector as its third argument.
    divergence = _kl_divergence_from_cholesky(
        residual_proj,
        obs_cov_lchol_proj,
        jnp.zeros_like(residual_proj),
        params.cov_prior_lchol_proj,
    )
    # Second term: half the quadratic form of the representer weights in the
    # projected observation covariance.
    quad_form = congruence_transform(params.repr_weights_proj.T, obs_cov_proj)
    correction = 0.5 * quad_form.squeeze()

    return divergence - correction

LanczosPolicy

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors of the linear operator \(A\).

Attributes:

Name Type Description
n_actions int

Number of Lanczos vectors/actions to compute.

key PRNGKeyArray | None

Random key for reproducible Lanczos iterations.

Source code in src/cagpjax/policies/lanczos.py
class LanczosPolicy(AbstractBatchLinearSolverPolicy):
    """Policy whose actions come from a Lanczos eigendecomposition.

    This policy uses the Lanczos algorithm to compute the top ``n_actions`` eigenvectors
    of the linear operator $A$.

    Attributes:
        n_actions: Number of Lanczos vectors/actions to compute.
        key: Random key for reproducible Lanczos iterations.
    """

    def __init__(self, n_actions: int, key: PRNGKeyArray | None = None):
        """Store the number of actions and the optional random key.

        Args:
            n_actions: Number of Lanczos vectors to compute.
            key: Random key for initialization.
        """
        self._n_actions: int = n_actions
        self.key: PRNGKeyArray | None = key

    @property
    @override
    def n_actions(self) -> int:
        """Number of actions to be used."""
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Compute action matrix.

        Args:
            A: Symmetric linear operator representing the linear system.

        Returns:
            Linear operator containing the Lanczos vectors as columns.
        """
        lanczos = cola.linalg.Lanczos(key=self.key)
        decomposition = cola.linalg.eig(
            cola.SelfAdjoint(A), self.n_actions, which="LM", alg=lanczos
        )
        # Only the second element (the eigenvectors) is used.
        eigenvectors = decomposition[1]
        if isinstance(eigenvectors, LinearOperator):
            return eigenvectors
        # Wrap anything else so callers always receive a LinearOperator.
        return Dense(eigenvectors)

__init__(n_actions, key=None)

Initialize the Lanczos policy.

Parameters:

Name Type Description Default
n_actions int

Number of Lanczos vectors to compute.

required
key PRNGKeyArray | None

Random key for initialization.

None
Source code in src/cagpjax/policies/lanczos.py
def __init__(self, n_actions: int, key: PRNGKeyArray | None = None):
    """Initialize the Lanczos policy.

    Args:
        n_actions: Number of Lanczos vectors to compute.
        key: Random key for initialization. Forwarded to ``cola.linalg.Lanczos``
            by ``to_actions``.
    """
    # Backing field for the read-only ``n_actions`` property.
    self._n_actions: int = n_actions
    self.key: PRNGKeyArray | None = key

to_actions(A)

Compute action matrix.

Parameters:

Name Type Description Default
A LinearOperator

Symmetric linear operator representing the linear system.

required

Returns:

Type Description
LinearOperator

Linear operator containing the Lanczos vectors as columns.

Source code in src/cagpjax/policies/lanczos.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Compute action matrix.

    Args:
        A: Symmetric linear operator representing the linear system.

    Returns:
        Linear operator containing the Lanczos vectors as columns.
    """
    lanczos = cola.linalg.Lanczos(key=self.key)
    decomposition = cola.linalg.eig(
        cola.SelfAdjoint(A), self.n_actions, which="LM", alg=lanczos
    )
    # Only the second element (the eigenvectors) is used.
    eigenvectors = decomposition[1]
    if isinstance(eigenvectors, LinearOperator):
        return eigenvectors
    # Wrap anything else so callers always receive a LinearOperator.
    return Dense(eigenvectors)