
[ROCm] float8 does not work

Open functionstackx opened this issue 1 year ago • 1 comment

Hi @hongxiayang @hliuca ,

It seems like float8 training using torchao.float8 is not supported at the moment. Is there a different library or code path I should be using for float8 training, or what are the timelines for ROCm supporting torchao.float8?

Attempting Install From Nightly

With the ROCm nightly torchao wheel, the torchao.float8 module is not present:

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/rocm6.2
python -c "import torchao; print(dir(torchao))"
['__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'apply_dynamic_quant', 'apply_weight_only_int8_quant', 'dtypes', 'kernel', 'quantization']
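Instead of eyeballing `dir(torchao)`, the presence of the submodule can be checked programmatically. This is a hedged, generic sketch (plain stdlib, nothing torchao-specific); pass it `"torchao.float8"` to reproduce the check from this report:

```python
# Hedged sketch: check whether a submodule such as torchao.float8 is present
# in the installed wheel, without importing the parent package eagerly.
import importlib.util


def module_available(dotted_name: str) -> bool:
    """Return True if `dotted_name` resolves to an importable module."""
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except ModuleNotFoundError:
        # A parent package in the dotted path is missing entirely.
        return False
```

With the nightly ROCm wheel described above, `module_available("torchao.float8")` would return False.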

Attempting Install From Source

Installing from source, I run into a Triton datatype issue under torch.compile. If I disable torch.compile, I then hit an eager-mode error because the fp8 dtype used is the Nvidia format rather than the AMD (fnuz) format.

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install git+https://github.com/pytorch/ao.git

Eager Mode Error

   tensor_out = addmm_float8_unwrapped(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torchao/float8/float8_python_api.py", line 55, in addmm_float8_unwrapped
    output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.
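The assert fires because the OCP-style `Float8_e4m3fn` dtype reaches the HIP backend, while MI300-class hardware implements the fnuz float8 variants instead. A hedged sketch of the two flavors involved (the numeric ranges in the comments are properties of the formats, not taken from this thread):

```python
# Hedged sketch of the two float8 "flavors" behind this error. MI300-class
# ROCm hardware implements the fnuz (finite, no negative zero) variants,
# while the dtype reaching torch._scaled_mm here is the OCP e4m3fn format.
OCP_TO_FNUZ = {
    "float8_e4m3fn": "float8_e4m3fnuz",  # max finite value: 448 vs 240
    "float8_e5m2": "float8_e5m2fnuz",    # max finite value: 57344 for both
}


def rocm_mi300_dtype(ocp_name: str) -> str:
    """Map an OCP float8 dtype name to the fnuz variant MI300 expects."""
    return OCP_TO_FNUZ[ocp_name]
```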

Compile Mode Error

    tmp15 = 448.0
    tmp16 = triton_helpers.minimum(tmp14, tmp15)
    tmp17 = tmp16.to(tl.float8e4nv)
            ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Repro Script From The torchao.float8 README Example

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)
# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

functionstackx avatar Oct 12 '24 23:10 functionstackx

Can you make sure to set: https://github.com/pytorch/ao/blob/e7b33bc91c831d10249c1222c8b4b667f18f28b7/torchao/float8/config.py#L246 to True

drisspg avatar Oct 14 '24 19:10 drisspg

Here are my thoughts on what we need to do to enable ROCm support for float8:

  1. ensure torch._scaled_mm's path for ROCm is fast and accurate
  2. there is a config setting for the nuz float8 flavors here (https://github.com/pytorch/ao/blob/0b71b8d38f6b238e510876bf2d75b4280a651175/torchao/float8/config.py#L246), but it's not tested at the moment. We should enable testing across all of our test suite, first locally and then in CI.
  3. get e2e performance/accuracy to be good, measured by benchmarks on real workloads
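The hardware-dependent dtype selection mentioned in item 2 (and automated later in #1142) could look roughly like the sketch below. The class and function names here are hypothetical illustrations, not torchao's actual config API:

```python
# Illustrative-only sketch of auto-selecting float8 dtypes by hardware;
# Float8Dtypes and default_float8_dtypes are hypothetical names, not
# torchao's real config API.
import dataclasses


@dataclasses.dataclass
class Float8Dtypes:
    e4m3: str = "float8_e4m3fn"  # OCP defaults (NVIDIA and newer AMD parts)
    e5m2: str = "float8_e5m2"


def default_float8_dtypes(is_mi300: bool) -> Float8Dtypes:
    # MI300-class ROCm GPUs only implement the fnuz float8 variants.
    if is_mi300:
        return Float8Dtypes("float8_e4m3fnuz", "float8_e5m2fnuz")
    return Float8Dtypes()
```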

vkuzo avatar Oct 16 '24 15:10 vkuzo

#1142 might partially fix this by automating the selection of these dtypes depending on the hardware detected.

jeffdaily avatar Oct 22 '24 22:10 jeffdaily

@OrenLeung please install latest torchao and try again. #1142 was reverted but #1150 (reland) merged.

jeffdaily avatar Oct 24 '24 05:10 jeffdaily

Do we close this issue now that #1150 is merged?

jeffdaily avatar Oct 29 '24 21:10 jeffdaily

@jeffdaily Have you verified that the existing fp8 routines work on ROCm? Unfortunately we still don't have ROCm runners in CI/CD, and I personally don't have much access to ROCm machines to test.

drisspg avatar Oct 29 '24 22:10 drisspg

I think "done" here would mean that we have float8 training working on ROCm with compelling reproducible perf/accuracy benchmarks that we can communicate. I don't feel strongly if that should be tracked in this issue or a higher level issue - if you prefer to close this one, happy to track this elsewhere.

vkuzo avatar Oct 29 '24 22:10 vkuzo

@vkuzo I'd like to look into this. Do you have particular workloads in mind? I'd love to leverage your training scripts for an apples-to-apples comparison.

petrex avatar Nov 11 '24 23:11 petrex

@vkuzo I'd like to look into this. Do you have particular workloads in mind? I'd love to leverage your training scripts for an apples-to-apples comparison.

Awesome!

For microbenchmarks, we have a couple of representative benchmarks where it would be great to measure how float8 gemm compares to bfloat16 on ROCm:

  • https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_matmul.py for benchmarking only torch._scaled_mm
  • https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_linear_float8.py for benchmarking torch.nn.Linear, which includes casting/scaling

For multi-GPU real training runs with float8 on LLaMa model variants, we have been using https://github.com/pytorch/torchtitan; good ROCm numbers there would be fairly representative of real workloads.

Let me know how I can help!

vkuzo avatar Nov 12 '24 03:11 vkuzo

If https://github.com/pytorch/pytorch/pull/140856 lands we're closer to fully passing ao's unit test suite. That PR only works for e4m3 currently. I'll either fix it up to also allow e5m2 or do a follow-up PR. There are a handful of ao UTs that expect to use e5m2 rowwise gemm.

jeffdaily avatar Nov 21 '24 18:11 jeffdaily

Support is still missing for float8_dynamic_activation_float8_weight and float8_static_activation_float8_weight on ROCm; both throw the error below:

  File "/usr/local/lib/python3.12/dist-packages/torchao/float8/inference.py", line 90, in addmm_float8_unwrapped_inference
    output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "/app/pytorch/aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.

This is from cloning the main torchao and pytorch repos and building from source on an MI300X with ROCm 6.3.

clintg6 avatar Feb 12 '25 00:02 clintg6

@clintg6 you need to specify the dtype for these to be the nuz variant e.g.:

float8_dynamic_activation_float8_weight(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz)

drisspg avatar Feb 12 '25 01:02 drisspg

@clintg6 you need to specify the dtype for these to be the nuz variant e.g.:

float8_dynamic_activation_float8_weight(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz)

thanks @drisspg .

@clintg6 ROCm supports the fnuz variants on MI300X, and the OCP F8 variants on MI350/Navi4.

petrex avatar Feb 12 '25 01:02 petrex

is this issue under active development?

ehartford avatar Aug 02 '25 01:08 ehartford

@ehartford, curious what you are trying to do?

I'm trying to find an AMD machine, after I find one I will test and report back with current status.

vkuzo avatar Aug 04 '25 18:08 vkuzo

You can use mine, if you like

ehartford avatar Aug 04 '25 19:08 ehartford

IIRC this should work out of the box. If not, please let me know.

petrex avatar Aug 04 '25 19:08 petrex

I'm trying to train in fp8 on mi300x using Axolotl

I'll make a repro tonight

ehartford avatar Aug 04 '25 20:08 ehartford

I finally got my hands on a machine with MI300X GPUs, and things in fact do just work. Here is a PR updating our training benchmarks with AMD results: https://github.com/pytorch/ao/pull/2736

vkuzo avatar Aug 11 '25 17:08 vkuzo

I ran some additional roofline microbenchmarks, which look good: https://github.com/pytorch/ao/pull/2737. The performance I see on MI300X and NVIDIA H100 is in the same ballpark in terms of % of roofline achieved and % speedup vs bfloat16.

vkuzo avatar Aug 11 '25 19:08 vkuzo

@ehartford , please let us know if something is not working for you and we'll help look into it! Closing this issue for now.

vkuzo avatar Aug 12 '25 00:08 vkuzo