cagpjax.operators

Custom linear operators.

Classes:

  • BlockDiagonalSparse

    Block-diagonal sparse linear operator.

  • LazyKernel

    A lazy kernel operator that avoids materializing large kernel matrices.

BlockDiagonalSparse

BlockDiagonalSparse(nz_values: Float[Array, N], n_blocks: int)

Bases: LinearOperator

Block-diagonal sparse linear operator.

This operator represents a block-diagonal matrix whose blocks are contiguous column vectors, so each row contains exactly one non-zero value.

Parameters:

  • nz_values

    (Float[Array, N]) –

    Non-zero values to be distributed across diagonal blocks.

  • n_blocks

    (int) –

    Number of diagonal blocks in the matrix.

Examples

>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 6x3 block-diagonal operator with 3 blocks
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(6, 3)
>>>
>>> # Apply to the identity matrix to materialize the dense matrix
>>> op @ jnp.eye(3)
Array([[1., 0., 0.],
       [2., 0., 0.],
       [0., 3., 0.],
       [0., 4., 0.],
       [0., 0., 5.],
       [0., 0., 6.]], dtype=float32)
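For intuition, the dense matrix shown above can be reconstructed directly in plain JAX. This is an illustrative sketch, not part of the library, and it assumes the number of values divides evenly into `n_blocks`:

```python
import jax.numpy as jnp

def block_diagonal_dense(nz_values, n_blocks):
    # Split the non-zero values into n_blocks contiguous column vectors
    # and place each on the diagonal of an (n, n_blocks) matrix.
    n = nz_values.shape[0]
    block_len = n // n_blocks  # assumes n is divisible by n_blocks
    rows = jnp.arange(n)
    cols = rows // block_len  # row i's single non-zero sits in column i // block_len
    return jnp.zeros((n, n_blocks), dtype=nz_values.dtype).at[rows, cols].set(nz_values)

print(block_diagonal_dense(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), 3))
```

This reproduces the `(6, 3)` matrix from the example above.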
Source code in src/cagpjax/operators/block_diagonal_sparse.py
def __init__(self, nz_values: Float[Array, "N"], n_blocks: int):
    n = nz_values.shape[0]
    super().__init__(nz_values.dtype, (n, n_blocks), annotations={ScaledOrthogonal})
    self.nz_values = nz_values

LazyKernel

LazyKernel(kernel: AbstractKernel, x1: Float[Array, 'M D'], x2: Float[Array, 'N D'], /, *, max_memory_mb: int = 2 ** 10, batch_size: int | None = None, checkpoint: bool = False)

Bases: LinearOperator

A lazy kernel operator that avoids materializing large kernel matrices.

This class implements a lazy kernel operator that computes rows/cols of the kernel matrix in batches, avoiding out-of-memory failures on large datasets.

Parameters:

  • kernel

    (AbstractKernel) –

    The kernel function to use for computations.

  • x1

    (Float[Array, 'M D']) –

    First set of input points for kernel evaluation.

  • x2

    (Float[Array, 'N D']) –

    Second set of input points for kernel evaluation.

  • max_memory_mb

    (int, default: 2 ** 10 ) –

    Maximum number of megabytes of memory to allocate for batching the kernel matrix. If batch_size is provided, this is ignored.

  • batch_size

    (int | None, default: None ) –

    Number of rows/cols to materialize at once. If None, the batch size is determined based on max_memory_mb.

  • checkpoint

    (bool, default: False ) –

    Whether to checkpoint the computation. This is usually necessary to prevent all materialized submatrices from being retained in memory for gradient computation.

Attributes:

  • batch_size_col (int) –

    Maximum number of columns to materialize at once during left mat(-vec)muls.

  • batch_size_row (int) –

    Maximum number of rows to materialize at once during right mat(-vec)muls.

  • max_elements (int) –

    Maximum number of elements to store in memory during matmul operations.

Source code in src/cagpjax/operators/lazy_kernel.py
def __init__(
    self,
    kernel: AbstractKernel,
    x1: Float[Array, "M D"],
    x2: Float[Array, "N D"],
    /,
    *,
    max_memory_mb: int = 2**10,  # 1GB
    batch_size: int | None = None,
    checkpoint: bool = False,
):
    shape = (x1.shape[0], x2.shape[0])
    dtype = kernel(x1[0, ...], x2[0, ...]).dtype
    super().__init__(dtype=dtype, shape=shape)
    self.kernel = kernel
    self.x1 = x1
    self.x2 = x2
    self._compute_engine = DenseKernelComputation()
    self.max_memory_mb = max_memory_mb
    self.batch_size = batch_size
    self.checkpoint = checkpoint
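The batching strategy can be sketched in plain JAX: compute the matrix-vector product `K @ v` while materializing only a few rows of `K` at a time. This is an illustrative reconstruction, not the library's implementation; `rbf` below is a toy stand-in for a GPJax `AbstractKernel`:

```python
import jax
import jax.numpy as jnp

def lazy_kernel_matvec(kernel_fn, x1, x2, v, batch_size=2):
    """Compute K @ v where K[i, j] = kernel_fn(x1[i], x2[j]),
    materializing at most batch_size rows of K at a time."""
    def row_block(x1_block):
        # Evaluate a (batch_size, N) slab of the kernel matrix ...
        k_block = jax.vmap(lambda a: jax.vmap(lambda b: kernel_fn(a, b))(x2))(x1_block)
        # ... and immediately contract it against v so the slab can be freed.
        return k_block @ v

    blocks = [row_block(x1[i:i + batch_size]) for i in range(0, x1.shape[0], batch_size)]
    return jnp.concatenate(blocks)

# Toy RBF kernel standing in for an AbstractKernel (hypothetical).
rbf = lambda a, b: jnp.exp(-0.5 * jnp.sum((a - b) ** 2))
x1, x2, v = jnp.ones((4, 2)), jnp.ones((3, 2)), jnp.ones(3)
print(lazy_kernel_matvec(rbf, x1, x2, v))  # all points identical -> K is all ones -> [3. 3. 3. 3.]
```

With `checkpoint=True`, the real operator additionally rematerializes these slabs during the backward pass instead of keeping them alive for gradient computation.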

batch_size_col property

batch_size_col: int

Maximum number of columns to materialize at once during left mat(-vec)muls.

batch_size_row property

batch_size_row: int

Maximum number of rows to materialize at once during right mat(-vec)muls.

max_elements property

max_elements: int

Maximum number of elements to store in memory during matmul operations.
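As an illustration of how these quantities could relate, a row batch size can be derived from the memory budget by converting megabytes to an element count and dividing by the number of columns. This is a hypothetical reconstruction; the actual formula lives in `src/cagpjax/operators/lazy_kernel.py`:

```python
def batch_size_from_memory(max_memory_mb, n_cols, itemsize=4):
    # Convert the budget from megabytes to bytes, then to elements
    # (itemsize=4 assumes float32; float64 would use 8).
    max_elements = (max_memory_mb * 2**20) // itemsize
    # Rows per batch: how many full rows of n_cols elements fit, at least one.
    return max(1, max_elements // n_cols)

# A 1 MB budget holds 262144 float32 elements -> 2 rows of 100000 columns.
print(batch_size_from_memory(max_memory_mb=1, n_cols=100_000))  # 2
```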