cagpjax.policies.block_sparse
Block-sparse policy.
Classes:
- BlockSparsePolicy – Block-sparse linear solver policy.
BlockSparsePolicy
Bases: AbstractBatchLinearSolverPolicy
Block-sparse linear solver policy.
This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The matrix has the following structure:
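As a sketch, assuming each action owns one contiguous block of rows, writing $\mathbf{s}_i$ for the column vector of non-zero values in block $i$ and $K$ for n_actions (both symbols are notation introduced here for illustration):

$$
S = \begin{bmatrix}
\mathbf{s}_1 & \mathbf{0} & \cdots & \mathbf{0} \\
\mathbf{0} & \mathbf{s}_2 & \cdots & \mathbf{0} \\
\vdots & \vdots & \ddots & \vdots \\
\mathbf{0} & \mathbf{0} & \cdots & \mathbf{s}_K
\end{bmatrix}
$$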
These are stacked and stored as a single trainable parameter nz_values.
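To make the stacking concrete, here is a minimal sketch in plain Python (not the cagpjax implementation) that expands a flat list of non-zeros into the dense matrix it represents. It assumes contiguous, equal-sized blocks and a length divisible by n_actions; the helper name `dense_from_nz_values` is hypothetical.

```python
def dense_from_nz_values(nz_values, n_actions):
    """Expand a flat list of non-zeros into a dense N x n_actions matrix.

    Assumption for this sketch: blocks are contiguous and equal-sized,
    so len(nz_values) must be divisible by n_actions.
    """
    n = len(nz_values)
    if n % n_actions != 0:
        raise ValueError("this sketch requires len(nz_values) % n_actions == 0")
    block_size = n // n_actions
    dense = [[0.0] * n_actions for _ in range(n)]
    for row, value in enumerate(nz_values):
        # Row `row` falls in block (and therefore column) row // block_size.
        dense[row][row // block_size] = value
    return dense


S = dense_from_nz_values([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], n_actions=3)
for row in S:
    print(row)
```

Each column holds exactly one block of non-zeros, which is why a single flat vector of length N is enough to parameterize the whole matrix.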
Attributes:
- n_actions (int) – Number of actions to use.
- nz_values (Float[Array, N] | AbstractUnwrappable[Float[Array, N]]) – Non-zero values of the block-diagonal sparse matrix.
Methods:
- from_random – Initialize policy from block-normalized random samples.
- to_actions – Convert to block diagonal sparse action operators.
from_random
classmethod
from_random(key: PRNGKeyArray, num_datapoints: int, n_actions: int, *, sampler: Callable[[PRNGKeyArray, tuple[int, ...], Any], Float[Array, ' N']] = jax.random.normal, dtype: Any = None) -> BlockSparsePolicy
Initialize policy from block-normalized random samples.
Parameters:
- key (PRNGKeyArray) – Random key used to sample initial values.
- num_datapoints (int) – Number of rows in the resulting operator.
- n_actions (int) – Number of action columns in the resulting operator.
- sampler (Callable[[PRNGKeyArray, tuple[int, ...], Any], Float[Array, ' N']], default: normal) – Callable with signature (key, shape, dtype) -> values.
- dtype (Any, default: None) – Optional dtype forwarded to sampler.
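One plausible reading of "block-normalized" can be sketched in plain Python: draw num_datapoints values, split them into n_actions contiguous blocks, and rescale each block to unit Euclidean norm. The choice of norm and the block layout are assumptions (the description above does not state them), the helper `block_normalized_samples` is hypothetical, and `random.gauss` stands in for the `sampler` argument.

```python
import math
import random


def block_normalized_samples(seed, num_datapoints, n_actions):
    """Draw Gaussian values and normalize each contiguous block.

    Assumptions for this sketch: equal-sized contiguous blocks and
    unit Euclidean norm per block.
    """
    if num_datapoints % n_actions != 0:
        raise ValueError("this sketch requires num_datapoints % n_actions == 0")
    rng = random.Random(seed)
    block_size = num_datapoints // n_actions
    values = [rng.gauss(0.0, 1.0) for _ in range(num_datapoints)]
    normalized = []
    for start in range(0, num_datapoints, block_size):
        block = values[start:start + block_size]
        # Rescale the block so its squared entries sum to 1.
        norm = math.sqrt(sum(v * v for v in block))
        normalized.extend(v / norm for v in block)
    return normalized


nz_values = block_normalized_samples(seed=0, num_datapoints=6, n_actions=3)
```

Under this reading every action column starts with the same scale, whatever distribution the sampler draws from.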
Source code in src/cagpjax/policies/block_sparse.py
to_actions
Convert to block diagonal sparse action operators.
Parameters:
- A (LinearOperator) – Linear operator (unused).
- key (PRNGKeyArray | None, default: None) – Optional random key (unused).
Returns:
- BlockDiagonalSparse (LinearOperator) – Sparse action structure representing the blocks.
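To illustrate why a sparse representation of the actions is convenient, the sketch below computes the product of the transposed action matrix with a vector using only the stored non-zeros, without forming a dense matrix. It is plain Python, not the cagpjax operator; the helper `apply_actions_transpose` is hypothetical and assumes contiguous, equal-sized blocks.

```python
def apply_actions_transpose(nz_values, n_actions, v):
    """Compute S^T v from the flat non-zeros of a block-diagonal S.

    Assumption for this sketch: len(nz_values) == len(v) and the length
    is divisible by n_actions, with contiguous equal-sized blocks.
    """
    if len(nz_values) != len(v) or len(v) % n_actions != 0:
        raise ValueError("incompatible sizes for this sketch")
    block_size = len(v) // n_actions
    out = [0.0] * n_actions
    for row, (s, x) in enumerate(zip(nz_values, v)):
        # Each row contributes to exactly one action (column).
        out[row // block_size] += s * x
    return out


result = apply_actions_transpose([1.0, 2.0, 3.0, 4.0], 2, [1.0, 1.0, 1.0, 1.0])
print(result)  # [3.0, 7.0]
```

The loop touches each of the N stored values once, so the cost is linear in the number of data points rather than proportional to N times n_actions.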