
cagpjax.operators.block_diagonal_sparse

Block-diagonal sparse linear operator.

Classes:

BlockDiagonalSparse

BlockDiagonalSparse(nz_values: Float[Array, N], n_blocks: int)

Bases: LinearOperator

Block-diagonal sparse linear operator.

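This operator represents an N × n_blocks block-diagonal matrix whose diagonal blocks are contiguous column vectors. The entries of nz_values are split, in order, into n_blocks consecutive segments, and segment j supplies the non-zeros of column j. As a result, each row contains exactly one non-zero value.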
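To make the layout concrete, here is a minimal sketch that builds the equivalent dense matrix with plain jax.numpy. It assumes N is divisible by n_blocks, so that every block has N // n_blocks rows; how the library splits values when N is not divisible is not covered here. The helper name dense_block_diagonal_sparse is hypothetical and not part of cagpjax.

```python
import jax.numpy as jnp


def dense_block_diagonal_sparse(nz_values, n_blocks):
    """Hypothetical helper: build the dense (N, n_blocks) matrix.

    Assumes N is divisible by n_blocks (equal-size blocks).
    """
    n = nz_values.shape[0]
    if n % n_blocks:
        raise ValueError("this sketch assumes N is divisible by n_blocks")
    block_size = n // n_blocks
    # Row i belongs to block i // block_size, so its single non-zero
    # sits in column i // block_size.
    rows = jnp.arange(n)
    cols = rows // block_size
    dense = jnp.zeros((n, n_blocks), dtype=nz_values.dtype)
    return dense.at[rows, cols].set(nz_values)


A = dense_block_diagonal_sparse(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), n_blocks=3)
```

With six values and three blocks, each block is a column segment of length 2, which reproduces the dense matrix shown in the example below.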

Parameters:

  • nz_values

    (Float[Array, N]) –

    Non-zero values, one per row, distributed in order across the diagonal blocks.

  • n_blocks

    (int) –

    Number of diagonal blocks in the matrix.

Examples

>>> import jax.numpy as jnp
>>> from cagpjax.operators import BlockDiagonalSparse
>>>
>>> # Create a 6x3 block-diagonal matrix with 3 blocks of 2 values each
>>> nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
>>> op = BlockDiagonalSparse(nz_values, n_blocks=3)
>>> print(op.shape)
(6, 3)
>>>
>>> # Multiply by the 3x3 identity to materialize the dense matrix
>>> op @ jnp.eye(3)
Array([[1., 0., 0.],
       [2., 0., 0.],
       [0., 3., 0.],
       [0., 4., 0.],
       [0., 0., 5.],
       [0., 0., 6.]], dtype=float32)
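Because each row has a single non-zero, products with this matrix never need the dense form. The sketch below shows one way to compute A @ X and A.T @ Y with a repeat and a reshape-and-sum. It assumes equal-size blocks (N divisible by n_blocks), and the function names are hypothetical; this is an illustration of the structure, not the library's own matmul implementation.

```python
import jax.numpy as jnp


def block_matmat(nz_values, n_blocks, X):
    """Hypothetical helper: A @ X for X of shape (n_blocks, k) -> (N, k).

    Assumes N is divisible by n_blocks.
    """
    block_size = nz_values.shape[0] // n_blocks
    # Row i of the result is row (i // block_size) of X, scaled by nz_values[i].
    return nz_values[:, None] * jnp.repeat(X, block_size, axis=0)


def block_rmatmat(nz_values, n_blocks, Y):
    """Hypothetical helper: A.T @ Y for Y of shape (N, k) -> (n_blocks, k)."""
    k = Y.shape[1]
    # Scale each row of Y, then sum the rows that fall within each block.
    scaled = nz_values[:, None] * Y
    return scaled.reshape(n_blocks, -1, k).sum(axis=1)


nz = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
AX = block_matmat(nz, 3, jnp.eye(3))             # same as the dense output above
ATy = block_rmatmat(nz, 3, jnp.ones((6, 1)))     # block sums: [[3.], [7.], [11.]]
```

Both products cost O(N k) rather than the O(N n_blocks k) of a dense multiply.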
Source code in src/cagpjax/operators/block_diagonal_sparse.py
def __init__(self, nz_values: Float[Array, "N"], n_blocks: int):
    n = nz_values.shape[0]
    super().__init__(nz_values.dtype, (n, n_blocks), annotations={ScaledOrthogonal})
    self.nz_values = nz_values
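The constructor tags the operator with the ScaledOrthogonal annotation. Independently of exactly what that annotation promises in the base library, the structure itself guarantees one useful property: the columns have disjoint row supports, so they are mutually orthogonal and A.T @ A is diagonal, with entries equal to the squared norms of the blocks of nz_values. The following sketch checks this numerically on the example values, again assuming equal-size blocks.

```python
import jax.numpy as jnp

nz_values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
n_blocks = 3
block_size = nz_values.shape[0] // n_blocks  # assumes equal-size blocks

# Dense reference: row i has its single non-zero in column i // block_size.
A = jnp.repeat(jnp.eye(n_blocks), block_size, axis=0) * nz_values[:, None]

# Disjoint column supports make the Gram matrix diagonal; each diagonal
# entry is the squared norm of one block of nz_values.
gram = A.T @ A
block_sq_norms = (nz_values.reshape(n_blocks, block_size) ** 2).sum(axis=1)
```

For these values the block squared norms are 5, 25, and 61, so the columns are orthogonal but not of equal length.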