
cagpjax.policies.lanczos

Lanczos-based policies.

LanczosPolicy

Bases: AbstractBatchLinearSolverPolicy

Lanczos-based policy for eigenvalue decomposition approximation.

This policy uses the Lanczos algorithm to approximate the top n_actions eigenvectors of the symmetric linear operator \(A\).

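Attributes:

- n_actions (int): Number of Lanczos vectors/actions to compute.
- key (PRNGKeyArray | None): Random key for reproducible Lanczos iterations.
- grad_rtol (float | None): Cutoff for treating eigenvalues as similar when computing gradients (see cagpjax.linalg.eigh).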
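For readers unfamiliar with the Lanczos algorithm, here is a minimal, self-contained JAX sketch of the idea, written against a dense matrix rather than a LinearOperator. It is not cagpjax's Lanczos implementation, and the name lanczos_eigh is hypothetical. A random starting vector drawn from key builds an orthonormal Krylov basis Q and a symmetric tridiagonal matrix T = QᵀAQ; eigenvectors of T mapped back through Q (Ritz vectors) approximate eigenvectors of A, with the extreme (largest and smallest) eigenpairs converging first. Full reorthogonalization is a choice made here for numerical robustness and is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def lanczos_eigh(A, k, key):
    """Approximate k eigenpairs of a symmetric matrix A with k Lanczos steps.

    Hypothetical helper for illustration (not the cagpjax API). Returns
    (ritz_values, ritz_vectors), sorted by descending Ritz value.
    Assumes k >= 1 and no breakdown (every beta > 0).
    """
    n = A.shape[0]
    # Random unit-norm starting vector; the same key gives the same result.
    q = jax.random.normal(key, (n,), dtype=A.dtype)
    Q = jnp.zeros((n, k), dtype=A.dtype).at[:, 0].set(q / jnp.linalg.norm(q))
    alphas = jnp.zeros(k, dtype=A.dtype)
    betas = jnp.zeros(k - 1, dtype=A.dtype)
    for j in range(k):
        w = A @ Q[:, j]
        alphas = alphas.at[j].set(Q[:, j] @ w)
        # Full reorthogonalization (applied twice) against all Lanczos vectors so far.
        for _ in range(2):
            w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j + 1 < k:
            beta = jnp.linalg.norm(w)
            betas = betas.at[j].set(beta)
            Q = Q.at[:, j + 1].set(w / beta)
    # Symmetric tridiagonal T = Q^T A Q built from the recurrence coefficients.
    T = jnp.diag(alphas) + jnp.diag(betas, 1) + jnp.diag(betas, -1)
    theta, S = jnp.linalg.eigh(T)
    # Map eigenvectors of T back to R^n (Ritz vectors), largest Ritz value first.
    return theta[::-1], (Q @ S)[:, ::-1]
```

With k much smaller than n, the leading Ritz vectors play the role of the action columns this policy returns, and reusing the same key reproduces the same vectors.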

Source code in src/cagpjax/policies/lanczos.py
class LanczosPolicy(AbstractBatchLinearSolverPolicy):
    """Lanczos-based policy for eigenvalue decomposition approximation.

    This policy uses the Lanczos algorithm to compute the top ``n_actions`` eigenvectors
    of the linear operator $A$.

    Attributes:
        n_actions: Number of Lanczos vectors/actions to compute.
        key: Random key for reproducible Lanczos iterations.
    """

    key: PRNGKeyArray | None
    grad_rtol: float | None

    def __init__(
        self,
        n_actions: int | None,
        key: PRNGKeyArray | None = None,
        grad_rtol: float | None = 0.0,
    ):
        """Initialize the Lanczos policy.

        Args:
            n_actions: Number of Lanczos vectors to compute.
            key: Random key for initialization.
            grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
                gradient computation for (almost-)degenerate matrices.
                If not provided, the default is 0.0.
                If None or negative, all eigenvalues are treated as distinct.
                (see [`cagpjax.linalg.eigh`][] for more details)
        """
        self._n_actions: int = n_actions
        self.key = key
        self.grad_rtol = grad_rtol

    @property
    @override
    def n_actions(self) -> int:
        return self._n_actions

    @override
    def to_actions(self, A: LinearOperator) -> LinearOperator:
        """Compute action matrix.

        Args:
            A: Symmetric linear operator representing the linear system.

        Returns:
            Linear operator containing the Lanczos vectors as columns.
        """
        vecs = eigh(
            A, alg=Lanczos(self.n_actions, key=self.key), grad_rtol=self.grad_rtol
        ).eigenvectors
        return vecs

__init__(n_actions, key=None, grad_rtol=0.0)

Initialize the Lanczos policy.

Parameters:

- n_actions (int | None, required): Number of Lanczos vectors to compute.
- key (PRNGKeyArray | None, default None): Random key used to initialize the Lanczos iterations.
- grad_rtol (float | None, default 0.0): Cutoff for treating eigenvalues as similar, used to improve gradient computation for (almost-)degenerate matrices. If None or negative, all eigenvalues are treated as distinct (see cagpjax.linalg.eigh for more details).
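Reverse-mode derivatives of eigenvectors involve a factor 1/(λ_j − λ_i) for each pair of eigenvalues, which becomes huge or infinite when two eigenvalues (nearly) coincide. The minimal sketch below shows one way a relative cutoff such as grad_rtol can tame this: pairs whose gap is at most rtol times the largest eigenvalue magnitude are treated as degenerate and their factor is set to zero. The helper name inverse_eigengaps and the choice of scale are assumptions for illustration; cagpjax.linalg.eigh may define and apply the cutoff differently.

```python
import jax.numpy as jnp


def inverse_eigengaps(evals, rtol=0.0):
    """Matrix F with F[i, j] = 1 / (evals[j] - evals[i]) for distinct pairs, else 0.

    Hypothetical helper: pairs whose gap is at most rtol * max|evals| count as
    degenerate. rtol=0.0 masks only exactly equal pairs (including the diagonal);
    rtol=None or negative masks only the diagonal, i.e. all eigenvalues are
    treated as distinct.
    """
    gaps = evals[None, :] - evals[:, None]
    n = evals.shape[0]
    if rtol is None or rtol < 0:
        degenerate = jnp.eye(n, dtype=bool)
    else:
        scale = jnp.max(jnp.abs(evals))
        degenerate = jnp.abs(gaps) <= rtol * scale
    # Replace masked gaps by 1 before dividing so no inf/nan is produced, then zero them.
    safe_gaps = jnp.where(degenerate, 1.0, gaps)
    return jnp.where(degenerate, 0.0, 1.0 / safe_gaps)
```

For evals = [1.0, 1.0 + 1e-6, 3.0], rtol=1e-3 zeroes the factor for the nearly equal pair, while rtol=0.0 keeps it at roughly 1e6.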

to_actions(A)

Compute the action matrix.

Parameters:

- A (LinearOperator, required): Symmetric linear operator representing the linear system.

Returns:

- LinearOperator: Linear operator whose columns are the approximate eigenvectors (Ritz vectors) of A computed with the Lanczos algorithm.
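To illustrate what an action matrix S (n × k) is for, the sketch below shows one common way actions are consumed by a projection-based linear solver: the solve of Ax = b is restricted to the span of the columns of S by solving the small system (SᵀAS)y = Sᵀb and returning x = Sy. This is an assumption about downstream use, not a description of how cagpjax's solvers use the actions, and projected_solve is a hypothetical name. When the columns of S are orthonormal eigenvectors of A, SᵀAS is diagonal, so the approximation keeps only the top-k spectral components of A⁻¹b.

```python
import jax.numpy as jnp


def projected_solve(A, S, b):
    """Approximate A^{-1} b within the span of the action columns S (n x k).

    Hypothetical helper for illustration: solves the k x k system
    (S^T A S) y = S^T b and returns x = S y.
    """
    StAS = S.T @ A @ S  # k x k projected operator
    y = jnp.linalg.solve(StAS, S.T @ b)
    return S @ y
```

With k = n the result matches a full solve; with fewer eigenvector actions it equals the sum over the kept eigenpairs of v (vᵀb) / λ.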
