
[QST] How to use groupwise scaling along M for FP8 GEMM to implement per-token-per-128-channel and block-wise scaling?

yizhang2077 opened this issue 1 year ago • 7 comments

What is your question? Hi, I am trying to use KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum to implement DeepSeek-V3 style block-wise FP8 as well as per-token-per-128-channel scaling, but I find it does not produce correct results. However, when I simply replace sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp with the same file from vLLM, it works correctly. I think this commit is the critical change: https://github.com/vllm-project/vllm/pull/11868/commits/d963eb47f0c7934b793ec37e65f212c7890072db. I am not familiar with CUTLASS; can someone help me figure out what the problem is? Thank you very much!

Base code

// Includes assumed for this snippet (PyTorch extension + CUTLASS 3.x builders):
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"

template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                                         const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr int AlignmentD = AlignmentC;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;

  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC,
      LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
                                           CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto o_ptr = static_cast<ElementD*>(out.data_ptr());

  auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));

  // 4 is mma_promotion_interval: how often the FP8 partial results are promoted
  // into the main FP32 accumulator.
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
  // ElementC is void, so ptr_C is nullptr and the C stride value is unused.
  typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      mainloop_args,
      epilogue_args,
  };

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
}

yizhang2077 · Feb 07 '25
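For reference, a minimal call-site sketch for the launcher above (the tile/cluster shapes, output dtype, and scale-tensor shapes here are assumptions, not taken from the issue). With ScaleGranularityM = 1, A gets one scale per row (token) per 128-wide K block, i.e. per-token-per-128-channel scaling, while B gets one scale per 128x128 block:

// Hypothetical call site for launch_sm90_fp8_blockwise_scaled_mm defined above
// (names, shapes, and dtype are assumptions).
void example_blockwise_fp8_gemm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                                const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
  using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
  // Assumed scale shapes: scales_a ~ [M, ceil(K/128)] (float32),
  //                       scales_b ~ [ceil(K/128), ceil(N/128)] (float32).
  // ScaleGranularityM = 1 -> per-token scales for A along M.
  launch_sm90_fp8_blockwise_scaled_mm<cutlass::bfloat16_t, TileShape, ClusterShape, 1>(
      out, a, b, scales_a, scales_b);
}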

@yizhang2077 I am also working on this using similar code, which compiles successfully but produces incorrect results. Have you solved this problem?

xuzhenqi · Feb 09 '25

vLLM's version works.

xuzhenqi · Feb 09 '25

@xuzhenqi Not yet. If you make any progress, please let me know, thank you very much!

yizhang2077 · Feb 09 '25

The stride of scaleA in the PR is different from the one in vLLM. The vLLM stride for scaleA is definitely the more commonly used one, so I'm looking forward to vLLM's PR too.

soundOfDestiny · Feb 10 '25

PR opened: https://github.com/NVIDIA/cutlass/pull/2095

LucasWilkinson · Feb 10 '25

> PR opened: #2095

Hello @LucasWilkinson, thanks for your PR, it really helps me a lot. I have one small question about the scale layouts: why does scaleA have a stride like (1, M)? Does this layout improve the copying of scaleA in CUTLASS?

MetaBlues · Feb 18 '25

> I have one small question about the scale layouts: why does scaleA have a stride like (1, M)? Does this layout improve the copying of scaleA in CUTLASS?

Yes, it just makes the A-scale loads coalesced when a block/tile loads them into shared memory.

LucasWilkinson · Feb 18 '25
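To illustrate the answer above (a sketch, not part of the thread): a stride of (1, M) means the A scales are stored M-major, so consecutive rows' scales sit next to each other in memory and a tile's scale loads into shared memory can be coalesced. On the PyTorch side this layout can be produced, for example, by backing the [M, ceil(K/128)] scale tensor with column-major storage; the helper below is hypothetical.

#include <torch/all.h>

// Hypothetical helper: keep the logical [M, ceil(K/128)] view of scales_a but
// back it with M-contiguous (column-major) storage, i.e. stride (1, M), so that
// per-tile scale loads are coalesced.
torch::Tensor make_scales_a_m_major(const torch::Tensor& scales_a) {
  return scales_a.t().contiguous().t();
}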

I see https://github.com/NVIDIA/cutlass/pull/2095 has been merged, thanks a lot! @LucasWilkinson

yizhang2077 · Feb 28 '25