
cagpjax.solvers.base

Base classes for linear solvers and methods.

AbstractLinearSolver

Bases: Module

Base class for linear solvers.

These solvers are used to exactly or approximately solve the linear system \(Ax = b\) for \(x\), where \(A\) is a positive (semi-)definite (PSD) linear operator.

Solvers should always be constructed by an AbstractLinearSolverMethod.

Source code in src/cagpjax/solvers/base.py
class AbstractLinearSolver(nnx.Module):
    """
    Base class for linear solvers.

    These solvers are used to exactly or approximately solve the linear
    system $Ax = b$ for $x$, where $A$ is a positive (semi-)definite (PSD)
    linear operator.

    Solvers should always be constructed by an `AbstractLinearSolverMethod`.
    """

    @abstractmethod
    def solve(self, b: Float[Array, "N #K"]) -> Float[Array, "N #K"]:
        """Compute a solution to the linear system $Ax = b$.

        Arguments:
            b: Right-hand side of the linear system.
        """
        pass

    @abstractmethod
    def logdet(self) -> ScalarFloat:
        """Compute the logarithm of the (pseudo-)determinant of $A$."""
        pass

    @abstractmethod
    def inv_congruence_transform(
        self, B: LinearOperator | Float[Array, "N K"]
    ) -> LinearOperator | Float[Array, "K K"]:
        """Compute the inverse congruence transform $B^T x$ for $x$ in $Ax = B$.

        Arguments:
            B: Linear operator or array to be applied.

        Returns:
            Linear operator or array resulting from the congruence transform.
        """
        pass

    def inv_quad(self, b: Float[Array, "N #1"]) -> ScalarFloat:
        """Compute the inverse quadratic form $b^T x$, for $x$ in $Ax = b$.

        Arguments:
            b: Right-hand side of the linear system.
        """
        return self.inv_congruence_transform(b[:, None]).squeeze()

    @abstractmethod
    def trace_solve(self, B: Self) -> ScalarFloat:
        r"""Compute $\mathrm{trace}(X)$ in $AX=B$ for PSD $B$.

        Arguments:
            B: An `AbstractLinearSolver` of the same type as `self` representing
                the PSD linear operator $B$.
        """
        pass

inv_congruence_transform(B) abstractmethod

Compute the inverse congruence transform \(B^T X\), where \(X\) solves \(AX = B\) (that is, \(B^T A^{-1} B\) when \(A\) is invertible).

Parameters:

B (LinearOperator | Float[Array, 'N K'], required): Linear operator or array to be applied.

Returns:

LinearOperator | Float[Array, 'K K']: Linear operator or array resulting from the congruence transform.
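As a minimal sketch, assuming \(A\) is a dense, strictly positive definite NumPy array (cagpjax itself works with JAX arrays and `LinearOperator`s, and its concrete solvers may compute this differently), the transform can be obtained from one Cholesky factorization \(A = LL^T\), since \(B^T A^{-1} B = (L^{-1}B)^T (L^{-1}B)\). The function name `inv_congruence_transform_dense` is hypothetical.

```python
import numpy as np


def inv_congruence_transform_dense(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return B^T A^{-1} B for dense positive definite A."""
    # Factor A = L L^T once; then B^T A^{-1} B = (L^{-1} B)^T (L^{-1} B).
    L = np.linalg.cholesky(A)
    W = np.linalg.solve(L, B)  # W = L^{-1} B, shape (N, K)
    return W.T @ W  # shape (K, K), symmetric PSD by construction


rng = np.random.default_rng(0)
N, K = 6, 3
G = rng.standard_normal((N, N))
A = G @ G.T + N * np.eye(N)  # well-conditioned positive definite test matrix
B = rng.standard_normal((N, K))

C = inv_congruence_transform_dense(A, B)
# Agrees with the direct formula B^T X, where X solves A X = B.
assert np.allclose(C, B.T @ np.linalg.solve(A, B))
```

Writing the result as \(W^T W\) keeps the \(K \times K\) output symmetric even in floating point, which a direct `B.T @ solve(A, B)` does not guarantee exactly.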


inv_quad(b)

Compute the inverse quadratic form \(b^T x\), where \(x\) solves \(Ax = b\).

Parameters:

b (Float[Array, N], required): Right-hand side of the linear system.
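The default `inv_quad` reduces to `inv_congruence_transform` with a single column: `b[:, None]` has shape (N, 1), the transform returns a (1, 1) result, and `.squeeze()` turns it into a scalar. The sketch below mirrors that reduction, using a hypothetical dense stand-in for `inv_congruence_transform` and NumPy in place of JAX.

```python
import numpy as np


def inv_congruence_transform_dense(A, B):
    # Hypothetical stand-in for a concrete solver's transform: B^T A^{-1} B.
    return B.T @ np.linalg.solve(A, B)


def inv_quad_dense(A, b):
    # Same reduction as the default inv_quad: view b as an (N, 1) matrix,
    # get a (1, 1) result, and squeeze it to a 0-d scalar.
    return inv_congruence_transform_dense(A, b[:, None]).squeeze()


A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
q = inv_quad_dense(A, b)  # equals b^T A^{-1} b = 15/11 for this A and b
```

Here \(A^{-1} b = (1, 7)^T / 11\), so \(b^T x = (1 + 14)/11 = 15/11\).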

logdet() abstractmethod

Compute the logarithm of the (pseudo-)determinant of \(A\). When \(A\) is singular, the pseudo-determinant is the product of its nonzero eigenvalues.
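A minimal sketch of two common ways to obtain this quantity for dense NumPy matrices, shown for illustration rather than as cagpjax's implementation: for positive definite \(A = LL^T\), \(\log\det A = 2\sum_i \log L_{ii}\); for a singular PSD \(A\), the log pseudo-determinant sums the logs of the nonzero eigenvalues. The function names and the relative tolerance `rtol` that decides which eigenvalues count as zero are assumptions.

```python
import numpy as np


def logdet_cholesky(A):
    # Hypothetical helper. For positive definite A = L L^T,
    # log det A = 2 * sum(log diag L).
    L = np.linalg.cholesky(A)
    return 2.0 * np.sum(np.log(np.diag(L)))


def log_pdet_eigh(A, rtol=1e-10):
    # Hypothetical helper. For PSD A that may be singular, the pseudo-determinant
    # is the product of the nonzero eigenvalues; eigenvalues at or below
    # rtol * (largest eigenvalue) are treated as zero (illustrative tolerance).
    eigvals = np.linalg.eigvalsh(A)
    keep = eigvals > rtol * eigvals.max()
    return np.sum(np.log(eigvals[keep]))


A = np.array([[4.0, 1.0], [1.0, 3.0]])  # det A = 11
print(logdet_cholesky(A), np.log(11.0))

v = np.array([1.0, 2.0, 2.0])
S = np.outer(v, v)  # rank-1 PSD matrix; its only nonzero eigenvalue is |v|^2 = 9
print(log_pdet_eigh(S), np.log(9.0))
```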


solve(b) abstractmethod

Compute a solution to the linear system \(Ax = b\).

Parameters:

b (Float[Array, N], required): Right-hand side of the linear system.
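A minimal sketch of an exact solve, assuming a dense positive definite \(A\) stored as a NumPy array: factor \(A = LL^T\) once, then answer each right-hand side with a forward solve \(Ly = b\) followed by a backward solve \(L^T x = y\). Since the source signature annotates `b` as `Float[Array, "N #K"]`, the sketch accepts either a single vector or several right-hand sides stacked as columns. `cholesky_solve` is a hypothetical helper, and it calls `np.linalg.solve` for the triangular systems for brevity, which does not exploit the triangular structure.

```python
import numpy as np


def cholesky_solve(L, b):
    # Hypothetical helper: solve A x = b given the lower Cholesky factor L
    # (A = L L^T). Works for b of shape (N,) or (N, K).
    y = np.linalg.solve(L, b)      # forward step: L y = b
    return np.linalg.solve(L.T, y)  # backward step: L^T x = y


rng = np.random.default_rng(1)
N = 5
G = rng.standard_normal((N, N))
A = G @ G.T + N * np.eye(N)
L = np.linalg.cholesky(A)  # factor once, reuse for every right-hand side

b_vec = rng.standard_normal(N)       # one right-hand side, shape (N,)
b_mat = rng.standard_normal((N, 3))  # three right-hand sides, shape (N, 3)

x_vec = cholesky_solve(L, b_vec)
x_mat = cholesky_solve(L, b_mat)
```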

trace_solve(B) abstractmethod

Compute \(\mathrm{trace}(X)\), where \(X\) solves \(AX = B\) and \(B\) is PSD.

Parameters:

B (Self, required): An AbstractLinearSolver of the same type as self, representing the PSD linear operator \(B\).
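Taking \(B\) as a solver rather than a raw array plausibly lets an implementation reuse a factorization of \(B\); that is an assumption, not something this page states. The sketch below assumes both operators are dense positive definite NumPy matrices with Cholesky factors \(A = L_A L_A^T\) and \(B = L_B L_B^T\), and uses \(\mathrm{trace}(A^{-1}B) = \lVert L_A^{-1} L_B \rVert_F^2\), which follows from the cyclic property of the trace. `trace_solve_dense` is a hypothetical name.

```python
import numpy as np


def trace_solve_dense(L_A, L_B):
    # Hypothetical helper: trace(X) for A X = B, i.e. trace(A^{-1} B).
    # With A^{-1} = L_A^{-T} L_A^{-1} and B = L_B L_B^T, cycling the trace gives
    # trace(L_A^{-1} L_B L_B^T L_A^{-T}) = ||L_A^{-1} L_B||_F^2.
    W = np.linalg.solve(L_A, L_B)
    return np.sum(W**2)


rng = np.random.default_rng(2)
N = 4
G, H = rng.standard_normal((2, N, N))
A = G @ G.T + N * np.eye(N)
B = H @ H.T + np.eye(N)

t = trace_solve_dense(np.linalg.cholesky(A), np.linalg.cholesky(B))
t_direct = np.trace(np.linalg.solve(A, B))  # reference value via an explicit solve
```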

AbstractLinearSolverMethod

Bases: Module

Base class for linear solver methods.

These methods are used to construct AbstractLinearSolver instances.

Source code in src/cagpjax/solvers/base.py
class AbstractLinearSolverMethod(nnx.Module):
    """
    Base class for linear solver methods.

    These methods are used to construct `AbstractLinearSolver` instances.
    """

    @abstractmethod
    def __call__(self, A: LinearOperator) -> AbstractLinearSolver:
        """Construct a solver from the positive (semi-)definite linear operator."""
        pass
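The split between methods and solvers follows a factory pattern: a method object holds configuration and, when called on \(A\), returns a solver bound to that operator. The sketch below illustrates the pattern with plain Python classes and NumPy arrays instead of `nnx.Module` and `LinearOperator`; every class name and the `jitter` parameter are hypothetical, and only `solve` and `logdet` are implemented.

```python
from abc import ABC, abstractmethod

import numpy as np


class LinearSolverSketch(ABC):
    # Hypothetical stand-in for AbstractLinearSolver (no flax.nnx, NumPy arrays).
    @abstractmethod
    def solve(self, b): ...

    @abstractmethod
    def logdet(self): ...


class CholeskySolverSketch(LinearSolverSketch):
    # Hypothetical concrete solver: stores the lower Cholesky factor of A.
    def __init__(self, L):
        self.L = L

    def solve(self, b):
        # Forward then backward solve with the stored factor.
        return np.linalg.solve(self.L.T, np.linalg.solve(self.L, b))

    def logdet(self):
        return 2.0 * np.sum(np.log(np.diag(self.L)))


class CholeskyMethodSketch:
    # Hypothetical stand-in for an AbstractLinearSolverMethod: it holds the
    # configuration (a diagonal jitter) and builds a solver for a given A.
    def __init__(self, jitter=1e-6):
        self.jitter = jitter

    def __call__(self, A) -> LinearSolverSketch:
        A_jit = A + self.jitter * np.eye(A.shape[0])
        return CholeskySolverSketch(np.linalg.cholesky(A_jit))


method = CholeskyMethodSketch(jitter=1e-8)
A = np.array([[4.0, 1.0], [1.0, 3.0]])
solver = method(A)  # construct the solver from the operator
x = solver.solve(np.array([1.0, 2.0]))
```

Keeping configuration in the method and the factorization in the solver lets one method instance build solvers for many operators.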

__call__(A) abstractmethod

Construct a solver from the positive (semi-)definite linear operator \(A\).
