
[BUG] TMA Cooperative GEMM with Stream-K scheduler hangs for specific GEMM shapes

Open. Algy opened this issue on Sep 10, 2024.

Describe the bug

GEMM kernels with the following configuration hang for specific GEMM shapes.

  • Type: e4m3 x e4m3 -> bf16
  • Tile: 256x32x128
  • Cluster: 2x1x1
  • Kernel Schedule: KernelTmaWarpSpecializedCooperative
  • Epilogue Schedule: TmaWarpSpecializedCooperative
  • Tile Scheduler: Stream-K

Tested GEMM shapes (MxNxK):

  • 3584x1x4736: Hang
  • 3328x1x4736: Hang
  • 3200x1x4736: Hang
  • 3136x1x4736: Hang
  • 3104x1x4736: Hang
  • 3088x1x4736: Hang
  • 3072x1x4736: OK

When I change the epilogue schedule to NoSmemWarpSpecialized, the issue seems to disappear, so I suspect something is wrong with the TMA epilogue when it is used with Stream-K.
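
For reference, the non-hanging variant differs only in the epilogue schedule tag passed to the collective builder. A minimal sketch of that change (it reuses the type aliases defined in the patched example below, so it is not self-contained):

// Workaround observed to avoid the hang: select the non-TMA epilogue schedule.
// Same collective builder instantiation as in the patch below; only the last
// template argument changes.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, ClusterShape,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, AlignmentC,
    ElementC, LayoutC, AlignmentC,
    cutlass::epilogue::NoSmemWarpSpecialized   // instead of TmaWarpSpecializedCooperative
  >::CollectiveOp;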

Steps/Code to reproduce bug

Apply the following patch file to 48_hopper_warp_specialized_gemm.cu:

diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
index f26f4da3..da827d6d 100644
--- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
+++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu
@@ -60,6 +60,8 @@
 
 #include "cute/tensor.hpp"
 #include "cutlass/tensor_ref.h"
+#include "cutlass/float8.h"
+#include "cutlass/bfloat16.h"
 #include "cutlass/epilogue/collective/default_epilogue.hpp"
 #include "cutlass/epilogue/thread/linear_combination.h"
 #include "cutlass/gemm/dispatch_policy.hpp"
@@ -89,17 +91,17 @@ using namespace cute;
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 // A matrix configuration
-using         ElementA    = float;                                          // Element type for A matrix operand
+using         ElementA    = cutlass::float_e4m3_t;                                          // Element type for A matrix operand
 using         LayoutA     = cutlass::layout::RowMajor;                      // Layout type for A matrix operand
 constexpr int AlignmentA  = 128 / cutlass::sizeof_bits<ElementA>::value;    // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
 
 // B matrix configuration
-using         ElementB    = float;                                          // Element type for B matrix operand
+using         ElementB    = cutlass::float_e4m3_t;                                          // Element type for B matrix operand
 using         LayoutB     = cutlass::layout::ColumnMajor;                   // Layout type for B matrix operand
 constexpr int AlignmentB  = 128 / cutlass::sizeof_bits<ElementB>::value;    // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
 
 // C/D matrix configuration
-using         ElementC    = float;                                          // Element type for C and D matrix operands
+using         ElementC    = cutlass::bfloat16_t;                                          // Element type for C and D matrix operands
 using         LayoutC     = cutlass::layout::ColumnMajor;                   // Layout type for C and D matrix operands
 constexpr int AlignmentC  = 128 / cutlass::sizeof_bits<ElementC>::value;    // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
 
@@ -107,8 +109,8 @@ constexpr int AlignmentC  = 128 / cutlass::sizeof_bits<ElementC>::value;    // M
 using ElementAccumulator  = float;                                          // Element type for internal accumulation
 using ArchTag             = cutlass::arch::Sm90;                            // Tag indicating the minimum SM that supports the intended feature
 using OperatorClass       = cutlass::arch::OpClassTensorOp;                 // Operator class tag
-using TileShape           = Shape<_128,_128,_32>;                           // Threadblock-level tile size
-using ClusterShape        = Shape<_1,_2,_1>;                                // Shape of the threadblocks in a cluster
+using TileShape           = Shape<_256,_32,_128>;                           // Threadblock-level tile size
+using ClusterShape        = Shape<_2,_1,_1>;                                // Shape of the threadblocks in a cluster
 using StageCountType = cutlass::gemm::collective::StageCountAuto;           // Stage count maximized based on the tile size
 using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;       // Kernel to launch based on the default setting in the Collective Builder
 
@@ -119,7 +121,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
     ElementAccumulator, ElementAccumulator,
     ElementC, LayoutC, AlignmentC,
     ElementC, LayoutC, AlignmentC,
-    cutlass::epilogue::collective::EpilogueScheduleAuto
+    cutlass::epilogue::TmaWarpSpecializedCooperative
   >::CollectiveOp;
 
 using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -130,13 +132,14 @@ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder
     TileShape, ClusterShape,
     cutlass::gemm::collective::StageCountAutoCarveout<
       static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-    cutlass::gemm::collective::KernelScheduleAuto
+    cutlass::gemm::KernelTmaWarpSpecializedCooperative
   >::CollectiveOp;
 
 using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
     Shape<int,int,int>, // Indicates ProblemShape
     CollectiveMainloop,
-    CollectiveEpilogue
+    CollectiveEpilogue,
+    cutlass::gemm::StreamKScheduler
 >;
 
 using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
@@ -303,14 +306,14 @@ bool initialize_block(
   int bits_input = cutlass::sizeof_bits<Element>::value;
 
   if (bits_input == 1) {
-    scope_max = 2;
-    scope_min = 0;
+    scope_max = Element(2);
+    scope_min = Element(0);
   } else if (bits_input <= 8) {
-    scope_max = 2;
-    scope_min = -2;
+    scope_max = Element(2);
+    scope_min = Element(-2);
   } else {
-    scope_max = 8;
-    scope_min = -8;
+    scope_max = Element(8);
+    scope_min = Element(-8);
   }
 
   cutlass::reference::device::BlockFillRandomUniform(

(To apply the patch, use patch -p1 < xxx.patch)

Then execute the example with the command

./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=3584 --n=1 --k=4736

Environment details

  • Environment location: Bare metal on H100 80GB HBM3

Algy, Sep 10, 2024

@jackkosaian

thakkarV, Sep 12, 2024

Thanks for reporting. This is due to a bug in the CUTLASS 3.x implementation of "separate reduction." For the time being, you can circumvent it with the following change, which got this problem size to work for me.

diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h
index 36888a29..46adb3ed 100644
--- a/include/cutlass/gemm/kernel/tile_scheduler_params.h
+++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h
@@ -1047,11 +1047,7 @@ struct PersistentTileSchedulerSm90StreamKParams {
   CUTLASS_HOST_DEVICE
   static bool
   should_perform_separate_reduction(uint32_t epilogue_subtile, uint64_t sk_units, uint64_t sk_tiles, uint64_t dp_tiles, uint64_t ctas_per_wave) {
-    // We perform separate reduction if we have fewer than one wave of output tiles
-    // and each output tile is covered by at least to stream-K units. When sk_units is
-    // multiple of sk_tiles, will choose basic split-k path instead of separate reduction for now.
-    return (epilogue_subtile != 1) && (dp_tiles == 0) && (sk_units > 2u * sk_tiles) &&
-           (sk_units + sk_tiles * epilogue_subtile <= ctas_per_wave);
+    return false;
   }

   // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when
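
If patching CUTLASS headers is not desirable, a possible runtime alternative (untested, and assuming your CUTLASS version exposes the decomposition_mode and splits fields on the stream-K scheduler arguments) is to steer the scheduler away from the heuristic that selects separate reduction:

// Untested sketch: force a decomposition that, to my understanding, does not take
// the separate-reduction path. Assumes the stream-K scheduler Arguments expose
// these fields; Gemm is the GemmUniversalAdapter type from the example above.
using DecompositionMode =
    cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;

typename Gemm::Arguments arguments;   // populate mode, problem size, mainloop/epilogue args, hw_info as usual
arguments.scheduler.decomposition_mode = DecompositionMode::SplitK;   // or DecompositionMode::DataParallel
arguments.scheduler.splits = 4;       // split-K factor; tune for the K extent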

jackkosaian, Sep 19, 2024

How long until this bug is expected to be fixed on the main branch? If it will take a while, maybe I should fork the branch and use it with the patch you provided. The affected GEMM shapes come from LLMs that are quite popular right now.

I also wonder whether there is any performance implication to applying your patch. That is, is there a potential performance penalty from always turning off separate reduction?

Algy, Sep 20, 2024

There is no timeline for when the separate reduction implementation will be fixed. We plan to roll out the patch I described soon, though.

There is no performance implication because, as far as I have seen, separate reduction is currently broken in all of its use cases.

jackkosaian, Sep 20, 2024

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Oct 20, 2024

@jackkosaian I'm curious how long the separate-reduction fix is expected to take, and whether there are any suggested workarounds.

My understanding is that for small GEMM shapes with a large K dimension, separate reduction would be very helpful, and since it is disabled, this directly affects performance for these GEMMs. One such GEMM configuration is m=16, n=2560, k=8192.

NihalPotdar, Nov 2, 2024

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Dec 2, 2024