Is NVFP4 not supported on the RTX 50 series?
Hi, I locally compiled the release 2.8 branch.
When I tried to use NVFP4 on an RTX 50 series GPU, it gave me this error:
/home/aza/workspace/projects/nvfp4/TransformerEngine/transformer_engine/common/util/nvfp4_transpose.cuh:234 in function mul_cvt_bf16_to_fp4_4x_with_rn (thread (95,0,0), block (2,2,0)): FP4 cvt PTX instructions are architecture-specific. Try recompiling with sm_XXXa instead of sm_XXX.
Hello @yash3056. The default NVFP4 recipe uses stochastic rounding, which is accelerated only on SM 100 and SM 103 (B100/B200/B300). You should be able to disable stochastic rounding and compile for the sm120a architecture, as round-to-nearest mode is supported there. @Oleg-Goncharov we should add a fallback path for stochastic rounding on non-10x cards.
Actually, looking at the error message, the kernel that failed is the round-to-nearest kernel (the stochastic rounding one will fail too, but it did not get there yet). This is because we check for SM 10X features, while we should be checking for both 10X and 12X there. I can open a quick PR tomorrow to fix that. In the meantime you can try removing the condition here: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/nvfp4_transpose.cuh#L201 (just make sure you compile for SM120a).
@ptrendx will disabling stochastic rounding and using round-to-nearest mode cause model performance to degrade?
By the way, can you tell me how to compile for SM120a? I am trying export CMAKE_CUDA_ARCHITECTURES="120a",
but it is not building for sm_120a.
Hello @yash3056, stochastic rounding does impact the model performance, but its contribution depends on the model size, the quantization granularity, and whether other techniques are used (e.g., the random Hadamard transform, which is currently part of the NVFP4 recipe).
To build TE for sm_120a, please try export NVTE_CUDA_ARCHS=120a.
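If it helps, here is a quick sanity check (plain PyTorch, nothing TE-specific) to confirm the GPU really reports the sm_120 family before rebuilding:

import torch

# Report the device's compute capability; RTX 50 series cards should show sm_120,
# which is the family targeted by the 120a build flag above.
major, minor = torch.cuda.get_device_capability()
print(f"GPU compute capability: sm_{major}{minor}")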
I removed the condition from https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/util/nvfp4_transpose.cuh#L201 and compiled with the 120a arch flag, but
I am getting this error:
Error: Failed to set Shared Memory size.
The most probable cause is the architectural difference, so it needs a proper fallback for the RTX 50 series; just removing the condition will not make it work.
To replicate it:
CODE:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, NVFP4BlockScaling
# Setup
device = torch.device("cuda")
fp4_format = Format.E2M1
fp4_recipe = NVFP4BlockScaling(fp4_format=fp4_format)
print("NVFP4 Matrix Multiplication Test")
print(f"Format: {fp4_format}")
print("Recipe: NVFP4BlockScaling\n")
# Create input
torch.manual_seed(42)
batch_size, seq_len, hidden_dim = 4, 128, 768
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.bfloat16).cuda()
# Create linear layer
linear = te.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=torch.bfloat16).cuda()
print(f"Input shape: {x.shape}")
print(f"Input dtype: {x.dtype}")
print(f"Weight shape: {linear.weight.shape}\n")
# BF16 baseline
# BF16 baseline
with torch.no_grad():
    out_bf16 = linear(x)
# NVFP4 forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp4_recipe):
    out_fp4 = linear(x)
print(f"Output shape: {out_fp4.shape}")
print(f"BF16 output sample: {out_bf16[0, 0, :5]}")
print(f"FP4 output sample: {out_fp4[0, 0, :5]}\n")
# Error metrics
rel_error = torch.abs(out_bf16 - out_fp4) / (torch.abs(out_bf16) + 1e-5)
print(f"Mean relative error: {rel_error.mean().item():.6f}")
print(f"Max relative error: {rel_error.max().item():.6f}")
Error:
NVFP4 Matrix Multiplication Test
Format: Format.E2M1
Recipe: NVFP4BlockScaling
Input shape: torch.Size([4, 128, 768])
Input dtype: torch.bfloat16
Weight shape: torch.Size([768, 768])
Error: Failed to set Shared Memory size.
Traceback (most recent call last):
File "/home/aza/workspace/projects/nvfp4/TransformerEngine/../test.py", line 32, in <module>
out_fp4 = linear(x)
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 1098, in _fn
return fn(*args, **kwargs)
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/transformer_engine/pytorch/module/linear.py", line 1482, in forward
out = linear_fn(*args)
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/torch/autograd/function.py", line 582, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/transformer_engine/pytorch/module/linear.py", line 254, in forward
weightmat = module.get_weight_workspace(
tensor=weight,
...<5 lines>...
workspace_dtype=activation_dtype,
)
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/transformer_engine/pytorch/module/base.py", line 1428, in get_weight_workspace
out = quantizer.quantize(tensor, dtype=workspace_dtype)
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/transformer_engine/pytorch/tensor/quantized_tensor.py", line 203, in quantize
return _QuantizeFunc.forward(None, tensor, self)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/aza/miniforge3/envs/transformer-engine/lib/python3.13/site-packages/transformer_engine/pytorch/tensor/quantized_tensor.py", line 282, in forward
return tex.quantize(tensor, quantizer)
~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: invalid argument
Search for `cudaErrorInvalidValue' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Ok, this is actually a slightly larger can of worms than I anticipated (in addition to the problem you flagged there are random things like CMake<4.0.2 not understanding "f" in the SM arch and erroring out 😞). Will continue looking at it on Monday.
Some PTX instructions are not supported on SM_120, such as 'cvt.rs.satfinite.e2m1x4.f32' 🥲 https://docs.nvidia.com/cuda/parallel-thread-execution/#:~:text=sm_103a-,cvt,-.rs%7B.e2m1x4/.e4m3x4
The RTX 50 series instructions are closer to the Ada series GPUs.
PR #2279 should help move this forward a little bit - there is still a problem with the cvt.rs instructions (we need to add emulation for those, unfortunately, and that is not part of this PR), but it should hopefully let you run once you set the recipe to not do stochastic rounding. Could you try, @yash3056?
@ptrendx the NVFP4 matmul is working now.
Here is the updated code (the only change is that torch.no_grad() was added during inference):
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, NVFP4BlockScaling
# NVFP4 Matrix Multiplication Test
# This test demonstrates NVFP4 quantization for linear layers
# Key requirements:
# - Dimensions should be multiples of 16 for optimal NVFP4 block scaling (16x16 blocks)
# - fp8_autocast context must wrap all FP4 operations
# - Use torch.no_grad() when not training to avoid gradient computation
# Setup
device = torch.device("cuda")
fp4_format = Format.E2M1
fp4_recipe = NVFP4BlockScaling(fp4_format=fp4_format)
print("NVFP4 Matrix Multiplication Test")
print(f"Format: {fp4_format}")
print("Recipe: NVFP4BlockScaling\n")
# Create input - dimensions should be multiples of 16 for NVFP4 block scaling
torch.manual_seed(42)
batch_size, seq_len, hidden_dim = 4, 128, 768 # 768 is divisible by 16
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=torch.bfloat16).cuda()
# Create linear layer
linear = te.Linear(hidden_dim, hidden_dim, bias=True, params_dtype=torch.bfloat16).cuda()
print(f"Input shape: {x.shape}")
print(f"Input dtype: {x.dtype}")
print(f"Weight shape: {linear.weight.shape}\n")
# BF16 baseline
with torch.no_grad():
    out_bf16 = linear(x)
# NVFP4 forward pass - wrap all FP4 operations in the context
print("Running NVFP4 forward pass...")
with torch.no_grad():
    with te.fp8_autocast(enabled=True, fp8_recipe=fp4_recipe):
        out_fp4 = linear(x)
print(f"Output shape: {out_fp4.shape}")
print(f"BF16 output sample: {out_bf16[0, 0, :5]}")
print(f"FP4 output sample: {out_fp4[0, 0, :5]}\n")
# Error metrics
abs_error = torch.abs(out_bf16 - out_fp4)
rel_error = abs_error / (torch.abs(out_bf16) + 1e-5)
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
print(f"Max absolute error: {abs_error.max().item():.6f}")
print(f"Mean relative error: {rel_error.mean().item():.6f}")
print(f"Median relative error: {rel_error.median().item():.6f}")
# Also compute RMSE for a more stable metric
rmse = torch.sqrt(torch.mean((out_bf16 - out_fp4) ** 2))
print(f"RMSE: {rmse.item():.6f}")
# Signal-to-noise ratio
signal_power = torch.mean(out_bf16 ** 2)
noise_power = torch.mean((out_bf16 - out_fp4) ** 2)
snr = 10 * torch.log10(signal_power / noise_power)
print(f"SNR: {snr.item():.2f} dB")
Output:
NVFP4 Matrix Multiplication Test
Format: Format.E2M1
Recipe: NVFP4BlockScaling
Input shape: torch.Size([4, 128, 768])
Input dtype: torch.bfloat16
Weight shape: torch.Size([768, 768])
Running NVFP4 forward pass...
Output shape: torch.Size([4, 128, 768])
BF16 output sample: tensor([-0.3867, -0.0835, 0.3047, -0.9023, -0.4121], device='cuda:0',
dtype=torch.bfloat16)
FP4 output sample: tensor([-0.3223, -0.1543, 0.3105, -0.8945, -0.4609], device='cuda:0',
dtype=torch.bfloat16)
Mean absolute error: 0.073730
Max absolute error: 0.478516
Mean relative error: 1.031250
Median relative error: 0.146484
RMSE: 0.092285
SNR: 16.75 dB
In short, it is working.
Well, what about FP4 fused attention? Or te.fp8_model_init()?
@yash3056 Were you able to get this working for the backward pass as well, or only the forward pass?
I've been able to build off the main branch and get the forward pass working, but the backward pass fails with an unsupported GEMM.
Stacktrace below.
Traceback (most recent call last):
File "/app/minimal_example.py", line 121, in <module>
main()
File "/app/minimal_example.py", line 107, in main
loss.backward()
File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward
_engine_run_backward(
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 315, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/linear.py", line 895, in backward
wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/linear.py", line 875, in wgrad_gemm
dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/cpp_extensions/gemm.py", line 149, in general_gemm
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /app/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:744 in function cublas_gemm: Assertion failed: status != CUBLAS_STATUS_NOT_SUPPORTED. Unable to find suitable cuBLAS GEMM algorithm
For anyone wondering, this is how you can get NVFP4BlockScaling to work on SM120:
recipe = NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True)
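A minimal usage sketch (assuming those keyword arguments exist in your TE build, as quoted above), plugging the recipe into the same kind of test as earlier in this thread:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

# Recipe with the random Hadamard transform and stochastic rounding disabled.
recipe = NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True)
linear = te.Linear(768, 768, bias=True, params_dtype=torch.bfloat16).cuda()
x = torch.randn(4, 128, 768, dtype=torch.bfloat16).cuda()
# Wrap the FP4 forward pass in fp8_autocast, same as the earlier example.
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = linear(x)
print(out.shape)  # torch.Size([4, 128, 768])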
@ptrendx any progress on this?
@alint77 isn't stochastic rounding required for FP4 training?
"stochastic rounding ensures that gradients are rounded up or down randomly, with probabilities proportional to how close a number lies between two representable values. This step is essential for reducing rounding bias, maintaining gradient flow during training, and ultimately improving model accuracy."
https://developer.nvidia.com/blog/nvfp4-trains-with-precision-of-16-bit-and-speed-and-efficiency-of-4-bit/
Is the team not planning to implement stochastic rounding for our SM120 blackwell cards?
@vgoklani Yeah, I also thought it was mandatory, but it turns out it isn't, and the model actually does converge, although I haven't tried with model_init set to NVFP4.
You're welcome to try it on my nanogpt-fp8 repo.
@alint77 Thanks for the follow-up. Is there a timeline for implementing stochastic rounding etc. for the SM120 Blackwell series? We are specifically interested in NVFP4 and will use it with 4x RTX Blackwell MaxQ cards. For your nanogpt-fp8 repo, is there a baseline we can compare against? When you say the model actually does "converge", how does that compare to a similar model trained on an H200? We want to get these things sorted out before we start the training runs, as those are very time-consuming. Thanks!
Stochastic rounding is an SM101/SM110 hardware/CUDA feature; PyTorch has discussed an extra step or something similar to implement an equivalent, but it would be slower, since it inherently requires higher precision to determine the rounding direction. NVIDIA pretty obviously, semi-brilliantly, semi-evilly, figured out a way to create a hampered inference card (the RTX Pros) that doesn't do training well, so as to not hurt the DC training products (beyond the HBM).
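For reference, here is a rough illustrative sketch (not TE's kernel, just plain PyTorch in higher precision) of what software-emulated stochastic rounding onto the E2M1 value grid could look like, assuming inputs have already been scaled into the representable FP4 range:

import torch

# Positive E2M1 magnitudes; negative values are handled via the sign.
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def stochastic_round_e2m1(x: torch.Tensor) -> torch.Tensor:
    """Round x (assumed pre-scaled so |x| <= 6) to the E2M1 grid, picking the
    upper neighbour with probability proportional to how close |x| lies to it."""
    sign = torch.sign(x)
    mag = x.abs().clamp(max=6.0)
    # Index of the largest grid point <= mag (the lower neighbour).
    lo_idx = torch.searchsorted(E2M1_GRID, mag, right=True) - 1
    lo_idx = lo_idx.clamp(0, len(E2M1_GRID) - 2)
    lo, hi = E2M1_GRID[lo_idx], E2M1_GRID[lo_idx + 1]
    # Probability of rounding up is the fractional position between lo and hi.
    p_up = (mag - lo) / (hi - lo)
    round_up = torch.rand_like(mag) < p_up
    return sign * torch.where(round_up, hi, lo)

x = torch.randn(8)
print(x)
print(stochastic_round_e2m1(x))

Rounding up with probability proportional to the distance from the lower neighbour keeps the quantization unbiased in expectation, which is the property the NVIDIA blog post linked above attributes to stochastic rounding.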