[BUG] Trying to optimize mixed input for kernels
**Describe the bug**
I was reading through the CUTLASS mixed-precision kernels: https://github.com/NVIDIA/cutlass/blob/cc3c29a81a140f7b97045718fb88eb0664c37bd7/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp#L552. Written this way, the check implicitly requires that the group size be at least the tile shape K, which should not matter since we account for this with a reload factor later in the code: https://github.com/NVIDIA/cutlass/blob/cc3c29a81a140f7b97045718fb88eb0664c37bd7/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp#L765.
Should this be reversed, i.e. allow tile shape K > group size? This would benefit GEMMs with a large K dimension.
Something like `implementable && (args.group_size == K || ((size<2>(TileShape{})) % args.group_size == 0));`.
We simply made it a compile-time parameter for perf. CC @IwakuraRein
@azhurkevich No, group_size is still a runtime argument at this stage.
@NihalPotdar I agree with you: the condition `(args.group_size % size<2>(TileShape{})) == 0` is too strong. It forces a whole row of an A-matrix tile to use the same scale value (since `SmemLayoutAtomScale`'s second dimension is 1), but I think it is reasonable to apply this only to the row within a single wgmma instruction. Therefore, the condition should be `size<2>(TileShape{}) % args.group_size == 0 && args.group_size % gmma_shape_k == 0`.
For instance, if a tile is 64x128x64, a thread will issue four 64x128x16 wgmma instructions in total to compute the tile, so the group size could go as low as 16 IMO. Let me dig into the code and see whether this can be fixed.
@NihalPotdar Oh, sorry, I made a mistake. `SmemLayoutAtomScale` and `ScaleTileShape` having a second dimension of 1 broadcasts one scale value across the whole K dimension of a tile. To reduce the granularity to the wgmma's K dimension, `ScaleTileShape` would need to be updated, and that also requires `group_size` to be a compile-time value, so currently there is no easy fix for this.
@NihalPotdar can you please explain in detail whether this is a blocker for you.
@azhurkevich yes it is. I was running some matmuls and benchmarking their performance with this code for uint4 and fp16 data types. However, when the K dimension is large and the other dimensions are smaller (say m x 4096 x 8192), this leads to suboptimal performance and bandwidth utilization; for these cases, the peak bandwidth utilization I have observed is close to 30%. My thinking is that allowing the tile size K to be larger than the group size might alleviate some of these issues.
@NihalPotdar can you please provide more detailed data of what you've encountered and what are your expectations pls. Thank you
@azhurkevich sure. I was working with this example code, with the mmaType set to float16 (`cutlass::half_t`) and the quantType set to uint4 (`cutlass::uint4b_t`).
For the problem size, M=16, N=2560, K=8192. This problem is memory bound, dominated by how fast we can read from HBM into the SMs. I found the best results with tile_m = 64, tile_n = 16, and tile_k = 64 using the KernelTmaWarpSpecializedMixedInput kernel. The group size I am using for testing is 128.
However, even with these optimal settings, profiling with ncu shows a maximum DRAM utilization of ~30%. This is not great, and my hunch is that increasing tile_k so that tile_k > group_size could help, since (I think) this problem size is limited by the number of MAC loop iterations. Stream-K will not help in this case due to the overheads of that scheduling strategy at such a small problem size, so parallelizing across K is not an option.
^ so being able to set the tile_k > group_size would be great!
@NihalPotdar keep in mind that you are most likely taking advantage of the default SwapAB, where M corresponds to the second operand of the tile shape (hence you can use 16) and N to the first. M and N are responsible for parallelization across CTAs, so with your problem shape you are currently launching only 40 CTAs on a 132-SM GPU, which is why you see such low utilization. You are correct that Stream-K adds additional latency; however, Stream-K is actually quite well positioned for your problem shape. I typically see some benefit, improving as K grows beyond 8k. Also please try Split-K, which has less overhead than Stream-K. This should help you.
@NihalPotdar Sorry, I don't quite understand why making tile_k > group_size would help in this case. When tile_k > group_size, the scale matrix becomes larger, so you are loading more data from device memory. Also, within a tile you are no longer broadcasting one scale value across the whole K dimension, so you are loading more data from shared memory as well. As for the computation, the number of multiplications is always TileM x TileK (since this kernel applies scaling before the MMA), which is independent of the group size.
IMHO the only reason a small group size is needed is to reduce the accuracy loss from quantization. It won't improve the latency of the kernel.
@IwakuraRein The core reason is likely the low utilization of the SMs. So is there any expected support for Stream-K or Split-K?
@ZZBoom As far as I remember, Stream-K/Split-K is only available with the cooperative schedule, which requires a tile size M of at least 128. With the warp-specialized or ping-pong schedules it can go down to 64, which might be more efficient than using Stream-K/Split-K, but please play around with it and figure out what works best for your use case.
To enable Stream-K, add `cutlass::gemm::StreamKScheduler` to each `GemmUniversal` template. E.g.,
```cpp
using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>, // Indicates ProblemShape
    CollectiveMainloopScaleOnly,
    CollectiveEpilogue,
    cutlass::gemm::StreamKScheduler
>;
```
Also note that `gemm.initialize(arguments, workspace.get())` is needed before each `gemm.run()`, so it is important to include it in the profiling loop:
```cpp
for (int iter = 0; iter < options.iterations; ++iter) {
  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
  CUTLASS_CHECK(gemm.run());
}
```
To further enable Split-K, also set the split count in the GEMM arguments. E.g.,
```cpp
template <typename Args>
Args args_from_options(Options const& options) {
  // ...
  auto args = Args {
    cutlass::gemm::GemmUniversalMode::kGemm,
    {options.n, options.m, options.k, options.l},
    {block_B.get(), stride_B, block_A.get(), stride_A},
    {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
  };
  args.scheduler.splits = 2;
  return args;
}
```
@IwakuraRein Thank you for your prompt response! I was asking whether Stream-K or Split-K is suitable for the problem size (M=16, N=2560, K=8192) mentioned above, given the low SM utilization.