cagpjax.policies
AbstractBatchLinearSolverPolicy
Bases: AbstractLinearSolverPolicy, ABC
Abstract base class for policies that produce action matrices.
Source code in src/cagpjax/policies/base.py
n_actions (abstract property)
Number of actions in this policy.
to_actions(A) (abstract method)
Compute all actions used to solve the linear system \(Ax=b\).
For a matrix \(A\) with shape (n, n), the action matrix has shape (n, n_actions).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
A | LinearOperator | Linear operator representing the linear system. | required |
Returns:
Type | Description |
---|---|
LinearOperator | Linear operator representing the action matrix. |
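As a rough illustration of this contract, the plain-Python sketch below mirrors the two abstract members with the standard `abc` module and adds a toy concrete policy. NumPy arrays stand in for LinearOperator, and the class names BatchPolicySketch and UnitVectorPolicy are hypothetical; they are not part of cagpjax.

```python
from abc import ABC, abstractmethod

import numpy as np


class BatchPolicySketch(ABC):
    """Hypothetical stand-in for the batch-policy interface (not cagpjax's class)."""

    @property
    @abstractmethod
    def n_actions(self) -> int:
        """Number of actions (columns) the policy produces."""

    @abstractmethod
    def to_actions(self, A: np.ndarray) -> np.ndarray:
        """Return an action matrix of shape (n, n_actions) for an (n, n) matrix A."""


class UnitVectorPolicy(BatchPolicySketch):
    """Toy policy: the actions are the first k standard basis vectors."""

    def __init__(self, k: int):
        self.k = k

    @property
    def n_actions(self) -> int:
        return self.k

    def to_actions(self, A: np.ndarray) -> np.ndarray:
        n = A.shape[0]
        # The first k columns of the n x n identity select the first k coordinates.
        return np.eye(n, self.k)


A = np.diag([4.0, 3.0, 2.0, 1.0])
S = UnitVectorPolicy(k=2).to_actions(A)
print(S.shape)  # (4, 2)
```

Because both members are abstract, `BatchPolicySketch()` itself cannot be instantiated; only subclasses that implement n_actions and to_actions can.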
AbstractLinearSolverPolicy
Bases: Module
Abstract base class for all linear solver policies.
Policies define actions used to solve a linear system \(A x = b\), where \(A\) is a square linear operator.
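The reference does not spell out how a solver consumes the actions. One standard choice, assumed here, is a Galerkin projection onto the span of the actions, \(\hat{x} = S (S^\top A S)^{-1} S^\top b\), as used in projection-based solvers and, for example, in computation-aware Gaussian process approximations. The NumPy sketch below uses a hypothetical helper name, projected_solve, and a dense matrix in place of a LinearOperator.

```python
import numpy as np


def projected_solve(A: np.ndarray, b: np.ndarray, S: np.ndarray) -> np.ndarray:
    """Approximate A^{-1} b within span(S): x = S (S^T A S)^{-1} S^T b.

    Assumes A is symmetric positive definite and S has full column rank.
    """
    gram = S.T @ A @ S                      # (k, k) projected system
    coeffs = np.linalg.solve(gram, S.T @ b)
    return S @ coeffs


rng = np.random.default_rng(0)
n = 6
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)                 # symmetric positive definite
b = rng.standard_normal(n)

# With n linearly independent actions the projection recovers the exact solution.
x_full = projected_solve(A, b, rng.standard_normal((n, n)))
print(np.allclose(x_full, np.linalg.solve(A, b)))  # True

# With fewer actions, the residual is orthogonal to every action (Galerkin condition).
S = rng.standard_normal((n, 2))
x_k = projected_solve(A, b, S)
print(np.allclose(S.T @ (b - A @ x_k), 0.0))  # True
```

Under this reading, the number of actions trades accuracy for cost: only a k x k system is ever solved.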
BlockSparsePolicy
Bases: AbstractBatchLinearSolverPolicy
Block-sparse linear solver policy.
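This policy uses a fixed block-diagonal sparse structure to define independent learnable actions. The action matrix has the following structure:
\[
S = \begin{pmatrix}
\mathbf{s}_1 & \mathbf{0} & \cdots & \mathbf{0} \\
\mathbf{0} & \mathbf{s}_2 & \cdots & \mathbf{0} \\
\vdots & \vdots & \ddots & \vdots \\
\mathbf{0} & \mathbf{0} & \cdots & \mathbf{s}_k
\end{pmatrix} \in \mathbb{R}^{n \times k},
\]
where \(k\) is n_actions and each block \(\mathbf{s}_i\) is a column vector of non-zero values. These blocks are stacked and stored as a single trainable parameter, nz_values.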
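For intuition, the sketch below materialises such an action matrix densely from a flat vector of stacked values, so that column \(i\) is non-zero only on its own contiguous block of rows. It assumes NumPy in place of JAX and assumes the \(n\) rows are split into n_actions nearly equal contiguous blocks with np.array_split; the helper name block_sparse_actions is hypothetical, and the library's actual block sizes may differ.

```python
import numpy as np


def block_sparse_actions(nz_values: np.ndarray, n_actions: int) -> np.ndarray:
    """Dense (n, n_actions) matrix whose column i is non-zero only on block i.

    Assumption: the n stacked values are split into n_actions contiguous,
    nearly equal blocks (np.array_split gives the first n % n_actions blocks
    one extra row).
    """
    n = nz_values.shape[0]
    S = np.zeros((n, n_actions))
    for i, rows in enumerate(np.array_split(np.arange(n), n_actions)):
        S[rows, i] = nz_values[rows]   # block i lives only in column i
    return S


nz_values = np.arange(1.0, 8.0)        # n = 7 stacked non-zero values
S = block_sparse_actions(nz_values, n_actions=3)
print(S.shape)                         # (7, 3)
print((S != 0).sum(axis=0))            # [3 2 2]
```

Because the columns have disjoint supports, \(S^\top S\) is diagonal, and every stored value appears exactly once in \(S\).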
Source code in src/cagpjax/policies/block_sparse.py
n_actions (property)
Number of actions to use.
__init__(n_actions, n=None, nz_values=None, key=None, **kwargs)
Initialize the block sparse policy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_actions | int | Number of actions to use. | required |
n | int \| None | Number of rows and columns of the full operator. Must be provided if | None |
nz_values | Float[Array, N] \| Variable[Float[Array, N]] \| None | Non-zero values of the block-diagonal sparse matrix (shape | None |
key | PRNGKeyArray \| None | Random key for sampling actions if | None |
**kwargs | | Additional keyword arguments for | {} |
to_actions(A)
Convert to block diagonal sparse action operators.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
A | LinearOperator | Linear operator (unused). | required |
Returns:
Type | Description |
---|---|
LinearOperator | Transposed[BlockDiagonalSparse]: Sparse action structure representing the blocks. |
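Since each action touches only its own block of rows, the action matrix and its transpose can be applied in \(O(n)\) time without materialising the dense \((n, k)\) matrix: \(S w\) repeats each weight across its block, and \(S^\top v\) is one dot product per block. The NumPy sketch below checks this against a dense reference. The helper names and the nearly-equal block layout are assumptions for illustration, not cagpjax's internal representation.

```python
import numpy as np


def block_sizes(n: int, n_actions: int) -> np.ndarray:
    """Assumed layout: n rows split into n_actions contiguous, nearly equal blocks."""
    return np.array([len(rows) for rows in np.array_split(np.arange(n), n_actions)])


def apply_actions(nz_values: np.ndarray, sizes: np.ndarray, w: np.ndarray) -> np.ndarray:
    """S @ w: each action weight is broadcast over its own block of rows."""
    return nz_values * np.repeat(w, sizes)


def apply_actions_t(nz_values: np.ndarray, sizes: np.ndarray, v: np.ndarray) -> np.ndarray:
    """S.T @ v: one dot product per block, computed as segment sums."""
    starts = np.concatenate(([0], np.cumsum(sizes)[:-1]))  # first row of each block
    return np.add.reduceat(nz_values * v, starts)


rng = np.random.default_rng(1)
n, k = 7, 3
nz_values = rng.standard_normal(n)
sizes = block_sizes(n, k)                  # [3, 2, 2]

# Dense reference: row j holds nz_values[j] in the column of the block it belongs to.
S = np.zeros((n, k))
S[np.arange(n), np.repeat(np.arange(k), sizes)] = nz_values

w, v = rng.standard_normal(k), rng.standard_normal(n)
print(np.allclose(apply_actions(nz_values, sizes, w), S @ w))      # True
print(np.allclose(apply_actions_t(nz_values, sizes, v), S.T @ v))  # True
```

The segment-sum form assumes every block is non-empty, i.e. n >= n_actions.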
LanczosPolicy
Bases: AbstractBatchLinearSolverPolicy
Lanczos-based policy for approximating an eigenvalue decomposition.
This policy uses the Lanczos algorithm to approximate the top n_actions eigenvectors of the linear operator \(A\).
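For intuition, here is a textbook version of the idea. Lanczos builds an orthonormal basis \(Q\) of the Krylov subspace \(\operatorname{span}\{q, Aq, \dots, A^{m-1}q\}\) in which \(Q^\top A Q = T\) is tridiagonal. Eigenvectors of \(T\) mapped back through \(Q\) (Ritz vectors) approximate extremal eigenvectors of \(A\). The sketch below uses NumPy, full reorthogonalization, a seeded random start vector standing in for the key attribute, and a hypothetical function name; cagpjax's own iteration may differ, for example in reorthogonalization or in which vectors it returns.

```python
from typing import Optional

import numpy as np


def lanczos_top_eigvecs(A: np.ndarray, k: int, n_iter: Optional[int] = None,
                        seed: int = 0) -> np.ndarray:
    """Approximate the k eigenvectors of symmetric A with the largest eigenvalues.

    Textbook Lanczos with full reorthogonalization; assumes no breakdown
    (no zero beta before the last step). `seed` stands in for a JAX key.
    Returns an (n, k) matrix of Ritz vectors.
    """
    n = A.shape[0]
    m = k if n_iter is None else n_iter        # number of Lanczos steps (k <= m <= n)
    rng = np.random.default_rng(seed)
    Q = np.zeros((n, m))                       # Lanczos vectors as columns
    alpha = np.zeros(m)                        # diagonal of T
    beta = np.zeros(m - 1)                     # off-diagonal of T
    q0 = rng.standard_normal(n)
    Q[:, 0] = q0 / np.linalg.norm(q0)
    for j in range(m):
        w = A @ Q[:, j]
        alpha[j] = Q[:, j] @ w
        # Remove components along all previous Lanczos vectors; two passes of
        # classical Gram-Schmidt keep Q orthonormal to working precision.
        for _ in range(2):
            w -= Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j + 1 < m:
            beta[j] = np.linalg.norm(w)
            Q[:, j + 1] = w / beta[j]
    # Tridiagonal projection T = Q^T A Q and its eigendecomposition.
    T = np.diag(alpha) + np.diag(beta, 1) + np.diag(beta, -1)
    theta, V = np.linalg.eigh(T)               # Ritz values, ascending
    top = np.argsort(theta)[::-1][:k]          # indices of the k largest
    return Q @ V[:, top]                       # Ritz vectors, shape (n, k)


rng = np.random.default_rng(0)
M = rng.standard_normal((8, 8))
A = M @ M.T + np.eye(8)                        # symmetric positive definite
S = lanczos_top_eigvecs(A, k=3, n_iter=8)      # full Krylov space: Ritz pairs are exact

evals, evecs = np.linalg.eigh(A)
top3 = evecs[:, ::-1][:, :3]                   # exact top-3 eigenvectors
print(np.allclose(np.abs(S.T @ top3), np.eye(3), atol=1e-6))  # True
```

With n_iter smaller than n the Ritz vectors are only approximations, typically most accurate for the eigenvalues at the ends of the spectrum.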
Attributes:
Name | Type | Description |
---|---|---|
n_actions | int | Number of Lanczos vectors/actions to compute. |
key | PRNGKeyArray \| None | Random key for reproducible Lanczos iterations. |
Source code in src/cagpjax/policies/lanczos.py
__init__(n_actions, key=None)
Initialize the Lanczos policy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_actions | int | Number of Lanczos vectors to compute. | required |
key | PRNGKeyArray \| None | Random key for initialization. | None |
to_actions(A)
Compute action matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
A | LinearOperator | Symmetric linear operator representing the linear system. | required |
Returns:
Type | Description |
---|---|
LinearOperator | Linear operator containing the Lanczos vectors as columns. |
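The class description speaks of eigenvectors, while this return value is described in terms of Lanczos vectors. If the actions are consumed through a projection such as \(S (S^\top A S)^{-1} S^\top\) (an assumption; the reference does not state how solvers use the actions), the distinction does not affect the solve. The projection depends only on \(\operatorname{span}(S)\), so any basis of the same subspace, such as Lanczos vectors \(Q\) or Ritz vectors \(QV\), gives the same result. The NumPy check below uses a hypothetical helper name and an explicit invertible change of basis \(R\).

```python
import numpy as np


def projected_inverse(A: np.ndarray, S: np.ndarray) -> np.ndarray:
    """S (S^T A S)^{-1} S^T: the approximation of A^{-1} restricted to span(S)."""
    return S @ np.linalg.solve(S.T @ A @ S, S.T)


rng = np.random.default_rng(2)
n, k = 8, 3
M = rng.standard_normal((n, n))
A = M @ M.T + np.eye(n)                        # symmetric positive definite

S = rng.standard_normal((n, k))
# Upper-triangular change of basis with a fixed non-zero diagonal, hence invertible.
R = np.triu(rng.standard_normal((k, k)), 1) + np.diag([1.0, 2.0, 3.0])
print(np.allclose(projected_inverse(A, S), projected_inverse(A, S @ R)))  # True
```

Algebraically, \((SR)\big((SR)^\top A (SR)\big)^{-1}(SR)^\top = S (S^\top A S)^{-1} S^\top\) for any invertible \(R\), which is what the check confirms numerically.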