
bug: possible bug in handling kernels that are combinations of combinations

Open matthewrhysjones opened this issue 2 years ago • 3 comments

Bug Report

GPJax version: 0.8.0

Current behavior:

This may be a problem with how I am interpreting how GPJax handles combination kernels, so sorry if I've missed something.

It seems that kernels which are combinations of combination kernels are not handled as expected when more than one type of combination operator is used (e.g. the kernel is a sum of product kernels, or a product of sum kernels). There doesn't appear to be a problem if both combination operators are identical (a sum of sum kernels, or a product of product kernels).

Expected behavior:

When using a combination of combination kernels, the predictive mean should be identical whether it is computed with GPJax or manually.

Steps to reproduce: see below

Related code:

import jax.numpy as jnp
import gpjax as gpx
import matplotlib.pyplot as plt

xall = jnp.linspace(-5,5,1000)
toy_fun = lambda x: 1/5*x**2 + jnp.sin(x*5)**3 + jnp.cos(x*3)**2

xtrain = xall[0::25][:, None]
ytrain = toy_fun(xtrain)
xtest = xall[:, None]
ytest = toy_fun(xtest)

D = gpx.gps.Dataset(xtrain, ytrain)

kernel1 = gpx.kernels.RBF()
kernel2 = gpx.kernels.Matern32()
sum_kernel = kernel1 + kernel2

# using GPJax
pos_kernel = sum_kernel * sum_kernel # pos = product of sum

pos_prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel = pos_kernel)
pos_posterior = pos_prior * gpx.gps.Gaussian(D.n)

latent_dist_pos = pos_posterior.likelihood(pos_posterior(xtest, train_data=D))
mu_pos = latent_dist_pos.mean()
std_pos = latent_dist_pos.stddev()

# manual calculation of predictive dist

kxx = (kernel1.gram(xtrain).to_dense() + kernel2.gram(xtrain).to_dense()) * (kernel1.gram(xtrain).to_dense() + kernel2.gram(xtrain).to_dense())
kxt = (kernel1.cross_covariance(xtrain,xtest) + kernel2.cross_covariance(xtrain,xtest)) * (kernel1.cross_covariance(xtrain,xtest) + kernel2.cross_covariance(xtrain,xtest))
ktt = (kernel1.gram(xtest).to_dense() + kernel2.gram(xtest).to_dense()) * (kernel1.gram(xtest).to_dense() + kernel2.gram(xtest).to_dense())

L = jnp.linalg.cholesky(kxx + 1*jnp.eye(D.n))  #1 here is to match the obs noise as assigned in the GPJax likelihood
alpha = jnp.linalg.solve(L.T,jnp.linalg.solve(L,ytrain))
v = jnp.linalg.solve(L,kxt)

mu_manual_pos = kxt.T @ alpha
cov_manual_pos = ktt - v.T @ v
var_manual_pos = jnp.diag(cov_manual_pos) +1 # adding obs variance to match GPJax stddev output

plt.plot(xtest,mu_manual_pos,':')
plt.plot(xtest,mu_pos,'--')

There is a discrepancy between "mu_manual_pos" and "mu_pos" when I don't believe there should be. The same happens with a kernel that is a sum of individual product kernels. However, if the combination operators are identical (sum of sums, product of products), the results agree, so it appears there is a problem with the way GPJax handles combinations of combinations that mix operators.

Other information:

I found this issue while working with kernels that are combinations of combinations for a personal project, where I am seeing drastic differences between GPJax and manual computation. I've tried to simplify the problem for this post to make it as clear as possible.

matthewrhysjones avatar Dec 27 '23 23:12 matthewrhysjones

Hey Matthew, I just ran into the same problem. I think the issue is in the __post_init__ of the CombinationKernel class.

    def __post_init__(self):
        # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
        kernels_list: List[AbstractKernel] = []

        for kernel in self.kernels:
            if not isinstance(kernel, AbstractKernel):
                raise TypeError("can only combine Kernel instances")  # pragma: no cover

            if isinstance(kernel, self.__class__):
                kernels_list.extend(kernel.kernels)
            else:
                kernels_list.append(kernel)

        self.kernels = kernels_list

Here it calculates a flattened list of kernels and saves it to the kernels attribute. When the kernel is called, it applies its operator across all kernels in that list:

  return self.operator(jnp.stack([k(x, y) for k in self.kernels]))

So the nested structure of operations is lost: it blindly applies the current operation (e.g. sum) to all sub-kernels. This explains why the results are consistent if all kernel operations are the same.
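To make the effect concrete, here is a minimal sketch in plain Python that only imitates the flattening logic quoted above. It is not GPJax code: `ToyCombination`, `rbf`, and `matern32` are hypothetical stand-ins with unit hyperparameters, and kernels are evaluated on scalar inputs for simplicity.

```python
import math


def rbf(x, y):
    # Squared-exponential kernel with unit lengthscale and variance.
    return math.exp(-0.5 * (x - y) ** 2)


def matern32(x, y):
    # Matern-3/2 kernel with unit lengthscale and variance.
    r = math.sqrt(3.0) * abs(x - y)
    return (1.0 + r) * math.exp(-r)


class ToyCombination:
    """Hypothetical stand-in that copies the flattening test from the snippet above."""

    def __init__(self, kernels, operator):
        self.operator = operator
        flat = []
        for kernel in kernels:
            # The check looks only at the class, not at the operator,
            # so a sum nested inside a product is flattened as well.
            if isinstance(kernel, self.__class__):
                flat.extend(kernel.kernels)
            else:
                flat.append(kernel)
        self.kernels = flat

    def __call__(self, x, y):
        return self.operator([k(x, y) for k in self.kernels])


sum_kernel = ToyCombination([rbf, matern32], sum)
pos_kernel = ToyCombination([sum_kernel, sum_kernel], math.prod)

x, y = 0.3, 1.1
intended = (rbf(x, y) + matern32(x, y)) ** 2   # product of sums
flattened = pos_kernel(x, y)                   # rbf * matern32 * rbf * matern32
print(len(pos_kernel.kernels), intended, flattened)
```

In this toy version the product of sums ends up holding four base kernels and evaluates their plain product, which differs from the intended value.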

I assume the easy fix would be to have two attributes, self.kernels and self.flattened_kernels.

ChrisBoettner avatar Dec 30 '23 08:12 ChrisBoettner

This is indeed a bug, thank you for spotting it!

I don't think we need to have a separate flattened_kernels; I would either a) change SumKernel and ProductKernel to be actual subclasses of CombinationKernel (in which case the test on self.__class__ would only allow combining when the operation matches), or b) explicitly add an additional check that self.operator is kernel.operator.
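As a rough illustration of option a), the sketch below uses plain Python stand-ins rather than the GPJax classes. `ToyCombination`, `ToySum`, `ToyProduct`, and the two toy kernels are hypothetical names; the point is only that once sum and product are real subclasses, the `isinstance(kernel, self.__class__)` test flattens a nested kernel only when its operation matches.

```python
import math


def k1(x, y):
    # Toy kernel: squared exponential with unit hyperparameters.
    return math.exp(-0.5 * (x - y) ** 2)


def k2(x, y):
    # Toy kernel: exponential (Matern-1/2) with unit hyperparameters.
    return math.exp(-abs(x - y))


class ToyCombination:
    """Hypothetical base class; subclasses fix the operator."""

    operator = None

    def __init__(self, kernels):
        flat = []
        for kernel in kernels:
            # self.__class__ is now ToySum or ToyProduct, so only
            # combinations with the same operation are flattened.
            if isinstance(kernel, self.__class__):
                flat.extend(kernel.kernels)
            else:
                flat.append(kernel)
        self.kernels = flat

    def __call__(self, x, y):
        return self.operator([k(x, y) for k in self.kernels])


class ToySum(ToyCombination):
    operator = staticmethod(sum)


class ToyProduct(ToyCombination):
    operator = staticmethod(math.prod)


sum_kernel = ToySum([k1, k2])
sum_of_sum = ToySum([sum_kernel, k1])                # same operation: flattened
product_of_sum = ToyProduct([sum_kernel, sum_kernel])  # mixed: structure kept

x, y = 0.3, 1.1
value = product_of_sum(x, y)
expected = (k1(x, y) + k2(x, y)) ** 2
print(len(sum_of_sum.kernels), len(product_of_sum.kernels), value, expected)
```

Option b) would instead keep a single class and extend the condition with a comparison of the two operators; either way, the nesting is preserved whenever the operations differ.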

Personally I'd prefer a) ... @thomaspinder @daniel-dodd ?

st-- avatar Jan 25 '24 21:01 st--

This issue has been marked as stale because it has been open for 7 days with no activity.

github-actions[bot] avatar Sep 01 '24 02:09 github-actions[bot]

There has been no activity on this PR for some time. Therefore, we will be automatically closing the PR if no new activity occurs within the next seven days. Thank you for your contributions.

github-actions[bot] avatar Feb 14 '25 08:02 github-actions[bot]