
[QST] GEMM Epilogue Fusion: Row-wise and Column-wise Multiplication

Open • Hongbosherlock opened this issue 1 year ago • 2 comments

What is your question?

Hi, I'd like to compute the following:

    // inputs
    //     A           [M, K]    int8
    //     B           [N, K]    int4
    //     alphaCol    [M, 1]    fp32
    //     alphaRow    [1, N]    fp32

    // outputs
    //     mat [M, N]            fp32

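This is mathematically equivalent to (A x B) * (alphaCol x alphaRow), where alphaCol x alphaRow is the [M, N] outer product and * is element-wise multiplication. Because B is stored as [N, K], the product is really A x B^T, i.e. mat[m][n] = alphaCol[m] * alphaRow[n] * sum_k A[m][k] * B[n][k]. Based on this PR, I have implemented the following basic s4/s8 GEMM, and it now produces the correct int32 result C = A x B^T.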

/*
s4/s8 GEMM
*/
torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B) {
    torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
    auto M = A.size(0);
    auto N = B.size(0);
    auto K = A.size(1);  
    auto C = torch::empty({M, N}, torch::dtype(torch::kInt32).device(A.device()));

    using ElementOutput = int32_t;
    using LayoutC  = cutlass::layout::RowMajor;   
    using ElementAccumulator = int32_t;
    using ElementA = int8_t;
    using LayoutA  = cutlass::layout::RowMajor;   
    using ElementB = cutlass::int4b_t;
    using LayoutB  = cutlass::layout::ColumnMajor;   

    using Gemm = cutlass::gemm::device::GemmUniversal<
        ElementA,                // ElementA
        cutlass::layout::RowMajor,       // LayoutA
        ElementB,                // ElementB
        cutlass::layout::ColumnMajor,    // LayoutB
        ElementOutput,                         // ElementOutput
        cutlass::layout::RowMajor,       // LayoutOutput
        ElementAccumulator,                         // ElementAccumulator
        cutlass::arch::OpClassTensorOp,  // tag indicating Tensor Cores
        cutlass::arch::Sm80,  // tag indicating target GPU compute architecture
        cutlass::gemm::GemmShape<128, 128, 64>,
        cutlass::gemm::GemmShape<64, 64, 64>,
        cutlass::gemm::GemmShape<16, 8, 32>,
        cutlass::epilogue::thread::LinearCombination<
            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
            ElementAccumulator, ElementAccumulator>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        4,   // Stages
        16,  // AlignmentA
        32,  // AlignmentB
        cutlass::arch::OpMultiplyAddMixedInputUpcast,
        cutlass::ComplexTransform::kNone,
        cutlass::ComplexTransform::kNone
        >;

    using GemmCoord = cutlass::gemm::GemmCoord;

    cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
    int batch_count = 1;
    cutlass::gemm::GemmCoord problem_size(M, N, K);

    int64_t stride_A = M * K;
    int64_t stride_B = N * K;
    int64_t stride_C = M * N;
    int64_t stride_D = M * N;

    typename Gemm::Arguments arguments{
        mode,
        {static_cast<GemmCoord::Index>(M), static_cast<GemmCoord::Index>(N),
       static_cast<GemmCoord::Index>(K)},
        batch_count,
        {1, 0},
        A.data_ptr<int8_t>(),
        (cutlass::int4b_t *)B.data_ptr<uint8_t>(),
        C.data_ptr<int32_t>(),
        C.data_ptr<int32_t>(),
        stride_A,
        stride_B,
        stride_C,
        stride_D,
        K,
        K,
        N,  
        N
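        };

    // Launch: query and allocate the workspace on A's device, validate, then run.
    Gemm gemm_op;
    size_t workspace_size = Gemm::get_workspace_size(arguments);
    auto workspace = torch::empty({static_cast<int64_t>(workspace_size)},
                                  torch::dtype(torch::kUInt8).device(A.device()));
    cutlass::Status status = gemm_op.can_implement(arguments);
    TORCH_CHECK(status == cutlass::Status::kSuccess, "W4A8 GEMM: unsupported problem");
    status = gemm_op(arguments, workspace.data_ptr());
    TORCH_CHECK(status == cutlass::Status::kSuccess, "W4A8 GEMM: kernel launch failed");
    return C;
}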

Now I want to apply alphaCol and alphaRow in the epilogue to get the final [M, N] output. What is the best way to implement this? I'm trying EVT (Epilogue Visitor Tree), but I don't know where to start. Any pointers in the right direction would be greatly appreciated, thanks!
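It helps to think of an EVT as an expression tree:

- Leaves fetch the accumulator or load broadcast vectors.
- Interior nodes combine their children's results.
- The root stores the output.

The plain C++ toy model below is not the CUTLASS API. Its node names loosely mirror the tree built later in this thread (Accum, V1Broadcast, V2Broadcast, Compute0/1, StoreD). The real CUTLASS visitors work on register fragments of an output tile rather than on one (m, n) element at a time, so treat this only as a mental model.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy leaf: returns the GEMM accumulator value at (m, n).
struct AccFetch {
    const std::vector<float>* acc;
    int N;
    float visit(int m, int n) const { return (*acc)[m * N + n]; }
};

// Toy leaf: one value per row m, reused for every column n (like alphaCol).
struct ColBroadcast {
    const float* ptr;
    float visit(int m, int) const { return ptr[m]; }
};

// Toy leaf: one value per column n, reused for every row m (like alphaRow).
struct RowBroadcast {
    const float* ptr;
    float visit(int, int n) const { return ptr[n]; }
};

// Toy interior node: evaluates both children, then multiplies the results.
template <class L, class R>
struct Multiply {
    L lhs;
    R rhs;
    float visit(int m, int n) const { return lhs.visit(m, n) * rhs.visit(m, n); }
};

// Toy root ("StoreD"): evaluates the tree at every output coordinate, row-major.
template <class Tree>
std::vector<float> run_epilogue(const Tree& tree, int M, int N) {
    std::vector<float> D(static_cast<std::size_t>(M) * N);
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            D[m * N + n] = tree.visit(m, n);
    return D;
}
```

The tree for this issue is Multiply<Multiply<AccFetch, ColBroadcast>, RowBroadcast>, i.e. (accumulator * alphaCol) * alphaRow.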

Similar issues:

  • https://github.com/NVIDIA/cutlass/issues/783
  • https://github.com/NVIDIA/cutlass/issues/785

Hongbosherlock • May 31 '24 03:05

Hi @apuaaChen, thanks for your work on EVT. I have finished building my GEMM and EVT, but my EVTD::Arguments callback_args parameter seems to be incorrect, and compilation fails with the errors below. Could you please take some time to help me check it? Thanks!
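The nesting of callback_args has to follow the shape of the EVT type exactly. Judging from the corrected version later in this thread, each Sm80EVT<Node, Children...> takes its arguments as {children's arguments..., node's own arguments}.

The plain C++ sketch below models that layout with std::tuple. The structs are toy stand-ins, not the real CUTLASS Arguments types, and the broadcast and store stride fields are left out.

```cpp
#include <cassert>
#include <tuple>
#include <type_traits>

// Toy argument types (NOT the CUTLASS ones).
struct Empty {};                                             // e.g. Accum, Compute0, Compute1
struct BcastArgs { const float* ptr; float null_default; };  // broadcast pointer + default value
struct StoreArgs { float* ptr; };                            // output pointer

// Toy layout rule: a node's arguments are {children's arguments..., node's arguments}.
template <class NodeArgs, class... ChildArgs>
using TreeArgs = std::tuple<ChildArgs..., NodeArgs>;

using EVTCompute0Args = TreeArgs<Empty, Empty, BcastArgs>;            // {Accum, V1, Compute0}
using EVTCompute1Args = TreeArgs<Empty, EVTCompute0Args, BcastArgs>;  // {EVTCompute0, V2, Compute1}
using EVTDArgs        = TreeArgs<StoreArgs, EVTCompute1Args>;         // {EVTCompute1, D}
```

In this model, a node can only be built from exactly one entry per child plus one for itself. Wrapping EVTCompute1's arguments together with an extra Compute2 entry, which has no node in the tree, makes the tuple impossible to construct. That is the same kind of shape mismatch the compiler reports below.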

torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &alphaCol, const torch::Tensor &alphaRow) {
    torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
    auto M = A.size(0);
    auto N = B.size(0);
    auto K = A.size(1);  // 4bit packing is on the columns
    auto D = torch::empty({M, N}, torch::dtype(torch::kFloat32).device(A.device()));

    // A matrix configuration
    using         ElementA         = int8_t;                                    // Element type for A matrix operand
    using         LayoutA          = cutlass::layout::RowMajor;                        // Layout type for A matrix operand
    constexpr int AlignmentA       = 128 / cutlass::sizeof_bits<ElementA>::value;      // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

    // B matrix configuration
    using         ElementB         = cutlass::int4b_t;                                  // Element type for B matrix operand
    using         LayoutB          = cutlass::layout::ColumnMajor;                        // Layout type for B matrix operand
    constexpr int AlignmentB       = 128 / cutlass::sizeof_bits<ElementB>::value;      // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

    // C1/C2/D matrix configuration
    using         ElementC         = float;        //cutlass::half_t;                   // Element type for C matrix operands
    using         LayoutC          = cutlass::layout::RowMajor;                        // Layout type for C matrix operands
    constexpr int AlignmentC       = 128 / cutlass::sizeof_bits<ElementC>::value;      // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)

    // Output matrix configuration
    using         ElementOutput    = float;                                          // Element type for output matrix operands
    using         LayoutOutput     = cutlass::layout::RowMajor;                        // Layout type for output matrix operands
    // constexpr int AlignmentOutput  = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)

    // Multiply-accumulate blocking/pipelining details
    using ElementAccumulator  = int32_t;                                 // Element type for internal accumulation
    using ElementCompute      = float;  //cutlass::half_t;                          // Element type for compute
    using ArchTag             = cutlass::arch::Sm80;                      // Tag indicating the minimum SM that supports the intended feature
    using OperatorClass       = cutlass::arch::OpClassTensorOp;           // Operator class tag
    using ThreadblockShape    = cutlass::gemm::GemmShape<128, 128, 64>;   // Threadblock-level tile size (concept: GemmShape)
    using WarpShape           = cutlass::gemm::GemmShape<64, 64, 64>;     // Warp-level tile size (concept: GemmShape)
    using InstructionShape    = cutlass::gemm::GemmShape<16, 8, 32>;      // Instruction-level tile size (concept: GemmShape)
    constexpr int NumStages   = 4;                                        // Number of global->shared pipeline stages used in the GEMM mainloop
    constexpr int EVTEpilogueStages = 1;   

    // StreamK device GEMM implementation type with EVT
    using namespace cute;

    using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    ThreadblockShape, 
    WarpShape, 
    ElementC, 
    AlignmentC, 
    EVTEpilogueStages
    >;

    using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

    // alphaCol    [M, 1]    fp32
    using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<int32_t, _1, _0>  // StrideMNL
    >;

    // alphaRow    [1, N]    fp32
    using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<_0, _1, int32_t>  // StrideMNL
    >;

    // mul
    using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::multiplies, ElementCompute, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest
    >;

    // alphaCol * accumulator
    using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
        Compute0,
        Accum,
        V1Broadcast>;

    // mul
    using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::multiplies, ElementCompute, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest
    >;

    // alphaRow * alphaCol * accumulator
    using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
        Compute1,
        EVTCompute0,
        V2Broadcast>;

    using StoreD = cutlass::epilogue::threadblock::VisitorAuxStore<
        OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
        cute::Stride<int64_t, _1, int64_t> // StrideMNL
    >;

    using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
        StoreD,
        EVTCompute1>;

    using EVTKernelStreamK =
        typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
        ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
        ElementC, LayoutC, AlignmentC,
        ElementAccumulator,
        ElementCompute,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        NumStages,
        cutlass::arch::OpMultiplyAdd,
        EVTEpilogueStages
    >::GemmKernel;

    using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;

    // Build the EVT callback arguments and the DeviceGemmStreamK::Arguments

    // Make sure the tensors are contiguous (row-major) before taking raw pointers
    auto tensor_a = A.contiguous();
    auto tensor_b = B.contiguous();
    auto tensor_v1 = alphaCol.contiguous();
    auto tensor_v2 = alphaRow.contiguous();
    auto tensor_d = D.contiguous();

    typename EVTD::Arguments callback_args{
        {
            {
                {
                    {},                                                                                                          // Accum
                    {tensor_v1.data_ptr<ElementC>(), ElementC(0), {int32_t(M), _1{}, _0{}}},                                    // V1 Broadcast
                    {}                                                                                                           // Compute0
                },                                                                                                             // EVTCompute0
                {tensor_v2.data_ptr<ElementC>(), ElementC(0), {_0{}, _1{}, int32_t(N)}},                                      // V2 Broadcast
                {}                                                                                                             // Compute1
            },                                                                                                               // EVTCompute1
            {}                                                                                                               // Compute2
        },                                                                                                                // EVTCompute2
        {tensor_d.data_ptr<ElementC>(), {int32_t{N}, _1{}, int32_t{M*N}}}                                                                  // D
    };                                                                                                                   // EVTD
    
                  
    using GemmCoord = cutlass::gemm::GemmCoord;
    cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
    int batch_count = 1;
    // Construct Gemm ProblemSize with user defined output size
    cutlass::gemm::GemmCoord problem_size(M, N, K);

    int64_t stride_A = M * K;
    int64_t stride_B = N * K;
    int     avail_sms = -1;
    typename DeviceGemmStreamK::Arguments arguments(
        mode,                                     // universal mode
        problem_size,                             // problem_size
        batch_count,                              // batch count / splitk slices
        callback_args,                            // argument of EVT callbacks
        tensor_a.data_ptr<ElementA>(),            // ptr_A
        (cutlass::int4b_t *)tensor_b.data_ptr<uint8_t>(),            // ptr_B
        nullptr,                                  // ptr_C (unused)
        nullptr,                                  // ptr_D (unused)
        stride_A,                                 // batch_stride_A
        stride_B,                                 // batch_stride_B
        0,                                        // batch_stride_C (unused)
        0,                                        // batch_stride_D (unused)
        tensor_a.stride(0),                       // stride_a
        K,                                        // stride_b in int4 elements (B packs two per byte, so B.stride(0) is K/2)
        0,                                        // stride_c (unused)
        0,                                        // stride_d (unused)
        avail_sms);                               // avail_sms
    

    DeviceGemmStreamK gemm_op;

Error message:

error: no instance of constructor "cute::tuple<T...>::tuple [with T=<>]" matches the argument list
          argument types are: ({...}, {...}, {...})

/target/w4a8/cutlass/include/cute/container/tuple.hpp(147): error: no instance of constructor "cute::detail::EBO<N, T, true>::EBO [with N=2UL, T=cute::C<0>]" matches the argument list
          argument types are: (const int32_t)
          detected during:
            instantiation of "cute::detail::TupleBase<cute::index_sequence<I...>, T...>::TupleBase(const U &...) [with I=<0UL, 1UL, 2UL>, T=<int32_t, cute::_1, cute::_0>, U=<cute::_0, cute::_1, int32_t>]"
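The second error is a type mismatch. A stride whose element types are (int32_t, _1, _0) is being initialized from values of types (_0, _1, int32_t). In cute, _0 and _1 are compile-time constants whose value lives in the type, so a runtime int32_t cannot be turned into a _0.

The sketch below reproduces the same rule with std::integral_constant standing in for cute's static integers. This is an analogy only; cute's own types are not used, and the namespace `toy` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <tuple>
#include <type_traits>

namespace toy {
// Stand-ins for cute's static integers: the value is part of the type.
using _0 = std::integral_constant<int, 0>;
using _1 = std::integral_constant<int, 1>;

// T from the error message: a stride typed (int32_t, _1, _0).
using ErrorStride = std::tuple<int32_t, _1, _0>;
// Column-broadcast stride used in the answer: (M, N, L) = (1, 0, runtime batch stride).
using ColBroadcastStride = std::tuple<_1, _0, int32_t>;
}  // namespace toy
```

A static zero can be read as an int, but an int cannot become a static zero. So (_0, _1, int32_t) fits ColBroadcastStride-shaped slots but not ErrorStride.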

Hongbosherlock • Jun 04 '24 02:06

Hi @Hongbosherlock! Thank you for your patience. Below is a revised version that should work:

torch::Tensor matmul_w4a8(const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &alphaCol, const torch::Tensor &alphaRow) {
    torch::checkAllSameGPU("W4A8Matmul", {{A, "A", 0}, {B, "B", 1}});
    auto M = A.size(0);
    auto N = B.size(0);
    auto K = A.size(1);  // 4bit packing is on the columns
    auto D = torch::empty({M, N}, torch::dtype(torch::kFloat32).device(A.device()));

    // A matrix configuration
    using         ElementA         = int8_t;                                    // Element type for A matrix operand
    using         LayoutA          = cutlass::layout::RowMajor;                        // Layout type for A matrix operand
    constexpr int AlignmentA       = 128 / cutlass::sizeof_bits<ElementA>::value;      // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

    // B matrix configuration
    using         ElementB         = cutlass::int4b_t;                                  // Element type for B matrix operand
    using         LayoutB          = cutlass::layout::ColumnMajor;                        // Layout type for B matrix operand
    constexpr int AlignmentB       = 128 / cutlass::sizeof_bits<ElementB>::value;      // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

    // C1/C2/D matrix configuration
    using         ElementC         = float;        //cutlass::half_t;                   // Element type for C matrix operands
    using         LayoutC          = cutlass::layout::RowMajor;                        // Layout type for C matrix operands
    constexpr int AlignmentC       = 128 / cutlass::sizeof_bits<ElementC>::value;      // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes)

    // Output matrix configuration
    using         ElementOutput    = float;                                          // Element type for output matrix operands
    using         LayoutOutput     = cutlass::layout::RowMajor;                        // Layout type for output matrix operands
    // constexpr int AlignmentOutput  = 128 / cutlass::sizeof_bits<ElementOutput>::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes)

    // Multiply-accumulate blocking/pipelining details
    using ElementAccumulator  = int32_t;                                 // Element type for internal accumulation
    using ElementCompute      = float;  //cutlass::half_t;                          // Element type for compute
    using ArchTag             = cutlass::arch::Sm80;                      // Tag indicating the minimum SM that supports the intended feature
    using OperatorClass       = cutlass::arch::OpClassTensorOp;           // Operator class tag
    using ThreadblockShape    = cutlass::gemm::GemmShape<128, 128, 64>;   // Threadblock-level tile size (concept: GemmShape)
    using WarpShape           = cutlass::gemm::GemmShape<64, 64, 64>;     // Warp-level tile size (concept: GemmShape)
    using InstructionShape    = cutlass::gemm::GemmShape<16, 8, 32>;      // Instruction-level tile size (concept: GemmShape)
    constexpr int NumStages   = 4;                                        // Number of global->shared pipeline stages used in the GEMM mainloop
    constexpr int EVTEpilogueStages = 1;   

    // StreamK device GEMM implementation type with EVT
    using namespace cute;

    using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    ThreadblockShape, 
    WarpShape, 
    ElementC, 
    AlignmentC, 
    EVTEpilogueStages
    >;

    using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

    // alphaCol    [M, 1]    fp32
    using V1Broadcast = cutlass::epilogue::threadblock::VisitorColBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<_1, _0, int32_t>  // StrideMNL
    >;

    // alphaRow    [1, N]    fp32
    using V2Broadcast = cutlass::epilogue::threadblock::VisitorRowBroadcast<
        OutputTileThreadMap, ElementC,
        cute::Stride<_0, _1, int32_t>  // StrideMNL
    >;

    // mul
    using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::multiplies, ElementCompute, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest
    >;

    // alphaCol * accumulator
    using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
        Compute0,
        Accum,
        V1Broadcast>;

    // mul
    using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
        cutlass::multiplies, ElementCompute, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest
    >;

    // alphaRow * alphaCol * accumulator
    using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<
        Compute1,
        EVTCompute0,
        V2Broadcast>;

    using StoreD = cutlass::epilogue::threadblock::VisitorAuxStore<
        OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
        cute::Stride<int64_t, _1, int64_t> // StrideMNL
    >;

    using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
        StoreD,
        EVTCompute1>;

    using EVTKernelStreamK =
        typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA,
        ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB,
        ElementC, LayoutC, AlignmentC,
        ElementAccumulator,
        ElementCompute,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
        NumStages,
        cutlass::arch::OpMultiplyAdd,
        EVTEpilogueStages
    >::GemmKernel;

    using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<EVTKernelStreamK>;

    // Build the EVT callback arguments and the DeviceGemmStreamK::Arguments

    // Make sure the tensors are contiguous (row-major) before taking raw pointers
    auto tensor_a = A.contiguous();
    auto tensor_b = B.contiguous();
    auto tensor_v1 = alphaCol.contiguous();
    auto tensor_v2 = alphaRow.contiguous();
    auto tensor_d = D.contiguous();

    typename EVTD::Arguments callback_args{  // EVTD
        {  // EVTCompute1
            {  // EVTCompute0
                {},  // Accum
                {tensor_v1.data_ptr<ElementC>(), ElementC(0), {_1{}, _0{}, int32_t(M)}},  // V1 Broadcast
                {}   // Compute0
            },  // EVTCompute0
            {tensor_v2.data_ptr<ElementC>(), ElementC(0), {_0{}, _1{}, int32_t(N)}},  // V2 Broadcast
            {}   // Compute1
        },  // EVTCompute1
        {tensor_d.data_ptr<ElementC>(), {int64_t{N}, _1{}, int64_t{M*N}}}  // D
    };  // EVTD

    using GemmCoord = cutlass::gemm::GemmCoord;
    cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm;
    int batch_count = 1;
    // Construct Gemm ProblemSize with user defined output size
    cutlass::gemm::GemmCoord problem_size(M, N, K);

    int64_t stride_A = M * K;
    int64_t stride_B = N * K;
    int     avail_sms = -1;
    typename DeviceGemmStreamK::Arguments arguments(
        mode,                                     // universal mode
        problem_size,                             // problem_size
        batch_count,                              // batch count / splitk slices
        callback_args,                            // argument of EVT callbacks
        tensor_a.data_ptr<ElementA>(),            // ptr_A
        (cutlass::int4b_t *)tensor_b.data_ptr<uint8_t>(),            // ptr_B
        nullptr,                                  // ptr_C (unused)
        nullptr,                                  // ptr_D (unused)
        stride_A,                                 // batch_stride_A
        stride_B,                                 // batch_stride_B
        0,                                        // batch_stride_C (unused)
        0,                                        // batch_stride_D (unused)
        tensor_a.stride(0),                       // stride_a
        K,                                        // stride_b in int4 elements (B packs two per byte, so B.stride(0) is K/2)
        0,                                        // stride_c (unused)
        0,                                        // stride_d (unused)
        avail_sms);                               // avail_sms
    

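    DeviceGemmStreamK gemm_op;

    // Stream-K may need a workspace for cross-threadblock partial sums; allocate it on A's device.
    size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
    auto workspace = torch::empty({static_cast<int64_t>(workspace_size)},
                                  torch::dtype(torch::kUInt8).device(A.device()));

    TORCH_CHECK(DeviceGemmStreamK::can_implement(arguments) == cutlass::Status::kSuccess,
                "W4A8 EVT GEMM: unsupported problem");
    TORCH_CHECK(gemm_op.initialize(arguments, workspace.data_ptr()) == cutlass::Status::kSuccess,
                "W4A8 EVT GEMM: initialize failed");
    TORCH_CHECK(gemm_op() == cutlass::Status::kSuccess, "W4A8 EVT GEMM: run failed");
    return D;
}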

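There are a few changes:

  1. The stride of V1Broadcast should be cute::Stride<_1, _0, int32_t>, because the modes are ordered (M, N, L). alphaCol advances by 1 along M, is broadcast (stride 0) along N, and uses a runtime stride along the batch mode L.
  2. The stride of D should use int64_t rather than int32_t, to match the cute::Stride<int64_t, _1, int64_t> declared for StoreD.
  3. The error message is triggered by the Compute2 entry in your original callback_args. That entry does not exist in your EVT definition, so the nested braces no longer line up with the tree.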
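To see what the (M, N, L) ordering in point 1 means, the sketch below computes the memory offset of a logical coordinate (m, n, l) as m*stride_M + n*stride_N + l*stride_L. It uses the stride values from the revised callback_args.

This is a plain C++ illustration of the stride convention, not cute code; `StrideMNL` and the helper functions are hypothetical names. Everything is int64_t, which also matches the D stride type from point 2.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stride triple, ordered (M, N, L) like the StrideMNL parameters.
struct StrideMNL { int64_t m, n, l; };

// Offset of logical coordinate (m, n, l) under the given strides.
int64_t offset(StrideMNL s, int64_t m, int64_t n, int64_t l) {
    return m * s.m + n * s.n + l * s.l;
}

// Stride values used in the revised callback_args (M, N are the GEMM extents).
StrideMNL col_broadcast_stride(int64_t M) { return {1, 0, M}; }              // alphaCol: [M, 1]
StrideMNL row_broadcast_stride(int64_t N) { return {0, 1, N}; }              // alphaRow: [1, N]
StrideMNL row_major_d_stride(int64_t M, int64_t N) { return {N, 1, M * N}; }  // D: [M, N] row-major
```

With the column-broadcast stride (1, 0, M), every column n of row m reads alphaCol[m]. A stride of 1 along N would instead walk alphaCol by column index.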

Please let me know if you have more questions or the script doesn't work. Thanks!

apuaaChen • Jun 21 '24 20:06