GPJax

feat: Shape checking.

Status: Open · opened by daniel-dodd · 0 comments

Feature Request

It would be desirable to check the shapes of kernel parameters.

For example, consider an ARD RBF kernel on two input dimensions, and compare two parameter dictionaries: `params_correct`, in which the length-scale has the correct shape `(num_dims,)`, and `params_incorrect`, in which it has the incorrect shape `(num_dims, 1)`:

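```python
import gpjax as gpx
import jax.numpy as jnp

# ARD RBF kernel acting on input dimensions 0 and 1 (num_dims = 2).
kernel = gpx.RBF(active_dims=[0, 1])

x1 = jnp.array([1, 2])
x2 = jnp.array([4, 5])

# Correct: one length-scale per input dimension, shape (num_dims,) = (2,).
params_correct = {'lengthscale': jnp.array([2., 7.]),
                  'variance': jnp.array([1.])}

print(kernel(x1, x2, params_correct))

# Incorrect: length-scale of shape (num_dims, 1) = (2, 1). No error is raised.
params_incorrect = {'lengthscale': jnp.array([[2.], [7.]]),
                    'variance': jnp.array([1.])}

print(kernel(x1, x2, params_incorrect))
```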

Both calls run without error, but they give different results:

```
0.2961655369001007
0.08771402524732493
```
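The gap between the two numbers can be reproduced with plain `jax.numpy`, without GPJax. This is a sketch under an assumption: `ard_rbf` is a hypothetical helper that computes variance · exp(-½‖x/ℓ − y/ℓ‖²), which is consistent with the first printed value but is not GPJax's own code. Dividing the `(2,)` input by a `(2, 1)` length-scale broadcasts to a `(2, 2)` array, so every squared difference is summed twice and the exponent doubles.

```python
import jax
import jax.numpy as jnp

# Use float64 so values are comparable with the precision printed in the issue.
jax.config.update("jax_enable_x64", True)


def ard_rbf(x, y, lengthscale, variance):
    """Hypothetical ARD RBF: variance * exp(-0.5 * ||x/l - y/l||^2)."""
    x = x / lengthscale
    y = y / lengthscale
    return jnp.squeeze(variance * jnp.exp(-0.5 * jnp.sum((x - y) ** 2)))


x1 = jnp.array([1.0, 2.0])
x2 = jnp.array([4.0, 5.0])
variance = jnp.array([1.0])

good = ard_rbf(x1, x2, jnp.array([2.0, 7.0]), variance)     # lengthscale shape (2,)
bad = ard_rbf(x1, x2, jnp.array([[2.0], [7.0]]), variance)  # lengthscale shape (2, 1)

# (2,) / (2, 1) broadcasts to (2, 2): each squared difference appears twice.
print((x1 / jnp.array([[2.0], [7.0]])).shape)  # (2, 2)
print(float(good), float(bad))
print(bool(jnp.allclose(bad, good ** 2)))  # True
```

In general the doubled exponent gives `bad == good**2 / variance`; with `variance = 1` the incorrect value is simply the square of the correct one, which is why it still looks like a plausible kernel value.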

This is an easy mistake to make, and because the incorrect shape still broadcasts, nothing warns the user that the result is wrong.
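One way the requested check could work is to validate each parameter's shape before the kernel is evaluated. The sketch below is hypothetical: `expected_rbf_shapes`, `check_rbf_params`, the expected-shape table (with `num_dims` taken to be `len(active_dims)`), and the choice to raise `ValueError` are all assumptions for illustration, not an existing GPJax API.

```python
import jax.numpy as jnp


def expected_rbf_shapes(num_dims):
    """Hypothetical expected shapes for an ARD RBF kernel's parameters."""
    return {"lengthscale": (num_dims,), "variance": (1,)}


def check_rbf_params(params, num_dims):
    """Raise ValueError if a parameter is missing or has an unexpected shape."""
    for name, shape in expected_rbf_shapes(num_dims).items():
        if name not in params:
            raise ValueError(f"missing parameter {name!r}")
        actual = jnp.shape(params[name])
        if actual != shape:
            raise ValueError(
                f"parameter {name!r} has shape {actual}, expected {shape}"
            )


params_correct = {"lengthscale": jnp.array([2.0, 7.0]), "variance": jnp.array([1.0])}
params_incorrect = {"lengthscale": jnp.array([[2.0], [7.0]]), "variance": jnp.array([1.0])}

check_rbf_params(params_correct, num_dims=2)  # passes silently
try:
    check_rbf_params(params_incorrect, num_dims=2)
except ValueError as err:
    print(err)  # parameter 'lengthscale' has shape (2, 1), expected (2,)
```

Because array shapes are static in JAX, the same Python-level check also works inside `jax.jit`, where it runs once at trace time rather than on every call.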

daniel-dodd, Aug 29 '22 15:08