
cagpjax.policies.block_sparse

Block-sparse policy.

Classes:

BlockSparsePolicy

Bases: AbstractBatchLinearSolverPolicy

Block-sparse linear solver policy.

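This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The action matrix S has one row per datapoint and one column per action, and each column is nonzero only within its own block of rows:

\[ S = \begin{bmatrix} s_1 & 0 & \cdots & 0 \\ 0 & s_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & s_{\text{n\_actions}} \end{bmatrix} \]

Each s_i is a column vector holding the nonzero values of action i. The blocks are stacked into one vector of length N and stored as a single trainable parameter nz_values.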
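To make the block layout concrete, the following is a minimal sketch that expands a stacked value vector into the dense matrix S. It is not part of cagpjax: the helper name dense_block_sparse is hypothetical, and the sketch assumes N is divisible by n_actions so that all blocks have equal length. How the library handles a remainder is not stated here.

```python
import jax.numpy as jnp


def dense_block_sparse(nz_values, n_actions):
    """Hypothetical helper: build the dense (N, n_actions) matrix S.

    Assumes N is divisible by n_actions, so every block has the same length.
    """
    n = nz_values.shape[0]
    if n % n_actions != 0:
        raise ValueError("this sketch assumes N is divisible by n_actions")
    block_size = n // n_actions
    rows = jnp.arange(n)
    # Row i belongs to block i // block_size, which is also its column in S.
    columns = rows // block_size
    dense = jnp.zeros((n, n_actions), dtype=nz_values.dtype)
    return dense.at[rows, columns].set(nz_values)


nz_values = jnp.arange(1.0, 7.0)  # N = 6 stacked values: 1, 2, ..., 6
S = dense_block_sparse(nz_values, n_actions=3)
print(S)  # column 0 holds (1, 2), column 1 holds (3, 4), column 2 holds (5, 6)
```

Storing only the N nonzero values rather than the full N x n_actions matrix is what keeps the trainable parameter compact.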

Attributes:

  • n_actions (int) –

    Number of actions to use.

  • nz_values (Float[Array, N] | AbstractUnwrappable[Float[Array, N]]) –

    Non-zero values of the block-diagonal sparse matrix.

Methods:

  • from_random

    Initialize policy from block-normalized random samples.

  • to_actions

    Convert to block diagonal sparse action operators.

from_random classmethod

from_random(key: PRNGKeyArray, num_datapoints: int, n_actions: int, *, sampler: Callable[[PRNGKeyArray, tuple[int, ...], Any], Float[Array, ' N']] = jax.random.normal, dtype: Any = None) -> BlockSparsePolicy

Initialize policy from block-normalized random samples.

Parameters:

  • key

    (PRNGKeyArray) –

    Random key used to sample initial values.

  • num_datapoints

    (int) –

    Number of rows in the resulting operator.

  • n_actions

    (int) –

    Number of action columns in the resulting operator.

  • sampler

    (Callable[[PRNGKeyArray, tuple[int, ...], Any], Float[Array, ' N']], default: jax.random.normal ) –

    Callable with signature (key, shape, dtype) -> values.

  • dtype

    (Any, default: None ) –

    Optional dtype forwarded to sampler.
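As an illustration of the sampler parameter, here is a sketch of a custom callable matching the documented (key, shape, dtype) -> values signature. The name uniform_sampler is hypothetical, and the choice of a uniform distribution on [-1, 1) is only an example. The call at the end mirrors how from_random invokes the sampler, with a one-dimensional shape of length num_datapoints.

```python
import jax
import jax.numpy as jnp


def uniform_sampler(key, shape, dtype):
    """Hypothetical sampler: uniform values in [-1, 1)."""
    return jax.random.uniform(key, shape, dtype=dtype, minval=-1.0, maxval=1.0)


key = jax.random.PRNGKey(0)
num_datapoints = 8
# Same call pattern as in from_random: sampler(key, (num_datapoints,), dtype).
values = uniform_sampler(key, (num_datapoints,), jnp.float32)
print(values.shape, values.dtype)
```

Such a callable could then be passed as sampler=uniform_sampler together with an explicit dtype, both of which are keyword-only arguments of from_random.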

Source code in src/cagpjax/policies/block_sparse.py
@classmethod
def from_random(
    cls,
    key: PRNGKeyArray,
    num_datapoints: int,
    n_actions: int,
    *,
    sampler: Callable[
        [PRNGKeyArray, tuple[int, ...], Any], Float[Array, " N"]
    ] = jax.random.normal,
    dtype: Any = None,
) -> "BlockSparsePolicy":
    """Initialize policy from block-normalized random samples.

    Args:
        key: Random key used to sample initial values.
        num_datapoints: Number of rows in the resulting operator.
        n_actions: Number of action columns in the resulting operator.
        sampler: Callable with signature ``(key, shape, dtype) -> values``.
        dtype: Optional dtype forwarded to ``sampler``.
    """
    if num_datapoints < 1:
        raise ValueError("num_datapoints must be at least 1")
    nz_values = sampler(key, (num_datapoints,), dtype)
    nz_values = _normalize_by_blocks(nz_values, n_actions)
    return cls(n_actions=n_actions, nz_values=nz_values)
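The private helper _normalize_by_blocks is not shown on this page. The sketch below illustrates one plausible reading of "block-normalized", under two assumptions: each contiguous block is scaled to unit Euclidean norm, and N is divisible by n_actions. The function normalize_by_blocks_sketch is a hypothetical stand-in, and the library's actual behavior may differ.

```python
import jax
import jax.numpy as jnp


def normalize_by_blocks_sketch(nz_values, n_actions):
    """Hypothetical stand-in: scale each contiguous block to unit norm.

    Assumes N is divisible by n_actions and that no block is all zeros.
    """
    n = nz_values.shape[0]
    if n % n_actions != 0:
        raise ValueError("this sketch assumes N is divisible by n_actions")
    # One row per block, so the norm is taken along axis 1.
    blocks = nz_values.reshape(n_actions, n // n_actions)
    norms = jnp.linalg.norm(blocks, axis=1, keepdims=True)
    return (blocks / norms).reshape(n)


key = jax.random.PRNGKey(0)
raw = jax.random.normal(key, (6,), jnp.float32)
normalized = normalize_by_blocks_sketch(raw, n_actions=3)
# Each of the three blocks of length 2 should now have norm close to 1.
block_norms = jnp.linalg.norm(normalized.reshape(3, 2), axis=1)
print(block_norms)
```

Under this reading, every column of the action matrix starts with the same scale regardless of which sampler produced the raw values.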

to_actions

to_actions(A: LinearOperator, *, key: PRNGKeyArray | None = None) -> LinearOperator

Convert to block diagonal sparse action operators.

Parameters:

  • A

    (LinearOperator) –

    Linear operator (unused).

  • key

    (PRNGKeyArray | None, default: None ) –

    Optional random key (unused).

Returns:

  • BlockDiagonalSparse ( LinearOperator ) –

    Sparse action structure representing the blocks.

Source code in src/cagpjax/policies/block_sparse.py
@override
def to_actions(
    self, A: LinearOperator, *, key: PRNGKeyArray | None = None
) -> LinearOperator:
    """Convert to block diagonal sparse action operators.

    Args:
        A: Linear operator (unused).
        key: Optional random key (unused).

    Returns:
        BlockDiagonalSparse: Sparse action structure representing the blocks.
    """
    return BlockDiagonalSparse(paramax.unwrap(self.nz_values), self.n_actions)