[QST] GEMM Epilogue Fusion: Row-wise and Column-wise Multiplication
What is your question?
Hi, I'd like to compute the following:
// inputs
// A [M, K] int8
// B [N, K] int4
// alphaCol [M, 1] fp32
// alphaRow [1, N] fp32
// outputs
// mat [M, N] fp32
Mathematically, this is equivalent to (A x B) * (alphaCol x alphaRow), where alphaCol x alphaRow is an [M, N] outer product and * is an element-wise multiply.
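For reference, here is the target computation in plain libtorch (a sketch only; B_unpacked is a hypothetical int8 [N, K] tensor holding the unpacked 4-bit values of B):
// Reference computation (sketch). B_unpacked is a hypothetical int8 [N, K]
// tensor holding the unpacked 4-bit values of B.
auto acc = torch::matmul(A.to(torch::kFloat32), B_unpacked.to(torch::kFloat32).t()); // [M, N] accumulator
auto ref = acc * torch::matmul(alphaCol, alphaRow); // element-wise scale by the [M, N] outer product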
Based on this PR, I have implemented the following basic s4/s8 GEMM. With it I can get the correct result C = A x B.
/*
s4/s8 GEMM
*/
torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B) {
torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1);
auto C = torch::empty({M, N}, torch::dtype(torch::kInt32).device(A.device()));
using ElementOutput = int32_t;
using LayoutC = cutlass::layout::RowMajor;
using ElementAccumulator = int32_t;
using ElementA = int8_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::int4b_t;
using LayoutB = cutlass::layout::ColumnMajor;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA, // ElementA
cutlass::layout::RowMajor, // LayoutA
ElementB, // ElementB
cutlass::layout::ColumnMajor, // LayoutB
ElementOutput, // ElementOutput
cutlass::layout::RowMajor, // LayoutOutput
ElementAccumulator, // ElementAccumulator
cutlass::arch::OpClassTensorOp, // tag indicating Tensor Cores
cutlass::arch::Sm80, // tag indicating target GPU compute architecture
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 32>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
32, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
using GemmCoord = cutlass::gemm::GemmCoord;
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
int batch_count = 1;
cutlass::gemm::GemmCoord problem_size(M, N, K);
int64_t stride_A = M * K;
int64_t stride_B = N * K;
int64_t stride_C = M * N;
int64_t stride_D = M * N;
typename Gemm::Arguments arguments{
    mode,
    {static_cast<GemmCoord::Index>(M), static_cast<GemmCoord::Index>(N),
     static_cast<GemmCoord::Index>(K)},
    batch_count,
    {1, 0},                                     // epilogue: alpha = 1, beta = 0
    A.data_ptr<int8_t>(),                       // ptr_A
    (cutlass::int4b_t *)B.data_ptr<uint8_t>(),  // ptr_B
    C.data_ptr<int32_t>(),                      // ptr_C
    C.data_ptr<int32_t>(),                      // ptr_D
    stride_A,                                   // batch_stride_A
    stride_B,                                   // batch_stride_B
    stride_C,                                   // batch_stride_C
    stride_D,                                   // batch_stride_D
    K,                                          // lda
    K,                                          // ldb
    N,                                          // ldc
    N                                           // ldd
};
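For completeness, the kernel is then launched with the usual CUTLASS device-API pattern (a sketch with minimal status checking):
// Query workspace, initialize, and run the kernel (status checks abbreviated).
Gemm gemm_op;
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace = torch::empty({static_cast<int64_t>(workspace_size)},
                              torch::dtype(torch::kUInt8).device(A.device()));
cutlass::Status status = gemm_op.can_implement(arguments);
TORCH_CHECK(status == cutlass::Status::kSuccess, "unsupported problem size/alignment");
status = gemm_op.initialize(arguments, workspace.data_ptr<uint8_t>());
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS initialization failed");
status = gemm_op();
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel launch failed");
return C;
}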
Now I want to apply alphaCol and alphaRow in the epilogue to produce the final [M, N] output. What is the best way to implement this?
I'm trying EVT, but I don't know how to get started with it.
If possible, pointers in the right direction would be greatly appreciated, thanks!
Similar issues:
- https://github.com/NVIDIA/cutlass/issues/783
- https://github.com/NVIDIA/cutlass/issues/785
Hi @apuaaChen, thanks for your work on EVT. I have completed my GEMM and EVT construction, but my EVTD::Arguments callback_args parameter seems to be incorrect and I am getting errors. Could you please take some time to help me check it? Thanks!
torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &alphaCol, const torch::Tensor &alphaRow) {
torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1); // 4bit packing is on the columns
auto D = torch::empty({M, N}, torch::dtype(torch::kFloat32).device(A.device()));
// A matrix configuration
using ElementA = int8_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::int4b_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C1/C2/D matrix configuration
using ElementC = float; //cutlass::half_t; // Element type for C matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = float; // Element type for output matrix operands
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = int32_t; // Element type for internal accumulation
using ElementCompute = float; //cutlass::half_t; // Element type for compute
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1;
// StreamK device GEMM implementation type with EVT
using namespace cute;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementC,
AlignmentC,
EVTEpilogueStages
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
// alphaCol [M, 1] fp32
using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<int32_t, _1, _0> // StrideMNL
>;
// alphaRow [1, N] fp32
using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<_0, _1, int32_t> // StrideMNL
>;
// mul
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
// alphaCol * accumulator
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
Accum,
V1Broadcast>;
// mul
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
// alphaRow * alphaCol * accumulator
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
Compute1,
EVTCompute0,
V2Broadcast>;
using StoreD = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
StoreD,
EVTCompute1>;
using EVTKernelStreamK =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages
>::GemmKernel;
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;
// Populate the DeviceGemmStreamK::Arguments structure
// Ensure the input tensors are on the right device and contiguous
auto tensor_a = A.contiguous();
auto tensor_b = B.contiguous();
auto tensor_v1 = alphaCol.contiguous();
auto tensor_v2 = alphaRow.contiguous();
auto tensor_d = D.contiguous(); // EVTD
typename EVTD::Arguments callback_args{
    {
        {
            {
                {},                                                                       // Accum
                {tensor_v1.data_ptr<ElementC>(), ElementC(0), {int32_t(M), _1{}, _0{}}},  // V1 Broadcast
                {}                                                                        // Compute0
            },                                                                            // EVTCompute0
            {tensor_v2.data_ptr<ElementC>(), ElementC(0), {_0{}, _1{}, int32_t(N)}},      // V2 Broadcast
            {}                                                                            // Compute1
        },                                                                                // EVTCompute1
        {}                                                                                // Compute2
    },                                                                                    // EVTCompute2
    {tensor_d.data_ptr<ElementC>(), {int32_t{N}, _1{}, int32_t{M*N}}}                     // D
}; // EVTD
using GemmCoord = cutlass::gemm::GemmCoord;
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
int batch_count = 1;
// Construct Gemm ProblemSize with user defined output size
cutlass::gemm::GemmCoord problem_size(M, N, K);
int64_t stride_A = M * K;
int64_t stride_B = N * K;
// EVTD
int avail_sms = -1;
typename DeviceGemmStreamK::Arguments arguments(
mode, // universal mode
problem_size, // problem_size
batch_count, // batch count / splitk slices
callback_args, // argument of EVT callbacks
tensor_a.data_ptr<ElementA>(), // ptr_A
(cutlass::int4b_t *)tensor_b.data_ptr<uint8_t>(), // ptr_B
nullptr, // ptr_C (unused)
nullptr, // ptr_D (unused)
stride_A, // batch_stride_A
stride_B, // batch_stride_B
0, // batch_stride_C (unused)
0, // batch_stride_D (unused)
tensor_a.stride(0), // stride_a
tensor_b.stride(0), // stride_b
0, // stride_c (unused)
0, // stride_d (unused)
avail_sms); // avail_sms
DeviceGemmStreamK gemm_op;
Error message:
error: no instance of constructor "cute::tuple<T...>::tuple [with T=<>]" matches the argument list
        argument types are: ({...}, {...}, {...})
/target/w4a8/cutlass/include/cute/container/tuple.hpp(147): error: no instance of constructor "cute::detail::EBO<N, T, true>::EBO [with N=2UL, T=cute::C<0>]" matches the argument list
        argument types are: (const int32_t)
        detected during instantiation of "cute::detail::TupleBase<cute::index_sequence<I...>, T...>::TupleBase(const U &...) [with I=<0UL, 1UL, 2UL>, T=<int32_t, cute::_1, cute::_0>, U=<cute::_0, cute::_1, int32_t>]"
Hi @Hongbosherlock! Thank you for your patience. I have attached a revision that should work:
torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &alphaCol, const torch::Tensor &alphaRow) {
torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
auto M = A.size(0);
auto N = B.size(0);
auto K = A.size(1); // 4bit packing is on the columns
auto D = torch::empty({M, N}, torch::dtype(torch::kFloat32).device(A.device()));
// A matrix configuration
using ElementA = int8_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::int4b_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C1/C2/D matrix configuration
using ElementC = float; //cutlass::half_t; // Element type for C matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)
// Output matrix configuration
using ElementOutput = float; // Element type for output matrix operands
using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands
// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = int32_t; // Element type for internal accumulation
using ElementCompute = float; //cutlass::half_t; // Element type for compute
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1;
// StreamK device GEMM implementation type with EVT
using namespace cute;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementC,
AlignmentC,
EVTEpilogueStages
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
// alphaCol [M, 1] fp32
using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<_1, _0, int32_t> // StrideMNL
>;
// alphaRow [1, N] fp32
using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<_0, _1, int32_t> // StrideMNL
>;
// mul
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
// alphaCol * accumulator
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
Accum,
V1Broadcast>;
// mul
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementCompute, ElementCompute,
cutlass::FloatRoundStyle::round_to_nearest
>;
// alphaRow * alphaCol * accumulator
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
Compute1,
EVTCompute0,
V2Broadcast>;
using StoreD = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, _1, int64_t> // StrideMNL
>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
StoreD,
EVTCompute1>;
using EVTKernelStreamK =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, LayoutC, AlignmentC,
ElementAccumulator,
ElementCompute,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages
>::GemmKernel;
using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;
// Populate the DeviceGemmStreamK::Arguments structure
// Ensure the input tensors are on the right device and contiguous
auto tensor_a = A.contiguous();
auto tensor_b = B.contiguous();
auto tensor_v1 = alphaCol.contiguous();
auto tensor_v2 = alphaRow.contiguous();
auto tensor_d = D.contiguous();
typename EVTD::Arguments callback_args{                                           // EVTD
    {                                                                             // EVTCompute1
        {                                                                         // EVTCompute0
            {},                                                                   // Accum
            {tensor_v1.data_ptr<ElementC>(), ElementC(0), {_1{}, _0{}, int32_t(M)}}, // V1 Broadcast
            {}                                                                    // Compute0
        },                                                                        // EVTCompute0
        {tensor_v2.data_ptr<ElementC>(), ElementC(0), {_0{}, _1{}, int32_t(N)}},  // V2 Broadcast
        {}                                                                        // Compute1
    },                                                                            // EVTCompute1
    {tensor_d.data_ptr<ElementC>(), {int64_t{N}, _1{}, int64_t{M*N}}}             // D
};                                                                                // EVTD
using GemmCoord = cutlass::gemm::GemmCoord;
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
int batch_count = 1;
// Construct Gemm ProblemSize with user defined output size
cutlass::gemm::GemmCoord problem_size(M, N, K);
int64_t stride_A = M * K;
int64_t stride_B = N * K;
// EVTD
int avail_sms = -1;
typename DeviceGemmStreamK::Arguments arguments(
mode, // universal mode
problem_size, // problem_size
batch_count, // batch count / splitk slices
callback_args, // argument of EVT callbacks
tensor_a.data_ptr<ElementA>(), // ptr_A
(cutlass::int4b_t *)tensor_b.data_ptr<uint8_t>(), // ptr_B
nullptr, // ptr_C (unused)
nullptr, // ptr_D (unused)
stride_A, // batch_stride_A
stride_B, // batch_stride_B
0, // batch_stride_C (unused)
0, // batch_stride_D (unused)
tensor_a.stride(0), // stride_a
tensor_b.stride(0), // stride_b
0, // stride_c (unused)
0, // stride_d (unused)
avail_sms); // avail_sms
DeviceGemmStreamK gemm_op;
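For completeness, launching follows the same device-API pattern as before (again just a sketch, with status checks abbreviated):
// Query workspace, initialize, and run the kernel.
size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
auto workspace = torch::empty({static_cast<int64_t>(workspace_size)},
                              torch::dtype(torch::kUInt8).device(A.device()));
cutlass::Status status = gemm_op.can_implement(arguments);
TORCH_CHECK(status == cutlass::Status::kSuccess, "unsupported problem size/alignment");
status = gemm_op.initialize(arguments, workspace.data_ptr<uint8_t>());
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS initialization failed");
status = gemm_op();
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel launch failed");
return D;
}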
There are a few changes:
- The stride of V1Broadcast should be cute::Stride<_1, _0, int32_t>, since the entries are in (M, N, L) order (see the short stride sketch below this list).
- The stride of D should be of type int64_t rather than int32_t, as defined in StoreD.
- The error message is triggered by the Compute2 entry in your original callback_args construction, which does not exist in your EVT definition.
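As a quick illustration of the (M, N, L) stride convention the broadcast visitors expect (my reading of the convention; the three entries step along rows, columns, and batch, respectively):
// (M, N, L) stride order: entry 0 steps along M (rows), entry 1 along N
// (columns), entry 2 along L (batch).
using ColVecStride = cute::Stride<cute::_1, cute::_0, int32_t>; // alphaCol [M, 1]: dense in M, broadcast over N
using RowVecStride = cute::Stride<cute::_0, cute::_1, int32_t>; // alphaRow [1, N]: broadcast over M, dense in N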
Please let me know if you have more questions or if the script doesn't work. Thanks!