
cagpjax.policies.pseudoinput

Pseudo-input linear solver policy.

PseudoInputPolicy

Bases: AbstractBatchLinearSolverPolicy

Pseudo-input linear solver policy.

This policy constructs actions from the cross-covariance between the training inputs and pseudo-inputs in the same input space. These pseudo-inputs are conceptually similar to inducing points and can be marked as trainable.
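
A minimal construction sketch (the arrays X and Z and the choice of RBF kernel are illustrative assumptions, not prescribed by the API):

import gpjax
import jax.numpy as jnp

from cagpjax.policies.pseudoinput import PseudoInputPolicy

X = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)  # training inputs, shape (N, D)
Z = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)   # pseudo-inputs, shape (M, D)

policy = PseudoInputPolicy(Z, X, gpjax.kernels.RBF())
policy.n_actions  # 10: one action per pseudo-input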

Parameters:

    pseudo_inputs : Float[Array, 'M D'] | Variable (required)
        Pseudo-inputs for the kernel. If wrapped as a gpjax.parameters.Parameter, they will be treated as trainable.

    train_inputs : Float[Array, 'N D'] | Dataset (required)
        Training inputs or a dataset containing training inputs. These must be the same inputs in the same order as the training data used to condition the CaGP model.

    kernel : AbstractKernel (required)
        Kernel for the GP prior. It must be able to take train_inputs and pseudo_inputs as arguments to its cross_covariance method.
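
The two optional forms above can be sketched as follows; this assumes gpjax.parameters.Real is available as a trainable Parameter subclass and uses gpjax.Dataset to carry the training inputs (both from GPJax, not defined on this page):

import gpjax
import jax.numpy as jnp

from cagpjax.policies.pseudoinput import PseudoInputPolicy

X = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(2.0 * X)
data = gpjax.Dataset(X=X, y=y)  # a Dataset may be passed instead of raw inputs; X is extracted

# Wrapping the pseudo-inputs in a Parameter subclass (here Real, assumed available) marks them trainable.
Z = gpjax.parameters.Real(jnp.linspace(0.0, 1.0, 5).reshape(-1, 1))

policy = PseudoInputPolicy(Z, data, gpjax.kernels.Matern52())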
Source code in src/cagpjax/policies/pseudoinput.py
class PseudoInputPolicy(AbstractBatchLinearSolverPolicy):
    """Pseudo-input linear solver policy.

    This policy constructs actions from the cross-covariance between the training inputs and
    pseudo-inputs in the same input space. These pseudo-inputs are conceptually similar to
    inducing points and can be marked as trainable.

    Args:
        pseudo_inputs: Pseudo-inputs for the kernel. If wrapped as a `gpjax.parameters.Parameter`,
            they will be treated as trainable.
        train_inputs: Training inputs or a dataset containing training inputs. These must be the
            same inputs in the same order as the training data used to condition the CaGP model.
        kernel: Kernel for the GP prior. It must be able to take `train_inputs` and `pseudo_inputs`
            as arguments to its `cross_covariance` method.
    """

    pseudo_inputs: nnx.Variable
    train_inputs: Float[Array, "N D"]
    kernel: gpjax.kernels.AbstractKernel

    def __init__(
        self,
        pseudo_inputs: Float[Array, "M D"] | nnx.Variable,
        train_inputs_or_dataset: Float[Array, "N D"] | gpjax.dataset.Dataset,
        kernel: gpjax.kernels.AbstractKernel,
    ):
        if isinstance(train_inputs_or_dataset, gpjax.dataset.Dataset):
            train_data = train_inputs_or_dataset
            if train_data.X is None:
                raise ValueError("Dataset must contain training inputs.")
            train_inputs = train_data.X
        else:
            train_inputs = train_inputs_or_dataset
        if not isinstance(pseudo_inputs, nnx.Variable):
            pseudo_inputs = gpjax.parameters.Static(jnp.atleast_2d(pseudo_inputs))
        self.pseudo_inputs = pseudo_inputs
        self.train_inputs = jnp.atleast_2d(train_inputs)
        self.kernel = kernel

    @property
    def n_actions(self):
        return self.pseudo_inputs.shape[0]

    def to_actions(self, A: LinearOperator) -> LinearOperator:
        S = self.kernel.cross_covariance(self.train_inputs, self.pseudo_inputs.value)
        return cola.lazify(S)
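
As the listing shows, to_actions lazifies the N x M cross-covariance between the training inputs and the pseudo-inputs; the operator A supplied by the solver is accepted for interface compatibility but is not used. A hedged illustration (the cross-covariance stand-in for A is an assumption about what a solver might pass):

import cola
import gpjax
import jax.numpy as jnp

from cagpjax.policies.pseudoinput import PseudoInputPolicy

X = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
Z = jnp.linspace(0.0, 1.0, 4).reshape(-1, 1)
kernel = gpjax.kernels.RBF()
policy = PseudoInputPolicy(Z, X, kernel)

A = cola.lazify(kernel.cross_covariance(X, X))  # illustrative stand-in for the solver's operator
S = policy.to_actions(A)
S.shape  # (20, 4): one column (action) per pseudo-input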