cagpjax.operators.diag_like

diag_like(operator, values)

Create a diagonal operator with the same shape, dtype, and device as the given operator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| operator | LinearOperator | Linear operator. | required |
| values | ScalarFloat \| Float[Array, N] | Scalar for a scalar matrix or array of diagonal values for a diagonal matrix. | required |

Returns:

| Type | Description |
| --- | --- |
| Diagonal \| ScalarMul | Diagonal or scalar operator. |

Source code in src/cagpjax/operators/diag_like.py
def diag_like(
    operator: LinearOperator, values: ScalarFloat | Float[Array, "N"]
) -> Diagonal | ScalarMul:
    """Create a diagonal operator with the same shape, dtype, and device as the operator.

    Args:
        operator: Linear operator.
        values: Scalar for a scalar matrix or array of diagonal values for a diagonal matrix.

    Returns:
        Diagonal or scalar operator.
    """
    device = operator.device
    dtype = operator.dtype
    if jnp.isscalar(values):
        return ScalarMul(values, operator.shape, dtype=dtype, device=device)
    else:
        return Diagonal(values.astype(dtype)).to(device)
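
The branching in the source above can be illustrated with a small dependency-free sketch. It is not the cagpjax implementation: `ScalarMulSketch`, `DiagonalSketch`, and `diag_like_sketch` are hypothetical stand-ins, plain Python lists take the place of JAX arrays, and dtype and device handling are left out. It only aims to show which kind of operator each type of `values` would select, and what dense matrix that operator represents.

```python
from dataclasses import dataclass
from numbers import Real
from typing import List, Sequence, Tuple, Union


@dataclass
class ScalarMulSketch:
    """Hypothetical stand-in for a scalar operator c * I (not the cagpjax class)."""

    c: float
    shape: Tuple[int, int]

    def to_dense(self) -> List[List[float]]:
        # c on the main diagonal, zeros elsewhere.
        n_rows, n_cols = self.shape
        return [
            [self.c if i == j else 0.0 for j in range(n_cols)]
            for i in range(n_rows)
        ]


@dataclass
class DiagonalSketch:
    """Hypothetical stand-in for a diagonal operator diag(v)."""

    diag: List[float]

    @property
    def shape(self) -> Tuple[int, int]:
        # The size comes from the diagonal values themselves.
        n = len(self.diag)
        return (n, n)

    def to_dense(self) -> List[List[float]]:
        n = len(self.diag)
        return [
            [self.diag[i] if i == j else 0.0 for j in range(n)]
            for i in range(n)
        ]


def diag_like_sketch(
    shape: Tuple[int, int], values: Union[Real, Sequence[float]]
) -> Union[ScalarMulSketch, DiagonalSketch]:
    """Mirror the scalar-versus-array dispatch of diag_like (assumed semantics)."""
    if isinstance(values, Real):
        # Scalar input: the shape is taken from the reference operator.
        return ScalarMulSketch(float(values), shape)
    # Array input: the shape is implied by the number of diagonal values.
    return DiagonalSketch([float(v) for v in values])


scalar_op = diag_like_sketch((3, 3), 2.0)
diagonal_op = diag_like_sketch((3, 3), [1.0, 2.0, 3.0])
print(type(scalar_op).__name__, scalar_op.to_dense())
print(type(diagonal_op).__name__, diagonal_op.to_dense())
```

Note that in the source listing only the scalar branch uses `operator.shape`; the array branch builds the operator from `values` alone, so the caller is presumably expected to pass an array whose length matches the operator's dimension.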