[ROCm] float8 does not work
Hi @hongxiayang @hliuca ,
It seems like float8 training using torchao.float8 is not supported at the moment. Is there a different library or code path I should be using for float8 training, or what are the timelines for ROCm supporting torchao.float8?
Attempting Install From Nightly
With the ROCm nightly torchao wheel, the torchao.float8 module is not present:
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/rocm6.2
python -c "import torchao; print(dir(torchao))"
['__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'apply_dynamic_quant', 'apply_weight_only_int8_quant', 'dtypes', 'kernel', 'quantization']
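A quick way to confirm whether an installed wheel ships the float8 subpackage at all (a minimal check using only the standard library):
import importlib.util

import torchao

# Check whether the installed torchao build ships the float8 subpackage.
print("torchao version:", getattr(torchao, "__version__", "unknown"))
print("torchao.float8 available:", importlib.util.find_spec("torchao.float8") is not None)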
Attempting Install From Source
Installing from source, I run into a Triton datatype issue. If I disable torch.compile, I then hit an eager mode error because the fp8 dtype used is the Nvidia format rather than the AMD (fnuz) format.
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install git+https://github.com/pytorch/ao.git
Eager Mode Error
tensor_out = addmm_float8_unwrapped(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torchao/float8/float8_python_api.py", line 55, in addmm_float8_unwrapped
output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.
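For reference, a minimal sketch (assuming PyTorch 2.4+ and its torch._scaled_mm signature) to probe which e4m3 flavor the backend's scaled matmul accepts; on MI300X it is expected to be the fnuz variant:
import torch

def scaled_mm_accepts(dtype):
    # torch._scaled_mm requires the second operand in column-major layout
    # and per-tensor float32 scales.
    a = torch.randn(16, 16, device="cuda").to(dtype)
    b = torch.randn(16, 16, device="cuda").to(dtype).t()
    scale = torch.tensor(1.0, device="cuda")
    try:
        torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
        return True
    except RuntimeError:
        return False

for dt in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    print(dt, scaled_mm_accepts(dt))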
Compile Mode Error
tmp15 = 448.0
tmp16 = triton_helpers.minimum(tmp14, tmp15)
tmp17 = tmp16.to(tl.float8e4nv)
^
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Repro Script (from the torchao.float8 README example)
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
Can you make sure to set: https://github.com/pytorch/ao/blob/e7b33bc91c831d10249c1222c8b4b667f18f28b7/torchao/float8/config.py#L246 to True
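For anyone hitting this, a hedged sketch of wiring that setting into the repro above, assuming the linked line is a use_fnuz_dtype field on Float8LinearConfig (verify the field name against torchao/float8/config.py in your install, as it may change or become auto-detected):
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Assumption: the flag at the linked config.py line is `use_fnuz_dtype`;
# check the field name in your torchao checkout.
config = Float8LinearConfig(use_fnuz_dtype=True)
convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)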
Here are my thoughts on what we need to do to enable ROCm support for float8:
- ensure torch._scaled_mm's path for ROCm is fast and accurate. There is a config setting for the fnuz float8 flavors here (https://github.com/pytorch/ao/blob/0b71b8d38f6b238e510876bf2d75b4280a651175/torchao/float8/config.py#L246), but it's not tested at the moment. We should enable testing across all of our test suite, first locally and then in CI.
- get e2e performance/accuracy to be good, measured by benchmarks on real workloads
#1142 might help partially fix this by automating the setting of these types depending on the hardware detected.
@OrenLeung please install latest torchao and try again. #1142 was reverted but #1150 (reland) merged.
Do we close this issue now that #1150 is merged?
@jeffdaily Have you verified that the existing fp8 routines work on ROCm? Unfortunately we still don't have ROCm runners in CI/CD, and I personally don't have much access to ROCm machines to test.
I think "done" here would mean that we have float8 training working on ROCm with compelling reproducible perf/accuracy benchmarks that we can communicate. I don't feel strongly if that should be tracked in this issue or a higher level issue - if you prefer to close this one, happy to track this elsewhere.
@vkuzo I'll look into this. Do you have particular workloads in mind? I'd love to leverage your training scripts for an apples-to-apples comparison.
Awesome!
For microbenchmarks, we have a couple of representative benchmarks where it would be great to measure how float8 gemm compares to bfloat16 on ROCm:
- https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_matmul.py for benchmarking only torch._scaled_mm
- https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_linear_float8.py for benchmarking torch.nn.Linear, which includes casting/scaling
For multi GPU real training runs with float8 on LLaMa model variants, we have been using https://github.com/pytorch/torchtitan, having good ROCm numbers there would be fairly representative of real workloads
Let me know how I can help!
If https://github.com/pytorch/pytorch/pull/140856 lands we're closer to fully passing ao's unit test suite. That PR only works for e4m3 currently. I'll either fix it up to also allow e5m2 or do a follow-up PR. There are a handful of ao UTs that expect to use e5m2 rowwise gemm.
Support is still missing for float8_dynamic_activation_float8_weight and float8_static_activation_float8_weight on ROCm, which both throw the error below:
File "/usr/local/lib/python3.12/dist-packages/torchao/float8/inference.py", line 90, in addmm_float8_unwrapped_inference
    output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "/app/pytorch/aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.
This is from cloning the main torchao and pytorch repos and building from source on an MI300X with ROCm 6.3.
@clintg6 you need to specify the dtype for these to be the nuz variant e.g.:
float8_dynamic_activation_float8_weight(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz)
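For reference, a fuller sketch of applying that on ROCm (assumes quantize_ and float8_dynamic_activation_float8_weight from torchao.quantization, with the dtypes passed positionally as above):
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

# Pass the fnuz e4m3 dtypes explicitly so the ROCm/MI300X scaled-mm path is used.
quantize_(
    model,
    float8_dynamic_activation_float8_weight(torch.float8_e4m3fnuz, torch.float8_e4m3fnuz),
)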
thanks @drisspg .
@clintg6 ROCm supports: 1. the fnuz variants on MI300X; 2. the OCP FP8 variants on MI350/Navi4.
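A small sketch of picking the e4m3 flavor by hardware along those lines (the gfx arch check is an assumption; adjust it for your device naming):
import torch

# MI300X (gfx942) supports the fnuz float8 variants; NVIDIA and newer AMD
# parts (MI350/Navi4) use the OCP variants.
is_mi300 = (
    torch.version.hip is not None
    and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
)
e4m3_dtype = torch.float8_e4m3fnuz if is_mi300 else torch.float8_e4m3fn
print("using", e4m3_dtype)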
is this issue under active development?
@ehartford , curious on what you are trying to do?
I'm trying to find an AMD machine, after I find one I will test and report back with current status.
You can use mine, if you like
IIRC this should work out of the box. If not, please let me know.
I'm trying to train in fp8 on mi300x using Axolotl
I'll make a repro tonight
I finally got my hands on a machine with MI300X GPUs, and things in fact do just work. Here is a PR updating our training benchmarks with AMD results: https://github.com/pytorch/ao/pull/2736
I ran some additional roofline microbenchmarks, which look good: https://github.com/pytorch/ao/pull/2737. The performance I see on MI300X and NVIDIA H100 is in the same ballpark in terms of % roofline achieved and % speedup vs bfloat16.
@ehartford , please let us know if something is not working for you and we'll help look into it! Closing this issue for now.