
cagpjax.interop

Interoperability utilities for converting between cola and Lineax linear operators.

Classes:

  • ColaLinearOperator

    Wrap a cola operator with a Lineax-compatible interface.

Functions:

  • lazify

    Convert GPJax/Lineax/array inputs into cola operators.

  • to_lineax

    Convert existing operators into Lineax only where GPJax requires it.

ColaLinearOperator

ColaLinearOperator(operator: LinearOperator)

Bases: AbstractLinearOperator

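Wrap a cola operator with a Lineax-compatible interface. The wrapped operator is stored in the `operator` attribute, and `lazify` returns it directly instead of converting it again.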

Source code in src/cagpjax/interop.py
def __init__(self, operator: LinearOperator):
    self.operator = operator

lazify

lazify(A: Any) -> LinearOperator

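Convert GPJax/Lineax/array inputs into cola operators. cola operators are returned unchanged, and `ColaLinearOperator` wrappers are unwrapped. Lineax matrix, diagonal, and identity operators map to `cola.ops.Dense`, `cola.ops.Diagonal`, and `cola.ops.Identity`, and a positive-semidefinite tag on a `TaggedLinearOperator` becomes a `cola.PSD` annotation. Any other Lineax operator is materialized with `as_matrix()`, and all remaining inputs (e.g. arrays) are passed to `cola.lazify`.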

Source code in src/cagpjax/interop.py
def lazify(A: Any) -> LinearOperator:
    """Convert GPJax/Lineax/array inputs into cola operators."""
    if isinstance(A, LinearOperator):
        return A
    if isinstance(A, ColaLinearOperator):
        return A.operator
    if isinstance(A, lx.TaggedLinearOperator):
        op = lazify(A.operator)
        if lx.positive_semidefinite_tag in A.tags:
            return cola.PSD(op)
        return op
    if isinstance(A, lx.MatrixLinearOperator):
        return cola.ops.Dense(A.matrix)
    if isinstance(A, lx.DiagonalLinearOperator):
        return cola.ops.Diagonal(A.diagonal)
    if isinstance(A, lx.IdentityLinearOperator):
        metadata = jax.eval_shape(A.as_matrix)
        return cola.ops.Identity(metadata.shape, metadata.dtype)
    if isinstance(A, lx.AbstractLinearOperator):
        return cola.lazify(A.as_matrix())
    return cola.lazify(A)

to_lineax

to_lineax(A: Any) -> lx.AbstractLinearOperator

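Convert existing operators into Lineax only where GPJax requires it. Lineax operators are returned unchanged; cola `Diagonal` and `Identity` operators map to `lx.DiagonalLinearOperator` and `lx.IdentityLinearOperator`; any other cola operator is wrapped in `ColaLinearOperator` rather than densified; and all remaining inputs (e.g. arrays) become `lx.MatrixLinearOperator`.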

Source code in src/cagpjax/interop.py
def to_lineax(A: Any) -> lx.AbstractLinearOperator:
    """Convert existing operators into Lineax only where GPJax requires it."""
    if isinstance(A, lx.AbstractLinearOperator):
        return A
    if isinstance(A, cola.ops.Diagonal):
        return lx.DiagonalLinearOperator(A.diag)
    if isinstance(A, cola.ops.Identity):
        metadata = jax.ShapeDtypeStruct((A.shape[1],), A.dtype)
        return lx.IdentityLinearOperator(metadata)
    if isinstance(A, cola.ops.LinearOperator):
        return ColaLinearOperator(A)
    return lx.MatrixLinearOperator(A)