
Avoid evaluating kernel when adding jitter

dannyfriar opened this issue 3 years ago • 4 comments

Memory usage is much higher as a result of evaluating the kernel on this line.

Is there any reason not to use the parent class method to do this?
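The underlying point is that jitter can be added without ever materializing the dense kernel matrix. Here is a minimal pure-PyTorch sketch of that idea (the `LazyAddedDiag` class is hypothetical, not gpytorch's or linear_operator's actual API): since (K + eps*I) @ v == K @ v + eps * v, a matrix-vector product with the jittered kernel never needs the dense N x N sum.

```python
import torch

# Hypothetical sketch (NOT gpytorch's real classes): a lazy added-diagonal
# operator.  A matvec with (K + eps*I) is computed as K @ v + eps * v, so the
# dense N x N sum is never formed.
class LazyAddedDiag:
    def __init__(self, matmul_fn, eps):
        self.matmul_fn = matmul_fn  # computes K @ v without storing K + eps*I
        self.eps = eps

    def matmul(self, v):
        return self.matmul_fn(v) + self.eps * v


n = 500
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.T  # a PSD "kernel" matrix for the demo
op = LazyAddedDiag(lambda v: K @ v, 1e-3)

v = torch.randn(n, 1, dtype=torch.float64)
dense = (K + 1e-3 * torch.eye(n, dtype=torch.float64)) @ v
print(torch.allclose(op.matmul(v), dense))  # True
```

This is the same trick a lazy/structured linear-operator library can exploit internally, which is why forcing kernel evaluation just to add jitter is wasteful.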

dannyfriar avatar Sep 16 '22 13:09 dannyfriar

Related issues:

  • https://github.com/cornellius-gp/gpytorch/issues/1826
  • https://github.com/cornellius-gp/gpytorch/issues/1951

dannyfriar avatar Sep 16 '22 13:09 dannyfriar

The test failure occurs when calling the predict method of the deep GP multitask example.

Possibly related to this issue https://github.com/cornellius-gp/gpytorch/issues/1892?

dannyfriar avatar Sep 18 '22 19:09 dannyfriar

Here is an example that illustrates the issue clearly.

import torch
import gpytorch
from torch.profiler import profile, ProfilerActivity
from torch.utils.data import DataLoader, TensorDataset

N = 4567
M = 123
X = torch.randn(N, 5)
Y = torch.randn(N)
Z = torch.randn(M, 5)

class VGP(gpytorch.models.ApproximateGP):
    def __init__(self, kern, inducing_points):
        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(inducing_points.shape[0])
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True,
            jitter_val=1e-3,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = kern

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


model_ng = VGP(gpytorch.kernels.RBFKernel(), Z)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model_ng, num_data=X.shape[0])

data = DataLoader(TensorDataset(X, Y), batch_size=X.shape[0])

variational_ngd_optimizer = gpytorch.optim.NGD(model_ng.variational_parameters(), num_data=X.size(0), lr=0.)

with profile(activities=[ProfilerActivity.CPU],
        profile_memory=True, record_shapes=True) as prof:

    for x, y in data:
        variational_ngd_optimizer.zero_grad()

        loss = -mll(model_ng(x), y)
        loss.backward()
        variational_ngd_optimizer.step()

print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_memory_usage", row_limit=10))

In the profile below you can see the dense [4567, 4567] tensors being allocated:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls                                   Input Shapes
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
                                    aten::empty_strided         0.10%     347.000us         0.10%     347.000us       4.627us      94.85 Mb      94.85 Mb            75                       [[], [], [], [], [], []]
                                            aten::empty         0.08%     286.000us         0.08%     286.000us       4.847us      80.70 Mb      80.70 Mb            59                       [[], [], [], [], [], []]
                                               aten::mm         2.68%       9.576ms         2.68%       9.576ms       9.576ms      79.57 Mb      79.57 Mb             1                         [[4567, 7], [7, 4567]]
                                              aten::mul         2.70%       9.647ms         2.70%       9.647ms       9.647ms      79.57 Mb      79.57 Mb             1                   [[4567, 4567], [4567, 4567]]
                                            aten::where         2.22%       7.923ms         2.22%       7.934ms       7.934ms      79.57 Mb      79.57 Mb             1               [[4567, 4567], [4567, 4567], []]
                                              aten::div         1.98%       7.086ms         2.00%       7.137ms       7.137ms      79.57 Mb      79.57 Mb             1                             [[4567, 4567], []]
                                               aten::ge         1.78%       6.355ms         1.79%       6.395ms       6.395ms      19.89 Mb      19.89 Mb             1                             [[4567, 4567], []]
                                               aten::mm        11.39%      40.754ms        11.40%      40.762ms       5.095ms      17.14 Mb      17.14 Mb             8                      [[123, 123], [123, 4567]]
                                          aten::resize_         0.05%     191.000us         0.05%     191.000us      13.643us       9.26 Mb       9.26 Mb            14                                  [[0], [], []]
                                              aten::add         2.87%      10.247ms         2.87%      10.247ms       2.562ms       8.57 Mb       8.57 Mb             4                 [[123, 4567], [123, 4567], []]
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
Self CPU time total: 357.657ms
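As a sanity check on my reading of the table, the repeated 79.57 Mb allocations in the mm/mul/where/div rows are exactly the size of a dense 4567 x 4567 float32 tensor, which confirms the full kernel matrix is being materialized:

```python
# Size of a dense N x N float32 tensor, for N = 4567 training points.
n = 4567
mb = n * n * 4 / 2**20  # float32 = 4 bytes per element
print(f"{mb:.2f} Mb")  # 79.57 Mb, matching the profiler's Self CPU Mem column
```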

The same output with the change from this PR:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls                                   Input Shapes
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
                                               aten::mm         8.24%      33.905ms         8.24%      33.914ms       4.239ms      17.14 Mb      17.14 Mb             8                      [[123, 123], [123, 4567]]
                                    aten::empty_strided         0.06%     237.000us         0.06%     237.000us       3.703us      15.25 Mb      15.25 Mb            64                       [[], [], [], [], [], []]
                                          aten::resize_         0.04%     169.000us         0.04%     169.000us      12.071us       9.26 Mb       9.26 Mb            14                                  [[0], [], []]
                                              aten::add         3.52%      14.471ms         3.52%      14.471ms       3.618ms       8.57 Mb       8.57 Mb             4                 [[123, 4567], [123, 4567], []]
                                              aten::mul         1.85%       7.617ms         1.85%       7.617ms       2.539ms       6.43 Mb       6.43 Mb             3                     [[4567, 123], [4567, 123]]
                                              aten::mul         4.15%      17.091ms         4.15%      17.091ms       5.697ms       6.43 Mb       6.43 Mb             3                     [[123, 4567], [123, 4567]]
                                              aten::mul         1.56%       6.421ms         1.56%       6.421ms       3.211ms       4.29 Mb       4.29 Mb             2                        [[123, 1], [123, 4567]]
                                              aten::add         1.69%       6.938ms         1.70%       6.995ms       3.498ms       4.29 Mb       4.29 Mb             2                          [[123, 4567], [], []]
                                               aten::mm         0.73%       2.998ms         0.73%       2.999ms       2.999ms       2.14 Mb       2.14 Mb             1                          [[123, 7], [7, 4567]]
                                               aten::mm         0.80%       3.290ms         0.80%       3.291ms       3.291ms       2.14 Mb       2.14 Mb             1                          [[4567, 1], [1, 123]]
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
Self CPU time total: 411.449ms

dannyfriar avatar Sep 20 '22 09:09 dannyfriar

This would be good to fix... there have been lots of open issues about this. I'll take a look next week.

gpleiss avatar Sep 21 '22 21:09 gpleiss

@gpleiss were you able to take a look at this?

dannyfriar avatar Oct 04 '22 21:10 dannyfriar

Sorry - hoping to get to this later this week

gpleiss avatar Oct 10 '22 21:10 gpleiss

I fixed the problem (it's quite deep in the code). It requires a PR to the LinearOperator repo (https://github.com/cornellius-gp/linear_operator/pull/26). Once that PR is merged in, I'll open a new PR that successfully gets rid of the evaluate_kernel line.

gpleiss avatar Oct 18 '22 19:10 gpleiss

Thanks @gpleiss!

dannyfriar avatar Oct 22 '22 12:10 dannyfriar