
[QST] How to understand the wgmma swizzle atom?

Open cermeng opened this issue 8 months ago • 4 comments

I'm confused about the wgmma swizzle atom (code):

// M|N-major GMMA layouts in units of bits
using Layout_MN_INTER_Atom_Bits = ComposedLayout<Swizzle<0,4,3>, smem_ptr_flag, Layout<Shape< _128,_8>,Stride<_1, _128>>>;
using Layout_MN_SW32_Atom_Bits  = ComposedLayout<Swizzle<1,4,3>, smem_ptr_flag, Layout<Shape< _256,_8>,Stride<_1, _256>>>;
using Layout_MN_SW64_Atom_Bits  = ComposedLayout<Swizzle<2,4,3>, smem_ptr_flag, Layout<Shape< _512,_8>,Stride<_1, _512>>>;
using Layout_MN_SW128_Atom_Bits = ComposedLayout<Swizzle<3,4,3>, smem_ptr_flag, Layout<Shape<_1024,_8>,Stride<_1,_1024>>>;

// K-major GMMA layouts in units of bits
using Layout_K_INTER_Atom_Bits  = ComposedLayout<Swizzle<0,4,3>, smem_ptr_flag, Layout<Shape<_8, _128>,Stride< _128,_1>>>;
using Layout_K_SW32_Atom_Bits   = ComposedLayout<Swizzle<1,4,3>, smem_ptr_flag, Layout<Shape<_8, _256>,Stride< _256,_1>>>;
using Layout_K_SW64_Atom_Bits   = ComposedLayout<Swizzle<2,4,3>, smem_ptr_flag, Layout<Shape<_8, _512>,Stride< _512,_1>>>;
using Layout_K_SW128_Atom_Bits  = ComposedLayout<Swizzle<3,4,3>, smem_ptr_flag, Layout<Shape<_8,_1024>,Stride<_1024,_1>>>;

// M|N-major layouts in units of Type
template <class Type>
using Layout_MN_INTER_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_MN_INTER_Atom_Bits{}));
template <class Type>
using Layout_MN_SW32_Atom  = decltype(upcast<sizeof_bits<Type>::value>(Layout_MN_SW32_Atom_Bits{}));
template <class Type>
using Layout_MN_SW64_Atom  = decltype(upcast<sizeof_bits<Type>::value>(Layout_MN_SW64_Atom_Bits{}));
template <class Type>
using Layout_MN_SW128_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_MN_SW128_Atom_Bits{}));

// K-major layouts in units of Type
template <class Type>
using Layout_K_INTER_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_INTER_Atom_Bits{}));
template <class Type>
using Layout_K_SW32_Atom  = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW32_Atom_Bits{}));
template <class Type>
using Layout_K_SW64_Atom  = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW64_Atom_Bits{}));
template <class Type>
using Layout_K_SW128_Atom = decltype(upcast<sizeof_bits<Type>::value>(Layout_K_SW128_Atom_Bits{}));
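As a rough illustration of what the upcast in the aliases above does to the K-major SW128 atom, the bit-level shape (8, 1024) with strides (1024, 1) can be divided by the element width. The helper below is a hypothetical stand-in written for this note, not CuTe's `upcast`, and it assumes the element width divides 1024 evenly.

```cpp
// Hypothetical plain-C++ sketch (not CuTe's upcast): convert the K-major
// SW128 atom from units of bits to units of elements.
struct AtomShape {
  int rows;        // number of rows in the atom (8 for every GMMA atom above)
  int cols;        // contiguous extent along K, in elements
  int row_stride;  // distance between rows, in elements
};

template <int ElemBits>
constexpr AtomShape k_major_sw128_atom() {
  static_assert(1024 % ElemBits == 0, "element width must divide 1024 bits");
  // Bits layout is (8, 1024):(1024, 1); dividing by ElemBits gives elements.
  return AtomShape{8, 1024 / ElemBits, 1024 / ElemBits};
}

static_assert(k_major_sw128_atom<16>().cols == 64, "16-bit types: 8 x 64 atom");
static_assert(k_major_sw128_atom<8>().cols == 128, "8-bit types: 8 x 128 atom");
```

In every case one row still spans 128 bytes, which is where the SW128 name comes from.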

Question 1:

Take the 128B K-major swizzle as an example. According to the PTX documentation, I can write the layout shown in the 128B K-major swizzle mode figure as

ComposedLayout<Swizzle<3,0,3>, Int<0>, Layout<Shape<_8, _8>,Stride<_8,_1>>>

Each element in this layout is 128 bits.
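For readers following along, the 8x8 pattern described by this layout can be sketched in plain C++. The `swizzle` function below is a hypothetical stand-in for `cute::Swizzle<B,M,S>` restricted to non-negative `S`; it XORs the `B` bits starting at bit `M+S` into the `B` bits starting at bit `M`. With `Swizzle<3,0,3>` on a row-major 8x8 tile this reduces to `col ^ row`.

```cpp
#include <cstdint>

// Hypothetical stand-in for cute::Swizzle<B, M, S> with S >= 0:
// XOR the B bits starting at bit M+S into the B bits starting at bit M.
template <int B, int M, int S>
constexpr std::uint32_t swizzle(std::uint32_t offset) {
  return offset ^ ((offset & (((1u << B) - 1u) << (M + S))) >> S);
}

// Row-major 8x8 tile of 128-bit chunks: linear index = row * 8 + col.
constexpr std::uint32_t chunk_index(std::uint32_t row, std::uint32_t col) {
  return swizzle<3, 0, 3>(row * 8u + col);
}

static_assert(chunk_index(0, 5) == 0u * 8u + 5u, "row 0 is left unchanged");
static_assert(chunk_index(1, 0) == 1u * 8u + 1u, "column 0 ^ row 1 = 1");
static_assert(chunk_index(5, 3) == 5u * 8u + 6u, "column 3 ^ row 5 = 6");
```

Each row is a permutation of its own eight chunks, so the swizzle never moves a chunk to a different row.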

We can also write it in units of bits as

ComposedLayout<Swizzle<3,7,3>, Int<0>, Layout<Shape<_8, _1024>,Stride<_1024,_1>>>

However, the implementation is

Layout_K_SW128_Atom_Bits  = ComposedLayout<Swizzle<3,4,3>, smem_ptr_flag, Layout<Shape<_8,_1024>,Stride<_1024,_1>>>;

How should I understand this? Why is the second parameter of Swizzle (the number of least-significant bits to keep constant) 4? What do these 2^4 = 16 bits mean?
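To make the three parameters concrete, here is a minimal sketch of the bit masks implied by `Swizzle<3,4,3>`. `SwizzleSketch` is a hypothetical plain-C++ model written for this note (non-negative shift only), not the CuTe class. One reading, consistent with the `recast_layout<uint8_t, ...>` call quoted in Question 2, is that the swizzle acts on a byte address, in which case the 2^4 = 16 untouched units would be 16 bytes, that is, one 128-bit chunk. That interpretation is an assumption here, not something the thread confirms outright.

```cpp
#include <cstdint>

// Hypothetical model of Swizzle<B, M, S> for S >= 0.
template <int B, int M, int S>
struct SwizzleSketch {
  static constexpr std::uint32_t bit_mask = (1u << B) - 1u;
  static constexpr std::uint32_t dst_mask = bit_mask << M;        // bits that get modified
  static constexpr std::uint32_t src_mask = bit_mask << (M + S);  // bits that are XORed in

  static constexpr std::uint32_t apply(std::uint32_t offset) {
    return offset ^ ((offset & src_mask) >> S);
  }
};

using Sw343 = SwizzleSketch<3, 4, 3>;

static_assert(Sw343::dst_mask == 0x070u, "bits 4..6 are rewritten");
static_assert(Sw343::src_mask == 0x380u, "bits 7..9 select the XOR pattern");
static_assert((Sw343::apply(0x3AFu) & 0xFu) == 0xFu, "low 4 bits never change");
static_assert(Sw343::apply(0x080u) == 0x090u, "bit 7 is XORed into bit 4");
```

Under the byte-address reading, bits 4..6 pick one of the eight 16-byte chunks in a 128-byte row and bits 7..9 pick the row, which matches the 8x8 picture in units of 128 bits.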

Question 2:

I notice that when we call print_layout or print_latex on the above ComposedLayout swizzle atom, a recast is applied as follows:

template <class SwizzleFn, int B, class Layout>
CUTE_HOST_DEVICE
auto
as_position_independent_swizzle_layout(ComposedLayout<SwizzleFn,smem_ptr_flag_bits<B>,Layout> const& layout)
{
  return composition(recast_layout<uint8_t,uint_bit_t<B>>(layout.layout_a()), Int<0>{}, layout.layout_b());
}

Why is this function applied, and why is uint8_t used as the reference type?

After applying this function, the result is consistent with my understanding from Q1:

auto sw = as_position_independent_swizzle_layout(Layout_K_SW128_Atom_Bits{});
print(sw); // Sw<3,7,3> o _0 o (_8,_1024):(_1024,_1)
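The relationship between the two swizzles can be checked exhaustively in plain C++. The sketch below assumes that `Swizzle<3,4,3>` is expressed in byte units and `Swizzle<3,7,3>` in bit units, which is what the printed output suggests; `swizzle` is again a hypothetical stand-in for the CuTe functor, not the library code.

```cpp
#include <cstdint>

// Hypothetical stand-in for cute::Swizzle<B, M, S> with S >= 0.
template <int B, int M, int S>
constexpr std::uint32_t swizzle(std::uint32_t offset) {
  return offset ^ ((offset & (((1u << B) - 1u) << (M + S))) >> S);
}

// Over the whole 1024-byte atom (8 rows x 1024 bits), check that swizzling a
// bit offset with <3,7,3> equals swizzling its byte offset with <3,4,3> and
// keeping the bit position inside the byte.
constexpr bool byte_and_bit_swizzles_agree() {
  for (std::uint32_t byte = 0; byte < 1024u; ++byte) {
    for (std::uint32_t bit = 0; bit < 8u; ++bit) {
      if (swizzle<3, 7, 3>(8u * byte + bit) != 8u * swizzle<3, 4, 3>(byte) + bit) {
        return false;
      }
    }
  }
  return true;
}

static_assert(byte_and_bit_swizzles_agree(),
              "Sw<3,4,3> on bytes matches Sw<3,7,3> on bits");
```

Moving from byte units to bit units multiplies every offset by 8, which shifts the swizzle base by log2(8) = 3 and leaves the other two parameters alone.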

@ccecka

cermeng, Jun 09 '25 11:06

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Jul 10 '25 08:07

Same question.

irasin, Sep 01 '25 10:09

The swizzle layout you are pointing to is in units of bits, not bytes. The meaning of the swizzle bits is documented in the swizzle.hpp header file.

The dtype is uint8_t because that is one byte; there is no deeper reason than that.

Why is this function applied?

To convert the position-dependent swizzle to a position-independent swizzle for printing.
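One way to read "position-dependent" is that the swizzle is applied to the full shared-memory address rather than to an offset within the atom, so the result depends on where the buffer starts. The sketch below illustrates that reading in plain C++; the base addresses are made-up values and `swizzle` is a hypothetical stand-in for the CuTe functor.

```cpp
#include <cstdint>

// Hypothetical stand-in for cute::Swizzle<B, M, S> with S >= 0.
template <int B, int M, int S>
constexpr std::uint32_t swizzle(std::uint32_t offset) {
  return offset ^ ((offset & (((1u << B) - 1u) << (M + S))) >> S);
}

// True if swizzling the full address (base + offset) gives the same result as
// swizzling only the offset and then adding the base, for every byte offset
// in a 1024-byte atom.
constexpr bool swizzle_commutes_with_base(std::uint32_t base) {
  for (std::uint32_t offset = 0; offset < 1024u; ++offset) {
    if (swizzle<3, 4, 3>(base + offset) != base + swizzle<3, 4, 3>(offset)) {
      return false;
    }
  }
  return true;
}

// Made-up base addresses for illustration only.
static_assert(swizzle_commutes_with_base(0x4000u),
              "1024-byte aligned base: both forms agree");
static_assert(!swizzle_commutes_with_base(0x4080u),
              "128-byte aligned base: the forms differ");
```

When the base is aligned to the span of the swizzle, the two forms coincide, which is presumably why a position-independent version can be used for printing.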

thakkarV, Sep 01 '25 13:09

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Nov 25 '25 14:11