cagpjax.operators
Custom linear operators.
Modules:
- annotations – Annotations for operators.
- block_diagonal_sparse – Block-diagonal sparse linear operator.
- diag_like
- lazy_kernel – Lazy kernel operator.
Classes:
- BlockDiagonalSparse – Block-diagonal sparse linear operator.
- LazyKernel – A lazy kernel operator that avoids materializing large kernel matrices.
BlockDiagonalSparse
Bases: LinearOperator
Block-diagonal sparse linear operator.
This operator represents a block-diagonal matrix of shape (N, n_blocks) whose blocks are contiguous column vectors. The N non-zero values are split into n_blocks consecutive segments, and segment j fills column j. As a result, each row contains exactly one non-zero value.
Parameters:
- nz_values (Float[Array, 'N']) – Non-zero values to be distributed across the diagonal blocks.
- n_blocks (int) – Number of diagonal blocks in the matrix.
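To make the layout concrete, here is a minimal sketch that builds the equivalent dense matrix. The helper `block_diagonal_sparse_dense` is hypothetical (not part of cagpjax), and it assumes N is divisible by n_blocks so that every block has the same length; the library may handle uneven block sizes differently.

```python
import jax.numpy as jnp


def block_diagonal_sparse_dense(nz_values, n_blocks):
    """Hypothetical helper: materialize the (N, n_blocks) matrix densely.

    Assumes N is divisible by n_blocks, so every block holds N // n_blocks rows.
    """
    n = nz_values.shape[0]
    assert n % n_blocks == 0, "this sketch assumes equal-sized blocks"
    block_size = n // n_blocks
    rows = jnp.arange(n)
    # Row i lies in block (column) i // block_size.
    cols = rows // block_size
    dense = jnp.zeros((n, n_blocks), dtype=nz_values.dtype)
    # Scatter each non-zero value into its single (row, column) slot.
    return dense.at[rows, cols].set(nz_values)


dense = block_diagonal_sparse_dense(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), n_blocks=3)
```

Every row of `dense` has exactly one non-zero entry, matching the structure described above.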
Examples
>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 6x3 block-diagonal matrix with 3 blocks
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(6, 3)
>>>
>>> # Apply to the 3x3 identity matrix to view the dense form
>>> op @ jnp.eye(3)
Array([[1., 0., 0.],
       [2., 0., 0.],
       [0., 3., 0.],
       [0., 4., 0.],
       [0., 0., 5.],
       [0., 0., 6.]], dtype=float32)
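Because each row has a single non-zero, products with this matrix never need the dense form. The sketch below shows one way to compute A @ v and A.T @ u in O(N) work and memory. The function names `bds_matvec` and `bds_rmatvec` are hypothetical, equal-sized blocks are assumed, and this is not necessarily how cagpjax implements its matmuls.

```python
import jax
import jax.numpy as jnp


def bds_matvec(nz_values, n_blocks, v):
    """Hypothetical: A @ v for the (N, n_blocks) block-diagonal sparse A.

    Row i has its only non-zero, nz_values[i], in column i // block_size,
    so (A @ v)[i] = nz_values[i] * v[i // block_size].
    """
    block_size = nz_values.shape[0] // n_blocks
    return nz_values * jnp.repeat(v, block_size)


def bds_rmatvec(nz_values, n_blocks, u):
    """Hypothetical: A.T @ u, i.e. a weighted sum of u within each block."""
    n = nz_values.shape[0]
    segment_ids = jnp.arange(n) // (n // n_blocks)
    return jax.ops.segment_sum(nz_values * u, segment_ids, num_segments=n_blocks)


nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
av = bds_matvec(nz_values, 3, jnp.array([1.0, 10.0, 100.0]))  # [1, 2, 30, 40, 500, 600]
atu = bds_rmatvec(nz_values, 3, jnp.ones(6))                   # [3, 7, 11]
```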
Source code in src/cagpjax/operators/block_diagonal_sparse.py
LazyKernel
LazyKernel(kernel: AbstractKernel, x1: Float[Array, 'M D'], x2: Float[Array, 'N D'], /, *, max_memory_mb: int = 2 ** 10, batch_size: int | None = None, checkpoint: bool = False)
Bases: LinearOperator
A lazy kernel operator that avoids materializing large kernel matrices.
Instead of materializing the full M x N kernel matrix K(x1, x2), this operator evaluates it in batches of rows (or columns) during matrix multiplication, so peak memory grows with the batch size rather than with M x N.
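The core idea can be sketched as a row-batched matrix-vector product. Here `rbf` is a hypothetical stand-in for an `AbstractKernel` and `lazy_kernel_matvec` is a hypothetical helper; only a (batch_size, N) block of K exists at any one time.

```python
import jax.numpy as jnp


def rbf(x1, x2, lengthscale=1.0):
    """Hypothetical RBF kernel standing in for an AbstractKernel; returns the (m, n) block."""
    sq_dists = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists / lengthscale**2)


def lazy_kernel_matvec(kernel, x1, x2, v, batch_size):
    """Compute K(x1, x2) @ v while materializing at most batch_size rows of K."""
    out = []
    for start in range(0, x1.shape[0], batch_size):
        # Only this (batch_size, N) slice of the kernel matrix is ever built.
        k_block = kernel(x1[start:start + batch_size], x2)
        out.append(k_block @ v)
    return jnp.concatenate(out)


x1 = jnp.linspace(0.0, 1.0, 14).reshape(7, 2)   # M = 7 points in D = 2
x2 = jnp.linspace(-1.0, 1.0, 10).reshape(5, 2)  # N = 5 points in D = 2
v = jnp.arange(5.0)
kv = lazy_kernel_matvec(rbf, x1, x2, v, batch_size=3)  # batches of 3, 3, 1 rows
```

The result matches `rbf(x1, x2) @ v`, but the peak kernel memory is 3 x 5 instead of 7 x 5.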
Parameters:
- kernel (AbstractKernel) – The kernel function to use for computations.
- x1 (Float[Array, 'M D']) – First set of input points for kernel evaluation.
- x2 (Float[Array, 'N D']) – Second set of input points for kernel evaluation.
- max_memory_mb (int, default: 2 ** 10) – Maximum number of megabytes of memory to allocate for batching the kernel matrix. Ignored if batch_size is provided.
- batch_size (int | None, default: None) – Number of rows/cols to materialize at once. If None, the batch size is determined from max_memory_mb.
- checkpoint (bool, default: False) – Whether to checkpoint the computation. This is usually necessary to prevent all materialized submatrices from being retained in memory for gradient computation.
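Why checkpointing matters can be illustrated with `jax.lax.scan`: when differentiating a scan, JAX normally keeps each step's residuals (here, values derived from every kernel block) for the backward pass. Wrapping the step in `jax.checkpoint` stores only the step inputs and recomputes the block during backpropagation. This is a sketch under assumptions: `rbf` and `scanned_matvec` are hypothetical, batch_size is assumed to divide M, and cagpjax's internal batching may be structured differently.

```python
import jax
import jax.numpy as jnp


def rbf(x1, x2):
    """Hypothetical stand-in for an AbstractKernel (unit-lengthscale RBF)."""
    sq_dists = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists)


def scanned_matvec(x1, x2, v, batch_size, checkpoint=False):
    """K(x1, x2) @ v with one (batch_size, N) kernel block per scan step.

    Assumes batch_size divides M so the rows reshape into equal batches.
    """
    x1_batches = x1.reshape(-1, batch_size, x1.shape[-1])

    def body(carry, x1_batch):
        return carry, rbf(x1_batch, x2) @ v

    if checkpoint:
        # Save only the step inputs; rebuild the kernel block during backprop.
        body = jax.checkpoint(body)
    _, out = jax.lax.scan(body, None, x1_batches)
    return out.reshape(-1)


x1 = jnp.linspace(0.0, 1.0, 12).reshape(6, 2)
x2 = jnp.linspace(-1.0, 1.0, 8).reshape(4, 2)
v = jnp.arange(4.0)


def loss(x2, checkpoint):
    return jnp.sum(scanned_matvec(x1, x2, v, batch_size=2, checkpoint=checkpoint) ** 2)


grad_ckpt = jax.grad(lambda z: loss(z, True))(x2)
grad_plain = jax.grad(lambda z: loss(z, False))(x2)
```

Both gradients are identical; checkpointing trades extra kernel evaluations in the backward pass for lower peak memory.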
Attributes:
- batch_size_col (int) – Maximum number of columns to materialize at once during left matrix(-vector) multiplications.
- batch_size_row (int) – Maximum number of rows to materialize at once during right matrix(-vector) multiplications.
- max_elements (int) – Maximum number of elements to store in memory during matmul operations.
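One plausible way to relate max_memory_mb to these attributes is sketched below. The function `batch_sizes` and its formula are assumptions for illustration (a budget of max_memory_mb * 2**20 bytes divided by the element size, then by the length of the other dimension); the library's exact rule is not stated here and may differ.

```python
def batch_sizes(m, n, max_memory_mb=2**10, itemsize=4):
    """Hypothetical mapping from a memory budget to batch sizes for an (m, n) kernel.

    itemsize is the bytes per element (4 for float32, 8 for float64).
    """
    max_elements = max_memory_mb * 2**20 // itemsize
    # Right matmul K @ V materializes blocks of shape (batch_size_row, n).
    batch_size_row = max(1, min(m, max_elements // n))
    # Left matmul U @ K materializes blocks of shape (m, batch_size_col).
    batch_size_col = max(1, min(n, max_elements // m))
    return max_elements, batch_size_row, batch_size_col


print(batch_sizes(100_000, 50_000))  # (268435456, 5368, 2684)
```

With the default 1 GiB budget and float32 values, about 2.7e8 elements fit, so a 100,000 x 50,000 kernel would be processed 5,368 rows (or 2,684 columns) at a time under this assumed rule.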
Source code in src/cagpjax/operators/lazy_kernel.py
batch_size_col
property
Maximum number of columns to materialize at once during left matrix(-vector) multiplications.
batch_size_row
property
Maximum number of rows to materialize at once during right matrix(-vector) multiplications.
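For a left multiplication U @ K, the natural unit is a batch of columns, since each output column depends only on the matching column of K. A minimal sketch follows; `rbf` and `lazy_kernel_rmatmul` are hypothetical names, not cagpjax API.

```python
import jax.numpy as jnp


def rbf(x1, x2):
    """Hypothetical stand-in for an AbstractKernel (unit-lengthscale RBF)."""
    sq_dists = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists)


def lazy_kernel_rmatmul(kernel, x1, x2, u, batch_size_col):
    """Compute u @ K(x1, x2) for u of shape (P, M), batch_size_col columns at a time."""
    cols = []
    for start in range(0, x2.shape[0], batch_size_col):
        # (M, batch_size_col) block: all rows of K, a slice of its columns.
        k_block = kernel(x1, x2[start:start + batch_size_col])
        cols.append(u @ k_block)
    return jnp.concatenate(cols, axis=-1)


x1 = jnp.linspace(0.0, 1.0, 14).reshape(7, 2)   # M = 7
x2 = jnp.linspace(-1.0, 1.0, 10).reshape(5, 2)  # N = 5
u = jnp.arange(14.0).reshape(2, 7) / 10.0       # P = 2
uk = lazy_kernel_rmatmul(rbf, x1, x2, u, batch_size_col=2)  # column batches of 2, 2, 1
```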