
cagpjax.policies.pseudoinput

Pseudo-input linear solver policy.

Classes:

PseudoInputPolicy

PseudoInputPolicy(pseudo_inputs: Float[Array, 'M D'] | Parameter[Float[Array, 'M D']], train_inputs_or_dataset: Float[Array, 'N D'] | Dataset, kernel: AbstractKernel)

Bases: AbstractBatchLinearSolverPolicy

Pseudo-input linear solver policy.

This policy constructs actions from the cross-covariance between the training inputs and a set of pseudo-inputs that live in the same input space. The pseudo-inputs play a role similar to inducing points in sparse GPs and can be marked as trainable.
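To make the shapes concrete, here is a minimal sketch of the cross-covariance the actions are built from. It assumes a squared-exponential (RBF) kernel and treats each column of K(X, Z) as one length-N action. Both choices are illustrative assumptions: the helper `rbf_cross_covariance` is hypothetical, and the sketch does not call the cagpjax or GPJax API.

```python
import numpy as np


def rbf_cross_covariance(x, z, lengthscale=1.0, variance=1.0):
    """Squared-exponential cross-covariance K(x, z) with shape (N, M).

    Hypothetical stand-in for a kernel's cross_covariance method.
    """
    # Pairwise squared Euclidean distances between rows of x (N, D) and z (M, D).
    sq_dists = np.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dists / lengthscale**2)


rng = np.random.default_rng(0)
train_inputs = rng.uniform(-3.0, 3.0, size=(100, 2))  # N=100 points, D=2
pseudo_inputs = rng.uniform(-3.0, 3.0, size=(10, 2))  # M=10 pseudo-inputs, D=2

# Assumed action matrix: one length-N action vector per pseudo-input.
actions = rbf_cross_covariance(train_inputs, pseudo_inputs)
print(actions.shape)  # (100, 10)
```

Each column peaks at the training points nearest its pseudo-input. Moving a pseudo-input therefore changes which training points its action emphasizes.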

Parameters:

  • pseudo_inputs

    (Float[Array, 'M D'] | Parameter[Float[Array, 'M D']]) –

    Pseudo-inputs for the kernel. If wrapped as a gpjax.parameters.Parameter, they will be treated as trainable.

  • train_inputs_or_dataset

    (Float[Array, 'N D'] | Dataset) –

    Training inputs, or a dataset containing training inputs. They must be the same inputs, in the same order, as the training data used to condition the CaGP model.

  • kernel

    (AbstractKernel) –

    Kernel for the GP prior. It must be able to take train_inputs and pseudo_inputs as arguments to its cross_covariance method.

Note

When training with many pseudo-inputs, it is common for the cross-covariance matrix to become poorly conditioned. Performance can be significantly improved by orthogonalizing the actions using an OrthogonalizationPolicy.
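The sketch below shows why conditioning degrades and how orthogonalization helps. Closely spaced pseudo-inputs produce nearly collinear action columns. A reduced QR factorization replaces them with orthonormal columns that span the same subspace. QR is a stand-in chosen for illustration: this page does not say which orthogonalization `OrthogonalizationPolicy` applies. The RBF helper is also a hypothetical assumption.

```python
import numpy as np


def rbf_cross_covariance(x, z, lengthscale=1.0):
    # Squared-exponential cross-covariance with unit variance, shape (N, M).
    sq_dists = np.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)


train_inputs = np.linspace(-3.0, 3.0, 200)[:, None]  # N=200, D=1
# Closely spaced pseudo-inputs give nearly collinear action columns.
pseudo_inputs = np.linspace(0.0, 0.05, 5)[:, None]  # M=5, D=1

actions = rbf_cross_covariance(train_inputs, pseudo_inputs)

# Reduced QR: actions = q @ r, where q (N, M) has orthonormal columns that
# span the same subspace as the columns of actions (when r is invertible).
q, r = np.linalg.qr(actions)

print(f"cond(actions) = {np.linalg.cond(actions):.2e}")
print(f"cond(q)       = {np.linalg.cond(q):.2e}")
```

The approximation itself should be unaffected, assuming the posterior depends on the actions S only through S(SᵀK̂S)⁻¹Sᵀ, as in the standard computation-aware GP formulation. That expression is unchanged under S → SR for any invertible R, so orthogonalizing alters the numerics rather than the subspace.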

Source code in src/cagpjax/policies/pseudoinput.py
def __init__(
    self,
    pseudo_inputs: Float[Array, "M D"] | Parameter[Float[Array, "M D"]],
    train_inputs_or_dataset: Float[Array, "N D"] | gpjax.dataset.Dataset,
    kernel: gpjax.kernels.AbstractKernel,
):
    if isinstance(train_inputs_or_dataset, gpjax.dataset.Dataset):
        train_data = train_inputs_or_dataset
        if train_data.X is None:
            raise ValueError("Dataset must contain training inputs.")
        train_inputs = train_data.X
    else:
        train_inputs = train_inputs_or_dataset
    self.pseudo_inputs = pseudo_inputs
    self.train_inputs = jnp.atleast_2d(train_inputs)
    self.kernel = kernel
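The constructor passes the training inputs through `jnp.atleast_2d`. Like its NumPy counterpart, that function adds a leading axis to a 1-D array. As a result, a 1-D array of N scalar inputs becomes shape (1, N) and would be read as a single N-dimensional point rather than N one-dimensional points. The sketch below uses NumPy, whose `np.atleast_2d` has the same shape semantics, to show the difference. Reshaping to (N, 1) before constructing the policy is one way to get the `'N D'` layout the signature expects.

```python
import numpy as np

x_1d = np.linspace(0.0, 1.0, 4)  # four scalar inputs, shape (4,)

# atleast_2d prepends an axis: (4,) -> (1, 4), i.e. one 4-dimensional point.
print(np.atleast_2d(x_1d).shape)  # (1, 4)

# Reshaping to a column first gives the 'N D' layout with N=4, D=1.
print(np.atleast_2d(x_1d.reshape(-1, 1)).shape)  # (4, 1)
```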