cagpjax.policies.pseudoinput
Pseudo-input linear solver policy.
Classes:
- PseudoInputPolicy – Pseudo-input linear solver policy.
PseudoInputPolicy
PseudoInputPolicy(pseudo_inputs: Float[Array, 'M D'] | Parameter[Float[Array, 'M D']], train_inputs_or_dataset: Float[Array, 'N D'] | Dataset, kernel: AbstractKernel)
Bases: AbstractBatchLinearSolverPolicy
Pseudo-input linear solver policy.
This policy constructs actions from the cross-covariance between the training inputs and a set of pseudo-inputs that live in the same input space. These pseudo-inputs are conceptually similar to inducing points and can be marked as trainable.
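The construction above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the cagpjax implementation: it assumes an RBF kernel, and the helpers `rbf_cross_covariance`, `pseudo_input_actions` and `toy_objective` are hypothetical. It shows that the action matrix is the N × M cross-covariance between training inputs and pseudo-inputs, and that gradients flow back to the pseudo-inputs, which is what allows them to be trained.

```python
import jax
import jax.numpy as jnp


def rbf_cross_covariance(x, z, lengthscale=1.0, variance=1.0):
    """Hypothetical RBF cross-covariance: returns an (N, M) matrix k(x_i, z_j)."""
    # Squared Euclidean distance between every training input and pseudo-input.
    sq_dists = jnp.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dists / lengthscale**2)


def pseudo_input_actions(pseudo_inputs, train_inputs):
    """Hypothetical helper: the action matrix S = k(X, Z) with shape (N, M)."""
    return rbf_cross_covariance(train_inputs, pseudo_inputs)


key_x, key_z = jax.random.split(jax.random.PRNGKey(0))
N, M, D = 50, 5, 2
train_inputs = jax.random.normal(key_x, (N, D))   # X: (N, D)
pseudo_inputs = jax.random.normal(key_z, (M, D))  # Z: (M, D), same input space

actions = pseudo_input_actions(pseudo_inputs, train_inputs)
print(actions.shape)  # (50, 5)


# Hypothetical scalar objective: S is differentiable in Z, so gradients reach
# the pseudo-inputs, which is what makes them trainable.
def toy_objective(z):
    return jnp.sum(pseudo_input_actions(z, train_inputs))


grad_z = jax.grad(toy_objective)(pseudo_inputs)
print(grad_z.shape)  # (5, 2)
```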
Parameters:
- pseudo_inputs (Float[Array, 'M D'] | Parameter[Float[Array, 'M D']]) – Pseudo-inputs for the kernel. If wrapped as a gpjax.parameters.Parameter, they will be treated as trainable.
- train_inputs_or_dataset (Float[Array, 'N D'] | Dataset) – Training inputs, or a dataset containing training inputs. These must be the same inputs, in the same order, as the training data used to condition the CaGP model.
- kernel (AbstractKernel) – Kernel for the GP prior. It must be able to take the training inputs and pseudo_inputs as arguments to its cross_covariance method.
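To see why the training inputs must match the conditioning data in both content and order, note that row i of the action matrix is tied to training point i. The sketch below uses a hypothetical RBF helper (not the library's kernel API) to show that shuffling the training inputs shuffles the rows of the actions in exactly the same way, so actions built from a reordered copy would no longer line up with the targets.

```python
import jax
import jax.numpy as jnp


def rbf_cross_covariance(x, z, lengthscale=1.0):
    """Hypothetical RBF cross-covariance with shape (len(x), len(z))."""
    sq_dists = jnp.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists / lengthscale**2)


key_x, key_z, key_p = jax.random.split(jax.random.PRNGKey(1), 3)
train_inputs = jax.random.normal(key_x, (8, 3))   # N=8, D=3
pseudo_inputs = jax.random.normal(key_z, (4, 3))  # M=4, same D as training inputs

# Row i of the action matrix belongs to training point i.
actions = rbf_cross_covariance(train_inputs, pseudo_inputs)

# Reordering the training inputs reorders the rows of the actions identically.
perm = jax.random.permutation(key_p, 8)
shuffled_actions = rbf_cross_covariance(train_inputs[perm], pseudo_inputs)
print(jnp.allclose(shuffled_actions, actions[perm]))  # True
```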
Note
When training with many pseudo-inputs, it is common for the cross-covariance matrix to become poorly conditioned. Performance can be significantly improved by orthogonalizing the actions using an OrthogonalizationPolicy.
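The conditioning issue can be reproduced with a small example. This sketch assumes an RBF kernel and uses a reduced QR factorization as a stand-in for orthogonalization; it is not necessarily how OrthogonalizationPolicy works internally. Densely packed pseudo-inputs give nearly collinear columns in the cross-covariance, whereas the orthogonalized actions have orthonormal columns and a condition number of about 1.

```python
import jax

# Double precision so the condition numbers below are meaningful.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def rbf_cross_covariance(x, z, lengthscale=1.0):
    """Hypothetical RBF cross-covariance with shape (len(x), len(z))."""
    sq_dists = jnp.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dists / lengthscale**2)


train_inputs = jax.random.uniform(jax.random.PRNGKey(2), (200, 1))
# Many pseudo-inputs packed into the same interval: nearly collinear columns.
pseudo_inputs = jnp.linspace(0.0, 1.0, 30)[:, None]

actions = rbf_cross_covariance(train_inputs, pseudo_inputs)
print(jnp.linalg.cond(actions))  # very large

# Stand-in orthogonalization: reduced QR yields actions with orthonormal columns.
q, _ = jnp.linalg.qr(actions)
print(jnp.linalg.cond(q))  # close to 1
```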