Skip to content

cagpjax.policies

AbstractBatchLinearSolverPolicy

Bases: AbstractLinearSolverPolicy, ABC

Abstract base class for policies that produce action matrices.

Source code in src/cagpjax/policies/base.py
class AbstractBatchLinearSolverPolicy(AbstractLinearSolverPolicy, abc.ABC):
    """Abstract base class for policies that produce action matrices.

    Concrete subclasses must implement both the ``n_actions`` property and the
    ``to_actions`` method, which are declared abstract here.
    """

    @property
    @abc.abstractmethod
    def n_actions(self) -> int:
        """Number of actions in this policy."""
        ...

    @abc.abstractmethod
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        r"""Compute all actions used to solve the linear system $Ax=b$.

        For a matrix $A$ with shape ``(n, n)``, the action matrix has shape
        ``(n, n_actions)``.

        Args:
            A: Linear operator representing the linear system.

        Returns:
            Linear operator representing the action matrix.
        """
        ...

n_actions abstractmethod property

Number of actions in this policy.

to_actions(A) abstractmethod

Compute all actions used to solve the linear system \(Ax=b\).

For a matrix \(A\) with shape (n, n), the action matrix has shape (n, n_actions).

Parameters:

Name Type Description Default
A LinearOperator

Linear operator representing the linear system.

required

Returns:

Type Description
LinearOperator

Linear operator representing the action matrix.

Source code in src/cagpjax/policies/base.py
@abc.abstractmethod
def to_actions(self, A: LinearOperator) -> LinearOperator:
    r"""Compute all actions used to solve the linear system $Ax=b$.

    For a matrix $A$ with shape ``(n, n)``, the action matrix has shape
    ``(n, n_actions)``.

    This method is abstract: concrete policies must override it.

    Args:
        A: Linear operator representing the linear system.

    Returns:
        Linear operator representing the action matrix.
    """
    ...

AbstractLinearSolverPolicy

Bases: Module

Abstract base class for all linear solver policies.

Policies define actions used to solve a linear system \(A x = b\), where \(A\) is a square linear operator.

Source code in src/cagpjax/policies/base.py
class AbstractLinearSolverPolicy(nnx.Module):
    r"""Abstract base class for all linear solver policies.

    Policies define actions used to solve a linear system $A x = b$, where $A$ is a
    square linear operator.

    This base class declares no attributes or methods of its own; it only serves
    as the common ancestor of all policy types.
    """

    ...

BlockSparsePolicy

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The matrix has the following structure:

\[ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_{\text{n_actions}} \end{bmatrix} \]

These are stacked and stored as a single trainable parameter nz_values.

Source code in src/cagpjax/policies/block_sparse.py
class BlockSparsePolicy(AbstractBatchLinearSolverPolicy):
    r"""Block-sparse linear solver policy.

    This policy uses a fixed block-diagonal sparse structure to define
    independent learnable actions. The matrix has the following structure:

    $$
    S = \begin{bmatrix}
        s_1 & 0 & \cdots & 0 \\
        0 & s_2 & \cdots & 0 \\
        \vdots & \vdots & \ddots & \vdots \\
        0 & 0 & \cdots & s_{\text{n_actions}}
    \end{bmatrix}
    $$

    These are stacked and stored as a single trainable parameter ``nz_values``.
    """

    def __init__(
        self,
        n_actions: int,
        n: int | None = None,
        nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
        key: PRNGKeyArray | None = None,
        **kwargs,
    ):
        """Initialize the block sparse policy.

        Args:
            n_actions: Number of actions to use.
            n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
                not provided.
            nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
                provided, random actions are sampled using the key if provided.
            key: Random key for sampling actions if ``nz_values`` is not provided. Defaults to
                ``jax.random.PRNGKey(0)`` when sampling is required and no key is given.
            **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)

        Raises:
            ValueError: If ``nz_values`` is not provided and either ``n`` is missing or
                ``n_actions`` does not satisfy ``1 <= n_actions <= n``.
        """
        if nz_values is None:
            if n is None:
                raise ValueError("n must be provided if nz_values is not provided")
            # Guard the normalisation below: outside this range ``n // n_actions``
            # is zero, negative, or a division by zero, which would either crash
            # or silently fill the sampled values with inf/nan.
            if not 1 <= n_actions <= n:
                raise ValueError(
                    f"n_actions must satisfy 1 <= n_actions <= n, "
                    f"got n_actions={n_actions} and n={n}"
                )
            if key is None:
                key = jax.random.PRNGKey(0)
            block_size = n // n_actions
            nz_values = jax.random.normal(key, (n,), **kwargs)
            # Scale the standard-normal draws by 1/sqrt(block_size).
            # NOTE(review): presumably so each action has roughly unit norm — confirm.
            nz_values /= jnp.sqrt(block_size)
        elif n is not None:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn("n is ignored because nz_values is provided", stacklevel=2)

        # Plain arrays are wrapped in ``Real``; existing variables are stored as given.
        if not isinstance(nz_values, nnx.Variable):
            nz_values = Real(nz_values)

        self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
        self._n_actions: int = n_actions

    @property
    @override
    def n_actions(self) -> int:
        """Number of actions to be used."""
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Convert to block diagonal sparse action operators.

        Args:
            A: Linear operator (unused).

        Returns:
            BlockDiagonalSparse: Sparse action structure representing the blocks.
        """
        return BlockDiagonalSparse(self.nz_values.value, self.n_actions)

n_actions property

Number of actions to be used.

__init__(n_actions, n=None, nz_values=None, key=None, **kwargs)

Initialize the block sparse policy.

Parameters:

Name Type Description Default
n_actions int

Number of actions to use.

required
n int | None

Number of rows and columns of the full operator. Must be provided if nz_values is not provided.

None
nz_values Float[Array, N] | Variable[Float[Array, N]] | None

Non-zero values of the block-diagonal sparse matrix (shape (n,)). If not provided, random actions are sampled using the key if provided.

None
key PRNGKeyArray | None

Random key for sampling actions if nz_values is not provided.

None
**kwargs

Additional keyword arguments for jax.random.normal (e.g. dtype)

{}
Source code in src/cagpjax/policies/block_sparse.py
def __init__(
    self,
    n_actions: int,
    n: int | None = None,
    nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
    key: PRNGKeyArray | None = None,
    **kwargs,
):
    """Initialize the block sparse policy.

    Args:
        n_actions: Number of actions to use.
        n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
            not provided.
        nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
            provided, random actions are sampled using the key if provided.
        key: Random key for sampling actions if ``nz_values`` is not provided. Defaults to
            ``jax.random.PRNGKey(0)`` when sampling is required and no key is given.
        **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)

    Raises:
        ValueError: If ``nz_values`` is not provided and either ``n`` is missing or
            ``n_actions`` does not satisfy ``1 <= n_actions <= n``.
    """
    if nz_values is None:
        if n is None:
            raise ValueError("n must be provided if nz_values is not provided")
        # Guard the normalisation below: outside this range ``n // n_actions``
        # is zero, negative, or a division by zero, which would either crash
        # or silently fill the sampled values with inf/nan.
        if not 1 <= n_actions <= n:
            raise ValueError(
                f"n_actions must satisfy 1 <= n_actions <= n, "
                f"got n_actions={n_actions} and n={n}"
            )
        if key is None:
            key = jax.random.PRNGKey(0)
        block_size = n // n_actions
        nz_values = jax.random.normal(key, (n,), **kwargs)
        # Scale the standard-normal draws by 1/sqrt(block_size).
        # NOTE(review): presumably so each action has roughly unit norm — confirm.
        nz_values /= jnp.sqrt(block_size)
    elif n is not None:
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn("n is ignored because nz_values is provided", stacklevel=2)

    # Plain arrays are wrapped in ``Real``; existing variables are stored as given.
    if not isinstance(nz_values, nnx.Variable):
        nz_values = Real(nz_values)

    self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
    self._n_actions: int = n_actions

to_actions(A)

Convert to block diagonal sparse action operators.

Parameters:

Name Type Description Default
A LinearOperator

Linear operator (unused).

required

Returns:

Name Type Description
BlockDiagonalSparse LinearOperator

Sparse action structure representing the blocks.

Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Build the block-diagonal sparse action operator for this policy.

    Args:
        A: Linear operator of the system; it is not read by this policy.

    Returns:
        BlockDiagonalSparse: Operator constructed from the current value of
        ``nz_values`` and ``n_actions``.
    """
    block_values = self.nz_values.value
    return BlockDiagonalSparse(block_values, self.n_actions)

LanczosPolicy

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors of the linear operator \(A\).

Attributes:

Name Type Description
n_actions int

Number of Lanczos vectors/actions to compute.

key PRNGKeyArray | None

Random key for reproducible Lanczos iterations.

Source code in src/cagpjax/policies/lanczos.py
class LanczosPolicy(AbstractBatchLinearSolverPolicy):
    """Policy whose actions come from a Lanczos eigendecomposition.

    The Lanczos algorithm is used to compute the top ``n_actions`` eigenvectors
    of the linear operator $A$; those eigenvectors are returned as the actions.

    Attributes:
        n_actions: Number of Lanczos vectors/actions to compute.
        key: Random key for reproducible Lanczos iterations.
        grad_rtol: Cutoff for similar eigenvalues, forwarded to ``eigh``.
    """

    key: PRNGKeyArray | None
    grad_rtol: float | None

    def __init__(
        self,
        n_actions: int | None,
        key: PRNGKeyArray | None = None,
        grad_rtol: float | None = 0.0,
    ):
        """Store the configuration of the Lanczos policy.

        Args:
            n_actions: Number of Lanczos vectors to compute.
            key: Random key for initialization.
            grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
                gradient computation for (almost-)degenerate matrices.
                If not provided, the default is 0.0.
                If None or negative, all eigenvalues are treated as distinct.
                (see [`cagpjax.linalg.eigh`][] for more details)
        """
        self._n_actions: int = n_actions
        self.key = key
        self.grad_rtol = grad_rtol

    @property
    @override
    def n_actions(self) -> int:
        """Number of Lanczos vectors/actions to compute."""
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Compute the action matrix for ``A``.

        Args:
            A: Symmetric linear operator representing the linear system.

        Returns:
            The ``eigenvectors`` of the decomposition produced by ``eigh`` with a
            ``Lanczos`` algorithm, i.e. a linear operator containing the Lanczos
            vectors as columns.
        """
        lanczos_alg = Lanczos(self.n_actions, key=self.key)
        decomposition = eigh(A, alg=lanczos_alg, grad_rtol=self.grad_rtol)
        return decomposition.eigenvectors

__init__(n_actions, key=None, grad_rtol=0.0)

Initialize the Lanczos policy.

Parameters:

Name Type Description Default
n_actions int | None

Number of Lanczos vectors to compute.

required
key PRNGKeyArray | None

Random key for initialization.

None
grad_rtol float | None

Specifies the cutoff for similar eigenvalues, used to improve gradient computation for (almost-)degenerate matrices. If not provided, the default is 0.0. If None or negative, all eigenvalues are treated as distinct. (see cagpjax.linalg.eigh for more details)

0.0
Source code in src/cagpjax/policies/lanczos.py
def __init__(
    self,
    n_actions: int | None,
    key: PRNGKeyArray | None = None,
    grad_rtol: float | None = 0.0,
):
    """Store the configuration of the Lanczos policy.

    Args:
        n_actions: Number of Lanczos vectors to compute.
        key: Random key for initialization.
        grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
            gradient computation for (almost-)degenerate matrices.
            If not provided, the default is 0.0.
            If None or negative, all eigenvalues are treated as distinct.
            (see [`cagpjax.linalg.eigh`][] for more details)
    """
    # The three settings are kept verbatim; nothing is validated or converted here.
    self._n_actions: int = n_actions
    self.key = key
    self.grad_rtol = grad_rtol

to_actions(A)

Compute action matrix.

Parameters:

Name Type Description Default
A LinearOperator

Symmetric linear operator representing the linear system.

required

Returns:

Type Description
LinearOperator

Linear operator containing the Lanczos vectors as columns.

Source code in src/cagpjax/policies/lanczos.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Compute the action matrix for ``A``.

    Args:
        A: Symmetric linear operator representing the linear system.

    Returns:
        The ``eigenvectors`` of the decomposition produced by ``eigh`` with a
        ``Lanczos`` algorithm, i.e. a linear operator containing the Lanczos
        vectors as columns.
    """
    lanczos_alg = Lanczos(self.n_actions, key=self.key)
    decomposition = eigh(A, alg=lanczos_alg, grad_rtol=self.grad_rtol)
    return decomposition.eigenvectors

PseudoInputPolicy

Bases: AbstractBatchLinearSolverPolicy

Pseudo-input linear solver policy.

This policy constructs actions from the cross-covariance between the training inputs and pseudo-inputs in the same input space. These pseudo-inputs are conceptually similar to inducing points and can be marked as trainable.

Parameters:

Name Type Description Default
pseudo_inputs Float[Array, 'M D'] | Variable

Pseudo-inputs for the kernel. If wrapped as a gpjax.parameters.Parameter, they will be treated as trainable.

required
train_inputs_or_dataset Float[Array, 'N D'] | Dataset

Training inputs or a dataset containing training inputs. These must be the same inputs in the same order as the training data used to condition the CaGP model.

required
kernel AbstractKernel

Kernel for the GP prior. It must be able to take train_inputs and pseudo_inputs as arguments to its cross_covariance method.

required
Source code in src/cagpjax/policies/pseudoinput.py
class PseudoInputPolicy(AbstractBatchLinearSolverPolicy):
    """Pseudo-input linear solver policy.

    Actions are built from the cross-covariance between the training inputs and a set
    of pseudo-inputs living in the same input space. The pseudo-inputs are conceptually
    similar to inducing points and can be marked as trainable.

    Args:
        pseudo_inputs: Pseudo-inputs for the kernel. If wrapped as a `gpjax.parameters.Parameter`,
            they will be treated as trainable. Values that are not already an ``nnx.Variable``
            are promoted to at least 2-D and wrapped in `gpjax.parameters.Static`.
        train_inputs_or_dataset: Training inputs or a dataset containing training inputs. These
            must be the same inputs in the same order as the training data used to condition
            the CaGP model.
        kernel: Kernel for the GP prior. It must be able to take `train_inputs` and `pseudo_inputs`
            as arguments to its `cross_covariance` method.

    Raises:
        ValueError: If a dataset is passed whose ``X`` is ``None``.
    """

    pseudo_inputs: nnx.Variable
    train_inputs: Float[Array, "N D"]
    kernel: gpjax.kernels.AbstractKernel

    def __init__(
        self,
        pseudo_inputs: Float[Array, "M D"] | nnx.Variable,
        train_inputs_or_dataset: Float[Array, "N D"] | gpjax.dataset.Dataset,
        kernel: gpjax.kernels.AbstractKernel,
    ):
        # Unwrap a dataset to its inputs; bare arrays pass straight through.
        inputs = train_inputs_or_dataset
        if isinstance(inputs, gpjax.dataset.Dataset):
            inputs = inputs.X
            if inputs is None:
                raise ValueError("Dataset must contain training inputs.")

        # Variables are stored as given; anything else becomes a static 2-D parameter.
        if isinstance(pseudo_inputs, nnx.Variable):
            self.pseudo_inputs = pseudo_inputs
        else:
            self.pseudo_inputs = gpjax.parameters.Static(jnp.atleast_2d(pseudo_inputs))
        self.train_inputs = jnp.atleast_2d(inputs)
        self.kernel = kernel

    @property
    def n_actions(self):
        """Number of actions, taken as the leading dimension of ``pseudo_inputs``."""
        n_pseudo = self.pseudo_inputs.shape[0]
        return n_pseudo

    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Build the action operator from the kernel cross-covariance.

        Args:
            A: Linear operator of the system; it is not read by this policy.

        Returns:
            The cross-covariance between ``train_inputs`` and the current value of
            ``pseudo_inputs``, wrapped with ``cola.lazify``.
        """
        pseudo_values = self.pseudo_inputs.value
        return cola.lazify(
            self.kernel.cross_covariance(self.train_inputs, pseudo_values)
        )