Skip to content

cagpjax.policies

AbstractBatchLinearSolverPolicy

Bases: AbstractLinearSolverPolicy, ABC

Abstract base class for policies that produce action matrices.

Source code in src/cagpjax/policies/base.py
class AbstractBatchLinearSolverPolicy(AbstractLinearSolverPolicy, abc.ABC):
    """Abstract base class for policies that produce action matrices.

    Subclasses must implement the ``n_actions`` property and ``to_actions``,
    which returns every action at once as the columns of a single operator.
    """

    @property
    @abc.abstractmethod
    def n_actions(self) -> int:
        """Number of actions in this policy."""
        ...

    @abc.abstractmethod
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        r"""Compute all actions used to solve the linear system $Ax=b$.

        For a matrix $A$ with shape ``(n, n)``, the action matrix has shape
        ``(n, n_actions)``.

        Args:
            A: Linear operator representing the linear system.

        Returns:
            Linear operator representing the action matrix.
        """
        ...

n_actions abstractmethod property

Number of actions in this policy.

to_actions(A) abstractmethod

Compute all actions used to solve the linear system \(Ax=b\).

For a matrix \(A\) with shape (n, n), the action matrix has shape (n, n_actions).

Parameters:

Name Type Description Default
A LinearOperator

Linear operator representing the linear system.

required

Returns:

Type Description
LinearOperator

Linear operator representing the action matrix.

Source code in src/cagpjax/policies/base.py
@abc.abstractmethod
def to_actions(self, A: LinearOperator) -> LinearOperator:
    r"""Build the full action matrix for the linear system $Ax=b$.

    Given a square operator $A$ of shape ``(n, n)``, implementations return an
    action matrix of shape ``(n, n_actions)``.

    Args:
        A: Linear operator representing the linear system.

    Returns:
        Linear operator whose columns are the actions.
    """
    ...

AbstractLinearSolverPolicy

Bases: Module

Abstract base class for all linear solver policies.

Policies define actions used to solve a linear system \(A x = b\), where \(A\) is a square linear operator.

Source code in src/cagpjax/policies/base.py
class AbstractLinearSolverPolicy(nnx.Module):
    r"""Root base class shared by every linear solver policy.

    A policy supplies the actions used when solving $A x = b$ for a square
    linear operator $A$.
    """

    ...

BlockSparsePolicy

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The matrix has the following structure:

\[ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_{\text{n_actions}} \end{bmatrix} \]

These are stacked and stored as a single trainable parameter nz_values.

Source code in src/cagpjax/policies/block_sparse.py
class BlockSparsePolicy(AbstractBatchLinearSolverPolicy):
    r"""Block-sparse linear solver policy.

    This policy uses a fixed block-diagonal sparse structure to define
    independent learnable actions. The matrix has the following structure:

    $$
    S = \begin{bmatrix}
        s_1 & 0 & \cdots & 0 \\
        0 & s_2 & \cdots & 0 \\
        \vdots & \vdots & \ddots & \vdots \\
        0 & 0 & \cdots & s_{\text{n_actions}}
    \end{bmatrix}
    $$

    These are stacked and stored as a single trainable parameter ``nz_values``.
    """

    def __init__(
        self,
        n_actions: int,
        n: int | None = None,
        nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
        key: PRNGKeyArray | None = None,
        **kwargs,
    ):
        """Initialize the block sparse policy.

        Args:
            n_actions: Number of actions to use.
            n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
                not provided.
            nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
                provided, random actions are sampled using the key if provided.
            key: Random key for sampling actions if ``nz_values`` is not provided.
            **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)

        Raises:
            ValueError: If ``nz_values`` is not provided and either ``n`` is missing or
                ``n_actions`` does not satisfy ``0 < n_actions <= n``.
        """
        if nz_values is None:
            if n is None:
                raise ValueError("n must be provided if nz_values is not provided")
            # With n_actions > n the block size below is 0, and dividing by sqrt(0)
            # would silently fill the actions with inf; n_actions == 0 would raise
            # a bare ZeroDivisionError. Reject both up front with a clear message.
            if not 0 < n_actions <= n:
                raise ValueError(
                    f"n_actions must satisfy 0 < n_actions <= n, "
                    f"got n_actions={n_actions} and n={n}"
                )
            if key is None:
                key = jax.random.PRNGKey(0)
            # Entries are N(0, 1) / sqrt(block_size), so a block of block_size
            # entries has expected squared norm 1 (presumably blocks hold roughly
            # n // n_actions entries each -- TODO confirm against BlockDiagonalSparse).
            block_size = n // n_actions
            nz_values = jax.random.normal(key, (n,), **kwargs)
            nz_values /= jnp.sqrt(block_size)
        elif n is not None:
            warnings.warn("n is ignored because nz_values is provided")

        # Raw arrays are wrapped so they are registered as trainable state.
        if not isinstance(nz_values, nnx.Variable):
            nz_values = Real(nz_values)

        self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
        self._n_actions: int = n_actions

    @property
    @override
    def n_actions(self) -> int:
        """Number of actions to be used."""
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Convert to block diagonal sparse action operators.

        Args:
            A: Linear operator (unused).

        Returns:
            Transposed[BlockDiagonalSparse]: Sparse action structure representing the blocks.
        """
        return BlockDiagonalSparse(self.nz_values.value, self.n_actions).T

n_actions property

Number of actions to be used.

__init__(n_actions, n=None, nz_values=None, key=None, **kwargs)

Initialize the block sparse policy.

Parameters:

Name Type Description Default
n_actions int

Number of actions to use.

required
n int | None

Number of rows and columns of the full operator. Must be provided if nz_values is not provided.

None
nz_values Float[Array, N] | Variable[Float[Array, N]] | None

Non-zero values of the block-diagonal sparse matrix (shape (n,)). If not provided, random actions are sampled using the key if provided.

None
key PRNGKeyArray | None

Random key for sampling actions if nz_values is not provided.

None
**kwargs

Additional keyword arguments for jax.random.normal (e.g. dtype)

{}
Source code in src/cagpjax/policies/block_sparse.py
def __init__(
    self,
    n_actions: int,
    n: int | None = None,
    nz_values: Float[Array, "N"] | nnx.Variable[Float[Array, "N"]] | None = None,
    key: PRNGKeyArray | None = None,
    **kwargs,
):
    """Initialize the block sparse policy.

    Args:
        n_actions: Number of actions to use.
        n: Number of rows and columns of the full operator. Must be provided if ``nz_values`` is
            not provided.
        nz_values: Non-zero values of the block-diagonal sparse matrix (shape ``(n,)``). If not
            provided, random actions are sampled using the key if provided.
        key: Random key for sampling actions if ``nz_values`` is not provided.
        **kwargs: Additional keyword arguments for ``jax.random.normal`` (e.g. ``dtype``)

    Raises:
        ValueError: If ``nz_values`` is not provided and either ``n`` is missing or
            ``n_actions`` does not satisfy ``0 < n_actions <= n``.
    """
    if nz_values is None:
        if n is None:
            raise ValueError("n must be provided if nz_values is not provided")
        # With n_actions > n the block size below is 0, and dividing by sqrt(0)
        # would silently fill the actions with inf; n_actions == 0 would raise
        # a bare ZeroDivisionError. Reject both up front with a clear message.
        if not 0 < n_actions <= n:
            raise ValueError(
                f"n_actions must satisfy 0 < n_actions <= n, "
                f"got n_actions={n_actions} and n={n}"
            )
        if key is None:
            key = jax.random.PRNGKey(0)
        # Entries are N(0, 1) / sqrt(block_size), so a block of block_size
        # entries has expected squared norm 1 (presumably blocks hold roughly
        # n // n_actions entries each -- TODO confirm against BlockDiagonalSparse).
        block_size = n // n_actions
        nz_values = jax.random.normal(key, (n,), **kwargs)
        nz_values /= jnp.sqrt(block_size)
    elif n is not None:
        warnings.warn("n is ignored because nz_values is provided")

    # Raw arrays are wrapped so they are registered as trainable state.
    if not isinstance(nz_values, nnx.Variable):
        nz_values = Real(nz_values)

    self.nz_values: nnx.Variable[Float[Array, "N"]] = nz_values
    self._n_actions: int = n_actions

to_actions(A)

Convert to block diagonal sparse action operators.

Parameters:

Name Type Description Default
A LinearOperator

Linear operator (unused).

required

Returns:

Type Description
LinearOperator

Transposed[BlockDiagonalSparse]: Sparse action structure representing the blocks.

Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Build the transposed block-diagonal sparse action operator.

    Args:
        A: Linear operator (unused; the actions depend only on ``nz_values``).

    Returns:
        Transposed[BlockDiagonalSparse]: Sparse action structure representing the blocks.
    """
    block_operator = BlockDiagonalSparse(self.nz_values.value, self.n_actions)
    return block_operator.T

LanczosPolicy

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to compute the n_actions eigenvectors of the linear operator \(A\) whose eigenvalues have the largest magnitude.

Attributes:

Name Type Description
n_actions int

Number of Lanczos vectors/actions to compute.

key PRNGKeyArray | None

Random key for reproducible Lanczos iterations.

Source code in src/cagpjax/policies/lanczos.py
class LanczosPolicy(AbstractBatchLinearSolverPolicy):
    """Lanczos-based policy for eigenvalue decomposition approximation.

    This policy uses the Lanczos algorithm to compute the ``n_actions``
    eigenvectors of the linear operator $A$ whose eigenvalues have the largest
    magnitude.

    Attributes:
        n_actions: Number of Lanczos vectors/actions to compute.
        key: Random key for reproducible Lanczos iterations.
    """

    def __init__(self, n_actions: int, key: PRNGKeyArray | None = None):
        """Initialize the Lanczos policy.

        Args:
            n_actions: Number of Lanczos vectors to compute.
            key: Random key for initialization, forwarded to ``cola.linalg.Lanczos``.
        """
        self._n_actions: int = n_actions
        # NOTE(review): a raw JAX key stored as a plain nnx.Module attribute may
        # need wrapping (e.g. in an nnx.Variable) depending on the Flax version.
        self.key: PRNGKeyArray | None = key

    @property
    @override
    def n_actions(self) -> int:
        """Number of actions to be used."""
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Compute action matrix.

        Args:
            A: Symmetric linear operator representing the linear system. It is
                wrapped in ``cola.SelfAdjoint`` before being passed to the solver.

        Returns:
            Linear operator containing the Lanczos vectors as columns.
        """
        # which="LM" selects the largest-magnitude eigenvalues; index 1 of the
        # ``eig`` result holds the corresponding eigenvectors.
        vecs = cola.linalg.eig(
            cola.SelfAdjoint(A),
            self.n_actions,
            which="LM",
            alg=cola.linalg.Lanczos(key=self.key),
        )[1]
        # The eigenvectors may come back as a plain array; wrap them so callers
        # always receive a LinearOperator, as the return annotation promises.
        if not isinstance(vecs, LinearOperator):
            vecs = Dense(vecs)
        return vecs

__init__(n_actions, key=None)

Initialize the Lanczos policy.

Parameters:

Name Type Description Default
n_actions int

Number of Lanczos vectors to compute.

required
key PRNGKeyArray | None

Random key for initialization.

None
Source code in src/cagpjax/policies/lanczos.py
def __init__(self, n_actions: int, key: PRNGKeyArray | None = None):
    """Set up a Lanczos policy.

    Args:
        n_actions: How many Lanczos vectors to compute.
        key: Optional random key used to initialize the Lanczos iteration.
    """
    self.key: PRNGKeyArray | None = key
    self._n_actions: int = n_actions

to_actions(A)

Compute action matrix.

Parameters:

Name Type Description Default
A LinearOperator

Symmetric linear operator representing the linear system.

required

Returns:

Type Description
LinearOperator

Linear operator containing the Lanczos vectors as columns.

Source code in src/cagpjax/policies/lanczos.py
@override
def to_actions(self, A: LinearOperator) -> LinearOperator:
    """Build the action matrix from Lanczos eigenvectors.

    Args:
        A: Symmetric linear operator representing the linear system.

    Returns:
        Linear operator whose columns are the Lanczos vectors.
    """
    solver = cola.linalg.Lanczos(key=self.key)
    eig_result = cola.linalg.eig(
        cola.SelfAdjoint(A),
        self.n_actions,
        which="LM",
        alg=solver,
    )
    eigvecs = eig_result[1]
    if isinstance(eigvecs, LinearOperator):
        return eigvecs
    return Dense(eigvecs)