GPJax
feat: Shape checking.
Feature Request
It would be desirable for GPJax to check the shapes of kernel parameters.
For example, consider an ARD RBF kernel on two input dimensions, and compare two parameter dictionaries: params_correct, in which the lengthscale has the correct shape (num_dims,), and params_incorrect, in which it has the incorrect shape (num_dims, 1).

```python
import gpjax as gpx
import jax.numpy as jnp

# ARD RBF kernel acting on both input dimensions
kernel = gpx.RBF(active_dims=[0, 1])

x1 = jnp.array([1, 2])
x2 = jnp.array([4, 5])

# Correct: one lengthscale per input dimension, shape (2,)
params_correct = {'lengthscale': jnp.array([2., 7.]),
                  'variance': jnp.array([1.])}
print(kernel(x1, x2, params_correct))

# Incorrect: lengthscale of shape (2, 1), which is accepted without error
params_incorrect = {'lengthscale': jnp.array([[2.], [7.]]),
                    'variance': jnp.array([1.])}
print(kernel(x1, x2, params_incorrect))
```

The two calls give different results:

```
0.2961655369001007
0.08771402524732493
```
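One plausible explanation is broadcasting. Assume (this is an assumption, not taken from GPJax's source) that the kernel divides each input by the lengthscale and then sums squared differences over all entries. A (2, 1) lengthscale then broadcasts against a (2,) input to produce a (2, 2) array, so every squared difference is counted twice and the result becomes the square of the correct value (0.2962² ≈ 0.0877). The sketch below uses only jax.numpy; the function `rbf` is a hypothetical stand-in for the kernel, and it reproduces both numbers above up to float32 precision.

```python
import jax.numpy as jnp

def rbf(x, y, lengthscale, variance):
    # Hypothetical ARD RBF kernel (assumed form, not GPJax's implementation):
    # scale inputs by the lengthscale, then sum squared differences over
    # *all* entries of the (possibly broadcast) array.
    x = x / lengthscale
    y = y / lengthscale
    sq_dist = jnp.sum((x - y) ** 2)
    return variance * jnp.exp(-0.5 * sq_dist)

x1 = jnp.array([1., 2.])
x2 = jnp.array([4., 5.])
ls_ok = jnp.array([2., 7.])        # shape (2,)
ls_bad = jnp.array([[2.], [7.]])   # shape (2, 1)

# Broadcasting (2,) against (2, 1) silently yields a (2, 2) array
print((x1 / ls_bad).shape)  # (2, 2)

k_ok = rbf(x1, x2, ls_ok, 1.0)
k_bad = rbf(x1, x2, ls_bad, 1.0)
print(k_ok, k_bad)
```

Because the (2, 2) array contains each scaled difference twice, the squared distance doubles, and exp(-2a) = exp(-a)², which is why the incorrect output is exactly the square of the correct one.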
Because the incorrectly shaped lengthscale is accepted silently, this is an easy mistake to make.
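A minimal sketch of the requested check, assuming parameters are stored in a flat dict of arrays as in the example above. The helper `check_param_shapes` and the `expected` shape mapping are hypothetical names for illustration, not part of GPJax.

```python
import jax.numpy as jnp

def check_param_shapes(params, expected_shapes):
    """Raise if any parameter is missing or has an unexpected shape.

    Hypothetical helper for illustration; not part of GPJax.
    """
    for name, expected in expected_shapes.items():
        if name not in params:
            raise KeyError(f"missing parameter '{name}'")
        actual = jnp.shape(params[name])
        if actual != tuple(expected):
            raise ValueError(
                f"parameter '{name}' has shape {actual}, expected {tuple(expected)}"
            )

# Assumed expected shapes for an ARD RBF kernel on two input dimensions
num_dims = 2
expected = {'lengthscale': (num_dims,), 'variance': (1,)}

params_correct = {'lengthscale': jnp.array([2., 7.]), 'variance': jnp.array([1.])}
params_incorrect = {'lengthscale': jnp.array([[2.], [7.]]), 'variance': jnp.array([1.])}

check_param_shapes(params_correct, expected)  # passes silently
try:
    check_param_shapes(params_incorrect, expected)
except ValueError as err:
    print(err)  # parameter 'lengthscale' has shape (2, 1), expected (2,)
```

Running the check before the kernel is evaluated turns the silent broadcast into an explicit error. In a library, the expected lengthscale shape would presumably be derived from the number of active dimensions rather than written out by hand.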