Feature fast cast-only mxfp8
Description
This pull request adds efficient implementations of MXFP8 quantize for cast-only cases. It improves casting performance by roughly 5% to 20%.
It supports:
- `BF16` or `FP16` as inputs
- `E5M2` or `E4M3` as outputs
- GPU arch >= `sm_100` (a quick capability check is sketched after this list)
- rowwise or row- & col-wise scaling
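As a hedged aside (not part of this PR), the architecture requirement can be checked with the compute capability reported by PyTorch; this only verifies the GPU arch, not the dtype or scaling conditions above:

```python
import torch

# sm_100 corresponds to compute capability (10, 0).
major, minor = torch.cuda.get_device_capability()
print("meets the sm_100 requirement:", (major, minor) >= (10, 0))
```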
Performance gain:
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactoring
Changes
Please list the changes introduced in this PR:
- Added an environment variable `ENABLE_CAST_ONLY` to select the optimized kernels. If the optimized kernels don't support the provided inputs, it automatically falls back to the original kernels; see the usage sketch after this list.
- If `ENABLE_CAST_ONLY` is not set or is set to `0`, the original kernels are used.
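The snippet below is a minimal usage sketch, not code from this PR. It reuses the `MXFP8Quantizer` API from the reproduction script further down and assumes the dispatcher reads `ENABLE_CAST_ONLY` from the process environment at quantize time:

```python
import os

# Assumption: setting ENABLE_CAST_ONLY before the quantize call opts in to the
# cast-only fast path; unsupported inputs fall back to the original kernels.
os.environ["ENABLE_CAST_ONLY"] = "1"

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

x = torch.randn(4096, 7168, dtype=torch.bfloat16, device="cuda")  # BF16 input on an sm_100+ GPU
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True)
y = quantizer.make_empty(x.shape, dtype=x.dtype, device=x.device)
quantizer.update_quantized(x, y)  # uses the optimized cast-only kernel when supported
```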
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
Steps to reproduce performance numbers
- Start a container with the image `nvcr.io/nvidia/pytorch:25.06-py3` on a GB200 cluster.
- Uninstall the pre-installed TE: `pip uninstall -y transformer_engine`
- Manually install this branch with:

  ```bash
  export PYTHONUSERBASE=/tmp/python
  unset PIP_CONSTRAINT && NVTE_CUDA_ARCHS="100a" NVTE_BUILD_THREADS_PER_JOB=8 NVTE_FRAMEWORK=pytorch pip install --no-build-isolation -v -e ./TransformerEngine
  ```

- Run the following script with NCU, which reports the kernel duration and memory bandwidth:

  ```bash
  ncu --section=MemoryWorkloadAnalysis --section=SpeedOfLight --clock-control=none --nvtx --nvtx-include="Update Quantized/" --nvtx-include="reference kernel/" python quantize.py
  ```
```python
# quantize.py
import os

os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp"
os.environ["TRITON_CACHE_DIR"] = "/tmp"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"
if "USER" not in os.environ:
    os.environ["USER"] = "you"

import torch
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
import argparse

torch.cuda.manual_seed(233376)


def run(args):
    check_consistency = args.check_consistency
    direction = {"rowwise": not args.no_rowwise, "colwise": not args.no_colwise}
    src_dtype = args.src_dtype
    dst_dtype = args.dst_dtype
    size_h, size_w = args.size_h, args.size_w
    msg_candidates = {"TrueTrue": "rowwise and colwise",
                      "TrueFalse": "rowwise",
                      "FalseTrue": "colwise",
                      "FalseFalse": None}
    msg = msg_candidates[f"{direction['rowwise']}{direction['colwise']}"]
    if msg is None:
        raise ValueError(f"Invalid direction: {direction}")

    print("=" * 120)
    print(f"checking {msg}, "
          f"src_dtype: {src_dtype}, dst_dtype: {dst_dtype}, size_h: {size_h}, size_w: {size_w}")
    print("=" * 120)

    with torch.cuda.nvtx.range("Ctor"):
        quantizer = MXFP8Quantizer(
            fp8_dtype=dst_dtype,
            rowwise=direction["rowwise"],
            columnwise=direction["colwise"],
        )

    with torch.cuda.nvtx.range("Create Input"):
        bf16_tensor = torch.randn(size_h, size_w, dtype=src_dtype, device="cuda")
        # bf16_tensor = torch.arange(size_h * size_w, dtype=src_dtype, device="cuda").reshape(size_h, size_w)
        # # Print every element in bf16_tensor
        # print("Elements of bf16_tensor:")
        # for i in range(bf16_tensor.shape[0]):
        #     print("row: ", i, end=": ")
        #     for j in range(bf16_tensor.shape[1]):
        #         print(f"{bf16_tensor[i, j].item():.4f}\t", end="")
        #     print()
        # amax = torch.abs(bf16_tensor).amax(axis=0, keepdim=True)
        # print(amax)

    if check_consistency:
        # Reference output, produced before this script sets ENABLE_CAST_ONLY below.
        with torch.cuda.nvtx.range("reference"):
            fp8_tensor_ref = quantizer.make_empty(
                bf16_tensor.shape,
                dtype=bf16_tensor.dtype,
                device=bf16_tensor.device,
            )
            with torch.cuda.nvtx.range("reference kernel"):
                quantizer.update_quantized(bf16_tensor, fp8_tensor_ref)

    with torch.cuda.nvtx.range("Make Empty"):
        fp8_tensor = quantizer.make_empty(
            bf16_tensor.shape,
            dtype=bf16_tensor.dtype,
            device=bf16_tensor.device,
        )

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    if check_consistency:
        # Switch to the optimized cast-only kernels for the timed run.
        os.environ["ENABLE_CAST_ONLY"] = "1"
    start.record()
    with torch.cuda.nvtx.range("Update Quantized"):
        quantizer.update_quantized(bf16_tensor, fp8_tensor)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end)

    # I/O estimate: 2-byte input read + 1-byte rowwise FP8 output write
    # + one E8M0 scale byte per 32 input elements.
    io_bytes = size_h * size_w * 2
    io_bytes += size_h * size_w * 1
    io_bytes += size_h * (size_w // 32) * 1
    print(f"Io Bytes: {io_bytes / 1e6} MB")
    print(f"Duration: {ms} ms")
    print(f"Bandwidth: {(io_bytes * 1e-9) / (ms * 1e-3)} GB/s")
    # print(fp8_tensor)

    if check_consistency:
        # print(fp8_tensor_ref)
        if direction["rowwise"]:
            torch.testing.assert_close(fp8_tensor._rowwise_data, fp8_tensor_ref._rowwise_data)
            print("rowwise data passed")
            # print(fp8_tensor._rowwise_scale_inv.shape)
            # for i in range(fp8_tensor._rowwise_scale_inv.shape[0]):
            #     print(f"row: {i}", end=": ")
            #     for j in range(fp8_tensor._rowwise_scale_inv.shape[1]):
            #         print(f"{fp8_tensor._rowwise_scale_inv[i, j].item():d},", end="")
            #     print("")
            # print("-------------ref tensor-------------------")
            # for i in range(fp8_tensor_ref._rowwise_scale_inv.shape[0]):
            #     print(f"row: {i}", end=": ")
            #     for j in range(fp8_tensor_ref._rowwise_scale_inv.shape[1]):
            #         print(f"{fp8_tensor_ref._rowwise_scale_inv[i, j].item():d},", end="")
            #     print("")
            torch.testing.assert_close(fp8_tensor._rowwise_scale_inv, fp8_tensor_ref._rowwise_scale_inv)
            print("rowwise scale_inv passed")
        if direction["colwise"]:
            torch.testing.assert_close(fp8_tensor._columnwise_data, fp8_tensor_ref._columnwise_data)
            print("colwise data passed")
            torch.testing.assert_close(fp8_tensor._columnwise_scale_inv, fp8_tensor_ref._columnwise_scale_inv)
            print("colwise scale_inv passed")
        torch.testing.assert_close(fp8_tensor, fp8_tensor_ref)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile MXFP8 quantization")
    parser.add_argument("--no_rowwise", action="store_true", default=False, help="Disable rowwise quantization")
    parser.add_argument("--no_colwise", action="store_true", default=False, help="Disable colwise quantization")
    parser.add_argument("--src_dtype", type=str, default="bfloat16", choices=["bfloat16", "float16"], help="Source dtype")
    parser.add_argument("--dst_dtype", type=str, default="kFloat8E4M3", choices=["kFloat8E4M3", "kFloat8E5M2"], help="Destination dtype")
    parser.add_argument("--size_h", type=int, default=4096, help="Input tensor height")
    parser.add_argument("--size_w", type=int, default=7168, help="Input tensor width")
    parser.add_argument("--check_consistency", action="store_true", default=True, help="Check consistency")
    args = parser.parse_args()

    if args.src_dtype == "bfloat16":
        args.src_dtype = torch.bfloat16
    elif args.src_dtype == "float16":
        args.src_dtype = torch.float16
    elif args.src_dtype == "float32":
        args.src_dtype = torch.float32
    else:
        raise ValueError(f"Unsupported src_dtype: {args.src_dtype}")

    if args.dst_dtype == "kFloat8E4M3":
        args.dst_dtype = tex.DType.kFloat8E4M3
    elif args.dst_dtype == "kFloat8E5M2":
        args.dst_dtype = tex.DType.kFloat8E5M2
    else:
        raise ValueError(f"Unsupported dst_dtype: {args.dst_dtype}")

    run(args)
```
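As a sanity check on the reported numbers, here is the script's `io_bytes` formula evaluated for the default 4096 x 7168 shape. It mirrors the script above, i.e. it counts the input read plus only the rowwise output and scales:

```python
size_h, size_w = 4096, 7168
io_bytes = size_h * size_w * 2             # BF16 input read:        58,720,256 B
io_bytes += size_h * size_w * 1            # FP8 rowwise output:     29,360,128 B
io_bytes += size_h * (size_w // 32) * 1    # E8M0 scales (1 per 32):    917,504 B
print(f"{io_bytes / 1e6:.1f} MB")          # -> 89.0 MB moved per cast
```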
Hello @Jianbing-D, thanks for your contribution! Would you please check why the PR is failing CI checks? If you don't find the reason, please request support from the reviewers. Thank you!
Hi @nvMelissa, thank you for pointing out the CI failures. I have made some modifications to fix them.
Please review and trigger the CI.
/te-ci
Hi @Oleg-Goncharov, it looks like there were issues on the CI system:

```
Installing collected packages: nvidia-mathdx, nvidia-cusparselt-cu12, mpmath, triton, sympy, pybind11-global, pybind11, nvidia-nvtx-cu12, nvidia-nvshmem-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufile-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, MarkupSafe, filelock, einops, onnx, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, jinja2, onnx_ir, nvidia-cusolver-cu12, torch, onnxscript
ERROR: Could not install packages due to an OSError: [Errno 28] No space left on device
```
Hi @Jianbing-D, could you please resolve the merge conflicts once more? Sorry for asking to do this again, but we refactored the cast kernels to organize them into clearer groups, so your MXFP8 kernel should now live under /common/cast/mxfp8/specialized/ with the appropriate kernel filtering. CI is fixed, so once the conflicts are resolved I can merge. Thanks a lot!
Hi @Oleg-Goncharov, I have refactored this branch and put my kernel inside the folder you suggested. Please review it. Thanks.
Greptile Summary
- Adds optimized MXFP8 cast-only quantization kernels gated by the `ENABLE_CAST_ONLY` environment variable, achieving 5-20% performance improvements for SM 10.0+ GPUs
- Implements specialized kernels for rowwise and bidimensional scaling using TMA operations, memory swizzling, and new PTX intrinsics (`mul_cvt_4x`, `fma_f32_f16/bf16`)
- Critical syntax errors in PTX inline assembly will cause compilation failures
Confidence Score: 2/5
- This PR contains compilation-breaking syntax errors that must be fixed before merging
- Score reflects critical syntax errors in `ptx.cuh` (missing commas in inline assembly at multiple locations) that will cause compilation failures. The architectural design and optimization approach are sound, but the code cannot compile in its current state.
- `transformer_engine/common/util/ptx.cuh` requires immediate attention due to syntax errors that will prevent compilation
Important Files Changed
| Filename | Overview |
|---|---|
| transformer_engine/common/util/ptx.cuh | Added new PTX intrinsics for SM 10.0+ including mul_cvt_4x, fma_f32_f16/bf16, and reduction operations; contains critical syntax errors with missing commas in inline assembly |
| transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh | New file implementing optimized cast-only MXFP8 quantization kernels for rowwise and bidimensional scaling with TMA, swizzling, and environment-gated activation |
| transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | Integrated specialized cast-only kernels into existing quantization dispatcher with environment variable gating and automatic fallback to original kernels |
Sequence Diagram
```mermaid
sequenceDiagram
    participant User
    participant Dispatcher as "quantize() dispatcher"
    participant SpecCheck as "specialized::hasSpec()"
    participant Kernel as "quantize_mxfp8_kernel_cast_only"
    participant PTX as "PTX intrinsics"

    User->>Dispatcher: Call quantize with input tensor
    Dispatcher->>SpecCheck: Check if specialized kernel available
    SpecCheck->>SpecCheck: Read ENABLE_CAST_ONLY env var
    alt Specialized kernel available
        SpecCheck-->>Dispatcher: Return true
        alt Rowwise scaling
            Dispatcher->>Dispatcher: Configure CastTraits<IType, OType, true, false>
            Dispatcher->>Kernel: Launch kernel (grid, block, smem)
            Kernel->>Kernel: Load input data chunks
            Kernel->>PTX: Compute amax via abs_max_2x
            Kernel->>PTX: Convert to e8m0 scale
            Kernel->>PTX: Scale and convert via mul_cvt_4x
            Kernel->>Kernel: Write FP8 output
            Kernel-->>Dispatcher: Complete
        else Bidimensional scaling
            Dispatcher->>Dispatcher: Configure CastTraits<IType, OType, true, true>
            Dispatcher->>Dispatcher: Create TMA tensor maps
            Dispatcher->>Kernel: Launch kernel with TMA maps
            Kernel->>Kernel: Load via TMA with swizzling
            Kernel->>PTX: Compute rowwise and colwise amax
            Kernel->>PTX: Scale and convert to FP8
            Kernel->>Kernel: Write rowwise and colwise outputs
            Kernel-->>Dispatcher: Complete
        end
        Dispatcher-->>User: Return quantized tensor
    else No specialized kernel
        SpecCheck-->>Dispatcher: Return false
        Dispatcher->>Dispatcher: Use original quantization path
        Dispatcher-->>User: Return quantized tensor
    end
```
/te-ci
Hi @Oleg-Goncharov,
It seems the CI encountered the storage issue again.