[QST] Tile is not evenly divisible
What is your question?
I'm doing matrix multiplication manually with cute:
Tensor A = make_tensor(make_gmem_ptr(Aptr), make_shape(m, k), make_stride(k, Int<1>{}));
Tensor B = make_tensor(make_gmem_ptr(Bptr), make_shape(n, k), make_stride(k, Int<1>{}));
Tensor C = make_tensor(make_gmem_ptr(Cptr), make_shape(m, n), make_stride(n, Int<1>{}));
Tensor gA = local_tile(A, make_tile(Int<kTileM>{}, Int<kTileK>{}), make_coord(iy, _));
Tensor gB = local_tile(B, make_tile(Int<kTileN>{}, Int<kTileK>{}), make_coord(ix, _));
Tensor gC = local_tile(C, make_tile(Int<kTileM>{}, Int<kTileN>{}), make_coord(iy, ix));
When m=10 and kTileM=16, I observe that gA repeats A's data to fill out the kTileM extent, and the memory addresses accessed by the remaining threads are then wrong.
I observed this using print_tensor.
I want to know how to handle this non-divisible case. Are there any examples in cute you can point me to?
There is a dedicated document on how to handle predication in the cute markdown docs.
https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/0y_predication.md
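The pattern in that doc boils down to roughly the following; a minimal sketch with hypothetical names (gA/sA/tiled_copy/thr_copy/residue_m/residue_k), not code from this thread, so see the doc for the authoritative version:

// Identity tensor: every element stores its own (m, k) coordinate within the tile
Tensor cA   = make_identity_tensor(shape(sA));        // (BLK_M, BLK_K) -> (m, k)
Tensor tAcA = thr_copy.partition_S(cA);               // partitioned like the data
Tensor tAgA = thr_copy.partition_S(gA);
Tensor tAsA = thr_copy.partition_D(sA);
// One predicate per copy vector
Tensor tApA = make_tensor<bool>(make_shape(size<1>(tAgA), size<2>(tAgA)));
for (int m = 0; m < size<0>(tApA); ++m)
  for (int k = 0; k < size<1>(tApA); ++k)
    tApA(m, k) = get<0>(tAcA(0, m, k)) < residue_m    // in-bounds in m
              && get<1>(tAcA(0, m, k)) < residue_k;   // in-bounds in k
// Note: with vectorized copies, residue_k must be a multiple of the copy vector
// width for this per-vector predicate to be exact.
clear(tAsA);                     // zero the smem so masked-off elements read as 0
copy_if(tApA, tAgA, tAsA);       // copy only where the predicate is true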
OK, thanks. I had missed that doc; I'll give it a try. But I feel the documentation should be accompanied by some code examples.
It is. Our Sm80 mainloop implements predication: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm80_mma_multistage.hpp#L504
I see, thank you very much for the quick reply. I am experimenting with it now.
I have a new problem. I think I set up the predicates correctly:
// Allocate predicate tensors for m and n
auto tApA = make_tensor<bool>(make_shape(size<1>(tAsA_copy), size<2>(tAsA_copy)), Stride<_1, _0>{});
auto tBpB = make_tensor<bool>(make_shape(size<1>(tBsB_copy), size<2>(tBsB_copy)), Stride<_1, _0>{});
// Construct identity layout for sA and sB
auto cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
auto cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat the partitioning with identity layouts
auto tAcA = g2s_thr_copy_a.partition_S(cA); // (CPY, CPY_M, CPY_K, kStage) -> (blk_m,blk_k)
auto tBcB = g2s_thr_copy_b.partition_S(cB); // (CPY, CPY_M, CPY_K, kStage) -> (blk_n,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<0>(tApA); ++m)
{
tApA(m, 0) = get<0>(tAcA(0, m, 0)) < M; // blk_m coord < residue_m
}
// Set predicates for n bounds
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<0>(tBpB); ++n)
{
tBpB(n, 0) = get<0>(tBcB(0, n, 0)) < N; // blk_n coord < residue_n
}
I printed the predicates and they looked fine.
My new problem is that when I set M < 16, i.e. M is not evenly divisible by kTileM, the global tensor A that I print changes every time; the picture shows the first run.
I ran further experiments and found that if I comment out the final copy back to gC, gA no longer changes:
cute::copy(tCrC, tCgC);
If M is exactly 16 and kTileM is 16, gA does not change either. The problem only occurs when M is not divisible by kTileM and the copy back to gC is present.
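I think the store to gC probably needs to be guarded the same way as the loads. A rough sketch of what I mean, using hypothetical residue_m/residue_n and the tensor names from my code below (not verified):

// Identity tensor over the C tile, partitioned like tCgC
Tensor cC   = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC)));  // (BLK_M, BLK_N) -> (m, n)
Tensor tCcC = thr_mma.partition_C(cC);
// Element-wise guarded store instead of cute::copy(tCrC, tCgC)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCgC); ++i)
{
    if (get<0>(tCcC(i)) < residue_m && get<1>(tCcC(i)) < residue_n)
    {
        tCgC(i) = tCrC(i);
    }
}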
Supplementary gA printouts (screenshots of thread 0 at step 0 and step 1):
Here is the full code I am running:
#include "helper.h"
#include <cublas_v2.h>
#include <cuda.h>
#include <cute/tensor.hpp>
#include <stdlib.h>
template <typename T>
void gen_rand_data(T *data, int n);
template <typename T, typename Config>
__global__ void gemm_simple(T *Cptr, const T *Aptr, const T *Bptr, int M, int N, int K)
{
using namespace cute;
using X = Underscore;
// using T = typename Config::T;
using SmemLayoutA = typename Config::SmemLayoutA;
using SmemLayoutB = typename Config::SmemLayoutB;
using SmemLayoutC = typename Config::SmemLayoutC;
using TiledMMA = typename Config::MMA;
using S2RCopyAtomA = typename Config::S2RCopyAtomA;
using S2RCopyAtomB = typename Config::S2RCopyAtomB;
using G2SCopyA = typename Config::G2SCopyA;
using G2SCopyB = typename Config::G2SCopyB;
using R2SCopyAtomC = typename Config::R2SCopyAtomC;
using S2GCopyAtomC = typename Config::S2GCopyAtomC;
using S2GCopyC = typename Config::S2GCopyC;
constexpr int kTileM = Config::kTileM;
constexpr int kTileN = Config::kTileN;
constexpr int kTileK = Config::kTileK;
constexpr int kStage = Config::kStage;
int idx = threadIdx.x;
int ix = blockIdx.x;
int iy = blockIdx.y;
Tensor A = make_tensor(make_gmem_ptr(Aptr), make_shape(M, K), make_stride(K, Int<1>{}));
Tensor B = make_tensor(make_gmem_ptr(Bptr), make_shape(N, K), make_stride(K, Int<1>{}));
Tensor C = make_tensor(make_gmem_ptr(Cptr), make_shape(M, N), make_stride(N, Int<1>{}));
// global-memory tiles for this thread block
Tensor gA = local_tile(A, make_tile(Int<kTileM>{}, Int<kTileK>{}), make_coord(iy, _)); // gA(kTileM, kTileK, num_tile_k)
Tensor gB = local_tile(B, make_tile(Int<kTileN>{}, Int<kTileK>{}), make_coord(ix, _)); // gB(kTileN, kTileK, num_tile_k)
Tensor gC = local_tile(C, make_tile(Int<kTileM>{}, Int<kTileN>{}), make_coord(iy, ix)); // gC(kTileM, kTileN)
// shared-memory tiles
extern __shared__ T shm_data[];
T *Ashm = shm_data;
T *Bshm = shm_data + cute::cosize(SmemLayoutA{});
auto sA = make_tensor(make_smem_ptr(Ashm), SmemLayoutA{}); // (kTileM, kTileK, kStage)
auto sB = make_tensor(make_smem_ptr(Bshm), SmemLayoutB{}); // (kTileN, kTileK, kStage)
//
// Declarations: tiled MMA and copies
//
TiledMMA tiled_mma;
G2SCopyA g2s_tiled_copy_a;
G2SCopyB g2s_tiled_copy_b;
// per-thread copy (global to shared)
auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(idx);
auto g2s_thr_copy_b = g2s_tiled_copy_b.get_slice(idx);
// tiled copy; shared to register
auto s2r_tiled_copy_a = make_tiled_copy_A(S2RCopyAtomA{}, tiled_mma);
auto s2r_tiled_copy_b = make_tiled_copy_B(S2RCopyAtomB{}, tiled_mma);
// per-thread MMA
auto thr_mma = tiled_mma.get_slice(idx);
// gA for thread A;
auto tAgA = thr_mma.partition_A(gA); // (MMA, MMA_M, MMA_K, num_tile_k)
auto tBgB = thr_mma.partition_B(gB); // (MMA, MMA_N, MMA_K, num_tile_k)
auto tCgC = thr_mma.partition_C(gC); // (MMA, MMA_M, MMA_N)
// same data as tAgA, but re-partitioned according to the copy atom
auto tAgA_copy = g2s_thr_copy_a.partition_S(gA); // (CPY, CPY_M, CPY_K, k)
auto tBgB_copy = g2s_thr_copy_b.partition_S(gB); // (CPY, CPY_N, CPY_K, k)
// likewise for the shared-memory destination
auto tAsA_copy = g2s_thr_copy_a.partition_D(sA); // (CPY, CPY_M, CPY_K, kStage)
auto tBsB_copy = g2s_thr_copy_b.partition_D(sB); // (CPY, CPY_N, CPY_K, kStage)
//
// predicates to handle M/N/K not divisible by the tile shape
//
// Allocate predicate tensors for m and n
auto tApA = make_tensor<bool>(make_shape(size<1>(tAsA_copy), size<2>(tAsA_copy)), Stride<_1, _0>{});
auto tBpB = make_tensor<bool>(make_shape(size<1>(tBsB_copy), size<2>(tBsB_copy)), Stride<_1, _0>{});
// Construct identity layout for sA and sB
auto cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
auto cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat the partitioning with identity layouts
auto tAcA = g2s_thr_copy_a.partition_S(cA); // (CPY, CPY_M, CPY_K, kStage) -> (blk_m,blk_k)
auto tBcB = g2s_thr_copy_b.partition_S(cB); // (CPY, CPY_M, CPY_K, kStage) -> (blk_n,blk_k)
// Set predicates for m bounds
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<0>(tApA); ++m)
{
tApA(m, 0) = get<0>(tAcA(0, m, 0)) < M; // blk_m coord < residue_m
}
// Set predicates for n bounds
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<0>(tBpB); ++n)
{
tBpB(n, 0) = get<0>(tBcB(0, n, 0)) < N; // blk_n coord < residue_n
}
// Clear the smem tiles to account for predicated off loads
clear(tAsA_copy);
clear(tBsB_copy);
// if(thread(0)) {
// // print("\n\ntAgA\n");
// // print(tAgA.shape()); print("\n"); print(tAgA.stride());
// // print_tensor(tAgA);
// printf("block:(%d, %d);thread: %d\n", ix,iy,idx);
// print("\n\ntApA\n");
// print(tApA.shape()); print("\n"); print(tApA.stride());
// print_tensor(tApA);
// print(size<0>(tApA));
// // print("\n\ntAsA_copy\n");
// // print(tAsA_copy.shape()); print("\n"); print(tAsA_copy.stride());
// // print_tensor(tAsA_copy);
// // print(size<1>(tAsA_copy));
// // print("\n\ncA\n");
// // print(cA.shape()); print("\n"); print(cA.stride());
// // print_tensor(cA);
// print("\n\ntAcA\n");
// print(tAcA.shape()); print("\n"); print(tAcA.stride());
// print_tensor(tAcA);
// print("\n\ntBcB\n");
// print(tBcB.shape()); print("\n"); print(tBcB.stride());
// print_tensor(tBcB);
// }
auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); // (MMA, MMA_M, MMA_K)
auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); // (MMA, MMA_N, MMA_K)
auto tCrC = thr_mma.partition_fragment_C(gC); // (MMA, MMA_M, MMA_N)
auto s2r_thr_copy_a = s2r_tiled_copy_a.get_slice(idx);
auto tCsA = s2r_thr_copy_a.partition_S(sA); // ? (CPY, CPY_M, CPY_K, kStage)
auto tCrA_view = s2r_thr_copy_a.retile_D(tCrA); // ? (CPY, CPY_M, CPY_K)
auto s2r_thr_copy_b = s2r_tiled_copy_b.get_slice(idx);
auto tCsB = s2r_thr_copy_b.partition_S(sB); // ? (CPY, CPY_M, CPY_K, kStage)
auto tCrB_view = s2r_thr_copy_b.retile_D(tCrB); // ? (CPY, CPY_M, CPY_K)
clear(tCrC);
if (thread(0))
{
print("\n\nmA\n");
print_tensor(A);
print("\n\ngA\n");
print_tensor(gA);
print("\n\ntApA\n");
print_tensor(tApA);
print("\n\ntBpB\n");
print_tensor(tBpB);
print("\n\ntCgC\n");
print_tensor(tCgC);
// print("\n\ntAgA\n");
// print(tAgA.shape()); print("\n"); print(tAgA.stride());
// print("\n\ntAgA_copy\n");
// print(tAgA_copy.shape()); print("\n"); print(tAgA_copy.stride());
// print("\n\ntAsA_copy\n");
// print(tAsA_copy.shape()); print("\n"); print(tAsA_copy.stride());
// print("\n\n");
// print("\n\ntAgA_copy\n");
// print_tensor(tAgA_copy);
// print("\n\ntAsA_copy\n");
// print_tensor(tAsA_copy);
print("\n\n");
}
int ntile = size<2>(gA);
#pragma unroll 1
for (int itile = 0; itile < ntile; ++itile)
{
// cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, itile), tAsA_copy(_, _, _, 0));
// cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, itile), tBsB_copy(_, _, _, 0));
cute::copy_if(g2s_tiled_copy_a, tApA, tAgA_copy(_, _, _, itile), tAsA_copy(_, _, _, 0));
cute::copy_if(g2s_tiled_copy_b, tBpB, tBgB_copy(_, _, _, itile), tBsB_copy(_, _, _, 0));
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
// if (thread(0) && itile == 0)
// {
// print("\n\ntAsA_copy\n");
// print_tensor(tAsA_copy(_, _, _, itile));
// print("\n\n");
// print("\n\ntBsB_copy\n");
// print_tensor(tBsB_copy(_, _, _, itile));
// print("\n\n");
// }
int nk = size<2>(tCrA);
#pragma unroll
for (int ik = 0; ik < nk; ++ik)
{
// shm -> reg s[itile][ik + 1] -> r[ik + 1]
cute::copy(s2r_tiled_copy_a, tCsA(_, _, ik, 0), tCrA_view(_, _, ik));
cute::copy(s2r_tiled_copy_b, tCsB(_, _, ik, 0), tCrB_view(_, _, ik));
// cute::copy_if(s2r_tiled_copy_a, tApA, tAsA(_, _, ik, 0), tArA_view(_, _, ik));
// cute::copy_if(s2r_tiled_copy_b, tBpB, tBsB(_, _, ik, 0), tBrB_view(_, _, ik));
cute::gemm(tiled_mma, tCrC, tCrA(_, _, ik), tCrB(_, _, ik), tCrC);
} // for ik
// cute::copy(tAgA(_, _, _, itile), tArA);
// cute::copy(tBgB(_, _, _, itile), tBrB);
// cute::gemm(tiled_mma, tCrC, tArA, tBrB, tCrC);
// if (thread(0) && itile == 0)
// {
// print("\n\ntCrC\n");
// print_tensor(tCrC);
// print("\n\n");
// }
} // itile
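// NOTE: this store is unpredicated: every thread writes its full fragment of the
// kTileM x kTileN gC tile, including rows beyond M when M is not a multiple of kTileM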
cute::copy(tCrC, tCgC);
}
namespace config
{
using namespace cute;
template <typename T_, int kTileM_ = 128, int kTileN_ = 128, int kTileK_ = 32,
int kStage_ = 5, int kSmemLayoutCBatch_ = 2,
typename ComputeType = T_>
struct GemmConfig
{
using T = T_;
// tile configuration
static constexpr int kTileM = kTileM_;
static constexpr int kTileN = kTileN_;
static constexpr int kTileK = kTileK_;
static constexpr int kStage = kStage_;
static constexpr int kSmemLayoutCBatch = kSmemLayoutCBatch_;
static constexpr int kShmLoadSwizzleM = 3;
static constexpr int kShmLoadSwizzleS = 3;
static constexpr int kShmLoadSwizzleB = 3;
using SmemLayoutAtom = decltype(composition(
Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
make_layout(make_shape(Int<8>{}, Int<kTileK>{}),
make_stride(Int<kTileK>{}, Int<1>{}))));
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtom{},
make_shape(Int<kTileM>{}, Int<kTileK>{}, Int<kStage>{})));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtom{},
make_shape(Int<kTileN>{}, Int<kTileK>{}, Int<kStage>{})));
using mma_op = SM80_16x8x16_F16F16F16F16_TN;
using mma_traits = MMA_Traits<mma_op>;
using mma_atom = MMA_Atom<mma_traits>;
static constexpr int kMmaEURepeatM = 1;
static constexpr int kMmaEURepeatN = 4;
static constexpr int kMmaEURepeatK = 1;
using mma_atom_shape = mma_traits::Shape_MNK;
static constexpr int kMmaPM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
static constexpr int kMmaPN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
static constexpr int kMmaPK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});
using MMA_EU_RepeatT = decltype(make_layout(make_shape(
Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
using MMA_P_T = Tile<Int<kMmaPM>, Int<kMmaPN>, Int<kMmaPK>>;
using MMA = decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_P_T{}));
using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
using g2s_copy_atom = Copy_Atom<g2s_copy_traits, T>;
using G2SCopyA =
decltype(make_tiled_copy(g2s_copy_atom{},
make_layout(make_shape(Int<16>{}, Int<8>{}),
make_stride(Int<8>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<8>{}))));
using G2SCopyB = G2SCopyA;
// shared memory to register copy
using s2r_copy_op = SM75_U32x4_LDSM_N;
using s2r_copy_traits = Copy_Traits<s2r_copy_op>;
using s2r_copy_atom = Copy_Atom<s2r_copy_traits, T>;
using S2RCopyAtomA = s2r_copy_atom;
using S2RCopyAtomB = s2r_copy_atom;
// epilogue: register to global via shared memory
using SmemLayoutAtomC = decltype(composition(
Swizzle<2, 3, 3>{}, make_layout(make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}),
make_stride(Int<kMmaPN>{}, Int<1>{}))));
using SmemLayoutC = decltype(tile_to_shape(
SmemLayoutAtomC{},
make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}, Int<kSmemLayoutCBatch>{})));
// static_assert(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) >=
// size(SmemLayoutC{}),
// "C shared memory request is large than A's one pipe");
using R2SCopyAtomC = Copy_Atom<UniversalCopy<int>, T>;
using S2GCopyAtomC = Copy_Atom<UniversalCopy<cute::uint128_t>, T>;
using S2GCopyC =
decltype(make_tiled_copy(S2GCopyAtomC{},
make_layout(make_shape(Int<32>{}, Int<4>{}),
make_stride(Int<4>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<8>{}))));
static constexpr int kThreadNum = size(MMA{});
static constexpr int shm_size_AB =
cute::cosize(SmemLayoutA{}) + cute::cosize(SmemLayoutB{});
static constexpr int shm_size_C = cute::cosize(SmemLayoutC{});
static constexpr int kShmSize =
cute::max(shm_size_AB, shm_size_C) * sizeof(T);
};
} // namespace config
int main()
{
// srand(10086);
using T = cute::half_t;
using namespace cute;
GpuTimer timer;
T *Cptr;
T *Aptr;
T *Bptr;
int m = 2;
int n = 64;
int k = 128;
cudaMalloc(&Cptr, sizeof(T) * m * n);
cudaMalloc(&Aptr, sizeof(T) * m * k);
cudaMalloc(&Bptr, sizeof(T) * k * n);
T *Aptr_host;
T *Bptr_host;
Aptr_host = (T *)malloc(sizeof(T) * m * k);
Bptr_host = (T *)malloc(sizeof(T) * n * k);
gen_rand_data(Aptr_host, m * k);
gen_rand_data(Bptr_host, n * k);
cudaMemcpy(Aptr, Aptr_host, sizeof(T) * m * k, cudaMemcpyHostToDevice);
cudaMemcpy(Bptr, Bptr_host, sizeof(T) * n * k, cudaMemcpyHostToDevice);
// cublas
T *Cptr_cublas;
cudaMalloc(&Cptr_cublas, sizeof(T) * m * n);
cublasHandle_t handle;
cublasCreate(&handle);
half alpha = half(1.f);
half beta = half(0.f);
timer.start();
for (int i = 0; i < 100; ++i)
{
cublasStatus_t ret = cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
n, m, k,
&alpha,
(half *)Bptr, k,
(half *)Aptr, k,
&beta,
(half *)Cptr_cublas, n);
// if (ret != CUBLAS_STATUS_SUCCESS)
// {
// printf("blas err = %d, str = %s\n", ret, cublasGetStatusString(ret));
// }
}
timer.stop();
printf("cublas avg runtime %f ms\n", timer.elapsed_millis());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
printf("err = %d, str = %s\n", err, cudaGetErrorString(err));
// cute
constexpr int kTileM = 16;
constexpr int kTileN = 128;
constexpr int kTileK = 128;
config::GemmConfig<T, kTileM, kTileN, kTileK, 1> gemm_config;
dim3 block = gemm_config.kThreadNum;
dim3 grid((n + gemm_config.kTileN - 1) / gemm_config.kTileN,
(m + gemm_config.kTileM - 1) / gemm_config.kTileM);
int shm_size = gemm_config.kShmSize;
std::cout << block << std::endl;
std::cout << grid << std::endl;
print(typename decltype(gemm_config)::MMA{});
cudaMemset(Cptr, 0, sizeof(T) * m * n);
timer.start();
for (int i = 0; i < 100; ++i)
{
cudaFuncSetAttribute(gemm_simple<T, decltype(gemm_config)>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
gemm_simple<T, decltype(gemm_config)><<<grid, block, shm_size>>>(Cptr, Aptr, Bptr, m, n, k);
}
cudaDeviceSynchronize();
timer.stop();
printf("cute avg runtime %f ms\n", timer.elapsed_millis());
err = cudaGetLastError();
printf("err = %d, str = %s\n", err, cudaGetErrorString(err));
T *Cptr_host;
T *Cptr_cublas_host;
Cptr_host = (T *)malloc(sizeof(T) * m * n);
Cptr_cublas_host = (T *)malloc(sizeof(T) * m * n);
// compare
cudaMemcpy(Cptr_host, Cptr, sizeof(T) * m * n, cudaMemcpyDeviceToHost);
cudaMemcpy(Cptr_cublas_host, Cptr_cublas, sizeof(T) * m * n, cudaMemcpyDeviceToHost);
// float threshold = 0.1;
// for (int i = 0; i < m * n; ++i) {
// float v1 = Cptr_host[i];
// float v2 = Cptr_cublas_host[i];
// printf("v1 = %f, v2 = %f\n", v1, v2);
// if (fabs(v2 - v1) > threshold) {
// printf("v1 = %f, v2 = %f\n", v1, v2);
// }
// }
Tensor tensor_C = make_tensor(Cptr_host, make_shape(m, n), make_stride(n, 1));
Tensor tensor_C_cublas = make_tensor(Cptr_cublas_host, make_shape(m, n), make_stride(n, 1));
auto tile = make_tile(8, 8);
auto coor = make_coord(0, 0);
Tensor tc1 = local_tile(tensor_C, tile, coor);
Tensor tc1_cublas = local_tile(tensor_C_cublas, tile, coor);
print_tensor(tc1);
print_tensor(tc1_cublas);
}
template <typename T>
void gen_rand_data(T *data, int n)
{
for (int i = 0; i < n; ++i)
{
// float v = (rand() % 200 - 100) * 0.01;
float v = i * 0.1;
data[i] = v;
}
}
I hope to get some help. Thank you.
It is. Our Sm80 mainloop implements predication: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm80_mma_multistage.hpp#L504
Is there a unit test or example somewhere that would let me run this code?
Hi, all! I was wondering how the predication handles the address-misalignment problem when K is not divisible by TileK. For example, cp.async requires the address to be at least 16-byte aligned (please correct me if I am wrong); given TileM = 128 and TileK = 32, how should we handle the problem size M = 4096, K = 4097? If my understanding is correct, this example code https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm80_mma_multistage.hpp#L549 can only handle the case where K is a multiple of 8, right?
We simply shift the tensor base pointer so that all but the 0th k-tile is aligned to the tile shape:
https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp#L222
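To sketch the arithmetic behind that (made-up numbers, K = 4104 and kTileK = 32; this is not the CUTLASS code, just the indexing idea):

#include <cstdio>

int main() {
  int K          = 4104;                       // still a multiple of 8 half_t, so 16B row alignment holds
  int kTileK     = 32;
  int tile_count = (K + kTileK - 1) / kTileK;  // 129 k-tiles
  int residue_k  = K - tile_count * kTileK;    // -24: overhang of the tile grid

  // Shifting the tensor base pointer by residue_k elements along k makes the
  // tile grid cover [residue_k, K) instead of [0, tile_count * kTileK).
  for (int t = 0; t < tile_count; ++t) {
    int k_begin = residue_k + t * kTileK;
    int k_end   = k_begin + kTileK;
    // Only tile 0 contains out-of-range coordinates (k < 0) and needs
    // predication; tiles 1 .. tile_count-1 are full.
    if (t < 2 || t == tile_count - 1) {
      printf("tile %3d: k in [%5d, %5d) %s\n", t, k_begin, k_end,
             k_begin < 0 ? "(partial, predicated)" : "(full)");
    }
  }
  return 0;
}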
Hi, @thakkarV, thanks for your reply. It seems the shift mechanism is basically the same between the sm80 and sm90 MMA multistage code, so I think my question is still unanswered: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/collective/sm80_mma_multistage.hpp#L549 can only handle the case where K is a multiple of 8 due to the cp.async source-alignment requirement.
If the input shape is not aligned to the instruction alignment, that shape is not computable by the kernel. You have to reduce the vectorization if you want to compute on lower alignment.
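In terms of the config posted above, reducing the vectorization would look roughly like this (a sketch for half_t; as far as I know the .cg variant of cp.async only supports 16-byte transfers, so switch to the .ca variant for smaller ones):

// 8-byte cp.async: 4 half_t per copy, so rows only need 8-byte alignment
using g2s_copy_op     = SM80_CP_ASYNC_CACHEALWAYS<cute::uint64_t>;
using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
using g2s_copy_atom   = Copy_Atom<g2s_copy_traits, T>;
using G2SCopyA =
    decltype(make_tiled_copy(g2s_copy_atom{},
                             make_layout(make_shape(Int<16>{}, Int<8>{}),
                                         make_stride(Int<8>{}, Int<1>{})),
                             make_layout(make_shape(Int<1>{}, Int<4>{}))));  // 4 values per thread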
Got it, thanks a lot~
