Compilation with LabelTensors fails in PyTorch 2.8 due to incorrect runtime_type inference
Describe the bug
After upgrading to PyTorch 2.8.0, compilation involving LabelTensor fails due to incorrect inference of runtime_type in the torch compilation pipeline. This results in a RuntimeError during the backward pass.
To Reproduce
Upgrade to PyTorch 2.8.0 and run:
pytest tests/test_solver/test_pinn.py
You might also want to upgrade torchvision.
Expected behavior
All tests should pass successfully, and compilation should handle LabelTensor without metadata mismatches.
Output
RuntimeError:
E During the backward, we encountered a tensor subclass where we guessed its
E metadata incorrectly.
E
E Expected metadata: None, expected type: <class 'torch.Tensor'>
E
E Runtime metadata: None, runtime type: <class 'pina.label_tensor.LabelTensor'>
E
E shape: torch.Size([1, 1])
E To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
Additional context
- Implementing __force_to_same_metadata__ in LabelTensor did not resolve the issue.
- This might be related to changes introduced in the compilation pipeline in PyTorch 2.8.
- The issue was not observed in PyTorch 2.7.x.
Maybe this can help: it seems the behaviour changed after this commit.
@GiovanniCanali @FilippoOlivo I am investigating this. I don't think it is related to LabelTensor. I tried to check the tests for SupervisedSolver and they all pass (even when we set use_lt=True and compile=True). I don't know what is going on with physics-informed solvers, but I will try to investigate a bit more and come back to you.
I think the main problem is with the operator module; something seems to break when using it in compile mode. Since physics-based solvers are the only ones using these modules (we do not test compilation in test_operator), those tests fail.
Here is a minimal working code to reproduce the error:
import torch
from pina import LabelTensor
from pina.operator import grad
def func(input):
    return input.pow(2).sum(-1, keepdim=True)
input_pts = torch.rand(10, 1)
input_pts = LabelTensor(input_pts, 'x')
input_pts.requires_grad_(True)
compiled_fn = torch.compile(func)
out = compiled_fn(input_pts)
out.labels = 'y'
grad_u = grad(out, input_pts)
loss = (grad_u - torch.zeros_like(grad_u)).pow(2).mean()
loss.sum().backward()
👋 I have news on this!
I have changed _scalar_grad as follows:
def _scalar_grad(output_, input_, d):
    """
    Compute the gradient of a scalar-valued ``output_``.

    :param LabelTensor output_: The output tensor on which the gradient is
        computed. It must be a column tensor.
    :param LabelTensor input_: The input tensor with respect to which the
        gradient is computed.
    :param list[str] d: The names of the input variables with respect to
        which the gradient is computed. It must be a subset of the input
        labels. If ``None``, all input variables are considered.
    :return: The computed gradient tensor.
    :rtype: LabelTensor
    """
    grad_out = torch.autograd.grad(
        outputs=output_.tensor,  # <====== Change 1. Ensure we pass a torch tensor
        inputs=input_,
        grad_outputs=torch.ones_like(output_.tensor, requires_grad=True),  # <====== Change 2. Ensure we get a torch tensor
        create_graph=True,
        retain_graph=True,
        allow_unused=True,
    )[0]
    return grad_out[..., [input_.labels.index(i) for i in d]]
If we do so, it works for torch 2.7, but it does not work for torch 2.8, where we get the following error:
RuntimeError: torch.compile with aot_autograd does not currently support double backward
This error is the same that we would get without LabelTensor, so I guess the LabelTensor problem can be easily fixed with this approach.
Regarding the above error, I investigated a little bit more and it seems there is an open issue on torch about this problem, introduced in the 2.8 version. Basically, every time we use aot_autograd, we can only do backprop once. This means that we cannot train and, at the same time, use automatic differentiation (as in PINNs).
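For reference, here is a minimal sketch (no LabelTensor involved, names purely illustrative) of the double-backward pattern that hits this limitation: the compiled forward is differentiated once with create_graph=True, and the PINN-style loss then needs a second backward through that gradient.

import torch

def f(x):
    return x.pow(2).sum(-1, keepdim=True)

x = torch.rand(10, 1, requires_grad=True)
out = torch.compile(f)(x)

# First-order gradient through the compiled graph; create_graph=True keeps it
# differentiable so we can backpropagate through it again.
g = torch.autograd.grad(out, x, torch.ones_like(out), create_graph=True)[0]

# The training step then requires a second backward through g, which
# aot_autograd in torch 2.8 rejects with the "double backward" error.
g.pow(2).mean().backward()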
I think the best way now is to wait for this issue to be solved in torch, possibly with a new release. I would still compile if the torch version is < 2.8. For torch >= 2.8, what I would do is disable compilation in every solver inheriting from PINNInterface. @GiovanniCanali @FilippoOlivo What do you think?
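As a rough illustration of the version gate (the helper name below is hypothetical, not the actual PINNInterface code, and it assumes packaging is available), the check could look like this:

import torch
from packaging.version import Version

def _resolve_compile_flag(requested: bool) -> bool:
    # Hypothetical helper: keep compilation for torch < 2.8, disable it for
    # torch >= 2.8, where aot_autograd cannot handle the double backward
    # needed by physics-informed solvers.
    if Version(torch.__version__).release >= (2, 8):
        return False
    return requested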
Hi @dario-coscia, Thank you for looking into this issue further. The linked issue has been open since January 2023 and doesn’t seem to be actively maintained.
I’m a bit perplexed, though, since compilation works with PyTorch 2.7.x, which was released after the reported issue. This makes me unsure whether the linked issue is actually the root cause, so I’ll investigate further.
In the meantime, let's disable compilation for PyTorch >= 2.8 as a temporary workaround.
@dario-coscia Since #626 was a temporary fix, I would not mark the issue as closed.