
[BUG] Cutlass Python API silently fails in (suspected) unsupported case

Open LucasWilkinson opened this issue 1 year ago • 4 comments

Describe the bug

When using the CUTLASS Python API to create a GEMM with an epilogue visitor tree (EVT), we found that on Ampere (specifically an A100, i.e. the CUTLASS 2.x code path) the GEMM can return garbage. We suspect we have hit an unsupported case that the Python API does not catch: we would have expected an error if we were doing something unsupported, rather than a silent failure that returns garbage. Happy to expand or help reproduce if this report is insufficient.

Steps/Code to reproduce bug

The following code reproduces the bug (we used the 04_epilogue_visitor.ipynb example as a starting point). plan_good is a GEMM plan that returns correct results, and plan_bad is one that returns garbage. The only difference between the two plans is layout_B: column-major in the good plan and row-major in the bad plan.


import torch
import cutlass
import copy
from cutlass.epilogue import relu
from cutlass import Tensor as FakeTensor
from cutlass.utils.profiler import CUDAEventProfiler

# This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to
# omit this information.
print_module = False

# The Epilogue Visitor feature currently only works for SM80 and 90
from cutlass.backend.utils.device import device_cc
if device_cc() not in [80, 90]:
    import sys
    print(device_cc(), "not supported")
    sys.exit()

m = 16384
n = m
k = 512

type_A = torch.int8
type_B = torch.int8
type_C = torch.int32
type_D = torch.int32

torch.manual_seed(2023)
scope_min = -4
scope_max = 4
tensor_A = torch.randint(low=scope_min, high=scope_max, size=(m, k), dtype=type_A, device="cuda")
tensor_B_row_major = torch.randint(low=scope_min, high=scope_max, size=(k, n), dtype=type_B, device="cuda")
tensor_B_col_major = tensor_B_row_major.t().contiguous().t()
tensor_C = torch.zeros(size=(m, n), dtype=type_C, device="cuda")
tensor_D = torch.zeros_like(tensor_C)

plan_args_good = dict(
    element_A=torch.int8, element_B=torch.int8, element_C=torch.int32, element_D=torch.int32, 
    layout_A=cutlass.LayoutType.RowMajor, 
    layout_B=cutlass.LayoutType.ColumnMajor,
    layout_C=cutlass.LayoutType.RowMajor, 
    element_accumulator=torch.int32, kernel_cc=80
)

plan_args_bad = copy.copy(plan_args_good)
plan_args_bad["layout_B"] = cutlass.LayoutType.RowMajor

plan_good = cutlass.op.Gemm(**plan_args_good)
plan_bad  = cutlass.op.Gemm(**plan_args_bad)

# Define epilogue visitor
def example_epilogue(accum, C):
    D = accum + C
    return D

# Construct inputs and outputs
examples_tensors = {
    "accum": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    "D": tensor_D,
    "C": tensor_C,
}

# Trace the epilogue visitor
epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)

visitor_args = {
    "D": tensor_D,
    "C": tensor_C,
}

class TorchReference(torch.nn.Module):
    def forward(self, A, B, C):
        accum = torch.matmul(A.to(dtype=torch.float32), B.to(dtype=torch.float32))
        return example_epilogue(accum.to(dtype=torch.float32), C.to(dtype=torch.float32))

torch_reference = TorchReference()
tensor_D_ref = torch_reference(tensor_A, tensor_B_row_major, tensor_C)


plan_good.epilogue_visitor = epilogue_visitor
plan_good.run(
    tensor_A, tensor_B_col_major, tensor_C, tensor_D, 
    visitor_args=visitor_args, print_module=print_module)

print("Plan good result: ", torch.allclose(tensor_D.to(dtype=torch.float32), tensor_D_ref, rtol=1e-2))

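# plan_bad declares layout_B as RowMajor, but the call below passes the same
# column-major tensor_B_col_major that was used for plan_good.
plan_bad.epilogue_visitor = epilogue_visitor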
plan_bad.run(
    tensor_A, tensor_B_col_major, tensor_C, tensor_D, 
    visitor_args=visitor_args, print_module=print_module)

print("Plan bad result: ", torch.allclose(tensor_D.to(dtype=torch.float32), tensor_D_ref, rtol=1e-2))
print("tensor_D:     \n", tensor_D.to(dtype=torch.float32))
print("tensor_D_ref: \n", tensor_D_ref)

Output:

Plan good result:  True
Plan bad result:  False
tensor_D:     
 tensor([[ 35.,  86., 147.,  ..., -23., 168., 126.],
        [ 82.,  21., -55.,  ..., 247., 207., 179.],
        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
        ...,
        [ 39., 199.,  52.,  ..., 174., -82., 141.],
        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
        [  0.,   0.,   0.,  ...,   0.,   0.,   0.]], device='cuda:0')
tensor_D_ref: 
 tensor([[ 194.,  -81.,   45.,  ...,   78.,  233.,  106.],
        [ 132.,  189.,  194.,  ...,  183.,  -78.,  188.],
        [ 107.,  141.,  201.,  ...,   54.,  332.,  223.],
        ...,
        [ 178.,  284.,  -96.,  ...,  -11.,  238.,  433.],
        [ 205.,  119.,  139.,  ...,  -10.,  329.,   17.],
        [ 338.,  -24.,  189.,  ..., -196.,  105.,  457.]], device='cuda:0')
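Since the two plans differ only in the layout of B, here is a minimal pure-Python sketch (no CUTLASS or PyTorch involved; the helpers strides, pack and unpack are illustrative names, not library APIs) of how one logical k x n matrix is laid out in memory in each case. In the repro, tensor_B_col_major = tensor_B_row_major.t().contiguous().t() keeps the same logical values but moves from row-major strides (n, 1) to column-major strides (1, k).

```python
def strides(rows, cols, layout):
    """Element strides (row_stride, col_stride) of a rows x cols matrix."""
    if layout == "row":
        return (cols, 1)
    if layout == "col":
        return (1, rows)
    raise ValueError(f"unknown layout: {layout}")


def pack(matrix, layout):
    """Flatten a list-of-lists matrix into a 1-D buffer using the given layout."""
    rows, cols = len(matrix), len(matrix[0])
    rs, cs = strides(rows, cols, layout)
    buf = [None] * (rows * cols)
    for i in range(rows):
        for j in range(cols):
            buf[i * rs + j * cs] = matrix[i][j]
    return buf


def unpack(buf, rows, cols, layout):
    """Read a rows x cols matrix back out of a buffer, assuming the given layout."""
    rs, cs = strides(rows, cols, layout)
    return [[buf[i * rs + j * cs] for j in range(cols)] for i in range(rows)]


# A small 2 x 3 stand-in for B (k = 2, n = 3).
B = [[1, 2, 3],
     [4, 5, 6]]
row_buf = pack(B, "row")  # [1, 2, 3, 4, 5, 6]
col_buf = pack(B, "col")  # [1, 4, 2, 5, 3, 6]

# Both buffers hold the same logical matrix when read with their own layout...
assert unpack(row_buf, 2, 3, "row") == B
assert unpack(col_buf, 2, 3, "col") == B
# ...but reading column-major storage with row-major strides gives different values.
print(unpack(col_buf, 2, 3, "row"))  # [[1, 4, 2], [5, 3, 6]]
```

The layout declared for a plan tells the kernel which of these two indexing schemes to use for B's memory, and it is also what the t/n letters in the generated kernel names record.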

Expected behavior

For the bad plan, we would have expected the Python API to raise an error stating that the combination of GEMM and EVT we specified is unsupported, instead of returning garbage.

Environment details

Bare metal, A100, CUTLASS installed via pip install nvidia-cutlass

Additional context

Looking at the generated kernels, we can see that in the good plan:

// Gemm operator cutlass_tensorop_i16832gemm_s8_256x128_128x3_tn_align16
using cutlass_tensorop_i16832gemm_s8_256x128_128x3_tn_align16_base =
    ...
    cutlass::arch::OpClassTensorOp,
    ...
>::GemmKernel;

it uses cutlass::arch::OpClassTensorOp, whereas in the bad plan:

// Gemm operator cutlass_simt_igemm_s8_128x128_8x2_tt_align1
using cutlass_simt_igemm_s8_128x128_8x2_tt_align1_base =
    ...
    cutlass::arch::OpClassSimt,
    ...
>::GemmKernel;

it falls back to cutlass::arch::OpClassSimt.
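The generated kernel names already encode these differences (opclass, tile shape, A/B layout, alignment). The sketch below decodes them; the field pattern is inferred only from the two names quoted above (CUTLASS abbreviates RowMajor as "t" and ColumnMajor as "n"), and describe_kernel is a hypothetical helper, not part of the CUTLASS Python API.

```python
import re

# Pattern inferred from the two kernel names in this issue; other CUTLASS kernels
# may carry extra fields, so this is an illustration rather than a full parser.
_NAME_RE = re.compile(
    r"^cutlass_(?P<opclass>[a-z]+)_(?P<op>.+?)_"
    r"(?P<tile_m>\d+)x(?P<tile_n>\d+)_(?P<tile_k>\d+)x(?P<stages>\d+)_"
    r"(?P<layout>[nt]{2})_align(?P<align>\d+)$"
)

# Short layout letters: "t" = RowMajor, "n" = ColumnMajor (first letter is A, second is B).
_LAYOUT = {"t": "RowMajor", "n": "ColumnMajor"}


def describe_kernel(name):
    """Split a CUTLASS 2.x-style kernel name into its main fields."""
    m = _NAME_RE.match(name)
    if m is None:
        raise ValueError(f"unrecognized kernel name: {name}")
    return {
        "opclass": m["opclass"],
        "op": m["op"],
        "threadblock_tile": (int(m["tile_m"]), int(m["tile_n"]), int(m["tile_k"])),
        "stages": int(m["stages"]),
        "layout_A": _LAYOUT[m["layout"][0]],
        "layout_B": _LAYOUT[m["layout"][1]],
        "alignment": int(m["align"]),
    }


good = describe_kernel("cutlass_tensorop_i16832gemm_s8_256x128_128x3_tn_align16")
bad = describe_kernel("cutlass_simt_igemm_s8_128x128_8x2_tt_align1")
print(good["opclass"], good["layout_B"], good["alignment"])  # tensorop ColumnMajor 16
print(bad["opclass"], bad["layout_B"], bad["alignment"])     # simt RowMajor 1
```

Read this way, switching B to row-major changes the layout suffix from tn to tt and coincides with the move from a tensor-core kernel (align16) to a SIMT kernel (align1).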

We suspect that the OutputTileThreadLayout, used as an adapter to attach EVT epilogues to 2.x GEMMs, only supports cutlass::arch::OpClassTensorOp (we see similar failures with cutlass::arch::OpClassWmmaTensorOp). However, we have not found any documentation or asserts that confirm this.
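For illustration, here is a sketch of the kind of up-front check we expected, assuming (as suspected above, but not confirmed) that on SM80 the EVT adapter only works with OpClassTensorOp. EVT_SUPPORTED_OPCLASSES_SM80, is_evt_supported, check_evt_support and UnsupportedEpilogueError are all hypothetical names; only the cutlass::arch opclass strings come from the generated code above.

```python
# Hypothetical: opclasses assumed to work with EVT on SM80. Per this issue, only
# OpClassTensorOp is *suspected* to be supported; this is not confirmed by CUTLASS.
EVT_SUPPORTED_OPCLASSES_SM80 = {"cutlass::arch::OpClassTensorOp"}


class UnsupportedEpilogueError(ValueError):
    """Raised when an epilogue visitor is attached to a kernel it cannot run with."""


def is_evt_supported(kernel_cc, opclass):
    # This sketch only restricts the SM80 (CUTLASS 2.x) path discussed in this issue.
    if kernel_cc != 80:
        return True
    return opclass in EVT_SUPPORTED_OPCLASSES_SM80


def check_evt_support(kernel_cc, opclass):
    """Fail loudly instead of launching a kernel that may silently return garbage."""
    if not is_evt_supported(kernel_cc, opclass):
        raise UnsupportedEpilogueError(
            f"Epilogue visitor (EVT) is not supported with {opclass} on SM{kernel_cc}"
        )


check_evt_support(80, "cutlass::arch::OpClassTensorOp")  # passes silently
try:
    check_evt_support(80, "cutlass::arch::OpClassSimt")
except UnsupportedEpilogueError as err:
    print(err)  # Epilogue visitor (EVT) is not supported with cutlass::arch::OpClassSimt on SM80
```

A check like this would turn the plan_bad case into an immediate error at plan or run time rather than a wrong result.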

LucasWilkinson, May 23 '24

@apuaachen, can you take a look?

jackkosaian, May 30 '24

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Jun 29 '24

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot], Sep 27 '24

Still broken on:

nvidia-cutlass==3.5.1.0

LucasWilkinson, Oct 01 '24

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

github-actions[bot], Dec 30 '24