[BUG] PyTorch emitter not working for e4m3, int8, and bf16
I am trying to emit PyTorch code, but it does not work for FP8 (e4m3), BF16, or INT8. I tried patching the converter type dict in https://github.com/OrenLeung/cutlass/commit/6d619c964eb8b9c150a5f97891849d33f6ee8b64.
This patch fixed the initial set of issues, but I now run into a deeper error when the emitter tries to compile the kernel.
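For context, the converter type dict maps CUTLASS element types to the PyTorch C++ dtype constants that the generated extension uses. The sketch below is hypothetical. The table name, its string keys, and the helper `torch_scalar_type` are illustrative and are not the actual contents of `cutlass/emit/pytorch.py` or of the linked commit. The C++ constant names assume the `c10::ScalarType` spellings, and `kFloat8_e4m3fn` needs a PyTorch build with FP8 dtypes.

```python
# Hypothetical sketch of a CUTLASS -> PyTorch C++ dtype table.
# Keys are illustrative CUTLASS DataType names; values are the C++
# dtype constants (assumed c10::ScalarType spellings) that would be
# written into the generated extension source.
CUTLASS_TO_TORCH_CPP = {
    "f16": "torch::kFloat16",
    "bf16": "torch::kBFloat16",
    "f32": "torch::kFloat32",
    "e4m3": "torch::kFloat8_e4m3fn",
    "s8": "torch::kInt8",
    "s32": "torch::kInt32",
}


def torch_scalar_type(cutlass_type: str) -> str:
    """Return the C++ dtype constant for a CUTLASS element type name."""
    try:
        return CUTLASS_TO_TORCH_CPP[cutlass_type]
    except KeyError:
        # A missing entry is the kind of gap that makes emission fail
        # before any C++ is compiled.
        raise ValueError(
            f"no PyTorch dtype mapping for CUTLASS type {cutlass_type!r}"
        ) from None


if __name__ == "__main__":
    for t in ("e4m3", "bf16", "s8"):
        print(t, "->", torch_scalar_type(t))
```

Fixing this table only gets the emitter past dtype lookup. The generated C++ still has to use the right element type for each operand, which is where the compile error below comes from.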
New error:
c /workspace/cutlass/examples/python/gemm_mod_fp8_kernel.cu -o gemm_mod_fp8_kernel.cuda.o
/workspace/cutlass/examples/python/gemm_mod_fp8_kernel.cu(100): error: a value of type "cutlass::float_e4m3_t *" cannot be used to initialize an entity of type "const float *"
D,
^
Reproduction:
import torch
import cutlass
# FP8 CUTLASS GEMM plan
plan_fp8 = cutlass.op.Gemm(
element=torch.float8_e4m3fn,
element_accumulator=torch.float32,
element_D=torch.float32,
layout_A=cutlass.LayoutType.RowMajor,
layout_B=cutlass.LayoutType.ColumnMajor,
layout_C=cutlass.LayoutType.ColumnMajor)
op_fp8 = plan_fp8.construct()
# Generate the PyTorch module for the FP8 GEMM operation
mod_fp8 = cutlass.emit.pytorch(op_fp8, name='gemm_mod_fp8', cc=plan_fp8.cc, jit=True)
# BF16 CUTLASS GEMM plan
plan_bf16 = cutlass.op.Gemm(
element=torch.bfloat16,
element_accumulator=torch.float32,
element_D=torch.float32,
layout_A=cutlass.LayoutType.RowMajor,
layout_B=cutlass.LayoutType.ColumnMajor,
layout_C=cutlass.LayoutType.ColumnMajor)
op_bf16 = plan_bf16.construct()
# Generate the PyTorch module for the BF16 GEMM operation
mod_bf16 = cutlass.emit.pytorch(op_bf16, name='gemm_mod_bf16', cc=plan_bf16.cc, jit=True)
# INT8 CUTLASS GEMM plan
plan = cutlass.op.Gemm(
element=cutlass.DataType.s8,
element_accumulator=cutlass.DataType.s32,
element_D=cutlass.DataType.s32,
layout=cutlass.LayoutType.RowMajor)
op = plan.construct()
# Generate the PyTorch module for the INT8 GEMM operation
mod = cutlass.emit.pytorch(op, name='gemm_mod', cc=plan.cc, jit=True)
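Each plan in the reproduction sets a generic `element` and overrides only `element_D`. The model below sketches why all three configurations can hit the same bug. It assumes that the generic `element` also fills in the unspecified `element_C`. That is an assumption about the Python interface, not something verified here. Under it, every plan ends up with `ElementC != ElementD`, so any emitted code that declares `D` as `ElementC*` disagrees with what the kernel expects. The helper `resolve_elements` is hypothetical.

```python
from typing import NamedTuple, Optional


class GemmElements(NamedTuple):
    A: str
    B: str
    C: str
    D: str


def resolve_elements(element: str, element_C: Optional[str] = None,
                     element_D: Optional[str] = None) -> GemmElements:
    """Hypothetical model: the generic element fills any operand not set explicitly."""
    return GemmElements(
        A=element,
        B=element,
        C=element_C or element,
        D=element_D or element,
    )


# The configurations from the reproduction, expressed as CUTLASS type names.
configs = {
    "gemm_mod_fp8": resolve_elements("e4m3", element_D="f32"),
    "gemm_mod_bf16": resolve_elements("bf16", element_D="f32"),
    "gemm_mod": resolve_elements("s8", element_D="s32"),
}

for name, elems in configs.items():
    # Declaring D as ElementC* only type-checks when the two types coincide.
    status = "OK" if elems.C == elems.D else "mismatch"
    print(f"{name}: ElementC={elems.C}, ElementD={elems.D} -> {status}")
```

Under this model, the FP8 case pairs `e4m3` with `f32`. These are the two element types named in the compiler error above. A plan with a single dtype throughout, for example `resolve_elements("f16")`, would not expose the bug.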
@jackkosaian
Thanks for reporting this. The issue seems to be that we are currently using ElementC in place of ElementD when emitting CUTLASS 3.x GEMMs.
Can you try out the following diff?
diff --git a/python/cutlass/emit/common.py b/python/cutlass/emit/common.py
index 87025eea..20eac2d4 100644
--- a/python/cutlass/emit/common.py
+++ b/python/cutlass/emit/common.py
@@ -112,7 +112,7 @@ using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status ${name}_kernel_run(
int M, int N, int K, int L,
- const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
+ const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementD* D,
ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {
typename DeviceKernel::Arguments arguments{
diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py
index ac13e866..0f659671 100644
--- a/python/cutlass/emit/pytorch.py
+++ b/python/cutlass/emit/pytorch.py
@@ -310,7 +310,7 @@ at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
ptrC,
- reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
+ reinterpret_cast<typename DeviceKernel::ElementD*>(D.contiguous().data_ptr()),
ElementCompute(alpha), ElementCompute(beta),
hw_info);
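The emitter fills `${name}`-style placeholders in C++ source templates, as the hunk headers in the diff show. The sketch below renders a trimmed copy of the `kernel_run` signature from the diff using Python's `string.Template` as a stand-in for CUTLASS's own substitution helper. It also adds a hypothetical check, `d_uses_element_d`, that a regression test could run against emitted source to catch the `ElementC`/`ElementD` mix-up.

```python
import re
from string import Template

# kernel_run signature before the fix, trimmed from the diff.
OLD = Template(
    "cutlass::Status ${name}_kernel_run(\n"
    "    int M, int N, int K, int L,\n"
    "    const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B,\n"
    "    const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,\n"
    "    ElementCompute alpha, ElementCompute beta,\n"
    "    const cutlass::KernelHardwareInfo& hw_info)"
)
# Same signature with the fix applied: D is typed as ElementD*.
NEW = Template(OLD.template.replace("DeviceKernel::ElementC* D",
                                    "DeviceKernel::ElementD* D"))

# Captures the declared element type of the output pointer parameter D.
D_PARAM = re.compile(r"DeviceKernel::(Element\w+)\*\s*D\b")


def d_uses_element_d(source: str) -> bool:
    """Return True if the D parameter is declared as ElementD*."""
    m = D_PARAM.search(source)
    return m is not None and m.group(1) == "ElementD"


old_src = OLD.substitute(name="gemm_mod_fp8")
new_src = NEW.substitute(name="gemm_mod_fp8")
print(d_uses_element_d(old_src))  # False
print(d_uses_element_d(new_src))  # True
```

A string check like this is cheap compared with a JIT compile. The real test is still compiling the emitted module for each of the FP8, BF16, and INT8 plans.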
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.