
Feature fast cast-only mxfp8

Jianbing-D opened this pull request 5 months ago • 6 comments

Description

This pull request adds efficient implementations of MXFP8 quantization for cast-only cases, improving casting performance by 5%–20%.

It supports:

  • BF16 or FP16 inputs
  • E5M2 or E4M3 outputs
  • GPU architectures >= sm_100
  • rowwise-only, or rowwise & columnwise scaling
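
For reference, a minimal sketch (not part of this PR's diff) of how these configurations map onto the PyTorch-level API, mirroring the benchmark script further below:

```python
# Minimal sketch mirroring the benchmark script below; not part of this PR's diff.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

quantizer = MXFP8Quantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,  # or tex.DType.kFloat8E5M2
    rowwise=True,                     # rowwise-only ...
    columnwise=True,                  # ... or rowwise & columnwise together
)
x = torch.randn(4096, 7168, dtype=torch.bfloat16, device="cuda")  # BF16 or FP16 input
y = quantizer.make_empty(x.shape, dtype=x.dtype, device=x.device)
quantizer.update_quantized(x, y)      # cast-only MXFP8 quantization
```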

Performance gain: see the benchmark screenshots attached to the PR.

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added an environment variable ENABLE_CAST_ONLY that selects the optimized kernels (see the sketch after this list). If the optimized kernels do not support the provided inputs, the implementation automatically falls back to the original kernels.
    1. If ENABLE_CAST_ONLY is not set, or is set to 0, the original kernels are used.
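
A minimal usage sketch (assuming the variable is read from the process environment at quantization time, as the benchmark script below suggests):

```python
# Opt into the optimized cast-only kernels; leaving ENABLE_CAST_ONLY unset or "0"
# keeps the original kernels. Unsupported inputs fall back automatically either way.
import os
os.environ["ENABLE_CAST_ONLY"] = "1"
```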

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

Jianbing-D avatar Aug 12 '25 04:08 Jianbing-D

Steps to reproduce performance numbers

  1. Start a container with the image `nvcr.io/nvidia/pytorch:25.06-py3` on a GB200 cluster.
  2. Uninstall the pre-installed TE: `pip uninstall -y transformer_engine`
  3. Manually install this branch:
     ```bash
     export PYTHONUSERBASE=/tmp/python
     unset PIP_CONSTRAINT && NVTE_CUDA_ARCHS="100a" NVTE_BUILD_THREADS_PER_JOB=8 NVTE_FRAMEWORK=pytorch pip install --no-build-isolation -v -e ./TransformerEngine
     ```
  4. Run the following script with NCU, which reports the kernel duration and memory bandwidth:

```bash
ncu --section=MemoryWorkloadAnalysis --section=SpeedOfLight --clock-control=none --nvtx --nvtx-include="Update Quantized/" --nvtx-include="reference kernel/" python quantize.py
```
```python
# quantize.py

import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp"
os.environ["TRITON_CACHE_DIR"] = "/tmp"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"

if "USER" not in os.environ:
    os.environ["USER"] = "you"

import torch

from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex

import argparse

torch.cuda.manual_seed(233376)

def run(args):
    check_consistency = args.check_consistency

    direction = {"rowwise": not args.no_rowwise, "colwise": not args.no_colwise}
    src_dtype = args.src_dtype
    dst_dtype = args.dst_dtype
    size_h, size_w = args.size_h, args.size_w

    msg_candidates = {"TrueTrue": "rowwise and colwise",
                    "TrueFalse": "rowwise",
                    "FalseTrue": "colwise",
                    "FalseFalse": None}
    msg = msg_candidates[f"{direction['rowwise']}{direction['colwise']}"]
    if msg is None:
        raise ValueError(f"Invalid direction: {direction}")
    print("=" * 120)
    print(f"checking {msg}, "
          f"src_dtype: {src_dtype}, dst_dtype: {dst_dtype}, size_h: {size_h}, size_w: {size_w}")
    print("=" * 120)

    with torch.cuda.nvtx.range("Ctor"):
        quantizer = MXFP8Quantizer(
            fp8_dtype=dst_dtype,
            rowwise=direction["rowwise"],
            columnwise=direction["colwise"],
        )

    with torch.cuda.nvtx.range("Create Input"):
        bf16_tensor = torch.randn(size_h, size_w, dtype=src_dtype, device="cuda")
        # bf16_tensor = torch.arange(size_h * size_w, dtype=src_dtype, device="cuda").reshape(size_h, size_w)
        # # Print every element in bf16_tensor
        # print("Elements of bf16_tensor:")
        # for i in range(bf16_tensor.shape[0]):
        #     print("row: ", i, end=": ")
        #     for j in range(bf16_tensor.shape[1]):
        #         print(f"{bf16_tensor[i, j].item():.4f}\t", end="")
        #     print()
        # amax = torch.abs(bf16_tensor).amax(axis=0, keepdim=True)
        # print(amax)

    if check_consistency:
        with torch.cuda.nvtx.range("reference"):
            fp8_tensor_ref = quantizer.make_empty(
                bf16_tensor.shape,
                dtype=bf16_tensor.dtype,
                device=bf16_tensor.device,
            )
            with torch.cuda.nvtx.range("reference kernel"):
                quantizer.update_quantized(bf16_tensor, fp8_tensor_ref)


    with torch.cuda.nvtx.range("Make Empty"):
        fp8_tensor = quantizer.make_empty(
            bf16_tensor.shape,
            dtype=bf16_tensor.dtype,
            device=bf16_tensor.device,
        )


    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    if check_consistency:
        os.environ["ENABLE_CAST_ONLY"] = "1"
    start.record()
    with torch.cuda.nvtx.range("Update Quantized"):
        quantizer.update_quantized(bf16_tensor, fp8_tensor)
    end.record()

    torch.cuda.synchronize()

    ms = start.elapsed_time(end)


    io_bytes = size_h * size_w * 2           # BF16/FP16 input read: 2 bytes per element
    io_bytes += size_h * size_w * 1          # FP8 output write: 1 byte per element
    io_bytes += size_h * (size_w // 32) * 1  # rowwise scales: 1 byte per 32-element block
    print(f"Io Bytes: {io_bytes / 1e6} MB")
    print(f"Duration: {ms} ms")
    print(f"Bandwidth: {(io_bytes * 1e-9) / (ms * 1e-3)} GB/s")

    # print(fp8_tensor)
    if check_consistency:
        # print(fp8_tensor_ref)

        if direction["rowwise"]:
            torch.testing.assert_close(fp8_tensor._rowwise_data, fp8_tensor_ref._rowwise_data)
            print("rowwise data passed")

            # print(fp8_tensor._rowwise_scale_inv.shape)
            # for i in range(fp8_tensor._rowwise_scale_inv.shape[0]):
            #     print(f"row: {i}", end=": ")
            #     for j in range(fp8_tensor._rowwise_scale_inv.shape[1]):
            #         print(f"{fp8_tensor._rowwise_scale_inv[i, j].item():d},", end="")
            #     print("")
            # print("-------------ref tensor-------------------")
            # for i in range(fp8_tensor_ref._rowwise_scale_inv.shape[0]):
            #     print(f"row: {i}", end=": ")
            #     for j in range(fp8_tensor_ref._rowwise_scale_inv.shape[1]):
            #         print(f"{fp8_tensor_ref._rowwise_scale_inv[i, j].item():d},", end="")
            #     print("")

            torch.testing.assert_close(fp8_tensor._rowwise_scale_inv, fp8_tensor_ref._rowwise_scale_inv)
            print("rowwise scale_inv passed")
        if direction["colwise"]:
            torch.testing.assert_close(fp8_tensor._columnwise_data, fp8_tensor_ref._columnwise_data)
            print("colwise data passed")
            torch.testing.assert_close(fp8_tensor._columnwise_scale_inv, fp8_tensor_ref._columnwise_scale_inv)
            print("colwise scale_inv passed")
        torch.testing.assert_close(fp8_tensor, fp8_tensor_ref)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile MXFP8 quantization")
    parser.add_argument("--no_rowwise", action="store_true", default=False, help="Enable rowwise quantization")
    parser.add_argument("--no_colwise", action="store_true", default=False, help="Enable colwise quantization")
    parser.add_argument("--src_dtype", type=str, default="bfloat16", choices=["bfloat16", "float16"], help="Source dtype")
    parser.add_argument("--dst_dtype", type=str, default="kFloat8E4M3", choices=["kFloat8E4M3", "kFloat8E5M2"], help="Destination dtype")
    parser.add_argument("--size_h", type=int, default=4096, help="Input tensor height")
    parser.add_argument("--size_w", type=int, default=7168, help="Input tensor width")
    parser.add_argument("--check_consistency", action="store_true", default=True, help="Check consistency")
    args = parser.parse_args()

    if args.src_dtype == "bfloat16":
        src_dtype = torch.bfloat16
        args.src_dtype = src_dtype
    elif args.src_dtype == "float16":
        src_dtype = torch.float16
        args.src_dtype = src_dtype
    elif args.src_dtype == "float32":
        src_dtype = torch.float32
        args.src_dtype = src_dtype
    else:
        raise ValueError(f"Unsupported src_dtype: {args.src_dtype}")

    if args.dst_dtype == "kFloat8E4M3":
        dst_dtype = tex.DType.kFloat8E4M3
        args.dst_dtype = dst_dtype
    elif args.dst_dtype == "kFloat8E5M2":
        dst_dtype = tex.DType.kFloat8E5M2
        args.dst_dtype = dst_dtype
    else:
        raise ValueError(f"Unsupported dst_dtype: {args.dst_dtype}")

    run(args)
```
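
For reference, a worked example of the script's I/O accounting for the default 4096 x 7168 case; the formulas are the ones in `run()` above, and the 32-element granularity reflects the MXFP8 block-scale layout:

```python
# Worked example of the bandwidth accounting above for the default shape.
size_h, size_w = 4096, 7168
io_bytes = size_h * size_w * 2           # BF16/FP16 input read: 2 bytes per element
io_bytes += size_h * size_w * 1          # FP8 output write: 1 byte per element
io_bytes += size_h * (size_w // 32) * 1  # rowwise scales: 1 byte per 32-element block
print(f"{io_bytes / 1e6:.1f} MB")        # ~89.0 MB moved for one cast
```

Other shapes and dtypes can be selected with the flags defined in the argparse block, e.g. `--no_colwise`, `--src_dtype float16`, or `--dst_dtype kFloat8E5M2`.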

Jianbing-D avatar Aug 12 '25 04:08 Jianbing-D

Hello @Jianbing-D, thanks for your contribution! Would you please check why the PR is failing CI checks? If you don't find the reason, please request support from the reviewers. Thank you!

nvMelissa avatar Oct 16 '25 10:10 nvMelissa

> Hello @Jianbing-D, thanks for your contribution! Would you please check why the PR is failing CI checks? If you don't find the reason, please request support from the reviewers. Thank you!

Hi @nvMelissa, thank you for pointing out the CI failures. I have made some modifications to fix them.

Please review and trigger the CI.

Jianbing-D avatar Oct 17 '25 07:10 Jianbing-D

/te-ci

Oleg-Goncharov avatar Oct 17 '25 21:10 Oleg-Goncharov

> /te-ci

Hi @Oleg-Goncharov, it looks like there were issues on the CI system:

```
Installing collected packages: nvidia-mathdx, nvidia-cusparselt-cu12, mpmath, triton, sympy, pybind11-global, pybind11, nvidia-nvtx-cu12, nvidia-nvshmem-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufile-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, networkx, MarkupSafe, filelock, einops, onnx, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, jinja2, onnx_ir, nvidia-cusolver-cu12, torch, onnxscript
ERROR: Could not install packages due to an OSError: [Errno 28] No space left on device
```

Jianbing-D avatar Oct 20 '25 01:10 Jianbing-D

Hi @Jianbing-D, could you please resolve the merge conflicts once more? Sorry for asking to do this again, but we refactored the cast kernels to organize them into clearer groups, so your MXFP8 kernel should now live under /common/cast/mxfp8/specialized/ with the appropriate kernel filtering. CI is fixed, so once the conflicts are resolved I can merge. Thanks a lot!

Oleg-Goncharov avatar Nov 03 '25 13:11 Oleg-Goncharov

> Hi @Jianbing-D, could you please resolve the merge conflicts once more? Sorry for asking to do this again, but we refactored the cast kernels to organize them into clearer groups, so your MXFP8 kernel should now live under /common/cast/mxfp8/specialized/ with the appropriate kernel filtering. CI is fixed, so once the conflicts are resolved I can merge. Thanks a lot!

Hi @Oleg-Goncharov, I have refactored this branch and put my kernel inside the folder you suggested. Please review it. Thanks.

Jianbing-D avatar Nov 18 '25 07:11 Jianbing-D

Greptile Summary

  • Adds optimized MXFP8 cast-only quantization kernels gated by ENABLE_CAST_ONLY environment variable, achieving 5-20% performance improvements for SM 10.0+ GPUs
  • Implements specialized kernels for rowwise and bidimensional scaling using TMA operations, memory swizzling, and new PTX intrinsics (mul_cvt_4x, fma_f32_f16/bf16)
  • Critical syntax errors in PTX inline assembly will cause compilation failures

Confidence Score: 2/5

  • This PR contains compilation-breaking syntax errors that must be fixed before merging
  • Score reflects critical syntax errors in ptx.cuh (missing commas in inline assembly at multiple locations) that will cause compilation failures. The architectural design and optimization approach are sound, but the code cannot compile in its current state.
  • transformer_engine/common/util/ptx.cuh requires immediate attention due to syntax errors that will prevent compilation

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/util/ptx.cuh | Added new PTX intrinsics for SM 10.0+ including mul_cvt_4x, fma_f32_f16/bf16, and reduction operations; contains critical syntax errors with missing commas in inline assembly |
| transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh | New file implementing optimized cast-only MXFP8 quantization kernels for rowwise and bidimensional scaling with TMA, swizzling, and environment-gated activation |
| transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | Integrated specialized cast-only kernels into the existing quantization dispatcher with environment-variable gating and automatic fallback to the original kernels |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Dispatcher as "quantize() dispatcher"
    participant SpecCheck as "specialized::hasSpec()"
    participant Kernel as "quantize_mxfp8_kernel_cast_only"
    participant PTX as "PTX intrinsics"
    
    User->>Dispatcher: Call quantize with input tensor
    Dispatcher->>SpecCheck: Check if specialized kernel available
    SpecCheck->>SpecCheck: Read ENABLE_CAST_ONLY env var
    alt Specialized kernel available
        SpecCheck-->>Dispatcher: Return true
        alt Rowwise scaling
            Dispatcher->>Dispatcher: Configure CastTraits<IType, OType, true, false>
            Dispatcher->>Kernel: Launch kernel (grid, block, smem)
            Kernel->>Kernel: Load input data chunks
            Kernel->>PTX: Compute amax via abs_max_2x
            Kernel->>PTX: Convert to e8m0 scale
            Kernel->>PTX: Scale and convert via mul_cvt_4x
            Kernel->>Kernel: Write FP8 output
            Kernel-->>Dispatcher: Complete
        else Bidimensional scaling
            Dispatcher->>Dispatcher: Configure CastTraits<IType, OType, true, true>
            Dispatcher->>Dispatcher: Create TMA tensor maps
            Dispatcher->>Kernel: Launch kernel with TMA maps
            Kernel->>Kernel: Load via TMA with swizzling
            Kernel->>PTX: Compute rowwise and colwise amax
            Kernel->>PTX: Scale and convert to FP8
            Kernel->>Kernel: Write rowwise and colwise outputs
            Kernel-->>Dispatcher: Complete
        end
        Dispatcher-->>User: Return quantized tensor
    else No specialized kernel
        SpecCheck-->>Dispatcher: Return false
        Dispatcher->>Dispatcher: Use original quantization path
        Dispatcher-->>User: Return quantized tensor
    end
```

greptile-apps[bot] avatar Nov 18 '25 07:11 greptile-apps[bot]

/te-ci

Oleg-Goncharov avatar Nov 19 '25 00:11 Oleg-Goncharov

Hi @Oleg-Goncharov

It seems the CI encountered the storage issue again (screenshot attached).

Jianbing-D avatar Nov 19 '25 07:11 Jianbing-D