
[QST] Any advice on optimizing for a particular problem size (small M)?

Open · HuaYZhao opened this issue 2 years ago • 5 comments

I am a beginner with CUTLASS. I have read many of the related documents and examples, and I have a general understanding of ThreadBlockShape, WarpShape, InstructionShape, and NumStages. In my current problem domain, I want to compute a matrix multiplication with M, N, K = [8, 4096, 12288], where M may vary between 2 and 8. The FlashDecoding++ paper calls this the flat GEMM problem. With that background, I have the following questions:

1. I have used the CUTLASS profiler to search for the best parameters, but when I change ThreadBlockShape, WarpShape, or InstructionShape by hand, I sometimes get compilation errors. What are the constraints on these parameters, and is there documentation or a tutorial on how to tune them? Or are the valid combinations already fixed by the definitions in the header files?
2. The M dimension of ThreadBlockShape is usually 32 or larger. If my problem size M is 8, will the kernel pad the remaining rows with zeros during the computation? That seems wasteful, and I am not sure how large the cost is. Is it possible, and worthwhile, to set the M of ThreadBlockShape to 8 to get the best performance for this problem size? If so, how do I make that adjustment?

In other words, I think I understand what these parameters mean, but not how they affect the actual computation while the kernel runs, including how resources are allocated and utilized. Is there any documentation or tutorial I could learn from? Any pointers would be greatly appreciated!

HuaYZhao, Jan 17 '24 13:01

I assume you are using fp16 on A100. If so, the instruction shape is always 16x8x16. Your problem size M is very small, so, as you said, resources would be wasted. You may have to use irregular tile sizes that are not listed in the profiler. These irregular ones may or may not work; you need to try them.
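As a rough illustration of that waste, assuming a single threadblock tile has to cover all of M: the tile always spans `ThreadblockShape::kM` rows, so only `M / kM` of its rows carry real data. The sketch below (the function name is hypothetical) counts only the M dimension and says nothing about memory traffic or the other overheads that also matter for flat GEMM.

```cpp
// Fraction of the M extent of the threadblock tiles that holds real rows.
// Rows beyond problem_m are padding: the tile still reserves shared memory,
// registers and MMA work for them, but their results are never written out.
constexpr double useful_m_fraction(int problem_m, int tile_m) {
  // ceil(problem_m / tile_m) threadblock tiles are needed along M.
  int tiles_m = (problem_m + tile_m - 1) / tile_m;
  return static_cast<double>(problem_m) / (tiles_m * tile_m);
}

// M = 8 with the two tile M sizes discussed in this thread.
static_assert(useful_m_fraction(8, 32) == 0.25, "3/4 of a 32-row tile is padding");
static_assert(useful_m_fraction(8, 16) == 0.5, "half of a 16-row tile is padding");
// M = 2, the low end of the range in the question.
static_assert(useful_m_fraction(2, 32) == 0.0625, "");
```

Dropping tile M from 32 to 16 doubles the useful fraction for M = 8, which is why the reply recommends keeping threadblock M as small as CUTLASS allows.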

You need to set threadblock M as small as possible, but no smaller than your problem size M. 16 is the minimum, since it cannot be smaller than the instruction shape M, but I am not sure whether that works in CUTLASS right now; 32 is more likely to work. Threadblock K is 32 or 64, and warp K is the same as threadblock K. Normally we use 4 warps per threadblock, where warp number = (threadblock_M / warp_M) x (threadblock_N / warp_N). You can try setting warp M equal to threadblock M, or to half of threadblock M, and then choose warp N and threadblock N accordingly.
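The arithmetic in that rule of thumb can be written out as compile-time checks. This is a minimal sketch, assuming the 4-warp and warp-K-equals-threadblock-K guidance above; `TileConfig`, `warp_count`, and `follows_rule_of_thumb` are hypothetical helpers, not CUTLASS types, and they are not the full set of constraints CUTLASS enforces (a shape can pass them and still be rejected by CUTLASS).

```cpp
// Hypothetical helper mirroring the two shapes passed to CUTLASS as
// cutlass::gemm::GemmShape<M, N, K> (threadblock tile and warp tile).
struct TileConfig {
  int tb_m, tb_n, tb_k;  // ThreadblockShape
  int w_m, w_n, w_k;     // WarpShape
};

// warp number = (threadblock_M / warp_M) x (threadblock_N / warp_N)
constexpr int warp_count(TileConfig c) {
  return (c.tb_m / c.w_m) * (c.tb_n / c.w_n);
}

// Checks the suggestions from the reply above for a given problem M.
constexpr bool follows_rule_of_thumb(TileConfig c, int problem_m) {
  return c.tb_m >= 16 &&                                 // not below instruction M
         c.tb_m >= problem_m &&                          // one tile covers all of M
         c.tb_m % c.w_m == 0 && c.tb_n % c.w_n == 0 &&   // warps tile the block evenly
         (c.w_m == c.tb_m || 2 * c.w_m == c.tb_m) &&     // warp M = tb M or tb M / 2
         (c.tb_k == 32 || c.tb_k == 64) &&               // threadblock K is 32 or 64
         c.w_k == c.tb_k &&                              // warp K equals threadblock K
         warp_count(c) == 4;                             // the usual 4 warps per block
}

// The configuration used later in this thread: 32x256x32 block, 32x64x32 warp.
constexpr TileConfig kStreamKConfig{32, 256, 32, 32, 64, 32};
static_assert(warp_count(kStreamKConfig) == 4, "1 x 4 warps");
static_assert(follows_rule_of_thumb(kStreamKConfig, 8), "");
```

With warp M fixed by threadblock M, the 4-warp target then determines threadblock N from warp N (here 4 x 64 = 256), which is the "choose warp N and threadblock N accordingly" step.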

You have a large problem size K, so to get good performance you need to use split-K or stream-K.
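One way to see why split-K or stream-K helps here is occupancy: with a single tile along M, the number of threadblocks (CTAs) in a plain data-parallel GEMM is set by N alone, and it can be smaller than the number of SMs. The sketch below models plain split-K only (stream-K distributes the same K work across a fixed number of CTAs differently). It assumes a 32x256 threadblock tile, as in the code later in this thread, and 108 SMs (an A100); the function names are hypothetical, and the estimate ignores reduction cost and K granularity.

```cpp
// Number of CTAs launched by a plain data-parallel GEMM:
// one CTA per (tile_m x tile_n) output tile.
constexpr int ctas_data_parallel(int m, int n, int tile_m, int tile_n) {
  return ((m + tile_m - 1) / tile_m) * ((n + tile_n - 1) / tile_n);
}

// Smallest split-K factor s such that splitting K into s slices per output
// tile gives at least one CTA per SM (ignores waves and reduction overhead).
constexpr int min_split_k_to_fill(int m, int n, int tile_m, int tile_n, int num_sms) {
  int ctas = ctas_data_parallel(m, n, tile_m, tile_n);
  return (num_sms + ctas - 1) / ctas;
}

constexpr int kA100Sms = 108;  // assumption: A100 SM count

// [M, N, K] = [8, 4096, 12288]: only 16 CTAs for 108 SMs without splitting K.
static_assert(ctas_data_parallel(8, 4096, 32, 256) == 16, "");
static_assert(min_split_k_to_fill(8, 4096, 32, 256, kA100Sms) == 7, "");
```

With the same tile, the other sizes mentioned in this thread give 48 CTAs for N = 12288 (a split factor of 3 to cover 108 SMs) and exactly 108 CTAs for N = 27648.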

hwu36, Jan 17 '24 19:01

Thank you for your reply; it cleared up some of my confusion, and I will try again. I am indeed working with fp16 now.

> I assume you are using fp16 on A100. If so, the instruction shape is always 16x8x16.

Does that mean there's no way to change it, or is it just strongly discouraged? Also, how does this parameter work on other cards such as the A10 or 3090? Is there documentation for this, or is it tied to the SM architecture?

> In my current problem domain, I want to compute a matrix multiplication with M, N, K = [8, 4096, 12288].

I made a mistake in my description: it is actually N that is large, so the problem size should be M, N, K = [8, 12288, 4096].

> You need to set threadblock M as small as possible, but no smaller than your problem size M. 16 is the minimum, since it cannot be smaller than the instruction shape M, but I am not sure whether that works in CUTLASS right now; 32 is more likely to work. Threadblock K is 32 or 64, and warp K is the same as threadblock K. Normally we use 4 warps per threadblock, where warp number = (threadblock_M / warp_M) x (threadblock_N / warp_N). You can try setting warp M equal to threadblock M, or to half of threadblock M, and then choose warp N and threadblock N accordingly.

This is where I'm confused. You clearly have a deeper understanding that lets you give this advice: is it purely empirical, or are there rules behind it? In other words, I would like a primer on how to set these parameters, so that when my problem size changes, say to [8, 4096, 4096] or [8, 27648, 5120], I can design the configuration myself based on that experience.

It's better to teach a man to fish than to give him a fish. Thank you very much again!

HuaYZhao, Jan 18 '24 02:01

> Does that mean there's no way to change it, or is it just strongly discouraged? Also, how does this parameter work on other cards such as the A10 or 3090? Is there documentation for this, or is it tied to the SM architecture?

All SM80 targets should use this shape for fp16.
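The instruction shape is the innermost level of the tiling: in each mainloop iteration, a warp covers its warp tile with a grid of 16x8x16 tensor-core MMA instructions. Each warp tile dimension is therefore expected to be a multiple of the matching instruction dimension, and a hand-edited WarpShape that violates this is one way to end up with a compilation error. A minimal sketch of that arithmetic follows; the function names are hypothetical, and it models only the shape relationship, not CUTLASS's actual iteration order.

```cpp
// SM80 fp16 tensor-core instruction shape (InstructionShape = GemmShape<16, 8, 16>).
constexpr int kInstM = 16, kInstN = 8, kInstK = 16;

// A warp tile fits this instruction only if every dimension divides evenly.
constexpr bool warp_tile_divides(int w_m, int w_n, int w_k) {
  return w_m % kInstM == 0 && w_n % kInstN == 0 && w_k % kInstK == 0;
}

// MMA instructions one warp issues per mainloop iteration (one threadblock K tile,
// since warp K equals threadblock K).
constexpr int mma_per_warp_tile(int w_m, int w_n, int w_k) {
  return (w_m / kInstM) * (w_n / kInstN) * (w_k / kInstK);
}

// Warp tile 32x64x32 from the code in this thread: 2 x 8 x 2 = 32 MMAs.
static_assert(warp_tile_divides(32, 64, 32), "");
static_assert(mma_per_warp_tile(32, 64, 32) == 32, "");
// A warp M of 8 is not a multiple of instruction M = 16.
static_assert(!warp_tile_divides(8, 64, 32), "");
```

This is also why a threadblock M of 8 is not an option with this instruction: the warp tile M, and hence the threadblock M, cannot go below 16.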

> I made a mistake in my description: it is actually N that is large, so the problem size should be M, N, K = [8, 12288, 4096].

What I said earlier still holds: you have a small M and a large K.

> This is where I'm confused. You clearly have a deeper understanding that lets you give this advice: is it purely empirical, or are there rules behind it? In other words, I would like a primer on how to set these parameters, so that when my problem size changes, say to [8, 4096, 4096] or [8, 27648, 5120], I can design the configuration myself based on that experience.

What I described applies to all problems with a small M and a large K.

hwu36, Jan 18 '24 02:01

I switched to stream-K mode and found that the kernel call itself (device_gemm()) is indeed faster, but the time spent in the following lines is not negligible. In my case, the computation took roughly 6.5088e-02 ms, while the workspace-related code took 4.50528e-01 ms. This overhead is especially pronounced in stream-K mode:

    size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    DeviceKernel device_gemm;
    cutlass::Status status = device_gemm.initialize(arguments,
                                                workspace.get());

I am wrapping this matrix multiplication as an operator, adjusting the layout for my matrix sizes, and exposing it to Python via pybind. Allocating the workspace every time the matrix multiplication is called seems to cost a lot, and the end-to-end time observed on the Python side is very long. Is there a way to optimize this?

I got good performance with example 47_ampere_gemm_universal_streamk, but because of this workspace I couldn't wrap it in an interface without losing that performance!

Here is my complete code

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/device_memory.h"
#include "helper.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

// A matrix configuration
using ElementA = cutlass::half_t;                                       // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor;                              // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = cutlass::half_t;                                       // Element type for B matrix operand
using LayoutB = cutlass::layout::RowMajor;                              // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C/D matrix configuration
using ElementC = cutlass::half_t;                                       // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor;                              // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes)

// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = cutlass::half_t;                     // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm80;                            // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;           // Operator class tag
using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>;         // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;   // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4;                                    // Number of global->shared pipeline stages used in the GEMM mainloop

// Epilogue output operator
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementC,            // Element type for C and D matrix operands
    AlignmentC,          // Memory access granularity of C and D matrix in units of elements
    ElementAccumulator,  // Element type from internal accumulation
    ElementAccumulator>; // Data type used to compute linear combination

// Stream-K device GEMM implementation type (selected by ThreadblockSwizzleStreamK below)
using DeviceKernel = cutlass::gemm::device::GemmUniversal<
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
    NumStages,
    AlignmentA,
    AlignmentB>;

using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

cutlass::Status gemm_mod_kernel_run(int M, int N, int K,
                                    const DeviceKernel::ElementA *A, const DeviceKernel::ElementB *B, const DeviceKernel::ElementC *C, DeviceKernel::ElementC *D,
                                    ElementCompute alpha, ElementCompute beta)
{
    GpuTimer timer;
    timer.start();
    typename DeviceKernel::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm,
        {M, N, K}, // problem size
        1,
        {alpha, beta},
        A,
        B,
        C,
        D,
        M * K,
        K * N,
        M * N,
        M * N,                                           // batch strides
        DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
        DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
        DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
        DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
        -1};

    size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    DeviceKernel gemm_op;
    cutlass::Status status = gemm_op.initialize(arguments,
                                                workspace.get());
    timer.stop();

    float elapsed_ms = timer.elapsed_millis();
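    // Only a single pass is timed here, so report it directly; dividing by
    // 10000 would only be right for a 10000-iteration loop.
    float avg_runtime_ms = elapsed_ms;
    std::cout << "  Runtime 1 (workspace + initialize): " << avg_runtime_ms << " ms" << std::endl;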

    timer.start();
    if (status != cutlass::Status::kSuccess)
    {
        return status;
    }

    status = gemm_op();

    // Run profiling loop
    // GpuTimer timer;
    // timer.start();
    // for (int iter = 0; iter < 10000; ++iter)
    // {
    //     CUTLASS_CHECK(gemm_op());
    // }
    timer.stop();

    elapsed_ms = timer.elapsed_millis();
    // A single gemm_op() call is timed; divide by 10000 only if the profiling loop above is enabled.
    avg_runtime_ms = elapsed_ms;

    std::cout << "  Runtime 2 (gemm_op): " << avg_runtime_ms << " ms" << std::endl;

    return status;
}

at::Tensor gemm_mod_kernel(const at::Tensor &A, const at::Tensor &B, at::optional<const at::Tensor> C, float alpha, float beta)
{
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

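    // Hold the contiguous copy of C in a named tensor: data_ptr() of a temporary
    // would dangle once this statement ends if C is not already contiguous.
    at::Tensor C_contig = C.has_value() ? C->contiguous() : at::Tensor();
    typename DeviceKernel::ElementC *ptrC = C_contig.defined() ? reinterpret_cast<typename DeviceKernel::ElementC *>(C_contig.data_ptr()) : nullptr;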
    at::Tensor D = B.new_empty({M, N}, torch::kF16);

    cutlass::Status status = gemm_mod_kernel_run(M, N, K,
                                                 reinterpret_cast<typename DeviceKernel::ElementA *>(A.contiguous().data_ptr()),
                                                 reinterpret_cast<typename DeviceKernel::ElementB *>(B.contiguous().data_ptr()),
                                                 ptrC,
                                                 reinterpret_cast<typename DeviceKernel::ElementC *>(D.contiguous().data_ptr()),
                                                 ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}

HuaYZhao, Jan 30 '24 05:01

Don't put the workspace allocation on the critical path. You can manage the memory yourself, or run this part in parallel with something else.
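One way to follow that advice in a wrapper like the one above is to allocate the workspace once and reuse it, growing it only when a larger size is requested. This is a minimal sketch, assuming a single stream and no concurrent calls. `WorkspaceCache` is a hypothetical helper, not part of CUTLASS, and its allocate/free callbacks stand in for `cudaMalloc`/`cudaFree` or any other device allocator.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>

// Reusable workspace buffer: reallocates only when a request exceeds the
// current capacity, so repeated GEMM calls of the same (or smaller) size
// skip the allocate/free pair on the critical path.
class WorkspaceCache {
 public:
  using AllocFn = std::function<void*(std::size_t)>;
  using FreeFn = std::function<void(void*)>;

  WorkspaceCache(AllocFn alloc, FreeFn free_fn)
      : alloc_(std::move(alloc)), free_(std::move(free_fn)) {
    assert(alloc_ && free_ && "both callbacks are required");
  }
  ~WorkspaceCache() { release(); }

  WorkspaceCache(const WorkspaceCache&) = delete;
  WorkspaceCache& operator=(const WorkspaceCache&) = delete;

  // Returns a buffer of at least `bytes` bytes (nullptr if nothing has been
  // allocated yet and bytes == 0). The pointer stays valid until a later call
  // has to grow the buffer, or until the cache is destroyed.
  void* get(std::size_t bytes) {
    if (bytes > capacity_) {
      release();
      ptr_ = alloc_(bytes);
      capacity_ = ptr_ ? bytes : 0;
      ++allocations_;
    }
    return ptr_;
  }

  std::size_t capacity() const { return capacity_; }
  int allocations() const { return allocations_; }

 private:
  void release() {
    if (ptr_) free_(ptr_);
    ptr_ = nullptr;
    capacity_ = 0;
  }

  AllocFn alloc_;
  FreeFn free_;
  void* ptr_ = nullptr;
  std::size_t capacity_ = 0;
  int allocations_ = 0;
};
```

In the wrapper above, a long-lived instance (for example a function-local `static`) could replace the `cutlass::device_memory::allocation<uint8_t>` line, with `get(workspace_size)` passed to `initialize`. Growing the buffer frees the old one, which is only safe if no in-flight kernel still uses it; the sketch does not track that. In a PyTorch extension, another option is to allocate the workspace as a byte tensor with `at::empty`, which goes through PyTorch's caching allocator instead of a fresh `cudaMalloc` per call.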

hwu36, Jan 30 '24 15:01

@HuaYZhao has your issue been resolved?

mnicely, Feb 22 '24 15:02