
Use the Tensor Cores on M1+ ($1000 bounty)

Status: Open. geohot opened this issue 2 years ago • 12 comments

Claim the bounty by implementing this and having tinygrad generate GEMM kernels for M1 that are faster than torch/Accelerate.framework.

Clean code only, must be merged to claim bounty.

geohot commented on Mar 25 '23

I'm working on a library that will make this task easy, even easy enough for you to do yourself. I have a different priority at the moment, but I plan to implement Stream-K in a few weeks.

I have it outperforming MPS in around half of my test cases (notably FP16). If you want to get a head start, I can post some of the raw source code from my repo and a few tables of perf data.

I'm not interested in the bounty; I'm just trying to prioritize my own interests.

philipturner commented on May 25 '23

The value in the bounty for me is the integration into tinygrad. Though there's definitely value in open implementations with great perf numbers for M1, considering Accelerate is closed source. Are you using the async copy stuff?

geohot commented on May 26 '23

> Are you using the async copy stuff?

I couldn't find it, and it's probably not needed. I suspect the simdgroup matrix load/store API calls it under the hood. It seems to utilize the texture hardware (arbitrary image width, texture pointer + sample position).

> The value in the bounty for me is the integration into tinygrad.

Basically, my library will compile into a .metallib. It's designed for very easy integration and low overhead. It should support not just Apple 7, but also AMD and Intel. I'm waiting to open-source it until I'm ready. You seem to be the only other person with interest in the subject, which is why I'm talking on GitHub.
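
For reference, consuming a precompiled .metallib from the host side is only a few lines of public Metal API. A minimal sketch (the file path is a placeholder, not this library's actual packaging):

```swift
import Foundation
import Metal

// Open a precompiled .metallib at runtime and inspect what it exports.
// The path below is an assumption for illustration only.
let device = MTLCreateSystemDefaultDevice()!
let url = URL(fileURLWithPath: "MetalFlashAttention.metallib")
let library = try device.makeLibrary(URL: url)
print(library.functionNames)  // e.g. ["gemm_simdgroup_f16", "gemm_simdgroup_f32", ...]
```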

philipturner commented on May 26 '23

I got this by fine-tuning the heck out of the block sizes. Stream-K should make this possible with a single block size.

| Large Square Sizes | 256 x 256 x 256 | 384 x 384 x 384 | 512 x 512 x 512 | 768 x 768 x 768 | 1024 x 1024 x 1024 | 1280 x 1280 x 1280 | 1440 x 1440 x 1440 |
|---|---|---|---|---|---|---|---|
| Accelerate F64 | 333 | 622 | 616 | 696 | 442 | | |
| MFA F64* | <590 | <590 | <590 | <590 | <590 | <590 | <590 |
| Accelerate F32 | 1223 | 1303 | 2282 | 2679 | 2262 | | |
| MPS F32 | 1836 | 3216 | 6200 | 6157 | 8143 | 7771 | 6497 |
| MFA F32 | 1327 | 2632 | 3960 | 6656 | 6879 | 6752 | 7017 |
| MPS F16 | 1730 | 4066 | 5849 | 5680 | 7336 | 7102 | 6433 |
| MFA F16 | 2133 | 2590 | 4225 | 7569 | 8185 | 8058 | 8047 |
| Large Square Sizes | 256 x 256 x 256 | 384 x 384 x 384 | 512 x 512 x 512 | 768 x 768 x 768 | 1024 x 1024 x 1024 | 1280 x 1280 x 1280 | 1440 x 1440 x 1440 |
|---|---|---|---|---|---|---|---|
| Faster than MPS? | | | | | | | |
| MPSGraph F16 | 16.3% | 38.3% | 55.1% | 53.5% | 69.1% | 66.9% | 60.6% |
| MPSGraph F32 | 17.3% | 30.3% | 58.4% | 58.0% | 76.8% | 73.2% | 61.2% |
| naive_f16 | 4.07% | 13.6% | 14.2% | 15.0% | 14.9% | 15.1% | 15.3% |
| naive_f32 | 6.69% | 9.50% | 12.8% | 13.0% | 13.4% | 11.3% | 12.2% |
| blocked_f16 | 16.0% | 23.6% | 37.2% | 55.5% | 61.7% | 55.4% | 59.3% |
| blocked_f32 | 4.87% | 17.4% | 18.8% | 30.6% | 31.4% | 28.0% | 30.8% |
| threadgroup_f16 | 5.50% | 16.4% | 25.2% | 50.5% | 51.5% | 58.9% | 61.3% |
| threadgroup_f32 | 4.71% | 11.2% | 25.0% | 30.5% | 32.3% | 32.4% | 32.4% |
| simdgroup_f16 | 20.1% | 24.4% | 39.8% | 71.3% | 77.1% | 75.9% | 75.8% |
| simdgroup_f32 | 12.5% | 24.8% | 37.3% | 62.7% | 64.8% | 63.6% | 66.1% |
| streamk_f16 | | | | | | | |
| streamk_f32 | | | | | | | |

philipturner commented on May 26 '23

Which chip, and what are the units?

geohot commented on May 26 '23

> Which chip, and what are the units?

A 32-core M1 Max. The first table is GFLOPS (the FP64 row is just a theoretical upper bound for emulation). The second is GPU ALU utilization (%). Use the FLOPS from AppleGPUInfo, not the 10.4 TFLOPS that Apple advertises.
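
To reproduce the ALU% numbers, divide measured GFLOPS by the GPU's true peak rather than the advertised 10.4 TFLOPS. A quick sketch of that arithmetic (the 1.296 GHz max clock and 128 FP32 ALUs per core are assumed M1 Max values, roughly what AppleGPUInfo reports; nothing here is queried programmatically):

```swift
// Peak throughput = cores × ALUs per core × 2 FLOPs per FMA × clock.
let cores = 32               // 32-core M1 Max
let alusPerCore = 128        // assumed for Apple 7-family GPU cores
let clockGHz = 1.296         // assumed max GPU clock; 10.4 TFLOPS would imply only ~1.27 GHz
let peakGFLOPS = Double(cores * alusPerCore * 2) * clockGHz  // ≈ 10617 GFLOPS

// Example: the 1024 x 1024 x 1024 MFA F16 entry above.
let measuredGFLOPS = 8185.0
print(measuredGFLOPS / peakGFLOPS)  // ≈ 0.771, i.e. the 77.1% reported for simdgroup_f16
```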

philipturner commented on May 26 '23

Do what you will with this. Using threadgroup memory inside gemm_simdgroup made it worse.

gemm_simdgroup
// Dimensions of the matrices.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];

// Stride for matrices during batched GEMM, in number of elements.
constant ulong A_stride [[function_constant(3)]];
constant ulong B_stride [[function_constant(4)]];
constant ulong C_stride [[function_constant(5)]];

// Sometimes abbreviated `tb_size`.
constant ushort thread_block_mn [[function_constant(10)]];
constant ushort thread_block_k = 1;

// Sometimes abbreviated `sb_size`.
constant ushort simdgroup_block_mn [[function_constant(12)]];
constant ushort simdgroup_block_k = 8;

// Sometimes abbreviated `tgb_size`.
constant ushort threadgroup_block_mn [[function_constant(14)]];
constant ushort threadgroup_block_k [[function_constant(15)]];

// Threads/Threadgroup = pow(`threadgroup_size`, 2)
// Except for `gemm_simdgroup`, where threadgroup size is always 32 or 128.
constant ushort threadgroup_size = threadgroup_block_mn / thread_block_mn;

// Whether to use threadgroup memory in `gemm_simdgroup`. This may be necessary
// for some unaligned matrix sizes (try to make it not necessary by exploiting
// knowledge about the layout pattern):
// https://patentimages.storage.googleapis.com/88/6d/9a/510adee9164f8f/US11256518.pdf
//
// TODO: Create separate load/store execution path for edges of matrices.
constant bool simdgroup_use_threadgroup [[function_constant(20)]];


//
//  GEMM_simdgroup.cpp
//  MetalFlashAttention
//
//  Created by Philip Turner on 5/18/23.
//

#include "../Common/Common.hpp"
#include "GEMM.hpp"

// Implementation of `gemm_simdgroup`

#ifdef __METAL__
// If not using threadgroups, then append `sid` to A, B, and C beforehand. This
// will now accept threadgroups with (1, 4), (2, 2) or (4, 1) shapes, and ignore
// the `sid` parameter.
//
// WARNING: Do not use `threadgroup_block_mn` in this loop.
template <
typename real,
ushort sb_size_mn,
ushort tgb_size_mn,
void init_acc(thread simdgroup_accumulator<real, sb_size_mn>&),
void write_threadgroup(threadgroup real*,
                       threadgroup real*,
                       thread simdgroup_matrix<real, 8, 8>*,
                       thread simdgroup_matrix<real, 8, 8>*,
                       ushort2, ushort2, ushort, ushort, ushort),
void read_threadgroup(const threadgroup real*,
                      const threadgroup real*,
                      thread simdgroup_matrix<real, 8, 8>*,
                      thread simdgroup_matrix<real, 8, 8>*,
                      ushort2, ushort2, ushort, ushort, ushort),
void outer_loop(const device real*,
                const device real*,
                thread simdgroup_matrix<real, 8, 8>*,
                thread simdgroup_matrix<real, 8, 8>*,
                uint2, uint2, uint, uint, uint),
void inner_loop(const thread simdgroup_matrix<real, 8, 8>*,
                const thread simdgroup_matrix<real, 8, 8>*,
                thread simdgroup_accumulator<real, sb_size_mn>&),
void store_acc(device real*,
               simdgroup_accumulator<real, sb_size_mn>,
               uint2, uint)
>
void _gemm_simdgroup
 (
  device real* A, device real* B, device real* C,
  threadgroup real* A_block, threadgroup real* B_block,
  thread uint2& A_index, thread uint2& B_index, thread uint2& C_index,
  ushort2 sid)
{
  // Do not use `threadgroup_block_mn` in the loop!
  typedef void threadgroup_block_mn;
  
  // Accumulator addressed by [y][x]
  // SIMD load addressed by [x][y]
  // m = y
  // n = x
  typedef simdgroup_matrix<real, 8, 8> simdgroup_real8x8;
  simdgroup_real8x8 A_value[sb_size_mn / 8];
  simdgroup_real8x8 B_value[sb_size_mn / 8];
  simdgroup_accumulator<real, sb_size_mn> C_value;
  
  // Initialize the accumulator.
  init_acc(C_value);
  
  if (simdgroup_use_threadgroup) {
    A_index.y += sid.x * sb_size_mn;
    B_index.x += sid.x * sb_size_mn;
    
    // Fusing the pointer with its index consumes some extra registers. It is
    // faster at smaller block sizes because there's less register pressure.
#define FUSE_DEVICE 1
#if FUSE_DEVICE
    auto A_src = A + A_index.y * K;
    auto B_src = B + B_index.x;
#endif
    
    for (uint k_floor = 0; k_floor < K; k_floor += threadgroup_block_k) {
      auto A_dst = A_block + sid.x * sb_size_mn * threadgroup_block_k;
      auto B_dst = B_block + sid.x * sb_size_mn;
      for (ushort k = sid.y * 8; k < threadgroup_block_k; k += 16) {
        // Temporarily store fetched tiles in SIMD matrices.
#if FUSE_DEVICE
        outer_loop(A_src, B_src, A_value, B_value, uint2(0), uint2(0), k, N, K);
#else
        outer_loop(A, B, A_value, B_value, A_index, B_index, k, N, K);
#endif
        
        // Cache in threadgroup memory.
#if 0
        ushort2 A_block_index(0, sid.x * sb_size_mn);
        ushort2 B_block_index(sid.x * sb_size_mn, 0);
        write_threadgroup
         (
          A_block, B_block, A_value, B_value, A_block_index, B_block_index,
          k, tgb_size_mn, threadgroup_block_k);
#else
        write_threadgroup
         (
          A_dst, B_dst, A_value, B_value, ushort2(0), ushort2(0),
          k, tgb_size_mn, threadgroup_block_k);
#endif
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);
      
      auto A_cache = A_block + sid.y * sb_size_mn * threadgroup_block_k;
      auto B_cache = B_block + sid.x * sb_size_mn;
      for (ushort k = 0; k < threadgroup_block_k; k += 8) {
        // Read from lower-latency cache.
#if 0
        ushort2 A_block_index(0, sid.y * sb_size_mn);
        ushort2 B_block_index(sid.x * sb_size_mn, 0);
        read_threadgroup
         (
          A_block, B_block, A_value, B_value, A_block_index, B_block_index,
          k, tgb_size_mn, threadgroup_block_k);
#else
        read_threadgroup
         (
          A_cache, B_cache, A_value, B_value, ushort2(0), ushort2(0),
          k, tgb_size_mn, threadgroup_block_k);
#endif
        
        inner_loop(A_value, B_value, C_value);
      }
#if FUSE_DEVICE
      A_src += threadgroup_block_k;
      B_src += threadgroup_block_k * N;
#else
      A_index.x += threadgroup_block_k;
      B_index.y += threadgroup_block_k;
#endif
      threadgroup_barrier(mem_flags::mem_none);
    }
#undef FUSE_DEVICE
  } else {
    auto A_src = A + A_index.y * K + A_index.x;
    auto B_src = B + B_index.y * N + B_index.x;
    for (uint k_floor = 0; k_floor < K; k_floor += 8) {
#if 0
      // With larger in-thread accumulators, this path is faster.
      outer_loop(A, B, A_value, B_value, A_index, B_index, 0, N, K);
      A_index.x += 8;
      B_index.y += 8;
#else
      // But in most situations, this one is instead.
      outer_loop(A_src, B_src, A_value, B_value, uint2(0), uint2(0), 0, N, K);
      A_src += 8;
      B_src += N * 8;
#endif
      
      inner_loop(A_value, B_value, C_value);
    }
  }

  // Store the accumulator.
  if (simdgroup_use_threadgroup) {
    C_index += uint2(sid * sb_size_mn);
  }
#if 1
  // This is faster for very small FP16 or very large accumulators.
  store_acc(C, C_value, C_index, N);
#else
  // This is faster most often, and with some large FP32 accumulators.
  auto _C = C + C_index.y * N + C_index.x;
  store_acc(_C, C_value, uint2(0), N);
#endif
}

// User must allocate threadgroup memory at runtime.
// bytes = 2 * (threadgroup_block_mn + 0) * threadgroup_block_k * sizeof(real)
// If not using threadgroups, set it to 16 bytes.
//
// If using threadgroups, `threadgroup_block_mn` will be ignored and replaced
// with `simdgroup_block_mn` * 2. You should still set the parameter to the
// correct value as good practice, and the API states that not doing so causes
// undefined behavior.
template <typename real>
void _gemm_simdgroup_common
 (
  device real* A,
  device real* B,
  device real* C,
  threadgroup real* blocks,
  uint2 gid,
  ushort2 gsize,
  ushort sidx)
{
  ushort2 simds_per_group(gsize.x / 32, gsize.y);
  
  // `amp_gid` = amplified group ID
  uint2 amp_gid = gid * uint2(simds_per_group * simdgroup_block_mn);
  ushort2 sid = 0;
  if (simdgroup_use_threadgroup) {
    // Threadgroup memory variant requires 128-threadgroups.
    sid = ushort2(sidx % 2, sidx / 2);
  } else {
    switch (as_type<uint>(simds_per_group)) {
#if 0
      // Support 256-threadgroups.
      case as_type<uint>(ushort2(8, 1)): {
        sid = ushort2(sidx, 0);
        break;
      }
      case as_type<uint>(ushort2(4, 2)): {
        sid = ushort2(sidx % 4, sidx / 4);
        break;
      }
      case as_type<uint>(ushort2(2, 4)): {
        sid = ushort2(sidx % 2, sidx / 2);
        break;
      }
      case as_type<uint>(ushort2(1, 8)): {
        sid = ushort2(0, sidx);
        break;
      }
#endif
      case as_type<uint>(ushort2(4, 1)): {
        sid = ushort2(sidx, 0);
        break;
      }
      case as_type<uint>(ushort2(2, 2)): {
        sid = ushort2(sidx % 2, sidx / 2);
        break;
      }
      case as_type<uint>(ushort2(1, 4)): {
        sid = ushort2(0, sidx);
        break;
      }
    }
    amp_gid += uint2(sid * simdgroup_block_mn);
  }
  uint2 A_index(0, amp_gid.y);
  uint2 B_index(amp_gid.x, 0);
  uint2 C_index(amp_gid.x, amp_gid.y);
  
  // A_block[threadgroup_block_mn][threadgroup_block_k]
  // B_block[threadgroup_block_k][threadgroup_block_mn]
  const ushort threadgroup_block_mn = 2 * simdgroup_block_mn;
  threadgroup real* A_block = blocks;
  threadgroup real* B_block = blocks;
  B_block += (threadgroup_block_mn + 0) * threadgroup_block_k;
  
  if (simdgroup_block_mn == 8) {
    _gemm_simdgroup<real, 8, 16,
    _gemm_simdgroup_init_acc8x8,
    _gemm_simdgroup_cache8x8<real, threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop8x8<real, const threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop8x8<real, const device real*, uint>,
    _gemm_simdgroup_inner_loop8x8,
    _gemm_simdgroup_store_acc8x8
    >(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
  } else if (simdgroup_block_mn == 16) {
    _gemm_simdgroup<real, 16, 32,
    _gemm_simdgroup_init_acc16x16,
    _gemm_simdgroup_cache16x16<real, threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop16x16<real, const threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop16x16<real, const device real*, uint>,
    _gemm_simdgroup_inner_loop16x16,
    _gemm_simdgroup_store_acc16x16
    >(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
  } else if (simdgroup_block_mn == 24) {
    _gemm_simdgroup<real, 24, 48,
    _gemm_simdgroup_init_acc24x24,
    _gemm_simdgroup_cache24x24<real, threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop24x24<real, const threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop24x24<real, const device real*, uint>,
    _gemm_simdgroup_inner_loop24x24,
    _gemm_simdgroup_store_acc24x24
    >(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
  } else if (simdgroup_block_mn == 32) {
    _gemm_simdgroup<real, 32, 64,
    _gemm_simdgroup_init_acc32x32,
    _gemm_simdgroup_cache32x32<real, threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop32x32<real, const threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop32x32<real, const device real*, uint>,
    _gemm_simdgroup_inner_loop32x32,
    _gemm_simdgroup_store_acc32x32
    >(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
  } else if (simdgroup_block_mn == 40) {
    _gemm_simdgroup<real, 40, 80,
    _gemm_simdgroup_init_acc40x40,
    _gemm_simdgroup_cache40x40<real, threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop40x40<real, const threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop40x40<real, const device real*, uint>,
    _gemm_simdgroup_inner_loop40x40,
    _gemm_simdgroup_store_acc40x40
    >(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
  } else if (simdgroup_block_mn == 48) {
    _gemm_simdgroup<real, 48, 96,
    _gemm_simdgroup_init_acc48x48,
    _gemm_simdgroup_cache48x48<real, threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop48x48<real, const threadgroup real*, ushort>,
    _gemm_simdgroup_outer_loop48x48<real, const device real*, uint>,
    _gemm_simdgroup_inner_loop48x48,
    _gemm_simdgroup_store_acc48x48
    >(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
  }
}

// MARK: - Declaration of Regular GEMM

template <typename real>
kernel void gemm_simdgroup
 (
  device real* A [[buffer(0)]],
  device real* B [[buffer(1)]],
  device real* C [[buffer(2)]],
  threadgroup real* blocks [[threadgroup(0)]],
  uint2 gid [[threadgroup_position_in_grid]],
  ushort2 gsize [[threads_per_threadgroup]],
  ushort sidx [[simdgroup_index_in_threadgroup]])
{
  _gemm_simdgroup_common(A, B, C, blocks, gid, gsize, sidx);
}

template [[host_name("gemm_simdgroup_f16")]]
kernel void gemm_simdgroup
(
 device half*, device half*, device half*,
 threadgroup half*, uint2, ushort2, ushort);

template [[host_name("gemm_simdgroup_f32")]]
kernel void gemm_simdgroup
(
 device float*, device float*, device float*,
 threadgroup float*, uint2, ushort2, ushort);

// MARK: - Declaration of Batched GEMM

template <typename real>
kernel void gemm_simdgroup_batched
 (
  device real* A [[buffer(0)]],
  device real* B [[buffer(1)]],
  device real* C [[buffer(2)]],
  threadgroup real* blocks [[threadgroup(0)]],
  uint3 gid [[threadgroup_position_in_grid]],
  ushort3 gsize [[threads_per_threadgroup]],
  ushort sidx [[simdgroup_index_in_threadgroup]])
{
  auto _A = A + gid.z * A_stride;
  auto _B = B + gid.z * B_stride;
  auto _C = C + gid.z * C_stride;
  _gemm_simdgroup_common(_A, _B, _C, blocks, gid.xy, gsize.xy, sidx);
}

template [[host_name("gemm_simdgroup_batched_f16")]]
kernel void gemm_simdgroup_batched
(
 device half*, device half*, device half*,
 threadgroup half*, uint3, ushort3, ushort);

template [[host_name("gemm_simdgroup_batched_f32")]]
kernel void gemm_simdgroup_batched
(
 device float*, device float*, device float*,
 threadgroup float*, uint3, ushort3, ushort);
#endif


// MARK: - SIMD-group GEMM

template <typename real, ushort sb_size_mn>
struct simdgroup_accumulator {
  simdgroup_matrix<real, 8, 8> data[sb_size_mn / 8][sb_size_mn / 8];
};

// TODO: Create template functions for 'simdgroup_use_threadgroup'.
// This might even run fastest if you immediately page it from device -> thread
// -> threadgroup. In that case, the name would be very different.
// _gemm_simdgroup_store_threadgroup

template <typename real>
void _gemm_simdgroup_init_acc8x8(thread simdgroup_accumulator<real, 8>& C) {
  C.data[0][0] = make_filled_simdgroup_matrix<real, 8, 8>(0);
}

template <typename real>
void _gemm_simdgroup_init_acc16x16(thread simdgroup_accumulator<real, 16>& C) {
  C.data[0][0] = make_filled_simdgroup_matrix<real, 8, 8>(0);
  C.data[0][1] = make_filled_simdgroup_matrix<real, 8, 8>(0);
  
  C.data[1][0] = make_filled_simdgroup_matrix<real, 8, 8>(0);
  C.data[1][1] = make_filled_simdgroup_matrix<real, 8, 8>(0);
}

template <typename real>
void _gemm_simdgroup_init_acc24x24(thread simdgroup_accumulator<real, 24>& C) {
#define ROW_24(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0); \

  ROW_24(0)
  ROW_24(1)
  ROW_24(2)
#undef ROW_24
}

template <typename real>
void _gemm_simdgroup_init_acc32x32(thread simdgroup_accumulator<real, 32>& C) {
#define ROW_32(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][3] = make_filled_simdgroup_matrix<real, 8, 8>(0); \

  ROW_32(0)
  ROW_32(1)
  ROW_32(2)
  ROW_32(3)
#undef ROW_32
}

template <typename real>
void _gemm_simdgroup_init_acc40x40(thread simdgroup_accumulator<real, 40>& C) {
#define ROW_40(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][3] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][4] = make_filled_simdgroup_matrix<real, 8, 8>(0); \

  ROW_40(0)
  ROW_40(1)
  ROW_40(2)
  ROW_40(3)
  ROW_40(4)
#undef ROW_40
}

template <typename real>
void _gemm_simdgroup_init_acc48x48(thread simdgroup_accumulator<real, 48>& C) {
#define ROW_48(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][3] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][4] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][5] = make_filled_simdgroup_matrix<real, 8, 8>(0); \

  ROW_48(0)
  ROW_48(1)
  ROW_48(2)
  ROW_48(3)
  ROW_48(4)
  ROW_48(5)
#undef ROW_48
}

#define BEGIN_STORE_ACC(size) \
template <typename real> \
void _gemm_simdgroup_store_acc##size##x##size \
 ( \
  device real* C, \
  simdgroup_accumulator<real, size> C_value, \
  uint2 C_index, \
  uint _N) \
{ \

#define __STORE simdgroup_store

#define STORE_ACC(m, n) \
__STORE(C_value.data[m / 8][n / 8], C, _N, ulong2(C_index + uint2(n, m)));

#define END_STORE_ACC \
} \

BEGIN_STORE_ACC(8)
  STORE_ACC(0, 0)
END_STORE_ACC

BEGIN_STORE_ACC(16)
  STORE_ACC(0, 0)
  STORE_ACC(0, 8)
  
  STORE_ACC(8, 0)
  STORE_ACC(8, 8)
END_STORE_ACC

BEGIN_STORE_ACC(24)
#define ROW_24(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16) \
  
  ROW_24(0)
  ROW_24(8)
  ROW_24(16)
#undef ROW_24
END_STORE_ACC

BEGIN_STORE_ACC(32)
#define ROW_32(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16) \
STORE_ACC(m, 24) \
  
  ROW_32(0)
  ROW_32(8)
  ROW_32(16)
  ROW_32(24)
#undef ROW_32
END_STORE_ACC

BEGIN_STORE_ACC(40)
#define ROW_40(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16) \
STORE_ACC(m, 24) \
STORE_ACC(m, 32) \
  
  ROW_40(0)
  ROW_40(8)
  ROW_40(16)
  ROW_40(24)
  ROW_40(32)
#undef ROW_40
END_STORE_ACC

BEGIN_STORE_ACC(48)
#define ROW_48(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16) \
STORE_ACC(m, 24) \
STORE_ACC(m, 32) \
STORE_ACC(m, 40) \
  
  ROW_48(0)
  ROW_48(8)
  ROW_48(16)
  ROW_48(24)
  ROW_48(32)
  ROW_48(40)
#undef ROW_48
END_STORE_ACC

#undef BEGIN_STORE_ACC
#undef __STORE
#undef STORE_ACC
#undef END_STORE_ACC

// TODO: Wrap the duplicated boiler plate in a macro.

#define BEGIN_OUTER(name, size) \
template < \
typename real, \
typename const_real_ptr, \
typename index_type, \
typename index_type2 = vec<index_type, 2> \
> \
void _gemm_simdgroup_##name##size##x##size \
 ( \
  const_real_ptr A, \
  const_real_ptr B, \
  thread simdgroup_matrix<real, 8, 8>* A_value, \
  thread simdgroup_matrix<real, 8, 8>* B_value, \
  index_type2 A_index, \
  index_type2 B_index, \
  index_type k, \
  index_type _N, \
  index_type _K) \

#define DEFINE_8x8(name, function) \
BEGIN_OUTER(name, 8) \
{ \
  function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
  function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
} \

DEFINE_8x8(outer_loop, simdgroup_load)
DEFINE_8x8(cache, simdgroup_store)
#undef DEFINE_8x8

// Load everything from B before starting A, so the first MACC
// iterations can start before A finishes loading.
//
// ?????????????????????????????????????????????????????????????
//
// Doing that makes it slightly slower.
#define DEFINE_16x16(name, function) \
BEGIN_OUTER(name, 16) \
{ \
  function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
  function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
 \
  function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
  function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
} \

DEFINE_16x16(outer_loop, simdgroup_load)
DEFINE_16x16(cache, simdgroup_store)
#undef DEFINE_16x16

// Not unrolling makes it faster - ???
#define DEFINE_24x24(name, function) \
BEGIN_OUTER(name, 24) \
{ \
  for (ushort mn = 0; mn < 24; mn += 8) { \
    auto _mn = mn / 8; \
    typedef index_type2 id_ty2; \
    function(A_value[_mn], A, _K, ulong2(A_index + id_ty2(k, mn)), 0); \
    function(B_value[_mn], B, _N, ulong2(B_index + id_ty2(mn, k)), 0); \
  } \
} \

DEFINE_24x24(outer_loop, simdgroup_load)
DEFINE_24x24(cache, simdgroup_store)
#undef DEFINE_24x24

// Why is this order faster now?
#define DEFINE_32x32(name, function) \
BEGIN_OUTER(name, 32) \
{ \
  function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
  function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
  function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
  function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
  function(A_value[2], A, _K, ulong2(A_index + index_type2(k, 16)), 0); \
  function(B_value[2], B, _N, ulong2(B_index + index_type2(16, k)), 0); \
  function(A_value[3], A, _K, ulong2(A_index + index_type2(k, 24)), 0); \
  function(B_value[3], B, _N, ulong2(B_index + index_type2(24, k)), 0); \
} \

DEFINE_32x32(outer_loop, simdgroup_load)
DEFINE_32x32(cache, simdgroup_store)
#undef DEFINE_32x32

// Not sure whether it's noise, but this feels more consistent too.
#define DEFINE_40x40(name, function) \
BEGIN_OUTER(name, 40) \
{ \
  function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
  function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
  function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
  function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
  function(A_value[2], A, _K, ulong2(A_index + index_type2(k, 16)), 0); \
  function(B_value[2], B, _N, ulong2(B_index + index_type2(16, k)), 0); \
  function(A_value[3], A, _K, ulong2(A_index + index_type2(k, 24)), 0); \
  function(B_value[3], B, _N, ulong2(B_index + index_type2(24, k)), 0); \
  function(A_value[4], A, _K, ulong2(A_index + index_type2(k, 32)), 0); \
  function(B_value[4], B, _N, ulong2(B_index + index_type2(32, k)), 0); \
} \

DEFINE_40x40(outer_loop, simdgroup_load)
DEFINE_40x40(cache, simdgroup_store)
#undef DEFINE_40x40

// Not sure whether it's just luck, but this configuration produced the
// highest value ever recorded.
#define DEFINE_48x48(name, function) \
BEGIN_OUTER(name, 48) \
{ \
  function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
  function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
  function(A_value[2], A, _K, ulong2(A_index + index_type2(k, 16)), 0); \
  function(A_value[3], A, _K, ulong2(A_index + index_type2(k, 24)), 0); \
  function(A_value[4], A, _K, ulong2(A_index + index_type2(k, 32)), 0); \
  function(A_value[5], A, _K, ulong2(A_index + index_type2(k, 40)), 0); \
  \
  function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
  function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
  function(B_value[2], B, _N, ulong2(B_index + index_type2(16, k)), 0); \
  function(B_value[3], B, _N, ulong2(B_index + index_type2(24, k)), 0); \
  function(B_value[4], B, _N, ulong2(B_index + index_type2(32, k)), 0); \
  function(B_value[5], B, _N, ulong2(B_index + index_type2(40, k)), 0); \
} \

DEFINE_48x48(outer_loop, simdgroup_load)
DEFINE_48x48(cache, simdgroup_store)
#undef DEFINE_48x48

#undef BEGIN_OUTER

#define BEGIN_INNER(size) \
template <typename real> \
void _gemm_simdgroup_inner_loop##size##x##size \
 ( \
  const thread simdgroup_matrix<real, 8, 8>* A, \
  const thread simdgroup_matrix<real, 8, 8>* B, \
  thread simdgroup_accumulator<real, size>& C) \
{ \

#define MACC(a, b, c) \
simdgroup_multiply_accumulate(c, a, b, c); \

#define END_INNER \
} \

BEGIN_INNER(8)
  MACC(A[0], B[0], C.data[0][0]);
END_INNER

BEGIN_INNER(16)
#define ROW_16(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \

  ROW_16(0)
  ROW_16(1)
#undef ROW_16
END_INNER

BEGIN_INNER(24)
#define ROW_24(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]); \

  ROW_24(0)
  ROW_24(1)
  ROW_24(2)
#undef ROW_24
END_INNER

BEGIN_INNER(32)
#define ROW_32(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]); \
MACC(A[m], B[3], C.data[m][3]); \

  ROW_32(0)
  ROW_32(1)
  ROW_32(2)
  ROW_32(3)
#undef ROW_32
END_INNER

BEGIN_INNER(40)
#define ROW_40(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]); \
MACC(A[m], B[3], C.data[m][3]); \
MACC(A[m], B[4], C.data[m][4]); \

  ROW_40(0)
  ROW_40(1)
  ROW_40(2)
  ROW_40(3)
  ROW_40(4)
#undef ROW_40
END_INNER

BEGIN_INNER(48)
#define ROW_48(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]); \
MACC(A[m], B[3], C.data[m][3]); \
MACC(A[m], B[4], C.data[m][4]); \
MACC(A[m], B[5], C.data[m][5]); \

  ROW_48(0)
  ROW_48(1)
  ROW_48(2)
  ROW_48(3)
  ROW_48(4)
  ROW_48(5)
#undef ROW_48
END_INNER

#undef BEGIN_INNER
#undef MACC
#undef END_INNER
#endif // __METAL__

philipturner commented on May 26 '23

I think I can get 8.7 TFLOPS.

Again not interested in the bounty though.

philipturner commented on May 30 '23

Changes made in tinygrad/:

------------------------------------------------------------
files                             insertions       deletions
------------------------------------------------------------
tinygrad/codegen/cstyle.py                46               7
tinygrad/codegen/linearizer.py            87             104
tinygrad/ops.py                            4               3
tinygrad/runtime/ops_gpu.py               78               1
tinygrad/runtime/ops_metal.py             65               0
------------------------------------------------------------
total                                    280             115
------------------------------------------------------------
lines added in the tinygrad folder: 165

tinyb0t commented on Jun 01 '23

I just got 9.92 TFLOPS (93.4% ALU utilization). Good luck getting this speed with JIT-compiled Metal shaders!

philipturner commented on Jun 01 '23

FP16 or FP32? I think it might be doable in FP16, but FP32 has too much register pressure.

geohot commented on Jun 03 '23

> FP32 has too much register pressure.

Exceeded 9 TFLOPS at both precisions.

| Large Square Sizes | 256 x 256 x 256 | 384 x 384 x 384 | 512 x 512 x 512 | 768 x 768 x 768 | 1024 x 1024 x 1024 | 1280 x 1280 x 1280 | 1440 x 1440 x 1440 |
|---|---|---|---|---|---|---|---|
| Accelerate F64 | 333 | 622 | 616 | 696 | 442 | | |
| MFA F64* | <590 | <590 | <590 | <590 | <590 | <590 | <590 |
| Accelerate F32 | 1223 | 1303 | 2282 | 2679 | 2262 | | |
| MPS F32 | 1847 | 3216 | 6200 | 6157 | 8153 | 7771 | 6497 |
| MFA F32 | 1550 | 3131 | 4894 | 7739 | 8185 | 8111 | 8472 |
| MPS F16 | 1730 | 4066 | 5849 | 5680 | 7336 | 7102 | 6433 |
| MFA F16 | 2133 | 3662 | 5372 | 8525 | 9353 | 9109 | 9215 |

| Large Square Sizes | 2048 x 2048 | 3072 x 3072 | 4096 x 4096 | 6144 x 6144 |
|---|---|---|---|---|
| Accelerate F64 | 536 | 516 | 520 | 504 |
| MFA F64* | <590 | <590 | <590 | <590 |
| Accelerate F32 | 1978 | 2058 | 1957 | 1947 |
| MPS F32 | 8472 | 8482 | 8270 | Error |
| MFA F32 | 8992 | 9236 | 9247 | 9257 |
| MPS F16 | 7729 | 7824 | 7771 | Error |
| MFA F16 | 9618 | 9788 | 9778 | 9905 |

philipturner commented on Jun 03 '23

@philipturner got M1 tensor cores merged, though not quite getting 9 TFLOPS :)

Are you using async copies? As for the 4 vs 6 shape thing, the searcher will find that when it's ready.

geohot commented on Jul 09 '23

The 4 vs 6 shape is split across a threadgroup of 128 threads. Your draft was only scoped to a single simd (32 threads).

| Block of Matrix | Address Space | M | N | K |
|---|---|---|---|---|
| A | threadgroup | 32-48 | | 24-40 |
| B | threadgroup | | 32-48 | 24-40 |
| A | thread | 16-24 | | 8 |
| B | thread | | 16-24 | 8 |
| C | thread | 16-24 | 16-24 | |

Threadgroup layout:

| | B (N=0-24) | B (N=24-48) |
|---|---|---|
| A (M=0-24) | simd 0 | simd 1 |
| A (M=24-48) | simd 2 | simd 3 |

The first simd async-copies data from device -> threadgroup and immediately waits. Then the entire threadgroup immediately does a threadgroup_barrier. So it's not actually asynchronous; it just decreases the number of instructions.

The async copy engine can perform zero-padding, but it's extremely expensive. All four simds fetch threadgroup -> thread, padded along K to a multiple of 8. M and N are not padded, potentially containing garbage data.

Virtual GEMM (threadgroup scoped):

| | device KxN -> threadgroup 24x48 |
|---|---|
| device MxK -> threadgroup 48x24 | 48x48 accumulator |

Virtual GEMM (simd scoped):

| | threadgroup 24x48 -> thread 8x24 |
|---|---|
| threadgroup 48x24 -> thread 24x8 | 24x24 accumulator |

Your draft (simd scoped):

| | device KxN -> thread 8x32 |
|---|---|
| device MxK -> thread 32x8 | 32x32 accumulator |
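
For concreteness, a hedged host-side dispatch sketch consistent with the layout above (the 64×2 threadgroup shape, the 48×48 C tile per threadgroup, FP32 data, and a threadgroup block K of 24 are assumptions drawn from the numbers in this comment, not MFA's actual encoder code):

```swift
import Metal

// Dispatch one 128-thread threadgroup per 48x48 tile of C.
// Buffers A, B, C are assumed to already be bound at indices 0-2.
func encodeGEMM(_ encoder: MTLComputeCommandEncoder,
                pipeline: MTLComputePipelineState,
                M: Int, N: Int) {
    encoder.setComputePipelineState(pipeline)

    // Threadgroup memory for the A and B blocks:
    // 2 blocks * 48 (mn) * 24 (k) * 4 bytes (FP32) = 9216 bytes.
    encoder.setThreadgroupMemoryLength(2 * 48 * 24 * 4, index: 0)

    // 64 threads along x gives two simds side by side; 2 rows of simds
    // makes 128 threads, matching the 2x2 simd layout in the table above.
    let threadsPerThreadgroup = MTLSize(width: 64, height: 2, depth: 1)
    let threadgroups = MTLSize(width: (N + 47) / 48,   // x covers N
                               height: (M + 47) / 48,  // y covers M
                               depth: 1)
    encoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadsPerThreadgroup)
}
```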

philipturner commented on Jul 09 '23

Async copies are unavailable to the runtime MSL compiler. You don't need codegen at the MSL level anyway; rather, ~500 lines of MSL are transformed into AIR offline. At runtime, codegen happens during the AIR -> assembly transformation via MTLFunctionConstantValues. I pre-compiled the AIR and hosted it here.

For the non-GEMM parts, MSL codegen could be useful. My metallib supports fused activations through MTLVisibleFunction. Would this API suffice for tinygrad?
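
As a sketch of what that runtime specialization looks like from the host with only public Metal API (the constant indices mirror the [[function_constant(...)]] declarations in the source above; the block-size values chosen here are illustrative):

```swift
import Metal

// Specialize the precompiled AIR function at pipeline-creation time.
// The AIR -> GPU assembly codegen happens inside makeFunction(name:constantValues:).
func makeGEMMPipeline(device: MTLDevice, library: MTLLibrary,
                      M: UInt32, N: UInt32, K: UInt32) throws -> MTLComputePipelineState {
    let constants = MTLFunctionConstantValues()
    var m = M, n = N, k = K
    constants.setConstantValue(&m, type: .uint, index: 0)   // M
    constants.setConstantValue(&n, type: .uint, index: 1)   // N
    constants.setConstantValue(&k, type: .uint, index: 2)   // K

    var sbSize: UInt16 = 32          // simdgroup_block_mn (illustrative choice)
    var tgbK: UInt16 = 32            // threadgroup_block_k (illustrative choice)
    var useThreadgroup = false       // simdgroup_use_threadgroup
    constants.setConstantValue(&sbSize, type: .ushort, index: 12)
    constants.setConstantValue(&tgbK, type: .ushort, index: 15)
    constants.setConstantValue(&useThreadgroup, type: .bool, index: 20)
    // Any other constants the specialized function actually references
    // (e.g. the batched strides) would need values as well.

    let function = try library.makeFunction(name: "gemm_simdgroup_f32",
                                            constantValues: constants)
    return try device.makeComputePipelineState(function: function)
}
```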

philipturner commented on Jul 09 '23