
cagpjax.linalg.eigh

Hermitian eigenvalue decomposition.

Classes:

  • Eigh

    Eigh algorithm for eigenvalue decomposition.

  • EighResult

    Result of Hermitian eigenvalue decomposition.

  • Lanczos

    Lanczos algorithm for approximate partial eigenvalue decomposition.

Functions:

  • eigh

    Compute the Hermitian eigenvalue decomposition of a linear operator.

Eigh

Bases: Algorithm

Eigh algorithm for eigenvalue decomposition.

EighResult

Bases: NamedTuple

Result of Hermitian eigenvalue decomposition.

Attributes:

  • eigenvalues (Float[Array, N]) –

    Eigenvalues of the operator.

  • eigenvectors (LinearOperator) –

    Eigenvectors of the operator.
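Because the result is a NamedTuple, its fields can be read by name or unpacked positionally. The sketch below uses a hypothetical stand-in class, `EighResultSketch`, with plain Python lists in place of the array and `LinearOperator` types, so it only illustrates the access pattern and is not the library's class.

```python
from typing import NamedTuple, Sequence


class EighResultSketch(NamedTuple):
    """Hypothetical stand-in mirroring the two documented fields."""

    eigenvalues: Sequence[float]
    # Assumption: nested lists stand in for the real LinearOperator field.
    eigenvectors: Sequence[Sequence[float]]


# Eigendecomposition of diag(2, 5), written out by hand.
result = EighResultSketch(
    eigenvalues=[2.0, 5.0],
    eigenvectors=[[1.0, 0.0], [0.0, 1.0]],
)

# Fields are available by name ...
vals_by_name = result.eigenvalues
# ... or by positional unpacking, in (eigenvalues, eigenvectors) order.
vals, vecs = result
```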

Lanczos

Lanczos(max_iters: int | None = None, /, *, v0: Float[Array, N] | None = None, key: PRNGKeyArray | None = None)

Bases: Algorithm

Lanczos algorithm for approximate partial eigenvalue decomposition.

Parameters:

  • max_iters

    (int | None, default: None ) –

    Maximum number of iterations (number of eigenvalues/vectors to compute). If None, all eigenvalues/eigenvectors are computed.

  • v0

    (Float[Array, N] | None, default: None ) –

    Initial vector. If None, a random vector is generated using key.

  • key

    (PRNGKeyArray | None, default: None ) –

    Random key for generating a random initial vector if v0 is not provided.

Source code in src/cagpjax/linalg/eigh.py
def __init__(
    self,
    max_iters: int | None = None,
    /,
    *,
    v0: Float[Array, "N"] | None = None,
    key: PRNGKeyArray | None = None,
):
    self.max_iters = max_iters
    self.v0 = v0
    self.key = key
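To make the roles of max_iters and v0 concrete, here is a minimal textbook Lanczos iteration in plain Python. It is a sketch under stated assumptions, not the implementation used by this module: the function name `lanczos_sketch`, the `matvec` callback, and the use of full reorthogonalization are all choices made for illustration. After k steps it yields a k-by-k tridiagonal matrix T = Qᵀ A Q whose eigenvalues approximate some eigenvalues of A.

```python
import math


def lanczos_sketch(matvec, v0, max_iters):
    """Run up to `max_iters` Lanczos steps; return (alphas, betas, basis).

    `alphas` is the diagonal and `betas` the off-diagonal of the tridiagonal
    matrix T = Q^T A Q, where the entries of `basis` are the columns of Q.
    """
    norm = math.sqrt(sum(x * x for x in v0))
    basis = [[x / norm for x in v0]]  # normalized start vector
    alphas, betas = [], []
    for k in range(max_iters):
        w = matvec(basis[k])
        alphas.append(sum(a * b for a, b in zip(w, basis[k])))
        # Full reorthogonalization against every basis vector so far.
        for p in basis:
            c = sum(a * b for a, b in zip(w, p))
            w = [a - c * b for a, b in zip(w, p)]
        beta = math.sqrt(sum(x * x for x in w))
        # Stop on the last step or when the Krylov subspace is exhausted.
        if k == max_iters - 1 or beta < 1e-12:
            break
        betas.append(beta)
        basis.append([x / beta for x in w])
    return alphas, betas, basis


# Hypothetical test operator: A = diag(1, 2, 3), applied as a matvec.
diag = [1.0, 2.0, 3.0]
alphas, betas, basis = lanczos_sketch(
    lambda v: [d * x for d, x in zip(diag, v)], [1.0, 1.0, 1.0], 3
)
```

With max_iters equal to the operator size and no early breakdown, Q is square and the trace of T equals the trace of A. With fewer iterations, only a partial basis is produced.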

eigh

eigh(A: LinearOperator, alg: Algorithm = Eigh(), grad_rtol: float | None = None) -> EighResult

Compute the Hermitian eigenvalue decomposition of a linear operator.

For some algorithms, the decomposition may be approximate or partial.

Parameters:

  • A

    (LinearOperator) –

    Hermitian linear operator.

  • alg

    (Algorithm, default: Eigh() ) –

    Algorithm for eigenvalue decomposition.

  • grad_rtol

    (float | None, default: None ) –

    Cutoff for treating eigenvalues as similar, used to improve gradient computation for (almost-)degenerate matrices. Defaults to None. If None or negative, all eigenvalues are treated as distinct.

Returns:

  • EighResult

    A named tuple of (eigenvalues, eigenvectors) where eigenvectors is a (semi-)orthogonal LinearOperator.

Note

Degenerate matrices have repeated eigenvalues. The set of eigenvectors that correspond to the same eigenvalue is not unique but instead forms a subspace. grad_rtol only improves the stability of gradient computation if the function being differentiated depends only on these subspaces and not on the specific eigenvectors themselves.
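The distinction between an eigenspace and a particular choice of eigenvectors can be checked by hand. The sketch below is plain Python and independent of this library; the helper `projector` is a hypothetical name. It takes two different orthonormal bases for the repeated eigenvalue of diag(1, 1, 2) and shows that the projector onto the eigenspace is the same for both, even though the individual vectors differ. A function built from such projectors is the kind of subspace-only quantity the note refers to.

```python
import math

# diag(1, 1, 2) has the repeated eigenvalue 1 with a two-dimensional eigenspace.
s = 1.0 / math.sqrt(2.0)
basis_a = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # one orthonormal eigenbasis
basis_b = [[s, s, 0.0], [s, -s, 0.0]]  # another orthonormal basis, same eigenspace


def projector(vectors):
    """Return sum_i v_i v_i^T, the orthogonal projector onto span(vectors)."""
    n = len(vectors[0])
    return [[sum(v[i] * v[j] for v in vectors) for j in range(n)] for i in range(n)]


proj_a = projector(basis_a)
proj_b = projector(basis_b)

# The individual eigenvectors differ between the two bases ...
vectors_differ = basis_a[0] != basis_b[0]
# ... but the projector onto the eigenspace agrees up to rounding.
max_gap = max(abs(proj_a[i][j] - proj_b[i][j]) for i in range(3) for j in range(3))
```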

Source code in src/cagpjax/linalg/eigh.py
def eigh(
    A: LinearOperator,
    alg: cola.linalg.Algorithm = Eigh(),
    grad_rtol: float | None = None,
) -> EighResult:
    """Compute the Hermitian eigenvalue decomposition of a linear operator.

    For some algorithms, the decomposition may be approximate or partial.

    Args:
        A: Hermitian linear operator.
        alg: Algorithm for eigenvalue decomposition.
        grad_rtol: Specifies the cutoff for similar eigenvalues, used to improve
            gradient computation for (almost-)degenerate matrices.
            Defaults to None.
            If None or negative, all eigenvalues are treated as distinct.

    Returns:
        A named tuple of `(eigenvalues, eigenvectors)` where `eigenvectors` is a
            (semi-)orthogonal `LinearOperator`.

    Note:
        Degenerate matrices have repeated eigenvalues.
        The set of eigenvectors that correspond to the same eigenvalue is not unique
        but instead forms a subspace.
        `grad_rtol` only improves the stability of gradient computation if the
        function being differentiated depends only on these subspaces and not on
        the specific eigenvectors themselves.
    """
    if grad_rtol is None:
        grad_rtol = -1.0
    vals, vecs = _eigh(A, alg, grad_rtol)  # pyright: ignore[reportArgumentType]
    if vecs.shape[-1] == A.shape[-1]:
        vecs = cola.Unitary(vecs)
    else:
        vecs = cola.Stiefel(vecs)
    return EighResult(vals, vecs)
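The final branch in the source above wraps a square eigenvector matrix as `cola.Unitary` and a tall one, with fewer columns than rows, as `cola.Stiefel`. The plain-Python sketch below illustrates the underlying distinction, using lists of columns in place of linear operators; the helper names are hypothetical. A matrix V with orthonormal columns always satisfies Vᵀ V = I, but V Vᵀ = I holds only when V is square.

```python
def gram(columns):
    """Return V^T V for a matrix V given as a list of its columns."""
    return [[sum(a * b for a, b in zip(u, v)) for v in columns] for u in columns]


def outer_sum(columns):
    """Return V V^T for a matrix V given as a list of its columns."""
    n = len(columns[0])
    return [[sum(c[i] * c[j] for c in columns) for j in range(n)] for i in range(n)]


def is_identity(m, tol=1e-12):
    """Check whether the square matrix m equals the identity within tol."""
    size = len(m)
    return all(abs(m[i][j] - (i == j)) <= tol for i in range(size) for j in range(size))


columns = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
full = columns  # N eigenvectors of an N x N operator (square case)
partial = columns[:2]  # e.g. a partial decomposition with 2 eigenvectors (tall case)

full_is_unitary = is_identity(gram(full)) and is_identity(outer_sum(full))
partial_has_orthonormal_columns = is_identity(gram(partial))
partial_is_unitary = is_identity(outer_sum(partial))
```

Here `full_is_unitary` and `partial_has_orthonormal_columns` are true while `partial_is_unitary` is false, which matches the "(semi-)orthogonal" wording in the Returns description.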