
cagpjax

Computation-Aware Gaussian Processes for GPJax.

BlockDiagonalSparse

Bases: LinearOperator

Block-diagonal sparse linear operator.

This operator represents a block-diagonal matrix in which the blocks are contiguous and each block is a single column vector, so exactly one entry in each row is non-zero.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `nz_values` | `Float[Array, "N"]` | Non-zero values to be distributed across diagonal blocks. | *required* |
| `n_blocks` | `int` | Number of diagonal blocks in the matrix. | *required* |

Examples

>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 6x3 block-diagonal matrix with 3 blocks
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(6, 3)
>>>
>>> # Apply to the identity matrix
>>> op @ jnp.eye(3)
Array([[1., 0., 0.],
       [2., 0., 0.],
       [0., 3., 0.],
       [0., 4., 0.],
       [0., 0., 5.],
       [0., 0., 6.]], dtype=float32)
Source code in src/cagpjax/operators/block_diagonal_sparse.py
class BlockDiagonalSparse(LinearOperator):
    """Block-diagonal sparse linear operator.

    This operator represents a block-diagonal matrix in which the blocks are contiguous and
    each block is a single column vector, so exactly one entry in each row is non-zero.

    Args:
        nz_values: Non-zero values to be distributed across diagonal blocks.
        n_blocks: Number of diagonal blocks in the matrix.

    Examples
    --------
    ```python
    >>> import jax.numpy as jnp
    >>> from cagpjax.operators import BlockDiagonalSparse
    >>>
    >>> # Create a 6x3 block-diagonal matrix with 3 blocks
    >>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    >>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
    >>> print(op.shape)
    (6, 3)
    >>>
    >>> # Apply to the identity matrix
    >>> op @ jnp.eye(3)
    Array([[1., 0., 0.],
           [2., 0., 0.],
           [0., 3., 0.],
           [0., 4., 0.],
           [0., 0., 5.],
           [0., 0., 6.]], dtype=float32)
    ```
    """

    def __init__(self, nz_values: Float[Array, "N"], n_blocks: int):
        n = nz_values.shape[0]
        super().__init__(nz_values.dtype, (n, n_blocks))
        self.nz_values = nz_values

    def _matmat(self, X: Float[Array, "K M"]) -> Float[Array, "N M"]:
        n, n_blocks = self.shape
        block_size = n // n_blocks
        n_blocks_main = n_blocks if n % n_blocks == 0 else n_blocks - 1
        n_main = n_blocks_main * block_size
        m = X.shape[1]

        # block-wise multiplication for main blocks
        blocks_main = self.nz_values[:n_main].reshape(n_blocks_main, block_size)
        X_main = X[:n_blocks_main, :]
        res_main = (blocks_main[..., None] * X_main[:, None, :]).reshape(n_main, m)

        # handle overhang if any
        if n > n_main:
            n_overhang = n - n_main
            X_overhang = X[n_blocks_main, :]
            block_overhang = self.nz_values[n_main:]
            res_overhang = jnp.outer(block_overhang, X_overhang).reshape(n_overhang, m)
            res = jnp.concatenate([res_main, res_overhang], axis=0)
        else:
            res = res_main

        return res

    def _rmatmat(self, X: Float[Array, "M N"]) -> Float[Array, "M K"]:
        # figure out size of main blocks
        n, n_blocks = self.shape
        block_size = n // n_blocks
        n_blocks_main = n_blocks if n % n_blocks == 0 else n_blocks - 1
        n_main = n_blocks_main * block_size
        m = X.shape[0]

        # block-wise multiplication for main blocks
        blocks_main = self.nz_values[:n_main].reshape(n_blocks_main, block_size)
        X_main = X[:, :n_main].reshape(m, n_blocks_main, block_size)
        res_main = jnp.einsum("ik,jik->ji", blocks_main, X_main)

        # handle overhang if any
        if n > n_main:
            n_overhang = n - n_main
            X_overhang = X[:, n_main:].reshape(m, n_overhang)
            block_overhang = self.nz_values[n_main:]
            res_overhang = (X_overhang @ block_overhang)[:, None]
            res = jnp.concatenate([res_main, res_overhang], axis=1)
        else:
            res = res_main

        return res

BlockSparsePolicy

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The matrix has the following structure:

\[
S = \begin{bmatrix}
    s_1 & 0 & \cdots & 0 \\
    0 & s_2 & \cdots & 0 \\
    \vdots & \vdots & \ddots & \vdots \\
    0 & 0 & \cdots & s_{\text{n\_actions}}
\end{bmatrix}
\]

The block vectors \(s_1, \ldots, s_{\text{n\_actions}}\) are stacked and stored as a single trainable parameter `nz_values`.
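
A minimal usage sketch (assuming `BlockSparsePolicy` is importable from `cagpjax.policies`, mirroring the source path below; the sizes are illustrative):

```python
import jax
from cagpjax.policies import BlockSparsePolicy

# A 6-row operator projected onto 3 learnable actions (block size 2).
policy = BlockSparsePolicy(n_actions=3, n=6, key=jax.random.PRNGKey(42))
print(policy.nz_values.value.shape)  # (6,) -- the stacked block entries

# Per the docstring, ``to_actions`` ignores its operator argument for this
# policy, so a placeholder can stand in for the operator here.
S = policy.to_actions(None)
print(S.shape)  # (6, 3) -- a BlockDiagonalSparse action matrix
```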

Source code in src/cagpjax/policies/block_sparse.py
class BlockSparsePolicy(AbstractBatchLinearSolverPolicy):
    r"""Block-sparse linear solver policy.

    This policy uses a fixed block-diagonal sparse structure to define
    independent learnable actions. The matrix has the following structure:

    $$
    S = \begin{bmatrix}
        s_1 & 0 & \cdots & 0 \\
        0 & s_2 & \cdots & 0 \\
        \vdots & \vdots & \ddots & \vdots \\
        0 & 0 & \cdots & s_{\text{n_actions}}
    \end{bmatrix}
    $$

    These are stacked and stored as a single trainable parameter ``nz_values``.
    """

    def __init__(
        self,
        n_actions: int,
        n: int | None = None,
        nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
        key: PRNGKeyArray | None = None,
        **kwargs,
    ):
        """Initialize the block sparse policy.

        Args:
            n_actions: Number of actions to use.
            n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
                not provided.
            nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
                provided, random actions are sampled using the key if provided.
            key: Random key for sampling actions if ``nz_values`` is not provided.
            **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)
        """
        if nz_values is None:
            if n is None:
                raise ValueError("n must be provided if nz_values is not provided")
            if key is None:
                key = jax.random.PRNGKey(0)
            block_size = n // n_actions
            nz_values = jax.random.normal(key, (n,), **kwargs)
            nz_values /= jnp.sqrt(block_size)
        elif n is not None:
            warnings.warn("n is ignored because nz_values is provided")

        if not isinstance(nz_values, nnx.Variable):
            nz_values = Real(nz_values)

        self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
        self._n_actions: int = n_actions

    @property
    @override
    def n_actions(self) -> int:
        """Number of actions to be used."""
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Convert to block diagonal sparse action operators.

        Args:
            A: Linear operator (unused).

        Returns:
            BlockDiagonalSparse: Sparse action structure representing the blocks.
        """
        return BlockDiagonalSparse(self.nz_values.value, self.n_actions)

n_actions property

Number of actions to be used.

__init__(n_actions, n=None, nz_values=None, key=None, **kwargs)

Initialize the block sparse policy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_actions` | `int` | Number of actions to use. | *required* |
| `n` | `int \| None` | Number of rows and columns of the full operator. Must be provided if `nz_values` is not provided. | `None` |
| `nz_values` | `Float[Array, "N"] \| nnx.Variable[Float[Array, "N"]] \| None` | Non-zero values of the block-diagonal sparse matrix (shape `(n,)`). If not provided, random actions are sampled using the key if provided. | `None` |
| `key` | `PRNGKeyArray \| None` | Random key for sampling actions if `nz_values` is not provided. | `None` |
| `**kwargs` | | Additional keyword arguments for `jax.random.normal` (e.g. `dtype`). | `{}` |
Source code in src/cagpjax/policies/block_sparse.py
def __init__(
    self,
    n_actions: int,
    n: int | None = None,
    nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
    key: PRNGKeyArray | None = None,
    **kwargs,
):
    """Initialize the block sparse policy.

    Args:
        n_actions: Number of actions to use.
        n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
            not provided.
        nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
            provided, random actions are sampled using the key if provided.
        key: Random key for sampling actions if ``nz_values`` is not provided.
        **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)
    """
    if nz_values is None:
        if n is None:
            raise ValueError("n must be provided if nz_values is not provided")
        if key is None:
            key = jax.random.PRNGKey(0)
        block_size = n // n_actions
        nz_values = jax.random.normal(key, (n,), **kwargs)
        nz_values /= jnp.sqrt(block_size)
    elif n is not None:
        warnings.warn("n is ignored because nz_values is provided")

    if not isinstance(nz_values, nnx.Variable):
        nz_values = Real(nz_values)

    self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
    self._n_actions: int = n_actions

to_actions(A)

Convert to block diagonal sparse action operators.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `LinearOperator` | Linear operator (unused). | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `BlockDiagonalSparse` | `LinearOperator` | Sparse action structure representing the blocks. |
Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Convert to block diagonal sparse action operators.

    Args:
        A: Linear operator (unused).

    Returns:
        BlockDiagonalSparse: Sparse action structure representing the blocks.
    """
    return BlockDiagonalSparse(self.nz_values.value, self.n_actions)

ComputationAwareGP

Bases: AbstractComputationAwareGP

Computation-aware Gaussian Process model.

This model implements scalable GP inference by using batch linear solver policies to project the kernel and data to a lower-dimensional subspace, while accounting for the extra uncertainty imposed by observing only this subspace.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `posterior` | `ConjugatePosterior` | The original (exact) posterior. |
| `policy` | `AbstractBatchLinearSolverPolicy` | The batch linear solver policy. |
| `solver_method` | `AbstractLinearSolverMethod` | The linear solver method to use for solving linear systems with positive semi-definite operators. |

Notes

- Only single-output models are currently supported.
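
A minimal end-to-end sketch, assuming standard GPJax model construction and that the import paths mirror the `Source code in` paths on this page:

```python
import gpjax as gpx
import jax
import jax.numpy as jnp
from cagpjax.models import ComputationAwareGP
from cagpjax.policies import LanczosPolicy

# Toy 1D regression data
key = jax.random.PRNGKey(0)
x = jnp.linspace(0.0, 1.0, 200)[:, None]
y = jnp.sin(10.0 * x) + 0.1 * jax.random.normal(key, x.shape)
data = gpx.Dataset(X=x, y=y)

# Standard GPJax prior and Gaussian likelihood -> conjugate posterior
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=data.n)

# Project inference onto 20 Lanczos actions
model = ComputationAwareGP(posterior, policy=LanczosPolicy(n_actions=20))
model.condition(data)  # compute and store the projected quantities
pred = model.predict(jnp.linspace(0.0, 1.0, 50)[:, None])  # GaussianDistribution
```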
Source code in src/cagpjax/models/cagp.py
class ComputationAwareGP(AbstractComputationAwareGP):
    """Computation-aware Gaussian Process model.

    This model implements scalable GP inference by using batch linear solver
    policies to project the kernel and data to a lower-dimensional subspace, while
    accounting for the extra uncertainty imposed by observing only this subspace.

    Attributes:
        posterior: The original (exact) posterior.
        policy: The batch linear solver policy.
        solver_method: The linear solver method to use for solving linear systems
            with positive semi-definite operators.

    Notes:
        - Only single-output models are currently supported.
    """

    posterior: ConjugatePosterior
    policy: AbstractBatchLinearSolverPolicy
    solver_method: AbstractLinearSolverMethod

    def __init__(
        self,
        posterior: ConjugatePosterior,
        policy: AbstractBatchLinearSolverPolicy,
        solver_method: AbstractLinearSolverMethod = Cholesky(1e-6),
    ):
        """Initialize the Computation-Aware GP model.

        Args:
            posterior: GPJax conjugate posterior.
            policy: The batch linear solver policy that defines the subspace into
                which the data is projected.
            solver_method: The linear solver method to use for solving linear systems with
                positive semi-definite operators.
        """
        super().__init__(posterior)
        self.policy = policy
        self.solver_method = solver_method
        self._posterior_params: _ProjectedPosteriorParameters | None = None

    @property
    def is_conditioned(self) -> bool:
        """Whether the model has been conditioned on training data."""
        return self._posterior_params is not None

    def condition(self, train_data: Dataset) -> None:
        """Compute and store the projected quantities of the conditioned GP posterior.

        Args:
            train_data: The training data used to fit the GP.
        """
        # Ensure we have supervised training data
        if train_data.X is None or train_data.y is None:
            raise ValueError("Training data must be supervised.")

        # Unpack training data
        x = jnp.atleast_2d(train_data.X)
        y = jnp.atleast_1d(train_data.y).squeeze()

        # Unpack prior and likelihood
        prior = self.posterior.prior
        likelihood = self.posterior.likelihood

        # Mean and covariance of prior-predictive distribution
        mean_prior = prior.mean_function(x).squeeze()
        # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
        if isinstance(prior.mean_function, Constant):
            mean_prior = mean_prior.astype(prior.mean_function.constant.value.dtype)
        cov_xx = prior.kernel.gram(x)
        obs_cov = diag_like(cov_xx, likelihood.obs_stddev.value**2)
        cov_prior = cov_xx + obs_cov

        # Project quantities to subspace
        actions = self.policy.to_actions(cov_prior)
        obs_cov_proj = congruence_transform(actions, obs_cov)
        cov_prior_proj = congruence_transform(actions, cov_prior)
        cov_prior_proj_solver = self.solver_method(cov_prior_proj)

        residual_proj = actions.T @ (y - mean_prior)
        repr_weights_proj = cov_prior_proj_solver.solve(residual_proj)

        self._posterior_params = _ProjectedPosteriorParameters(
            x=x,
            actions=actions,
            obs_cov_proj=obs_cov_proj,
            cov_prior_proj_solver=cov_prior_proj_solver,
            residual_proj=residual_proj,
            repr_weights_proj=repr_weights_proj,
        )

    @override
    def predict(
        self, test_inputs: Float[Array, "N D"] | None = None
    ) -> GaussianDistribution:
        """Compute the predictive distribution of the GP at the test inputs.

        ``condition`` must be called before this method can be used.

        Args:
            test_inputs: The test inputs at which to make predictions. If not provided,
                predictions are made at the training inputs.

        Returns:
            GaussianDistribution: The predictive distribution of the GP at the
                test inputs.
        """
        if not self.is_conditioned:
            raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

        # help out pyright
        assert self._posterior_params is not None

        # Unpack posterior parameters
        x = self._posterior_params.x
        actions = self._posterior_params.actions
        cov_prior_proj_solver = self._posterior_params.cov_prior_proj_solver
        repr_weights_proj = self._posterior_params.repr_weights_proj

        # Predictions at test points
        z = test_inputs if test_inputs is not None else x
        prior = self.posterior.prior
        mean_z = prior.mean_function(z).squeeze()
        # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
        if isinstance(prior.mean_function, Constant):
            mean_z = mean_z.astype(prior.mean_function.constant.value.dtype)
        cov_zz = prior.kernel.gram(z)
        cov_zx = cov_zz if test_inputs is None else prior.kernel.cross_covariance(z, x)
        cov_zx_proj = cov_zx @ actions

        # Posterior predictive distribution
        mean_pred = jnp.atleast_1d(mean_z + cov_zx_proj @ repr_weights_proj)
        cov_pred = cov_zz - cov_prior_proj_solver.inv_congruence_transform(
            cov_zx_proj.T
        )
        cov_pred = cola.PSD(cov_pred + diag_like(cov_pred, self.posterior.jitter))

        return GaussianDistribution(mean_pred, cov_pred)

    def prior_kl(self) -> ScalarFloat:
        r"""Compute KL divergence between CaGP posterior and GP prior..

        Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
        posterior approximation and $p(f)$ is the GP prior.

        ``condition`` must be called before this method can be used.

        Returns:
            KL divergence value (scalar).
        """
        if not self.is_conditioned:
            raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

        # help out pyright
        assert self._posterior_params is not None

        # Unpack posterior parameters
        obs_cov_proj = self._posterior_params.obs_cov_proj
        cov_prior_proj_solver = self._posterior_params.cov_prior_proj_solver
        residual_proj = self._posterior_params.residual_proj
        repr_weights_proj = self._posterior_params.repr_weights_proj

        obs_cov_proj_solver = self.solver_method(obs_cov_proj)

        kl = (
            _kl_divergence_from_solvers(
                residual_proj,
                obs_cov_proj_solver,
                jnp.zeros_like(residual_proj),
                cov_prior_proj_solver,
            )
            - 0.5 * congruence_transform(repr_weights_proj.T, obs_cov_proj).squeeze()
        )

        return kl

is_conditioned property

Whether the model has been conditioned on training data.

__init__(posterior, policy, solver_method=Cholesky(1e-06))

Initialize the Computation-Aware GP model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `posterior` | `ConjugatePosterior` | GPJax conjugate posterior. | *required* |
| `policy` | `AbstractBatchLinearSolverPolicy` | The batch linear solver policy that defines the subspace into which the data is projected. | *required* |
| `solver_method` | `AbstractLinearSolverMethod` | The linear solver method to use for solving linear systems with positive semi-definite operators. | `Cholesky(1e-06)` |
Source code in src/cagpjax/models/cagp.py
def __init__(
    self,
    posterior: ConjugatePosterior,
    policy: AbstractBatchLinearSolverPolicy,
    solver_method: AbstractLinearSolverMethod = Cholesky(1e-6),
):
    """Initialize the Computation-Aware GP model.

    Args:
        posterior: GPJax conjugate posterior.
        policy: The batch linear solver policy that defines the subspace into
            which the data is projected.
        solver_method: The linear solver method to use for solving linear systems with
            positive semi-definite operators.
    """
    super().__init__(posterior)
    self.policy = policy
    self.solver_method = solver_method
    self._posterior_params: _ProjectedPosteriorParameters | None = None

condition(train_data)

Compute and store the projected quantities of the conditioned GP posterior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_data` | `Dataset` | The training data used to fit the GP. | *required* |
Source code in src/cagpjax/models/cagp.py
def condition(self, train_data: Dataset) -> None:
    """Compute and store the projected quantities of the conditioned GP posterior.

    Args:
        train_data: The training data used to fit the GP.
    """
    # Ensure we have supervised training data
    if train_data.X is None or train_data.y is None:
        raise ValueError("Training data must be supervised.")

    # Unpack training data
    x = jnp.atleast_2d(train_data.X)
    y = jnp.atleast_1d(train_data.y).squeeze()

    # Unpack prior and likelihood
    prior = self.posterior.prior
    likelihood = self.posterior.likelihood

    # Mean and covariance of prior-predictive distribution
    mean_prior = prior.mean_function(x).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        mean_prior = mean_prior.astype(prior.mean_function.constant.value.dtype)
    cov_xx = prior.kernel.gram(x)
    obs_cov = diag_like(cov_xx, likelihood.obs_stddev.value**2)
    cov_prior = cov_xx + obs_cov

    # Project quantities to subspace
    actions = self.policy.to_actions(cov_prior)
    obs_cov_proj = congruence_transform(actions, obs_cov)
    cov_prior_proj = congruence_transform(actions, cov_prior)
    cov_prior_proj_solver = self.solver_method(cov_prior_proj)

    residual_proj = actions.T @ (y - mean_prior)
    repr_weights_proj = cov_prior_proj_solver.solve(residual_proj)

    self._posterior_params = _ProjectedPosteriorParameters(
        x=x,
        actions=actions,
        obs_cov_proj=obs_cov_proj,
        cov_prior_proj_solver=cov_prior_proj_solver,
        residual_proj=residual_proj,
        repr_weights_proj=repr_weights_proj,
    )

predict(test_inputs=None)

Compute the predictive distribution of the GP at the test inputs.

condition must be called before this method can be used.
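
Reading from the source below, with action matrix \(S\), noisy prior covariance \(\hat{K} = K_{xx} + \sigma^2 I\), and prior mean \(m(\cdot)\), the predictive moments at test inputs \(z\) are:

\[
\mu_* = m(z) + K_{zx} S \left(S^\top \hat{K} S\right)^{-1} S^\top \big(y - m(x)\big),
\qquad
\Sigma_* = K_{zz} - K_{zx} S \left(S^\top \hat{K} S\right)^{-1} S^\top K_{xz}.
\]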

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `test_inputs` | `Float[Array, "N D"] \| None` | The test inputs at which to make predictions. If not provided, predictions are made at the training inputs. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `GaussianDistribution` | `GaussianDistribution` | The predictive distribution of the GP at the test inputs. |

Source code in src/cagpjax/models/cagp.py
@override
def predict(
    self, test_inputs: Float[Array, "N D"] | None = None
) -> GaussianDistribution:
    """Compute the predictive distribution of the GP at the test inputs.

    ``condition`` must be called before this method can be used.

    Args:
        test_inputs: The test inputs at which to make predictions. If not provided,
            predictions are made at the training inputs.

    Returns:
        GaussianDistribution: The predictive distribution of the GP at the
            test inputs.
    """
    if not self.is_conditioned:
        raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

    # help out pyright
    assert self._posterior_params is not None

    # Unpack posterior parameters
    x = self._posterior_params.x
    actions = self._posterior_params.actions
    cov_prior_proj_solver = self._posterior_params.cov_prior_proj_solver
    repr_weights_proj = self._posterior_params.repr_weights_proj

    # Predictions at test points
    z = test_inputs if test_inputs is not None else x
    prior = self.posterior.prior
    mean_z = prior.mean_function(z).squeeze()
    # Work around GPJax promoting dtype of mean to float64 (See JaxGaussianProcesses/GPJax#523)
    if isinstance(prior.mean_function, Constant):
        mean_z = mean_z.astype(prior.mean_function.constant.value.dtype)
    cov_zz = prior.kernel.gram(z)
    cov_zx = cov_zz if test_inputs is None else prior.kernel.cross_covariance(z, x)
    cov_zx_proj = cov_zx @ actions

    # Posterior predictive distribution
    mean_pred = jnp.atleast_1d(mean_z + cov_zx_proj @ repr_weights_proj)
    cov_pred = cov_zz - cov_prior_proj_solver.inv_congruence_transform(
        cov_zx_proj.T
    )
    cov_pred = cola.PSD(cov_pred + diag_like(cov_pred, self.posterior.jitter))

    return GaussianDistribution(mean_pred, cov_pred)

prior_kl()

Compute KL divergence between CaGP posterior and GP prior.

Calculates \(\mathrm{KL}[q(f) || p(f)]\), where \(q(f)\) is the CaGP posterior approximation and \(p(f)\) is the GP prior.

condition must be called before this method can be used.
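
For reference, the closed-form KL divergence between \(k\)-dimensional Gaussians \(\mathcal{N}(\mu_1, \Sigma_1)\) and \(\mathcal{N}(\mu_0, \Sigma_0)\), which the helper `_kl_divergence_from_solvers` in the source below appears to evaluate via the solvers' factorizations, is:

\[
\mathrm{KL}\left[\mathcal{N}(\mu_1, \Sigma_1) \,\|\, \mathcal{N}(\mu_0, \Sigma_0)\right]
= \frac{1}{2}\left(\operatorname{tr}\big(\Sigma_0^{-1}\Sigma_1\big)
+ (\mu_0 - \mu_1)^\top \Sigma_0^{-1} (\mu_0 - \mu_1)
- k + \ln\frac{\det \Sigma_0}{\det \Sigma_1}\right).
\]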

Returns:

| Type | Description |
| --- | --- |
| `ScalarFloat` | KL divergence value (scalar). |

Source code in src/cagpjax/models/cagp.py
def prior_kl(self) -> ScalarFloat:
    r"""Compute KL divergence between CaGP posterior and GP prior..

    Calculates $\mathrm{KL}[q(f) || p(f)]$, where $q(f)$ is the CaGP
    posterior approximation and $p(f)$ is the GP prior.

    ``condition`` must be called before this method can be used.

    Returns:
        KL divergence value (scalar).
    """
    if not self.is_conditioned:
        raise ValueError("Model is not yet conditioned. Call ``condition`` first.")

    # help out pyright
    assert self._posterior_params is not None

    # Unpack posterior parameters
    obs_cov_proj = self._posterior_params.obs_cov_proj
    cov_prior_proj_solver = self._posterior_params.cov_prior_proj_solver
    residual_proj = self._posterior_params.residual_proj
    repr_weights_proj = self._posterior_params.repr_weights_proj

    obs_cov_proj_solver = self.solver_method(obs_cov_proj)

    kl = (
        _kl_divergence_from_solvers(
            residual_proj,
            obs_cov_proj_solver,
            jnp.zeros_like(residual_proj),
            cov_prior_proj_solver,
        )
        - 0.5 * congruence_transform(repr_weights_proj.T, obs_cov_proj).squeeze()
    )

    return kl

LanczosPolicy

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors of the linear operator \(A\).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `n_actions` | `int` | Number of Lanczos vectors/actions to compute. |
| `key` | `PRNGKeyArray \| None` | Random key for reproducible Lanczos iterations. |
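
A minimal sketch; wrapping the dense matrix with `cola` is an assumption about the expected `LinearOperator` type (`cola.PSD` also appears in the `ComputationAwareGP` source above):

```python
import cola
import jax
import jax.numpy as jnp
from cagpjax.policies import LanczosPolicy

# Hypothetical PSD operator standing in for a prior covariance.
B = jax.random.normal(jax.random.PRNGKey(0), (50, 50))
A = cola.PSD(cola.ops.Dense(B @ B.T + 1e-3 * jnp.eye(50)))

policy = LanczosPolicy(n_actions=10, key=jax.random.PRNGKey(1))
actions = policy.to_actions(A)  # columns approximate the top-10 eigenvectors of A
print(actions.shape)  # (50, 10)
```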

Source code in src/cagpjax/policies/lanczos.py
class LanczosPolicy(AbstractBatchLinearSolverPolicy):
    """Lanczos-based policy for eigenvalue decomposition approximation.

    This policy uses the Lanczos algorithm to compute the top ``n_actions`` eigenvectors
    of the linear operator $A$.

    Attributes:
        n_actions: Number of Lanczos vectors/actions to compute.
        key: Random key for reproducible Lanczos iterations.
    """

    key: PRNGKeyArray | None
    grad_rtol: float | None

    def __init__(
        self,
        n_actions: int | None,
        key: PRNGKeyArray | None = None,
        grad_rtol: float | None = 0.0,
    ):
        """Initialize the Lanczos policy.

        Args:
            n_actions: Number of Lanczos vectors to compute.
            key: Random key for initialization.
            grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
                gradient computation for (almost-)degenerate matrices.
                If not provided, the default is 0.0.
                If None or negative, all eigenvalues are treated as distinct.
                (see [`cagpjax.linalg.eigh`][] for more details)
        """
        self._n_actions: int = n_actions
        self.key = key
        self.grad_rtol = grad_rtol

    @property
    @override
    def n_actions(self) -> int:
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Compute action matrix.

        Args:
            A: Symmetric linear operator representing the linear system.

        Returns:
            Linear operator containing the Lanczos vectors as columns.
        """
        vecs = eigh(
            A, alg=Lanczos(self.n_actions, key=self.key), grad_rtol=self.grad_rtol
        ).eigenvectors
        return vecs

__init__(n_actions, key=None, grad_rtol=0.0)

Initialize the Lanczos policy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_actions` | `int \| None` | Number of Lanczos vectors to compute. | *required* |
| `key` | `PRNGKeyArray \| None` | Random key for initialization. | `None` |
| `grad_rtol` | `float \| None` | Specifies the cutoff for similar eigenvalues, used to improve gradient computation for (almost-)degenerate matrices. If not provided, the default is 0.0. If `None` or negative, all eigenvalues are treated as distinct (see `cagpjax.linalg.eigh` for more details). | `0.0` |
Source code in src/cagpjax/policies/lanczos.py
def __init__(
    self,
    n_actions: int | None,
    key: PRNGKeyArray | None = None,
    grad_rtol: float | None = 0.0,
):
    """Initialize the Lanczos policy.

    Args:
        n_actions: Number of Lanczos vectors to compute.
        key: Random key for initialization.
        grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
            gradient computation for (almost-)degenerate matrices.
            If not provided, the default is 0.0.
            If None or negative, all eigenvalues are treated as distinct.
            (see [`cagpjax.linalg.eigh`][] for more details)
    """
    self._n_actions: int = n_actions
    self.key = key
    self.grad_rtol = grad_rtol

to_actions(A)

Compute action matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `LinearOperator` | Symmetric linear operator representing the linear system. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `LinearOperator` | Linear operator containing the Lanczos vectors as columns. |

Source code in src/cagpjax/policies/lanczos.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Compute action matrix.

    Args:
        A: Symmetric linear operator representing the linear system.

    Returns:
        Linear operator containing the Lanczos vectors as columns.
    """
    vecs = eigh(
        A, alg=Lanczos(self.n_actions, key=self.key), grad_rtol=self.grad_rtol
    ).eigenvectors
    return vecs