cagpjax.policies.block_sparse
Block-sparse policy.
Classes:
- BlockSparsePolicy: Block-sparse linear solver policy.
BlockSparsePolicy
BlockSparsePolicy(n_actions: int, n: int | None = None, nz_values: Float[Array, N] | Variable[Float[Array, N]] | None = None, key: PRNGKeyArray | None = None, **kwargs)
Bases: AbstractBatchLinearSolverPolicy
Block-sparse linear solver policy.
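This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The action matrix $S \in \mathbb{R}^{n \times k}$, with $k$ = n_actions, has the following structure:

$$
S = \begin{bmatrix}
s_1 & 0 & \cdots & 0 \\
0 & s_2 & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
0 & 0 & \cdots & s_k
\end{bmatrix},
$$

where each block $s_i$ is a column vector holding the non-zero entries of action $i$. These blocks are stacked and stored as a single trainable parameter nz_values of shape (n,).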
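To make the structure concrete, the following minimal sketch materializes $S$ as a dense array from a vector of non-zero values. It is not the cagpjax implementation: the helper name dense_block_sparse is hypothetical, and the sketch assumes the n values are split into n_actions contiguous blocks of near-equal size (which needs n >= n_actions so that no column is empty).

```python
import jax.numpy as jnp


def dense_block_sparse(nz_values: jnp.ndarray, n_actions: int) -> jnp.ndarray:
    """Hypothetical helper: build the dense (n, n_actions) block-diagonal action matrix.

    Assumes contiguous, near-equal blocks; block i is the non-zero part of column i.
    """
    n = nz_values.shape[0]
    # Block (column) index of every row, e.g. n=6, n_actions=3 -> [0, 0, 1, 1, 2, 2].
    block_ids = (jnp.arange(n) * n_actions) // n
    # Each row holds exactly one non-zero, placed in the column of its block.
    dense = jnp.zeros((n, n_actions), dtype=nz_values.dtype)
    return dense.at[jnp.arange(n), block_ids].set(nz_values)


S = dense_block_sparse(jnp.arange(1.0, 7.0), n_actions=3)
print(S)
```

Because every row has a single non-zero entry, storing only the length-n vector nz_values is enough to represent the whole n by n_actions matrix.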
Initialize the block-sparse policy.
Parameters:
- n_actions (int): Number of actions to use.
- n (int | None, default: None): Number of rows and columns of the full operator. Must be provided if nz_values is not provided.
- nz_values (Float[Array, N] | Variable[Float[Array, N]] | None, default: None): Non-zero values of the block-diagonal sparse matrix, with shape (n,). If not provided, random actions are sampled, using key if it is provided.
- key (PRNGKeyArray | None, default: None): Random key for sampling actions if nz_values is not provided.
- **kwargs: Additional keyword arguments for jax.random.normal (e.g. dtype).
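The interplay between n, nz_values, key and **kwargs can be summarized with a short sketch. The function init_nz_values is hypothetical and only mirrors the contract described above for plain arrays (it ignores the Variable case); in particular, the fallback to jax.random.PRNGKey(0) when no key is given is an assumption, since the actual default is not documented here.

```python
import jax
import jax.numpy as jnp


def init_nz_values(n=None, nz_values=None, key=None, **kwargs):
    """Hypothetical sketch of the documented constructor contract (not cagpjax code)."""
    if nz_values is not None:
        # Explicit values win; their shape (n,) already fixes the operator size.
        return jnp.asarray(nz_values)
    if n is None:
        raise ValueError("n must be provided if nz_values is not provided")
    if key is None:
        # Assumption: fall back to a fixed key; the real default may differ.
        key = jax.random.PRNGKey(0)
    # Extra keyword arguments (e.g. dtype) are forwarded to jax.random.normal.
    return jax.random.normal(key, (n,), **kwargs)


vals = init_nz_values(n=8, key=jax.random.PRNGKey(42), dtype=jnp.float32)
print(vals.shape, vals.dtype)  # (8,) float32
```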
Methods:
- to_actions: Convert to block-diagonal sparse action operators.
Attributes:
- n_actions (int): Number of actions to be used.
Source code in src/cagpjax/policies/block_sparse.py
to_actions
Convert to block-diagonal sparse action operators.
Parameters:
- A (LinearOperator): Linear operator (unused).

Returns:
- BlockDiagonalSparse (LinearOperator): Sparse action structure representing the blocks.
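The internals of BlockDiagonalSparse are not shown here, so the following pure-JAX sketch only illustrates why a block-diagonal action matrix is cheap to apply: both $Sx$ and $S^\top v$ cost O(n) without ever forming $S$. All helper names (block_ids, actions_matvec, actions_rmatvec, projected_operator) are hypothetical, the contiguous near-equal block layout is an assumption, and the projection $S^\top A S$ is shown as a typical way actions enter a computation-aware solver, not as what cagpjax necessarily computes.

```python
import jax
import jax.numpy as jnp


def block_ids(n: int, n_actions: int) -> jnp.ndarray:
    # Assumed layout: contiguous, near-equal blocks; entry j belongs to column block_ids[j].
    return (jnp.arange(n) * n_actions) // n


def actions_matvec(nz_values, ids, x):
    """S @ x for x of shape (n_actions,): row j is nz_values[j] * x[ids[j]]."""
    return nz_values * x[ids]


def actions_rmatvec(nz_values, ids, v, n_actions):
    """S.T @ v for v of shape (n,): sum nz_values * v within each block."""
    return jax.ops.segment_sum(nz_values * v, ids, num_segments=n_actions)


def projected_operator(A, nz_values, ids, n_actions):
    """Dense (n_actions, n_actions) matrix S.T @ A @ S using only block-sparse products."""
    eye = jnp.eye(n_actions, dtype=nz_values.dtype)
    # Columns of A @ S: apply A to each action S @ e_i.
    AS = jax.vmap(lambda e: A @ actions_matvec(nz_values, ids, e), out_axes=1)(eye)
    # Reduce every column of A @ S block-wise to get S.T @ (A @ S).
    return jax.vmap(
        lambda col: actions_rmatvec(nz_values, ids, col, n_actions),
        in_axes=1,
        out_axes=1,
    )(AS)


nz = jnp.array([1.0, 2.0, 3.0, 4.0])
ids = block_ids(4, 2)
print(projected_operator(jnp.eye(4), nz, ids, 2))  # -> diag(5, 25)
```

With A set to the identity, the result is $S^\top S$, whose diagonal holds the squared norms of the blocks ($1^2 + 2^2 = 5$ and $3^2 + 4^2 = 25$); the off-diagonal entries vanish because the blocks do not overlap.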