
cagpjax.distributions

Classes:

  • GaussianDistribution

    Gaussian distribution with an implicit covariance and customizable linear solver.

GaussianDistribution

GaussianDistribution(loc: Float[Array, ' N'], scale: LinearOperator, solver: AbstractLinearSolver = Cholesky(1e-06), **kwargs)

Bases: Distribution

Gaussian distribution with an implicit covariance and customizable linear solver.

Initialize the Gaussian distribution.

Parameters:

  • loc

    (Float[Array, ' N']) –

    Mean of the distribution.

  • scale

    (LinearOperator) –

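    Covariance operator of the distribution. Despite the parameter name, it is used directly as the covariance: covariance() returns it unchanged, and log_prob takes its log-determinant and inverse quadratic form.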

  • solver

    (AbstractLinearSolver, default: Cholesky(1e-06) ) –

    Linear solver used to factorize the covariance and to compute its log-determinant, inverse quadratic forms, and unwhitened samples.
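The methods documented below call four solver operations on the covariance: `init`, `logdet`, `inv_quad`, and `unwhiten`. The following is a hypothetical dense NumPy stand-in (the class name `DenseCholeskySolver` is illustrative and not part of cagpjax) sketching what each operation computes under a Cholesky factorization. It assumes the `1e-06` argument of the default `Cholesky` is a diagonal jitter; check the cagpjax solver documentation before relying on that.

```python
import numpy as np


class DenseCholeskySolver:
    """Hypothetical dense stand-in for the solver interface GaussianDistribution calls.

    Assumption: `jitter` is added to the diagonal before factorizing.
    """

    def __init__(self, jitter: float = 1e-6):
        self.jitter = jitter

    def init(self, cov: np.ndarray) -> np.ndarray:
        # Solver state: lower Cholesky factor L with cov + jitter * I = L @ L.T
        n = cov.shape[0]
        return np.linalg.cholesky(cov + self.jitter * np.eye(n))

    def logdet(self, L: np.ndarray) -> float:
        # log|Sigma| = 2 * sum(log(diag(L)))
        return 2.0 * float(np.sum(np.log(np.diag(L))))

    def inv_quad(self, L: np.ndarray, r: np.ndarray) -> float:
        # r^T Sigma^{-1} r = ||L^{-1} r||^2 (general solve used to stay NumPy-only)
        w = np.linalg.solve(L, r)
        return float(w @ w)

    def unwhiten(self, L: np.ndarray, z: np.ndarray) -> np.ndarray:
        # Map standard-normal draws z to draws with covariance Sigma: x = L z
        return L @ z
```

Only the four operations used on this page are modeled; the real `AbstractLinearSolver` may expose more, and a non-Cholesky solver would hold a different state.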

Methods:

  • covariance

    Operator representing the covariance of the distribution.

  • log_prob

    Compute the log probability of the distribution at the given value.

  • sample

    Sample from the distribution.

Attributes:

  • mean (Float[Array, ' N']) –

    Mean of the distribution.

  • stddev (Float[Array, ' N']) –

    Marginal standard deviation of the distribution.

  • variance (Float[Array, ' N']) –

    Marginal variance of the distribution.
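These attributes follow the usual definitions for a multivariate Gaussian: mean is the location vector, variance is the diagonal of the covariance, and stddev is its elementwise square root. A dense NumPy sketch, assuming the covariance can be materialized as a matrix (the class keeps it as a LinearOperator, and how it extracts the diagonal is not shown on this page):

```python
import numpy as np

# Illustrative values: loc plays the role of `mean`, cov the role of the covariance operator.
loc = np.array([0.0, 1.0, -2.0])
cov = np.array([
    [4.0, 1.0, 0.0],
    [1.0, 9.0, 2.0],
    [0.0, 2.0, 1.0],
])

mean = loc                      # mean is the location vector
variance = np.diag(cov).copy()  # marginal variances: diagonal of the covariance
stddev = np.sqrt(variance)      # marginal standard deviations
```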

Source code in src/cagpjax/distributions.py
def __init__(
    self,
    loc: Float[Array, " N"],
    scale: LinearOperator,
    solver: AbstractLinearSolver = Cholesky(1e-6),
    **kwargs,
):
    """Initialize the Gaussian distribution.

    Args:
        loc: Mean of the distribution.
        scale: Scale of the distribution.
        solver: Method for solving the linear system of equations.
    """
    self.loc = loc
    self.scale = scale
    batch_shape = ()
    event_shape = jnp.shape(self.loc)
    self.solver = solver
    super().__init__(batch_shape, event_shape, **kwargs)

mean property

mean: Float[Array, ' N']

Mean of the distribution.

stddev property

stddev: Float[Array, ' N']

Marginal standard deviation of the distribution.

variance property

variance: Float[Array, ' N']

Marginal variance of the distribution.

covariance

covariance() -> LinearOperator

Operator representing the covariance of the distribution.

Source code in src/cagpjax/distributions.py
def covariance(self) -> LinearOperator:
    """Operator representing the covariance of the distribution."""
    return self.scale
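Because covariance() returns scale as-is, and log_prob passes the same operator to the solver's logdet and inv_quad, the scale argument is treated as the covariance Sigma rather than a square-root factor. The hypothetical stripped-down class below (`TinyGaussian` is not part of cagpjax) mirrors the constructor's bookkeeping and this method with a dense NumPy matrix, and shows forming Sigma from a factor before passing it in:

```python
import numpy as np


class TinyGaussian:
    """Hypothetical, stripped-down mirror of GaussianDistribution's bookkeeping."""

    def __init__(self, loc: np.ndarray, scale: np.ndarray):
        self.loc = loc
        self.scale = scale
        self.batch_shape = ()             # no batching: a single distribution
        self.event_shape = np.shape(loc)  # one event is a length-N vector

    def covariance(self) -> np.ndarray:
        # `scale` is returned unchanged, so it must already be Sigma
        return self.scale


L = np.array([[2.0, 0.0], [1.0, 1.0]])  # some lower-triangular factor
sigma = L @ L.T                         # pass the covariance, not its factor
dist = TinyGaussian(np.zeros(2), sigma)
```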

log_prob

log_prob(value: Float[Array, ' N']) -> ScalarFloat

Compute the log probability of the distribution at the given value.

Parameters:

  • value

    (Float[Array, ' N']) –

    Value at which to compute the log probability.

Returns:

  • ScalarFloat

    Log probability of the distribution at the given value.
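The returned value is the standard multivariate normal log density. It matches the source listing that follows (the three terms are summed and divided by -2), with mu = loc, Sigma the covariance operator, and N the length of loc:

```latex
\log p(x) = -\tfrac{1}{2}\left( N \log(2\pi) + \log\lvert\Sigma\rvert + (x - \mu)^\top \Sigma^{-1} (x - \mu) \right)
```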

Source code in src/cagpjax/distributions.py
def log_prob(self, value: Float[Array, " N"]) -> ScalarFloat:  # pyright: ignore[reportIncompatibleMethodOverride]
    """Compute the log probability of the distribution at the given value.

    Args:
        value: Value at which to compute the log probability.

    Returns:
        Log probability of the distribution at the given value.
    """
    mu = self.loc
    sigma = self.scale
    n = mu.shape[-1]
    solver_state = self.solver.init(sigma)
    return (
        n * jnp.log(2 * jnp.pi)
        + self.solver.logdet(solver_state)
        + self.solver.inv_quad(solver_state, value - mu)
    ) / -2
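The same computation can be sketched densely in NumPy, with a Cholesky factor standing in for the solver state, and cross-checked against a direct evaluation via `slogdet` and a full solve. The helper name `gaussian_log_prob` is illustrative, and unlike the default solver this sketch adds no jitter:

```python
import numpy as np


def gaussian_log_prob(value: np.ndarray, loc: np.ndarray, cov: np.ndarray) -> float:
    """Dense NumPy sketch of log_prob using a Cholesky factor as the solver state."""
    n = loc.shape[-1]
    L = np.linalg.cholesky(cov)                # stands in for solver.init(sigma)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))  # stands in for solver.logdet(...)
    w = np.linalg.solve(L, value - loc)
    inv_quad = w @ w                           # stands in for solver.inv_quad(...)
    return float((n * np.log(2 * np.pi) + logdet + inv_quad) / -2)


# Direct computation for comparison: slogdet plus a full linear solve.
loc = np.array([0.5, -1.0])
cov = np.array([[2.0, 0.3], [0.3, 1.5]])
x = np.array([1.0, 0.0])
sign, ld = np.linalg.slogdet(cov)
r = x - loc
direct = -0.5 * (2 * np.log(2 * np.pi) + ld + r @ np.linalg.solve(cov, r))
```

Factorizing once and reusing the factor for both the log-determinant and the quadratic form is what lets the class delegate both terms to a single solver state.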

sample

sample(key: Key, sample_shape: tuple[int, ...] = ()) -> Float[Array, '*sample_shape N']

Sample from the distribution.

Parameters:

  • key

    (Key) –

    Random key for sampling.

  • sample_shape

    (tuple[int, ...], default: () ) –

    Leading shape of the batch of samples; the default () draws a single vector.

Returns:

  • Float[Array, '*sample_shape N']

    Samples drawn from the distribution, with shape sample_shape + (N,).

Source code in src/cagpjax/distributions.py
def sample(
    self,
    key: Key,
    sample_shape: tuple[int, ...] = (),
) -> Float[Array, "*sample_shape N"]:
    """Sample from the distribution.

    Args:
        key: Random key for sampling.
        sample_shape: Shape of the sample.

    Returns:
        Sample from the distribution.
    """
    mu = self.loc
    sigma = self.scale
    n = mu.shape[-1]
    solver_state = self.solver.init(sigma)
    z = jax.random.normal(key, (n, math.prod(sample_shape)), dtype=mu.dtype)
    x = self.solver.unwhiten(solver_state, z)
    return x.T.reshape(sample_shape + (n,)) + mu
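The sampling path draws all standard normals in one (N, S) batch, unwhitens them, then transposes and reshapes. Since `math.prod(())` is 1, the default `sample_shape=()` yields a single vector of shape (N,). Below is a dense NumPy sketch of the same steps; it assumes `unwhiten` multiplies by a factor L with L L^T = Sigma, swaps the JAX key for a NumPy `Generator`, and the helper name `gaussian_sample` is illustrative:

```python
import math

import numpy as np


def gaussian_sample(rng, loc, cov, sample_shape=()):
    """Dense NumPy sketch of sample(): batched draw, unwhiten, reshape, shift."""
    n = loc.shape[-1]
    L = np.linalg.cholesky(cov)                            # solver state
    z = rng.standard_normal((n, math.prod(sample_shape)))  # (N, S) standard normals
    x = L @ z                                              # columns now have covariance cov
    return x.T.reshape(sample_shape + (n,)) + loc          # (S, N) -> (*sample_shape, N)


rng = np.random.default_rng(0)
loc = np.array([1.0, -1.0])
cov = np.array([[1.0, 0.8], [0.8, 1.0]])
```

With many draws, the empirical mean and covariance of the samples approach loc and cov.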