Using TorchMetrics with `torch.compile` in PyTorch 2.0
🐛 Bug
I have been trying to get TorchMetrics to work smoothly with `torch.compile` in PyTorch 2.0. However, I got this warning after running through a few training steps:
torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '<graph break in wrapped_func>' (pyt_tf2/lib/python3.9/site-packages/torchmetrics/metric.py:386)
reasons: self._update_count == 0
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
I've seen an explanation here: https://github.com/lucidrains/vit-pytorch/issues/262
TorchDynamo converts and caches Python bytecode, and the compiled functions are stored in a cache. When a later guard check finds that a function needs to be recompiled, it is recompiled and cached again. However, once the number of recompilations exceeds the configured maximum (64), the function is no longer cached or recompiled. As mentioned above, the loss calculation and post-processing parts of the object detection algorithm are computed dynamically, so these functions need to be recompiled every time.
I understand this warning to mean that although compilation is giving up, the metric should still work correctly. However, a possible speedup could be gained by addressing this warning.
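As a first diagnostic step (a sketch, not a fix), Dynamo's recompile budget and logging are configurable, so the limit can be raised or the recompile reasons printed:

import torch._dynamo

# The default cache_size_limit is 64 in PyTorch 2.0; raising it only delays
# the warning, since the underlying recompilations still happen.
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.verbose = True  # log guard failures / recompile reasons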
To Reproduce
Train any NN model with PyTorch 2.0's `torch.compile`, TorchMetrics, and PyTorch Lightning:
from typing import Any, Dict

import torch
from pytorch_lightning import LightningModule
from torchmetrics.classification import MulticlassAccuracy

# DummyClassifier and default_training_config are defined elsewhere in the report
class DummyClassifierPLWrapper(LightningModule):
    def __init__(
        self,
        num_classes: int,
        num_filters: int = 6,
        training_config: Dict[str, Any] = default_training_config,
    ):
        super().__init__()
        self.model = DummyClassifier(num_classes)
        self.lr = training_config["lr"]
        self.milestones = training_config["milestones"]
        self.decay_rate = training_config["decay_rate"]
        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        with torch.no_grad():
            self.train_acc.update(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    ...
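For completeness, a minimal sketch of how such a module would be compiled and trained; the Trainer arguments and train_loader below are placeholders of my own, not from the original report:

import torch
from pytorch_lightning import Trainer

module = DummyClassifierPLWrapper(num_classes=10)
module = torch.compile(module)  # compile before handing the module to the Trainer

trainer = Trainer(max_epochs=5, accelerator="gpu", devices=1)
trainer.fit(module, train_dataloaders=train_loader)  # train_loader: any classification DataLoader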
Expected behavior
The training run shouldn't see a "graph break" warning.
Environment
- TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): pip, version 0.11.4
- Python & PyTorch Version (e.g., 1.0): Python 3.9, PyTorch 2.0.1, PyTorch Lightning 2.0.2
- Any other relevant information such as OS (e.g., Linux): Ubuntu 22.04, NVIDIA driver 530.30.02, CUDA 12.1, cuDNN 8.9.1.23-1+cuda12.1
Hi @ductai199x, thanks for reporting this issue. Would it be possible for you to provide a fully reproducible example to make the debugging process easier?
Hi @SkafteNicki, I have a short reproducible example. It works fine on CPU, but fails on GPU!
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import MulticlassAccuracy


class BoringModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = nn.Linear(20, 5)
        self.accuracy = MulticlassAccuracy(5)

    def forward(self, x, y):
        y_hat = self.ln(x)
        return self.accuracy(y_hat, y)


class BoringDataset(Dataset):
    def __init__(self):
        self.X = torch.rand(1000, 20)
        self.y = torch.randint(0, 5, (1000,))

    def __getitem__(self, idx):
        return {"x": self.X[idx], "y": self.y[idx]}

    def __len__(self):
        return len(self.X)


model = BoringModel().to("cuda")
model = torch.compile(model)
data = DataLoader(BoringDataset(), batch_size=32)

torch._dynamo.config.verbose = True

for _ in range(5):
    for batch in tqdm.tqdm(data):
        o = model(**{k: v.to("cuda") for k, v in batch.items()})
At the end of the first epoch, it fails inside this method (`torchmetrics.metric.Metric._wrap_update`, likely on `self._update_count`):
def _wrap_update(self, update: Callable) -> Callable:
    @functools.wraps(update)
    def wrapped_func(*args: Any, **kwargs: Any) -> None:
        self._computed = None
        self._update_count += 1
        with torch.set_grad_enabled(self._enable_grad):
            try:
                update(*args, **kwargs)
            except RuntimeError as err:
                if "Expected all tensors to be on" in str(err):
                    raise RuntimeError(
                        "Encountered different devices in metric calculation (see stacktrace for details)."
                        " This could be due to the metric class not being on the same device as input."
                        f" Instead of `metric={self.__class__.__name__}(...)` try to do"
                        f" `metric={self.__class__.__name__}(...).to(device)` where"
                        " device corresponds to the device of the input."
                    ) from err
                raise err

        if self.compute_on_cpu:
            self._move_list_states_to_cpu()

    return wrapped_func
The suggested reason is:
function: '_forward_reduce_state_update' (/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py:283)
reasons: tensor 'args[0]' size mismatch at index 0. expected 32, actual 8
The whole traceback is very long.
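For what it's worth, the "expected 32, actual 8" guard failure is consistent with the dataset size: 1000 samples at batch_size=32 leave a ragged final batch of 1000 - 31 * 32 = 8, so Dynamo sees a new input shape at the end of the epoch. A sketch of a mitigation for that specific mismatch (my suggestion, not from the original thread; it does not touch the setattr recompiles discussed below):

from torch.utils.data import DataLoader

# 1000 % 32 == 8, so the last batch changes shape; drop_last=True keeps
# every batch at 32 samples and avoids the shape-change recompile
data = DataLoader(BoringDataset(), batch_size=32, drop_last=True)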
Hi @stancld, thanks for the reproducible example.
After looking at the traceback, it seems `setattr` is the culprit, meaning the `self._update_count += 1` statement is causing the recompile.
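To illustrate this outside TorchMetrics, here is a minimal sketch (the Counter class and its names are mine) of how mutating a plain Python attribute inside a compiled forward can invalidate Dynamo's guards on every call, analogous to `self._update_count += 1`:

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = torch.nn.Linear(4, 4)
        self.calls = 0  # plain int attribute, analogous to Metric._update_count

    def forward(self, x):
        self.calls += 1    # Dynamo guards on the attribute's value, so each
        return self.ln(x)  # increment can trigger a fresh recompile

m = torch.compile(Counter())
for _ in range(5):
    m(torch.randn(2, 4))  # expect recompiles (or graph breaks) across calls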
I found this recently opened issue, https://github.com/pytorch/pytorch/issues/101168, which mentions that `setattr` is not yet supported.
I therefore assume there is nothing we can do until Dynamo gains wider support.
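One possible user-side workaround until then (a sketch under that assumption, not an official TorchMetrics fix) is to keep the metric update out of the compiled region entirely with `torch._dynamo.disable`, so the setattr in `_wrap_update` is never traced:

import torch
import torch._dynamo
import torch.nn as nn
from torchmetrics.classification import MulticlassAccuracy

class BoringModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = nn.Linear(20, 5)
        self.accuracy = MulticlassAccuracy(5)

    def forward(self, x, y):
        y_hat = self.ln(x)
        self._update_metric(y_hat, y)  # runs eagerly, outside the compiled graph
        return y_hat

    @torch._dynamo.disable
    def _update_metric(self, y_hat, y):
        self.accuracy.update(y_hat, y)

The trade-off is that the metric update itself stays eager, while the rest of the forward remains compiled.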