[BUG] Float8Linear does not work with torch.inference_mode

Open leeeizhang opened this issue 1 year ago • 6 comments

FP8 Linear does not work for me:

  • torch == 2.4.0 + cu121
  • torchao == 0.4.0
  • cuda_arch == 8.9 (nvidia L40)

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

class FFN(nn.Module):
    def __init__(self, in_feature, hidden_feature, bias=True):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, hidden_feature, bias)
        self.fc2 = nn.Linear(hidden_feature, in_feature, bias)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        return x


bs, seq, dim = 32, 512, 1024

m = FFN(dim, dim * 4).cuda()
convert_to_float8_training(m)
# m = torch.compile(m)

x = torch.randn((bs, seq, dim), device="cuda")

with torch.inference_mode(mode=True):
    y = m(x)
/usr/local/lib/python3.10/dist-packages/torchao/ops.py:12: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  return torch.library.impl_abstract(f"{name}")(func)
Traceback (most recent call last):
  File "/root/erdos/ops/triton/t.py", line 28, in <module>
    y = m(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/erdos/ops/triton/t.py", line 14, in forward
    x = self.fc1(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 363, in forward
    output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 59, in forward
    input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1])
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_tensor.py", line 360, in __torch_dispatch__
    raise NotImplementedError(f"attempting to run {func}, this is not supported")
NotImplementedError: attempting to run aten.reshape.default, this is not supported
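
The traceback shows that Float8Tensor's __torch_dispatch__ receives aten.reshape.default under inference mode and has no handler for it. A hedged workaround sketch, not confirmed in this thread: use torch.no_grad() instead of torch.inference_mode(), so no inference tensors are created (note that on torch < 2.5 the separate torch._scaled_mm issue discussed below still applies):

# Workaround sketch (an assumption, not verified in this thread): no_grad()
# disables autograd without creating inference tensors, which sidesteps the
# unsupported aten.reshape.default dispatch on Float8Tensor.
with torch.no_grad():
    y = m(x)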

leeeizhang avatar Aug 09 '24 08:08 leeeizhang

It seems that Float8Linear cannot run in inference mode. I removed torch.inference_mode(), but it still does not work:

/usr/local/lib/python3.10/dist-packages/torchao/ops.py:12: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  return torch.library.impl_abstract(f"{name}")(func)
Traceback (most recent call last):
  File "/root/erdos/ops/triton/t.py", line 28, in <module>
    y = m(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/erdos/ops/triton/t.py", line 14, in forward
    x = self.fc1(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 363, in forward
    output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torchao/float8/float8_linear.py", line 60, in forward
    res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t)
RuntimeError: Unable to cast (tensor([[ 0.5145,  0.0516,  0.2376,  ...,  0.4223, -0.5424,  0.4097],
        [ 0.9510, -0.3045, -0.3444,  ..., -0.0056, -0.5410,  1.1299],
        [ 0.3378, -0.2371, -0.4324,  ..., -0.4220,  0.5146,  0.4283],
        ...,
        [ 0.5058,  0.1434,  0.4000,  ...,  0.0190,  0.5246,  0.2922],
        [-0.1201, -0.2883,  0.2411,  ...,  0.4197,  0.5214,  0.3386],
        [-0.4649, -0.6164,  0.3143,  ..., -0.3093, -0.0355,  0.4321]],
       device='cuda:0'), tensor(32.0000, device='cuda:0')) to Tensor

leeeizhang avatar Aug 09 '24 08:08 leeeizhang

cc @vkuzo, @drisspg

supriyar avatar Aug 09 '24 22:08 supriyar

Hey @leeeizhang, I'm also facing the same issue below, even with the latest changes, and it seems unrelated to the reshape or to inference mode. Did this issue get resolved on your side after the fix? Thank you! (I'm using torch 2.3.1, btw.) The returned tensor tuple is coming from here. (emulate=True does not have this problem, btw.) It seems that for every output dtype it returns a tuple of two tensors: for fp16/fp32/bf16 the second tensor is a zero scalar, and for an fp8 output dtype it is the scale. In all cases the tuple cannot be treated as a single tensor and cast. 🤔 This is a poor-man's fix: https://github.com/pytorch/ao/pull/702 (it needs a better solution).

RuntimeError: Unable to cast (tensor([[ 0.5145,  0.0516,  0.2376,  ...,  0.4223, -0.5424,  0.4097],
        [ 0.9510, -0.3045, -0.3444,  ..., -0.0056, -0.5410,  1.1299],
        [ 0.3378, -0.2371, -0.4324,  ..., -0.4220,  0.5146,  0.4283],
        ...,
        [ 0.5058,  0.1434,  0.4000,  ...,  0.0190,  0.5246,  0.2922],
        [-0.1201, -0.2883,  0.2411,  ...,  0.4197,  0.5214,  0.3386],
        [-0.4649, -0.6164,  0.3143,  ..., -0.3093, -0.0355,  0.4321]],
       device='cuda:0'), tensor(32.0000, device='cuda:0')) to Tensor
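
As a rough illustration of the kind of fix this needs (a sketch of the behavior described above, not the actual code in the linked PR):

def unwrap_scaled_mm(res):
    # On torch <= 2.4, torch._scaled_mm returns (output, aux): aux is a zero
    # scalar for fp16/fp32/bf16 output dtypes and the scale for an fp8 output
    # dtype. Downstream code only needs the output tensor.
    return res[0] if isinstance(res, tuple) else res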

qingquansong avatar Aug 18 '24 22:08 qingquansong

> Hey @leeeizhang I'm also facing the same issue below even with the latest changes and seems not related to reshaping or mode. [...] RuntimeError: Unable to cast (...) to Tensor [full comment and error quoted above]

Try the torch nightly (2.5.0.dev), which refactors torch._scaled_mm to return a tensor instead of a tuple.
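
For reference, a quick way to check whether a given build has the refactor (torch.__version__ supports PEP 440-style comparison; the 2.5 cutoff is an assumption based on this thread):

import torch

# The single-tensor return for torch._scaled_mm landed in the 2.5 nightlies;
# older builds hand back the (output, aux) tuple shown above.
if torch.__version__ >= "2.5":
    print("torch._scaled_mm returns a tensor")
else:
    print("torch._scaled_mm returns a tuple; upgrade or unwrap the result")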

leeeizhang avatar Aug 19 '24 01:08 leeeizhang

@leeeizhang Thank you very much!

qingquansong avatar Aug 19 '24 03:08 qingquansong

Thanks for filing. I think we should make the version expectations clear in the README; reopening until we make that happen.

vkuzo avatar Aug 20 '24 16:08 vkuzo