Avoid evaluating kernel when adding jitter
Memory usage is much higher as a result of evaluating the kernel on this line, since it materializes the full `[N, N]` covariance matrix. Is there any reason not to use the parent class method to do this?
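As a minimal sketch of why materializing the kernel hurts (plain PyTorch, not GPyTorch's actual code path; `dense_matvec` and `lazy_matvec` are hypothetical names for illustration): adding jitter after densifying allocates the full `N x N` matrix, while a lazy representation only ever touches the kernel through matrix-vector products.

```python
import torch

# Illustrative only -- not GPyTorch's internals. A low-rank kernel
# K = X X^T + jitter * I makes the two routes easy to compare.
N, D = 2000, 5
X = torch.randn(N, D, dtype=torch.float64)
jitter = 1e-3


def dense_matvec(v):
    # Densifying route: allocates the full [N, N] matrix before adding jitter.
    K = X @ X.T + jitter * torch.eye(N, dtype=torch.float64)
    return K @ v


def lazy_matvec(v):
    # Lazy route: same result, but never forms an [N, N] tensor.
    return X @ (X.T @ v) + jitter * v


v = torch.randn(N, 1, dtype=torch.float64)
assert torch.allclose(dense_matvec(v), lazy_matvec(v))
```

The dense route costs O(N²) memory per call, which is exactly the `[4567, 4567]` allocations visible in the profile below.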
Related issues:
- https://github.com/cornellius-gp/gpytorch/issues/1826
- https://github.com/cornellius-gp/gpytorch/issues/1951
The test failure occurs when calling the predict method of the deep GP multitask example.
Possibly related to this issue https://github.com/cornellius-gp/gpytorch/issues/1892?
Here is an example that illustrates the issue clearly:
```python
import torch
import gpytorch
from torch.profiler import profile, ProfilerActivity
from torch.utils.data import DataLoader, TensorDataset

N = 4567
M = 123
X = torch.randn(N, 5)
Y = torch.randn(N)
Z = torch.randn(M, 5)


class VGP(gpytorch.models.ApproximateGP):
    def __init__(self, kern, inducing_points):
        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(inducing_points.shape[0])
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True,
            jitter_val=1e-3,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = kern

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


model_ng = VGP(gpytorch.kernels.RBFKernel(), Z)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model_ng, num_data=X.shape[0])
data = DataLoader(TensorDataset(X, Y), batch_size=X.shape[0])
variational_ngd_optimizer = gpytorch.optim.NGD(model_ng.variational_parameters(), num_data=X.size(0), lr=0.)

with profile(activities=[ProfilerActivity.CPU],
             profile_memory=True, record_shapes=True) as prof:
    for x, y in data:
        variational_ngd_optimizer.zero_grad()
        loss = -mll(model_ng(x), y)
        loss.backward()
        variational_ngd_optimizer.step()

print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_memory_usage", row_limit=10))
```
Here you can see the `[4567, 4567]` tensor:

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
Name                 Self CPU %    Self CPU    CPU total %    CPU total    CPU time avg    CPU Mem    Self CPU Mem    # of Calls    Input Shapes
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
aten::empty_strided    0.10%    347.000us    0.10%    347.000us    4.627us    94.85 Mb    94.85 Mb    75    [[], [], [], [], [], []]
aten::empty            0.08%    286.000us    0.08%    286.000us    4.847us    80.70 Mb    80.70 Mb    59    [[], [], [], [], [], []]
aten::mm               2.68%    9.576ms      2.68%    9.576ms      9.576ms    79.57 Mb    79.57 Mb    1     [[4567, 7], [7, 4567]]
aten::mul              2.70%    9.647ms      2.70%    9.647ms      9.647ms    79.57 Mb    79.57 Mb    1     [[4567, 4567], [4567, 4567]]
aten::where            2.22%    7.923ms      2.22%    7.934ms      7.934ms    79.57 Mb    79.57 Mb    1     [[4567, 4567], [4567, 4567], []]
aten::div              1.98%    7.086ms      2.00%    7.137ms      7.137ms    79.57 Mb    79.57 Mb    1     [[4567, 4567], []]
aten::ge               1.78%    6.355ms      1.79%    6.395ms      6.395ms    19.89 Mb    19.89 Mb    1     [[4567, 4567], []]
aten::mm              11.39%    40.754ms    11.40%    40.762ms     5.095ms    17.14 Mb    17.14 Mb    8     [[123, 123], [123, 4567]]
aten::resize_          0.05%    191.000us    0.05%    191.000us    13.643us    9.26 Mb     9.26 Mb    14    [[0], [], []]
aten::add              2.87%    10.247ms     2.87%    10.247ms     2.562ms     8.57 Mb     8.57 Mb    4     [[123, 4567], [123, 4567], []]
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
Self CPU time total: 357.657ms
```
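As a sanity check on the profile, the repeated 79.57 Mb allocations are exactly the size of one dense `[4567, 4567]` float32 tensor (using the profiler's Mb = 2^20 bytes convention):

```python
# A dense [4567, 4567] float32 tensor: 4 bytes per entry.
size_mb = 4567 * 4567 * 4 / (1024 ** 2)
print(f"{size_mb:.2f} Mb")  # → 79.57 Mb
```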
The same output with the change from this PR:
```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
Name                 Self CPU %    Self CPU    CPU total %    CPU total    CPU time avg    CPU Mem    Self CPU Mem    # of Calls    Input Shapes
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
aten::mm               8.24%    33.905ms     8.24%    33.914ms     4.239ms    17.14 Mb    17.14 Mb    8     [[123, 123], [123, 4567]]
aten::empty_strided    0.06%    237.000us    0.06%    237.000us    3.703us    15.25 Mb    15.25 Mb    64    [[], [], [], [], [], []]
aten::resize_          0.04%    169.000us    0.04%    169.000us    12.071us    9.26 Mb     9.26 Mb    14    [[0], [], []]
aten::add              3.52%    14.471ms     3.52%    14.471ms     3.618ms     8.57 Mb     8.57 Mb    4     [[123, 4567], [123, 4567], []]
aten::mul              1.85%    7.617ms      1.85%    7.617ms      2.539ms     6.43 Mb     6.43 Mb    3     [[4567, 123], [4567, 123]]
aten::mul              4.15%    17.091ms     4.15%    17.091ms     5.697ms     6.43 Mb     6.43 Mb    3     [[123, 4567], [123, 4567]]
aten::mul              1.56%    6.421ms      1.56%    6.421ms      3.211ms     4.29 Mb     4.29 Mb    2     [[123, 1], [123, 4567]]
aten::add              1.69%    6.938ms      1.70%    6.995ms      3.498ms     4.29 Mb     4.29 Mb    2     [[123, 4567], [], []]
aten::mm               0.73%    2.998ms      0.73%    2.999ms      2.999ms     2.14 Mb     2.14 Mb    1     [[123, 7], [7, 4567]]
aten::mm               0.80%    3.290ms      0.80%    3.291ms      3.291ms     2.14 Mb     2.14 Mb    1     [[4567, 1], [1, 123]]
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------
Self CPU time total: 411.449ms
```
This would be good to fix... there have been lots of open issues about this. I'll take a look next week.
@gpleiss were you able to take a look at this?
Sorry - hoping to get to this later this week
I fixed the problem (it's quite deep in the code). It requires a PR to the LinearOperator repo (https://github.com/cornellius-gp/linear_operator/pull/26). Once that PR is merged in, I'll open a new PR that successfully gets rid of the evaluate_kernel line.
Thanks @gpleiss!