[QST] How to use groupwise scaling along M for FP8 GEMM to implement per-token-per-128-channel and blockwise scaling?
What is your question?
Hi, I am trying to use KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum to implement DeepSeek-V3 block-wise FP8 scaling as well as per-token-per-128-channel scaling, but it does not work. However, when I simply replace sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp with the vLLM version of the same file, it works correctly.
I think this commit https://github.com/vllm-project/vllm/pull/11868/commits/d963eb47f0c7934b793ec37e65f212c7890072db is the critical change.
I am not familiar with CUTLASS; can someone help me figure out what the problem is? Thank you very much!
Base code
```cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"

// ScaleGranularityM: number of rows of A that share one scale factor (1 = per-token).
template <typename OutType, typename TileShape, typename ClusterShape, int ScaleGranularityM = 1>
void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
                                         const torch::Tensor& scales_a, const torch::Tensor& scales_b) {
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  // A: M x K, row-major FP8 (e4m3)
  using ElementA = cutlass::float_e4m3_t;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B: K x N, column-major FP8 (e4m3)
  using ElementB = cutlass::float_e4m3_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // No C source operand; the epilogue only writes D
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutType>::value;

  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  constexpr int AlignmentD = AlignmentC;

  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
  // Epilogue stores the accumulator as-is
  using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC,
      LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
      TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<cute::Shape<int, int, int, int>,  // ProblemShape (M, N, K, L)
                                           CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto o_ptr = static_cast<ElementD*>(out.data_ptr());
  auto a_s_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto b_s_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c;
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));

  // Mainloop: {ptr_A, stride_A, ptr_B, stride_B, MMA promotion interval, ptr_scale_A, ptr_scale_B}
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
  // Epilogue: {fusion args, ptr_C (unused), stride_C, ptr_D, stride_D}
  typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      mainloop_args,
      epilogue_args,
  };

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));
  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
}
```
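When the kernel compiles but returns wrong numbers, it can help to compare its output against a plain host-side loop that spells out the intended math: every K block of gK elements produces its own partial dot product, which is multiplied by the A scale of its row group and the B scale of its column group before being accumulated. The sketch below is hypothetical (the function name and argument order are not from CUTLASS or vLLM). It assumes the FP8 inputs have already been widened to float, that scale_a is stored M-major (element (i, j) at i + j * Mg, i.e. stride (1, Mg)) and that scale_b is stored N-major (element (i, j) at i + j * Ng). Per-token-per-128-channel scaling corresponds to gM = 1, gN = gK = 128; a fully block-wise 128x128 scheme corresponds to gM = 128.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host reference for groupwise-scaled GEMM (not a CUTLASS API):
// D[m][n] = sum_kb sa(m / gM, kb) * sb(n / gN, kb) * sum_{k in block kb} A[m][k] * B[k][n]
std::vector<float> reference_groupwise_gemm(
    const std::vector<float>& A,   // M x K, row-major: A[m * K + k]
    const std::vector<float>& B,   // K x N, column-major: B[n * K + k]
    const std::vector<float>& sa,  // Mg x Kg scales, M-major: sa[i + j * Mg]
    const std::vector<float>& sb,  // Ng x Kg scales, N-major: sb[i + j * Ng]
    int M, int N, int K, int gM, int gN, int gK) {
  const int Mg = (M + gM - 1) / gM;  // number of row groups
  const int Ng = (N + gN - 1) / gN;  // number of column groups
  const int Kg = (K + gK - 1) / gK;  // number of K blocks
  assert(A.size() == static_cast<std::size_t>(M) * K);
  assert(B.size() == static_cast<std::size_t>(K) * N);
  assert(sa.size() == static_cast<std::size_t>(Mg) * Kg);
  assert(sb.size() == static_cast<std::size_t>(Ng) * Kg);

  std::vector<float> D(static_cast<std::size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int kb = 0; kb < Kg; ++kb) {
        // Unscaled partial dot product over one K block.
        float partial = 0.0f;
        for (int k = kb * gK; k < K && k < (kb + 1) * gK; ++k) {
          partial += A[static_cast<std::size_t>(m) * K + k] * B[static_cast<std::size_t>(n) * K + k];
        }
        // Scale the block result by its row-group and column-group scales, then accumulate.
        const float s = sa[static_cast<std::size_t>(m / gM) + static_cast<std::size_t>(kb) * Mg] *
                        sb[static_cast<std::size_t>(n / gN) + static_cast<std::size_t>(kb) * Ng];
        acc += s * partial;
      }
      D[static_cast<std::size_t>(m) * N + n] = acc;
    }
  }
  return D;
}
```

Because the GPU kernel multiplies FP8 values in hardware and promotes partial sums at intervals, its output is best compared with this reference using a relative tolerance rather than exact equality.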
@yizhang2077 I am also working on this with similar code, which compiles successfully but produces incorrect results. Have you solved this problem?
vLLM's version works.
@xuzhenqi Not yet. If you make any progress, please let me know. Thank you very much!
The stride of scaleA in the PR differs from the one in vLLM. The scaleA stride used in vLLM is clearly the more common layout, so I'm looking forward to vLLM's PR too.
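If the A scales come out of the quantizer K-major (row-major [Mg, Kg], one row of K-group scales per token), they can be repacked into the M-major layout with stride (1, Mg) by a transpose. A minimal host-side sketch, with a hypothetical helper name; in PyTorch the same layout can typically be obtained with `scales_a.t().contiguous().t()`, which keeps the logical [Mg, Kg] shape but gives strides (1, Mg).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: repack A scales from K-major (element (i, j) at i * Kg + j)
// to M-major (element (i, j) at i + j * Mg), i.e. logical shape [Mg, Kg] with stride (1, Mg).
std::vector<float> to_m_major(const std::vector<float>& k_major, int Mg, int Kg) {
  assert(k_major.size() == static_cast<std::size_t>(Mg) * Kg);
  std::vector<float> m_major(k_major.size());
  for (int i = 0; i < Mg; ++i) {
    for (int j = 0; j < Kg; ++j) {
      m_major[static_cast<std::size_t>(i) + static_cast<std::size_t>(j) * Mg] =
          k_major[static_cast<std::size_t>(i) * Kg + j];
    }
  }
  return m_major;
}
```

For a 2 x 3 scale matrix stored K-major as {0, 1, 2, 3, 4, 5}, the M-major copy is {0, 3, 1, 4, 2, 5}: the two rows' scales for the same K block now sit next to each other.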
PR opened: https://github.com/NVIDIA/cutlass/pull/2095
Hello, @LucasWilkinson.
Thanks for your PR; it really helps me a lot. I have one small question about the scale layouts: why does scaleA have a stride like (1, M)? Does this layout improve the copying of scaleA in CUTLASS?
> I have one small question about the scale layouts: why does scaleA have a stride like (1, M)? Does this layout improve the copying of scaleA in CUTLASS?

Yes, it just makes the A-scale loads coalesced when a block/tile loads them into shared memory.
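The coalescing argument can be made concrete by listing which scale elements a tile of consecutive rows reads for one K block. This sketch assumes ScaleGranularityM = 1 (one scale per row) and only models element offsets, not the actual CUTLASS copy path; the helper names are hypothetical. With the M-major layout the offsets are contiguous, while the K-major layout jumps by Kg between neighbouring rows.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical: element offsets read when a tile loads the A scales for rows
// [m0, m0 + tile_m) at K block kb, for an [Mg, Kg] scale matrix.
std::vector<std::size_t> scale_offsets(bool m_major, int m0, int tile_m, int kb, int Mg, int Kg) {
  assert(m0 + tile_m <= Mg && kb < Kg);
  std::vector<std::size_t> offs;
  for (int i = m0; i < m0 + tile_m; ++i) {
    offs.push_back(m_major ? static_cast<std::size_t>(i) + static_cast<std::size_t>(kb) * Mg  // stride (1, Mg)
                           : static_cast<std::size_t>(i) * Kg + kb);                          // stride (Kg, 1)
  }
  return offs;
}

// True if neighbouring rows (and thus neighbouring threads) read neighbouring elements.
bool is_unit_stride(const std::vector<std::size_t>& offs) {
  for (std::size_t t = 1; t < offs.size(); ++t) {
    if (offs[t] != offs[t - 1] + 1) return false;
  }
  return true;
}
```

For example, rows 0 to 3 at K block 1 with Mg = 8 and Kg = 3 map to offsets 8, 9, 10, 11 in the M-major layout but 1, 4, 7, 10 in the K-major one.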
I see https://github.com/NVIDIA/cutlass/pull/2095 has been merged. Thanks a lot, @LucasWilkinson!