
cagpjax.policies

Classes:

AbstractBatchLinearSolverPolicy

Bases: AbstractLinearSolverPolicy, ABC

Abstract base class for policies that produce action matrices.

Methods:

  • to_actions

    Compute all actions used to solve the linear system \(Ax=b\).

Attributes:

  • n_actions (int) –

    Number of actions in this policy.

n_actions abstractmethod property

n_actions: int

Number of actions in this policy.

to_actions abstractmethod

to_actions(A: LinearOperator) -> LinearOperator

Compute all actions used to solve the linear system \(Ax=b\).

For a matrix \(A\) with shape (n, n), the action matrix has shape (n, n_actions).

Parameters:

  • A

    (LinearOperator) –

    Linear operator representing the linear system.

Returns:

  • LinearOperator

    Linear operator representing the action matrix.

Source code in src/cagpjax/policies/base.py
@abc.abstractmethod
def to_actions(self, A: LinearOperator) -> LinearOperator:
    r"""Compute all actions used to solve the linear system $Ax=b$.

    For a matrix $A$ with shape ``(n, n)``, the action matrix has shape
    ``(n, n_actions)``.

    Args:
        A: Linear operator representing the linear system.

    Returns:
        Linear operator representing the action matrix.
    """
    ...
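
The contract above (an n_actions property plus a to_actions method returning an (n, n_actions) matrix) can be mirrored in plain Python to make it concrete. This is a hypothetical sketch, not cagpjax code: the class names BatchPolicySketch and RandomGaussianPolicySketch are invented for illustration, dense NumPy arrays stand in for LinearOperator, and random Gaussian actions are just one simple choice.

```python
import abc

import numpy as np


class BatchPolicySketch(abc.ABC):
    """Hypothetical plain-NumPy mirror of the batch-policy contract (not cagpjax)."""

    @property
    @abc.abstractmethod
    def n_actions(self) -> int:
        """Number of actions (columns of the action matrix)."""

    @abc.abstractmethod
    def to_actions(self, A: np.ndarray) -> np.ndarray:
        """Return an action matrix of shape (n, n_actions) for an (n, n) matrix A."""


class RandomGaussianPolicySketch(BatchPolicySketch):
    """Hypothetical policy whose actions are i.i.d. standard normal columns."""

    def __init__(self, n_actions: int, seed: int = 0):
        self._n_actions = n_actions
        self._seed = seed

    @property
    def n_actions(self) -> int:
        return self._n_actions

    def to_actions(self, A: np.ndarray) -> np.ndarray:
        n = A.shape[0]
        rng = np.random.default_rng(self._seed)
        # One random column per action; A is only used for its size here
        return rng.standard_normal((n, self.n_actions))


policy = RandomGaussianPolicySketch(n_actions=3)
S = policy.to_actions(np.eye(10))  # shape (10, 3)
```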

AbstractLinearSolverPolicy

Bases: Module

Abstract base class for all linear solver policies.

Policies define actions used to solve a linear system \(A x = b\), where \(A\) is a square linear operator.
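
One common way to use actions, assumed here for illustration: for a symmetric positive definite \(A\) and an action matrix \(S\) of shape (n, k), the system is compressed to the k × k matrix \(S^\top A S\) and the estimate \(x_k = S (S^\top A S)^{-1} S^\top b\) is formed. This is the \(A\)-orthogonal projection of \(A^{-1} b\) onto the span of \(S\). The NumPy sketch below uses a hypothetical helper, projected_solve; how cagpjax consumes the actions internally may differ in detail.

```python
import numpy as np


def projected_solve(A, b, S):
    """Estimate A^{-1} b within span(S) via x = S (S^T A S)^{-1} S^T b."""
    SAS = S.T @ A @ S  # k x k compressed system
    return S @ np.linalg.solve(SAS, S.T @ b)


rng = np.random.default_rng(0)
n = 6
B = rng.standard_normal((n, n))
A = B @ B.T + n * np.eye(n)  # symmetric positive definite test matrix
b = rng.standard_normal(n)

# n linearly independent actions: the estimate is the exact solution
x_full = projected_solve(A, b, rng.standard_normal((n, n)))

# Only 3 actions: an approximation restricted to a 3-dimensional subspace
S = rng.standard_normal((n, 3))
R = rng.standard_normal((3, 3))  # invertible change of basis (almost surely)
x_S = projected_solve(A, b, S)
x_SR = projected_solve(A, b, S @ R)  # same span, different basis
```

Comparing x_S with x_SR shows that the estimate depends on \(S\) only through its column span, so orthogonalizing the actions leaves it unchanged in exact arithmetic.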

BlockSparsePolicy

BlockSparsePolicy(n_actions: int, n: int | None = None, nz_values: Float[Array, 'N'] | Variable[Float[Array, 'N']] | None = None, key: PRNGKeyArray | None = None, **kwargs)

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

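This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The action matrix \(S\), of shape (n, n_actions), has the following structure, where each \(s_i\) is a column block of about n / n_actions entries: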

\[ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_{n_{\text{actions}}} \end{bmatrix} \]

The blocks \(s_1, \ldots, s_{n_{\text{actions}}}\) are stacked into a single length-n vector and stored as the trainable parameter nz_values.
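
The layout can be checked by expanding nz_values into the dense matrix \(S\). The sketch below assumes, for illustration, that n is divisible by n_actions and that the blocks are stored contiguously in order; block_sparse_dense is a hypothetical helper, not part of cagpjax.

```python
import numpy as np


def block_sparse_dense(nz_values, n_actions):
    """Dense (n, n_actions) matrix with block i of nz_values placed in column i."""
    nz_values = np.asarray(nz_values)
    n = nz_values.shape[0]
    if n % n_actions:
        raise ValueError("this sketch assumes n is divisible by n_actions")
    block_size = n // n_actions
    S = np.zeros((n, n_actions), dtype=nz_values.dtype)
    for i in range(n_actions):
        rows = slice(i * block_size, (i + 1) * block_size)
        S[rows, i] = nz_values[rows]  # block s_i occupies its own rows and column
    return S


S = block_sparse_dense(np.arange(1.0, 7.0), n_actions=3)
# S == [[1, 0, 0],
#       [2, 0, 0],
#       [0, 3, 0],
#       [0, 4, 0],
#       [0, 0, 5],
#       [0, 0, 6]]
```

Because the columns have disjoint supports, they are mutually orthogonal, and only n numbers are stored for an (n, n_actions) matrix.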

Initialize the block sparse policy.

Parameters:

  • n_actions

    (int) –

    Number of actions to use.

  • n

    (int | None, default: None) –

    Number of rows and columns of the full operator. Must be provided if nz_values is not provided.

  • nz_values

    (Float[Array, 'N'] | Variable[Float[Array, 'N']] | None, default: None) –

    Non-zero values of the block-diagonal sparse matrix (shape (n,)). If not provided, n values are drawn from a standard normal distribution (using key) and scaled by 1 / sqrt(n // n_actions).

  • key

    (PRNGKeyArray | None, default: None) –

    Random key for sampling actions if nz_values is not provided. If None, jax.random.PRNGKey(0) is used.

  • **kwargs

    Additional keyword arguments for jax.random.normal (e.g., dtype).

Methods:

  • to_actions

    Convert to block diagonal sparse action operators.

Attributes:

  • n_actions (int) –

    Number of actions to be used.

Source code in src/cagpjax/policies/block_sparse.py
def __init__(
    self,
    n_actions: int,
    n: int | None = None,
    nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
    key: PRNGKeyArray | None = None,
    **kwargs,
):
    """Initialize the block sparse policy.

    Args:
        n_actions: Number of actions to use.
        n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
            not provided.
        nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
            provided, random actions are sampled using the key if provided.
        key: Random key for sampling actions if ``nz_values`` is not provided.
        **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)
    """
    if nz_values is None:
        if n is None:
            raise ValueError("n must be provided if nz_values is not provided")
        if key is None:
            key = jax.random.PRNGKey(0)
        block_size = n // n_actions
        nz_values = jax.random.normal(key, (n,), **kwargs)
        nz_values /= jnp.sqrt(block_size)
    elif n is not None:
        warnings.warn("n is ignored because nz_values is provided")

    if not isinstance(nz_values, nnx.Variable):
        nz_values = Real(nz_values)

    self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
    self._n_actions: int = n_actions

n_actions property

n_actions: int

Number of actions to be used.

to_actions

to_actions(A: LinearOperator) -> LinearOperator

Convert to block diagonal sparse action operators.

Parameters:

  • A

    (LinearOperator) –

    Linear operator (unused).

Returns:

  • BlockDiagonalSparse (LinearOperator) –

    Sparse action structure representing the blocks.

Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Convert to block diagonal sparse action operators.

    Args:
        A: Linear operator (unused).

    Returns:
        BlockDiagonalSparse: Sparse action structure representing the blocks.
    """
    return BlockDiagonalSparse(self.nz_values[...], self.n_actions)
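
Since to_actions ignores \(A\) and the sparsity pattern is fixed, products with \(S\) never need the dense matrix: \(S v\) scales block i by \(v_i\), and \(S^\top u\) takes one dot product per block, so both cost O(n) rather than O(n · n_actions). A NumPy sketch of the two products, with hypothetical helper names and assuming contiguous blocks of equal size; this is not the BlockDiagonalSparse implementation:

```python
import numpy as np


def block_sparse_matvec(nz_values, n_actions, v):
    """S @ v for v of shape (n_actions,): block i of nz_values is scaled by v[i]."""
    blocks = np.asarray(nz_values).reshape(n_actions, -1)  # row i holds block s_i
    return (blocks * np.asarray(v)[:, None]).reshape(-1)


def block_sparse_rmatvec(nz_values, n_actions, u):
    """S.T @ u for u of shape (n,): one dot product per block."""
    blocks = np.asarray(nz_values).reshape(n_actions, -1)
    return np.sum(blocks * np.asarray(u).reshape(n_actions, -1), axis=1)


nz = np.arange(1.0, 7.0)  # blocks s_1 = [1, 2], s_2 = [3, 4], s_3 = [5, 6]
Sv = block_sparse_matvec(nz, 3, np.array([1.0, 10.0, 100.0]))  # [1, 2, 30, 40, 500, 600]
Stu = block_sparse_rmatvec(nz, 3, np.ones(6))  # [3, 7, 11]
```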

LanczosPolicy

LanczosPolicy(n_actions: int | None, key: PRNGKeyArray | None = None, grad_rtol: float | None = 0.0)

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to compute the top n_actions eigenvectors of the linear operator \(A\).
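
For intuition about what the action matrix contains, here is a textbook Lanczos iteration in NumPy with full reorthogonalization. It is a sketch, not cagpjax's Lanczos algorithm or eigh: the function name lanczos_top_eigvecs, the separate step count m, the random start vector, and the reorthogonalization strategy are assumptions. Lanczos builds an orthonormal basis \(Q\) of the Krylov subspace spanned by \(q, Aq, A^2 q, \ldots\), projects \(A\) onto it as a tridiagonal matrix \(T\), and returns the Ritz vectors \(QV\) from the eigendecomposition of \(T\). With m = n steps the Ritz pairs are exact; with m < n they approximate the eigenpairs, the extreme ones typically converging first.

```python
import numpy as np


def lanczos_top_eigvecs(A, k, m=None, seed=0):
    """Approximate the k largest eigenpairs of symmetric A with m Lanczos steps."""
    n = A.shape[0]
    m = n if m is None else m
    rng = np.random.default_rng(seed)
    q = rng.standard_normal(n)
    q /= np.linalg.norm(q)
    Q = np.zeros((n, m))
    alpha = np.zeros(m)
    beta = np.zeros(m)
    steps = m
    for j in range(m):
        Q[:, j] = q
        w = A @ q
        alpha[j] = q @ w
        # Full reorthogonalization against all Lanczos vectors so far (including q)
        w -= Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        beta[j] = np.linalg.norm(w)
        if j == m - 1 or beta[j] < 1e-12:
            steps = j + 1  # step budget used up or Krylov space exhausted
            break
        q = w / beta[j]
    # Tridiagonal projection T = Q^T A Q onto the Krylov subspace
    T = (np.diag(alpha[:steps])
         + np.diag(beta[: steps - 1], 1)
         + np.diag(beta[: steps - 1], -1))
    theta, V = np.linalg.eigh(T)
    top = np.argsort(theta)[::-1][:k]  # indices of the k largest Ritz values
    return theta[top], Q[:, :steps] @ V[:, top]


rng = np.random.default_rng(1)
B = rng.standard_normal((8, 8))
A = B @ B.T + 8 * np.eye(8)  # symmetric positive definite test matrix
vals, vecs = lanczos_top_eigvecs(A, k=3)  # vecs has shape (8, 3)
```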

Attributes:

  • n_actions (int) –

    Number of Lanczos vectors/actions to compute.

  • key (PRNGKeyArray | None) –

    Random key for reproducible Lanczos iterations.

Initialize the Lanczos policy.

Parameters:

  • n_actions

    (int | None) –

    Number of Lanczos vectors to compute.

  • key

    (PRNGKeyArray | None, default: None) –

    Random key for initialization.

  • grad_rtol

    (float | None, default: 0.0) –

    Relative tolerance below which eigenvalues are treated as equal, used to improve gradient computation for (nearly) degenerate matrices. Defaults to 0.0. If None or negative, all eigenvalues are treated as distinct (see cagpjax.linalg.eigh for details).
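
Why a cutoff helps: for a symmetric \(A = V \Lambda V^\top\), the eigenvector derivative is \(dV = V\,(F \circ (V^\top\, dA\, V))\) with \(F_{ij} = 1/(\lambda_j - \lambda_i)\) for \(i \ne j\) and \(F_{ii} = 0\), which blows up as two eigenvalues approach each other. The NumPy sketch below builds \(F\) and zeroes the entries of (near-)equal pairs. The masking rule used here (treat a pair as degenerate when |λ_i − λ_j| ≤ grad_rtol · max_k |λ_k|) and the helper name eigvec_grad_coeffs are assumptions for illustration; cagpjax.linalg.eigh documents the rule it actually applies.

```python
import numpy as np


def eigvec_grad_coeffs(eigvals, grad_rtol=0.0):
    """F[i, j] = 1 / (lam_j - lam_i) for i != j, with F[i, i] = 0.

    Hypothetical cutoff: pairs with |lam_i - lam_j| <= grad_rtol * max|lam| are
    treated as degenerate (coefficient 0). If grad_rtol is None or negative,
    every off-diagonal pair is treated as distinct.
    """
    lam = np.asarray(eigvals, dtype=float)
    diff = lam[None, :] - lam[:, None]  # diff[i, j] = lam_j - lam_i
    keep = ~np.eye(lam.size, dtype=bool)  # the diagonal is always zero
    if grad_rtol is not None and grad_rtol >= 0:
        keep &= np.abs(diff) > grad_rtol * np.max(np.abs(lam))
    F = np.zeros_like(diff)
    F[keep] = 1.0 / diff[keep]
    return F


lam = [1.0, 1.0 + 1e-9, 3.0]  # first two eigenvalues nearly degenerate
F_raw = eigvec_grad_coeffs(lam, grad_rtol=None)  # F_raw[0, 1] is about 1e9
F_cut = eigvec_grad_coeffs(lam, grad_rtol=1e-6)  # F_cut[0, 1] == 0
```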

Methods:

  • to_actions

    Compute action matrix.

Source code in src/cagpjax/policies/lanczos.py
def __init__(
    self,
    n_actions: int | None,
    key: PRNGKeyArray | None = None,
    grad_rtol: float | None = 0.0,
):
    """Initialize the Lanczos policy.

    Args:
        n_actions: Number of Lanczos vectors to compute.
        key: Random key for initialization.
        grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
            gradient computation for (almost-)degenerate matrices.
            If not provided, the default is 0.0.
            If None or negative, all eigenvalues are treated as distinct.
            (see [`cagpjax.linalg.eigh`][] for more details)
    """
    self._n_actions: int = n_actions
    self.key = key
    self.grad_rtol = grad_rtol

to_actions

to_actions(A: LinearOperator) -> LinearOperator

Compute action matrix.

Parameters:

  • A

    (LinearOperator) –

    Symmetric linear operator representing the linear system.

Returns:

  • LinearOperator

    Linear operator containing the Lanczos vectors as columns.

Source code in src/cagpjax/policies/lanczos.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Compute action matrix.

    Args:
        A: Symmetric linear operator representing the linear system.

    Returns:
        Linear operator containing the Lanczos vectors as columns.
    """
    vecs = eigh(
        A, alg=Lanczos(self.n_actions, key=self.key), grad_rtol=self.grad_rtol
    ).eigenvectors
    return vecs

OrthogonalizationPolicy

OrthogonalizationPolicy(base_policy: AbstractBatchLinearSolverPolicy, method: OrthogonalizationMethod = OrthogonalizationMethod.QR, n_reortho: int = 0)

Bases: AbstractBatchLinearSolverPolicy

Orthogonalization policy.

This policy orthogonalizes (if necessary) the action operator produced by the base policy.

Parameters:

  • base_policy

    (AbstractBatchLinearSolverPolicy) –

    The base policy that produces the action operator to be orthogonalized.

  • method

    (OrthogonalizationMethod, default: QR) –

    The method to use for orthogonalization.

  • n_reortho

    (int, default: 0) –

    The number of times to re-orthogonalize each column. For Gram-Schmidt variants, re-orthogonalizing once is generally sufficient to improve orthogonality (see, e.g., doi:10.1007/s00211-005-0615-4).
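
What n_reortho controls can be seen with classical Gram-Schmidt. In the NumPy sketch below, each column is projected against the already-accepted columns 1 + n_reortho times; the helpers cgs and ortho_error are hypothetical, and whether cagpjax's Gram-Schmidt variants work exactly this way is not shown here. The test input is the Läuchli matrix, whose nearly parallel columns make a single pass lose orthogonality badly.

```python
import numpy as np


def cgs(S, n_reortho=0):
    """Classical Gram-Schmidt with 1 + n_reortho projection passes per column."""
    S = np.asarray(S, dtype=float)
    n, k = S.shape
    Q = np.zeros((n, k))
    for j in range(k):
        v = S[:, j].copy()
        for _ in range(1 + n_reortho):
            # Project out all previously accepted columns in one block product
            v -= Q[:, :j] @ (Q[:, :j].T @ v)
        Q[:, j] = v / np.linalg.norm(v)
    return Q


def ortho_error(Q):
    """Frobenius norm of Q^T Q - I, i.e. the loss of orthogonality."""
    return np.linalg.norm(Q.T @ Q - np.eye(Q.shape[1]))


eps = 1e-8
# Lauchli matrix: nearly parallel columns that defeat one-pass Gram-Schmidt
S = np.array([[1.0, 1.0, 1.0],
              [eps, 0.0, 0.0],
              [0.0, eps, 0.0],
              [0.0, 0.0, eps]])

err_once = ortho_error(cgs(S, n_reortho=0))   # about 0.7
err_twice = ortho_error(cgs(S, n_reortho=1))  # near machine precision
err_qr = ortho_error(np.linalg.qr(S)[0])      # Householder QR, for comparison
```

Without re-orthogonalization the second and third computed columns have inner product 1/2; a single extra pass removes this, in line with the remark above.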

Source code in src/cagpjax/policies/orthogonalization.py
def __init__(
    self,
    base_policy: AbstractBatchLinearSolverPolicy,
    method: OrthogonalizationMethod = OrthogonalizationMethod.QR,
    n_reortho: int = 0,
):
    self.base_policy = base_policy
    self.method = method
    self.n_reortho = n_reortho

PseudoInputPolicy

PseudoInputPolicy(pseudo_inputs: Float[Array, 'M D'] | Parameter[Float[Array, 'M D']], train_inputs_or_dataset: Float[Array, 'N D'] | Dataset, kernel: AbstractKernel)

Bases: AbstractBatchLinearSolverPolicy

Pseudo-input linear solver policy.

This policy constructs actions from the cross-covariance between the training inputs and pseudo-inputs in the same input space. These pseudo-inputs are conceptually similar to inducing points and can be marked as trainable.
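
A minimal NumPy sketch of actions built this way, assuming a squared-exponential kernel and taking the action matrix to be the cross-covariance \(K(X, Z)\) itself (both assumptions; the policy accepts any gpjax kernel, and rbf_cross_covariance is a hypothetical helper). Each pseudo-input contributes one column, so M pseudo-inputs give M actions, and row i pairs with training point i.

```python
import numpy as np


def rbf_cross_covariance(X, Z, lengthscale=1.0, variance=1.0):
    """Squared-exponential k(X, Z) for inputs X of shape (N, D) and Z of shape (M, D)."""
    X = np.asarray(X, dtype=float)
    Z = np.asarray(Z, dtype=float)
    sq_dists = np.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)  # (N, M)
    return variance * np.exp(-0.5 * sq_dists / lengthscale**2)


X = np.linspace(0.0, 1.0, 50)[:, None]  # N = 50 training inputs, D = 1
Z = np.linspace(0.0, 1.0, 5)[:, None]   # M = 5 pseudo-inputs in the same space
S = rbf_cross_covariance(X, Z, lengthscale=0.2)  # action matrix, shape (50, 5)
```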

Parameters:

  • pseudo_inputs

    (Float[Array, 'M D'] | Parameter[Float[Array, 'M D']]) –

    Pseudo-inputs for the kernel. If wrapped as a gpjax.parameters.Parameter, they will be treated as trainable.

  • train_inputs_or_dataset

    (Float[Array, 'N D'] | Dataset) –

    Training inputs or a dataset containing training inputs. These must be the same inputs in the same order as the training data used to condition the CaGP model.

  • kernel

    (AbstractKernel) –

    Kernel for the GP prior. It must be able to take train_inputs and pseudo_inputs as arguments to its cross_covariance method.

Note

When training with many pseudo-inputs, it is common for the cross-covariance matrix to become poorly conditioned. Performance can be significantly improved by orthogonalizing the actions using an OrthogonalizationPolicy.
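
The conditioning problem, and the effect of orthogonalization, can be illustrated in NumPy. The kernel, lengthscale, and clustered pseudo-input locations below are illustrative assumptions, and np.linalg.qr stands in for an OrthogonalizationPolicy with OrthogonalizationMethod.QR rather than reproducing it.

```python
import numpy as np


def rbf_cross_covariance(X, Z, lengthscale=1.0):
    """Squared-exponential k(X, Z) for inputs X of shape (N, D) and Z of shape (M, D)."""
    sq_dists = np.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)


X = np.linspace(0.0, 1.0, 200)[:, None]   # N = 200 training inputs
Z = np.linspace(0.45, 0.55, 10)[:, None]  # M = 10 tightly clustered pseudo-inputs
S = rbf_cross_covariance(X, Z, lengthscale=0.2)

# Orthonormalize the actions: S = Q R with Q of shape (200, 10) and Q^T Q = I
Q, R = np.linalg.qr(S)

cond_S = np.linalg.cond(S)  # huge: neighboring columns are nearly parallel
cond_Q = np.linalg.cond(Q)  # 1 up to rounding error
```

Orthogonalizing changes the basis of the actions, not (in exact arithmetic) their span.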

Source code in src/cagpjax/policies/pseudoinput.py
def __init__(
    self,
    pseudo_inputs: Float[Array, "M D"] | Parameter[Float[Array, "M D"]],
    train_inputs_or_dataset: Float[Array, "N D"] | gpjax.dataset.Dataset,
    kernel: gpjax.kernels.AbstractKernel,
):
    if isinstance(train_inputs_or_dataset, gpjax.dataset.Dataset):
        train_data = train_inputs_or_dataset
        if train_data.X is None:
            raise ValueError("Dataset must contain training inputs.")
        train_inputs = train_data.X
    else:
        train_inputs = train_inputs_or_dataset
    self.pseudo_inputs = pseudo_inputs
    self.train_inputs = jnp.atleast_2d(train_inputs)
    self.kernel = kernel