
Using TorchMetrics with `torch.compile` in PyTorch 2.0

Open ductai199x opened this issue 2 years ago • 3 comments

🐛 Bug

I have been trying to get torchmetrics to work smoothly with torch.compile in PyTorch 2.0. However, I got this warning after running through a few training steps:

torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: '<graph break in wrapped_func>' (pyt_tf2/lib/python3.9/site-packages/torchmetrics/metric.py:386)
   reasons:  self._update_count == 0
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.

I've seen an explanation here: https://github.com/lucidrains/vit-pytorch/issues/262

TorchDynamo converts and caches Python bytecode, storing the compiled functions in a cache. When a later check finds that a function needs to be recompiled, it is recompiled and re-cached. However, once the number of recompilations exceeds the configured maximum (64), the function is no longer cached or recompiled. As mentioned above, the loss-calculation and post-processing parts of the object-detection algorithm are computed dynamically, so these functions need to be recompiled every time.

I understand this warning to mean that although compilation gives up, the metric should still work correctly. However, additional speedup could be gained by addressing this warning.
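The cache-limit behaviour described above can be sketched in plain Python. This is a toy model, not Dynamo's actual data structures: `FrameCache` and the guard-keyed dict are invented for illustration. The real knob, `torch._dynamo.config.cache_size_limit`, can be raised, but that only postpones the fallback when a guard value changes on every step.

```python
# Toy simulation (hypothetical) of TorchDynamo's per-frame compile cache:
# each distinct guard value triggers a recompile, and once the limit is
# exceeded Dynamo stops compiling that frame and falls back to eager mode.
CACHE_SIZE_LIMIT = 64  # default in PyTorch 2.0


class FrameCache:
    def __init__(self):
        self.compiled = {}      # guard value -> "compiled" artifact
        self.fell_back = False  # True once the limit is hit

    def run(self, guard_value):
        if self.fell_back:
            return "eager"
        if guard_value not in self.compiled:
            if len(self.compiled) >= CACHE_SIZE_LIMIT:
                self.fell_back = True  # emit the warning, run eagerly
                return "eager"
            self.compiled[guard_value] = f"graph_{len(self.compiled)}"
        return self.compiled[guard_value]


cache = FrameCache()
# A guard like `self._update_count == 0` changes on every call, so each
# training step looks like a brand-new specialization:
results = [cache.run(update_count) for update_count in range(70)]
```

Steps 0 through 63 each produce a fresh "graph"; from step 64 onward the frame runs eagerly, which matches the warning appearing after roughly 64 metric updates.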

To Reproduce

Train any NN model with PyTorch 2.0's torch.compile, torchmetrics, and pytorch-lightning:

from typing import Any, Dict

import torch
from pytorch_lightning import LightningModule
from torchmetrics.classification import MulticlassAccuracy

# DummyClassifier and default_training_config are defined elsewhere


class DummyClassifierPLWrapper(LightningModule):
    def __init__(
        self,
        num_classes: int,
        num_filters=6,
        training_config: Dict[str, Any] = default_training_config,
    ):
        super().__init__()
        self.model = DummyClassifier(num_classes)
        self.lr = training_config["lr"]
        self.milestones = training_config["milestones"]
        self.decay_rate = training_config["decay_rate"]

        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        with torch.no_grad():
            self.train_acc.update(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    ...

Expected behavior

The training run shouldn't see a "graph break" warning.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): using pip, version 0.11.4
  • Python & PyTorch Version (e.g., 1.0): Python 3.9, PyTorch 2.0.1, PyTorch Lightning 2.0.2
  • Any other relevant information such as OS (e.g., Linux): Ubuntu 22.04, NVIDIA driver 530.30.02, CUDA 12.1, cuDNN 8.9.1.23-1+cuda12.1

ductai199x avatar May 13 '23 13:05 ductai199x

Hi @ductai199x, thanks for reporting this issue. Would it be possible for you to provide a fully reproducible example to make the debugging process easier?

SkafteNicki avatar May 15 '23 12:05 SkafteNicki

Hi @SkafteNicki, I have a short reproducible example. It works fine on CPU, but fails on GPU!

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import MulticlassAccuracy


class BoringModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = nn.Linear(20, 5)
        self.accuracy = MulticlassAccuracy(5)
    
    def forward(self, x, y):
        y_hat = self.ln(x)
        return self.accuracy(y_hat, y)
    

class BoringDataset(Dataset):
    def __init__(self):
        self.X = torch.rand(1000, 20)
        self.y = torch.randint(0, 5, (1000,))
        
    def __getitem__(self, idx):
        return {"x": self.X[idx], "y": self.y[idx]}
    
    def __len__(self):
        return len(self.X)

model = BoringModel().to("cuda")
model = torch.compile(model)

data = DataLoader(BoringDataset(), batch_size=32)

torch._dynamo.config.verbose = True

for _ in range(5):
    for batch in tqdm.tqdm(data):
        o = model(**{k: v.to("cuda") for k, v in batch.items()})

At the end of the first epoch, it fails in this method (torchmetrics.Metric._wrap_update, likely on the self._update_count increment):

def _wrap_update(self, update: Callable) -> Callable:
    @functools.wraps(update)
    def wrapped_func(*args: Any, **kwargs: Any) -> None:
        self._computed = None
        self._update_count += 1
        with torch.set_grad_enabled(self._enable_grad):
            try:
                update(*args, **kwargs)
            except RuntimeError as err:
                if "Expected all tensors to be on" in str(err):
                    raise RuntimeError(
                        "Encountered different devices in metric calculation (see stacktrace for details)."
                        " This could be due to the metric class not being on the same device as input."
                        f" Instead of `metric={self.__class__.__name__}(...)` try to do"
                        f" `metric={self.__class__.__name__}(...).to(device)` where"
                        " device corresponds to the device of the input."
                    ) from err
                raise err

        if self.compute_on_cpu:
            self._move_list_states_to_cpu()

    return wrapped_func
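The `self._update_count == 0` reason in the original warning matches this wrapper: the traced code mutates the very attribute it is guarded on. A minimal pure-Python sketch of that feedback loop (hypothetical; `trace` and the returned guard are stand-ins for Dynamo internals, not real APIs):

```python
# Sketch (invented names) of why mutating an attribute inside a compiled
# region forces recompilation: the guard records the attribute's value at
# trace time, and the traced code itself changes that value.


class Metric:
    def __init__(self):
        self._update_count = 0


def trace(metric):
    # At trace time, record the current value as a guard,
    # e.g. "self._update_count == 0" on the first call.
    guarded_value = metric._update_count
    return lambda m: m._update_count == guarded_value  # guard check


metric = Metric()
recompiles = 0
guard = None
for step in range(5):
    if guard is None or not guard(metric):  # guard failed -> recompile
        guard = trace(metric)
        recompiles += 1
    metric._update_count += 1  # the setattr performed by wrapped_func
```

Every step invalidates the previous guard, so the number of recompiles equals the number of steps, and the cache-size limit is reached after 64 metric updates.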

The suggested reason is:

   function: '_forward_reduce_state_update' (/usr/local/lib/python3.10/dist-packages/torchmetrics/metric.py:283)
   reasons:  tensor 'args[0]' size mismatch at index 0. expected 32, actual 8

The whole traceback is very long.
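The size mismatch itself comes from the data: 1000 samples with batch_size=32 leave a final batch of 8, so the last step of every epoch presents a new input shape. A quick check of the arithmetic:

```python
# The repro's DataLoader yields 31 full batches of 32 plus a final
# batch of 1000 % 32 = 8 samples; that shape change is what the guard
# "size mismatch at index 0. expected 32, actual 8" refers to.
num_samples, batch_size = 1000, 32
batch_sizes = [min(batch_size, num_samples - i)
               for i in range(0, num_samples, batch_size)]
```

Passing drop_last=True to the DataLoader would make every batch 32 samples and avoid this particular shape-based recompile, though it would not address the _update_count guard discussed below.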

stancld avatar May 18 '23 18:05 stancld

Hi @stancld, thanks for the reproducible example. After looking at the traceback (attached as a screenshot), it seems setattr is the culprit, meaning the self._update_count += 1 statement is causing the recompile. I found this recently opened issue https://github.com/pytorch/pytorch/issues/101168, which mentions that setattr is not yet supported. I therefore assume there is nothing we can do until dynamo has wider support.

SkafteNicki avatar May 22 '23 15:05 SkafteNicki