Use the Tensor Cores on M1+ ($1000 bounty)
Claim the bounty by implementing this and having tinygrad generate GEMM kernels for M1 that are faster than torch/Accelerate.framework.
Clean code only, must be merged to claim bounty.
I’m working on a library that will make this task easy, even easy for you to do yourself. I have a different priority at the moment, but I plan to implement Stream-K in a few weeks.
I have it outperforming MPS in around half of my test cases (notably FP16). If you want to get a head start, I can post some of the raw source code from my repo and a few tables of perf data.
I’m not interested in the bounty; I’m just trying to prioritize my own interests.
The value in the bounty for me is the integration into tinygrad. Though there's definitely value in open implementations with great perf numbers for M1, considering Accelerate is closed source. Are you using the async copy stuff?
> Are you using the async copy stuff?

I couldn't find it, and it's probably not needed. I suspect the simdgroup matrix load/store API calls it under the hood. It seems to utilize the texture hardware (arbitrary image width, texture pointer + sample position).
> The value in the bounty for me is the integration into tinygrad.

Basically, my library will compile into a .metallib. It's designed for very easy integration + low overhead. It should support not just Apple 7, but also AMD and Intel. I'm waiting to open-source it until I'm ready. You seem to be the only other person with interest in the subject, which is why I'm talking on GitHub.
I got the numbers below by fine-tuning the heck out of block sizes. Stream-K should make this possible with a single block size.

| Large Square Sizes | 256 x 256 x 256 | 384 x 384 x 384 | 512 x 512 x 512 | 768 x 768 x 768 | 1024 x 1024 x 1024 | 1280 x 1280 x 1280 | 1440 x 1440 x 1440 |
|---|---|---|---|---|---|---|---|
| Accelerate F64 | 333 | 622 | 616 | 696 | 442 | | |
| MFA F64* | <590 | <590 | <590 | <590 | <590 | <590 | <590 |
| Accelerate F32 | 1223 | 1303 | 2282 | 2679 | 2262 | | |
| MPS F32 | 1836 | 3216 | 6200 | 6157 | 8143 | 7771 | 6497 |
| MFA F32 | 1327 | 2632 | 3960 | 6656 | 6879 | 6752 | 7017 |
| MPS F16 | 1730 | 4066 | 5849 | 5680 | 7336 | 7102 | 6433 |
| MFA F16 | 2133 | 2590 | 4225 | 7569 | 8185 | 8058 | 8047 |

| Large Square Sizes | 256 x 256 x 256 | 384 x 384 x 384 | 512 x 512 x 512 | 768 x 768 x 768 | 1024 x 1024 x 1024 | 1280 x 1280 x 1280 | 1440 x 1440 x 1440 |
|---|---|---|---|---|---|---|---|
| Faster than MPS? | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |
| MPSGraph F16 | 16.3% | 38.3% | 55.1% | 53.5% | 69.1% | 66.9% | 60.6% |
| MPSGraph F32 | 17.3% | 30.3% | 58.4% | 58.0% | 76.8% | 73.2% | 61.2% |
| naive_f16 | 4.07% | 13.6% | 14.2% | 15.0% | 14.9% | 15.1% | 15.3% |
| naive_f32 | 6.69% | 9.50% | 12.8% | 13.0% | 13.4% | 11.3% | 12.2% |
| blocked_f16 | 16.0% | 23.6% | 37.2% | 55.5% | 61.7% | 55.4% | 59.3% |
| blocked_f32 | 4.87% | 17.4% | 18.8% | 30.6% | 31.4% | 28.0% | 30.8% |
| threadgroup_f16 | 5.50% | 16.4% | 25.2% | 50.5% | 51.5% | 58.9% | 61.3% |
| threadgroup_f32 | 4.71% | 11.2% | 25.0% | 30.5% | 32.3% | 32.4% | 32.4% |
| simdgroup_f16 | 20.1% | 24.4% | 39.8% | 71.3% | 77.1% | 75.9% | 75.8% |
| simdgroup_f32 | 12.5% | 24.8% | 37.3% | 62.7% | 64.8% | 63.6% | 66.1% |
| streamk_f16 | | | | | | | |
| streamk_f32 | | | | | | | |

Which chip, and what are the units?
> Which chip, and what are the units?

32-core M1 Max. The first table is in GFLOPS (the FP64 emulation numbers are just a theoretical upper bound). The second is GPU ALU utilization (%). Use the FLOPS reported by AppleGPUInfo, not the 10.4 TFLOPS that Apple advertises.
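
As a sanity check on those units, assuming the ~1.296 GHz GPU clock that AppleGPUInfo reports for the 32-core M1 Max, the peak is

$$
32\,\text{cores} \times 128\,\text{ALUs} \times 2\,\tfrac{\text{FLOPs}}{\text{ALU}\cdot\text{cycle}} \times 1.296\,\text{GHz} \approx 10617\,\text{GFLOPS},
$$

so the 8185 GFLOPS entry for MFA F16 at 1024 x 1024 x 1024 corresponds to 8185 / 10617 ≈ 77.1% ALU utilization, which is the simdgroup_f16 figure in the second table.
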
Do what you will with this. Using threadgroup memory inside gemm_simdgroup made it worse.

```metal
// Dimensions of the matrices.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Stride for matrices during batched GEMM, in number of elements.
constant ulong A_stride [[function_constant(3)]];
constant ulong B_stride [[function_constant(4)]];
constant ulong C_stride [[function_constant(5)]];
// Sometimes abbreviated `tb_size`.
constant ushort thread_block_mn [[function_constant(10)]];
constant ushort thread_block_k = 1;
// Sometimes abbreviated `sb_size`.
constant ushort simdgroup_block_mn [[function_constant(12)]];
constant ushort simdgroup_block_k = 8;
// Sometimes abbreviated `tgb_size`.
constant ushort threadgroup_block_mn [[function_constant(14)]];
constant ushort threadgroup_block_k [[function_constant(15)]];
// Threads/Threadgroup = pow(`threadgroup_size`, 2)
// Except for `gemm_simdgroup`, where threadgroup size is always 32 or 128.
constant ushort threadgroup_size = threadgroup_block_mn / thread_block_mn;
// Whether to use threadgroup memory in `gemm_simdgroup`. This may be necessary
// for some unaligned matrix sizes (try to make it not necessary by exploiting
// knowledge about the layout pattern):
// https://patentimages.storage.googleapis.com/88/6d/9a/510adee9164f8f/US11256518.pdf
//
// TODO: Create separate load/store execution path for edges of matrices.
constant bool simdgroup_use_threadgroup [[function_constant(20)]];
//
// GEMM_simdgroup.cpp
// MetalFlashAttention
//
// Created by Philip Turner on 5/18/23.
//
#include "../Common/Common.hpp"
#include "GEMM.hpp"
// Implementation of `gemm_simdgroup`
#ifdef __METAL__
// If not using threadgroups, then append `sid` to A, B, and C beforehand. This
// will now accept threadgroups with (1, 4), (2, 2) or (4, 1) shapes, and ignore
// the `sid` parameter.
//
// WARNING: Do not use `threadgroup_block_mn` in this loop.
template <
typename real,
ushort sb_size_mn,
ushort tgb_size_mn,
void init_acc(thread simdgroup_accumulator<real, sb_size_mn>&),
void write_threadgroup(threadgroup real*,
threadgroup real*,
thread simdgroup_matrix<real, 8, 8>*,
thread simdgroup_matrix<real, 8, 8>*,
ushort2, ushort2, ushort, ushort, ushort),
void read_threadgroup(const threadgroup real*,
const threadgroup real*,
thread simdgroup_matrix<real, 8, 8>*,
thread simdgroup_matrix<real, 8, 8>*,
ushort2, ushort2, ushort, ushort, ushort),
void outer_loop(const device real*,
const device real*,
thread simdgroup_matrix<real, 8, 8>*,
thread simdgroup_matrix<real, 8, 8>*,
uint2, uint2, uint, uint, uint),
void inner_loop(const thread simdgroup_matrix<real, 8, 8>*,
const thread simdgroup_matrix<real, 8, 8>*,
thread simdgroup_accumulator<real, sb_size_mn>&),
void store_acc(device real*,
simdgroup_accumulator<real, sb_size_mn>,
uint2, uint)
>
void _gemm_simdgroup
(
device real* A, device real* B, device real* C,
threadgroup real* A_block, threadgroup real* B_block,
thread uint2& A_index, thread uint2& B_index, thread uint2& C_index,
ushort2 sid)
{
// Do not use `threadgroup_block_mn` in the loop!
typedef void threadgroup_block_mn;
// Accumulator addressed by [y][x]
// SIMD load addressed by [x][y]
// m = y
// n = x
typedef simdgroup_matrix<real, 8, 8> simdgroup_real8x8;
simdgroup_real8x8 A_value[sb_size_mn / 8];
simdgroup_real8x8 B_value[sb_size_mn / 8];
simdgroup_accumulator<real, sb_size_mn> C_value;
// Initialize the accumulator.
init_acc(C_value);
if (simdgroup_use_threadgroup) {
A_index.y += sid.x * sb_size_mn;
B_index.x += sid.x * sb_size_mn;
// Fusing the pointer with its index consumes some extra registers. It is
// faster at smaller block sizes because there's less register pressure.
#define FUSE_DEVICE 1
#if FUSE_DEVICE
auto A_src = A + A_index.y * K;
auto B_src = B + B_index.x;
#endif
for (uint k_floor = 0; k_floor < K; k_floor += threadgroup_block_k) {
auto A_dst = A_block + sid.x * sb_size_mn * threadgroup_block_k;
auto B_dst = B_block + sid.x * sb_size_mn;
for (ushort k = sid.y * 8; k < threadgroup_block_k; k += 16) {
// Temporarily store fetched tiles in SIMD matrices.
#if FUSE_DEVICE
outer_loop(A_src, B_src, A_value, B_value, uint2(0), uint2(0), k, N, K);
#else
outer_loop(A, B, A_value, B_value, A_index, B_index, k, N, K);
#endif
// Cache in threadgroup memory.
#if 0
ushort2 A_block_index(0, sid.x * sb_size_mn);
ushort2 B_block_index(sid.x * sb_size_mn, 0);
write_threadgroup
(
A_block, B_block, A_value, B_value, A_block_index, B_block_index,
k, tgb_size_mn, threadgroup_block_k);
#else
write_threadgroup
(
A_dst, B_dst, A_value, B_value, ushort2(0), ushort2(0),
k, tgb_size_mn, threadgroup_block_k);
#endif
}
threadgroup_barrier(mem_flags::mem_threadgroup);
auto A_cache = A_block + sid.y * sb_size_mn * threadgroup_block_k;
auto B_cache = B_block + sid.x * sb_size_mn;
for (ushort k = 0; k < threadgroup_block_k; k += 8) {
// Read from lower-latency cache.
#if 0
ushort2 A_block_index(0, sid.y * sb_size_mn);
ushort2 B_block_index(sid.x * sb_size_mn, 0);
read_threadgroup
(
A_block, B_block, A_value, B_value, A_block_index, B_block_index,
k, tgb_size_mn, threadgroup_block_k);
#else
read_threadgroup
(
A_cache, B_cache, A_value, B_value, ushort2(0), ushort2(0),
k, tgb_size_mn, threadgroup_block_k);
#endif
inner_loop(A_value, B_value, C_value);
}
#if FUSE_DEVICE
A_src += threadgroup_block_k;
B_src += threadgroup_block_k * N;
#else
A_index.x += threadgroup_block_k;
B_index.y += threadgroup_block_k;
#endif
threadgroup_barrier(mem_flags::mem_none);
}
#undef FUSE_DEVICE
} else {
auto A_src = A + A_index.y * K + A_index.x;
auto B_src = B + B_index.y * N + B_index.x;
for (uint k_floor = 0; k_floor < K; k_floor += 8) {
#if 0
// With larger in-thread accumulators, this path is faster.
outer_loop(A, B, A_value, B_value, A_index, B_index, 0, N, K);
A_index.x += 8;
B_index.y += 8;
#else
// But in most situations, this one is instead.
outer_loop(A_src, B_src, A_value, B_value, uint2(0), uint2(0), 0, N, K);
A_src += 8;
B_src += N * 8;
#endif
inner_loop(A_value, B_value, C_value);
}
}
// Store the accumulator.
if (simdgroup_use_threadgroup) {
C_index += uint2(sid * sb_size_mn);
}
#if 1
// This is faster for very small FP16 or very large accumulators.
store_acc(C, C_value, C_index, N);
#else
// This is faster most often, and with some large FP32 accumulators.
auto _C = C + C_index.y * N + C_index.x;
store_acc(_C, C_value, uint2(0), N);
#endif
}
// User must allocate threadgroup memory at runtime.
// bytes = 2 * (threadgroup_block_mn + 0) * threadgroup_block_k * sizeof(real)
// If not using threadgroups, set it to 16 bytes.
//
// If using threadgroups, `threadgroup_block_mn` will be ignored and replaced
// with `simdgroup_block_mn` * 2. You should still set the parameter to the
// correct value as good practice, and the API states that not doing so causes
// undefined behavior.
template <typename real>
void _gemm_simdgroup_common
(
device real* A,
device real* B,
device real* C,
threadgroup real* blocks,
uint2 gid,
ushort2 gsize,
ushort sidx)
{
ushort2 simds_per_group(gsize.x / 32, gsize.y);
// `amp_gid` = amplified group ID
uint2 amp_gid = gid * uint2(simds_per_group * simdgroup_block_mn);
ushort2 sid = 0;
if (simdgroup_use_threadgroup) {
// Threadgroup memory variant requires 128-threadgroups.
sid = ushort2(sidx % 2, sidx / 2);
} else {
switch (as_type<uint>(simds_per_group)) {
#if 0
// Support 256-threadgroups.
case as_type<uint>(ushort2(8, 1)): {
sid = ushort2(sidx, 0);
break;
}
case as_type<uint>(ushort2(4, 2)): {
sid = ushort2(sidx % 4, sidx / 4);
break;
}
case as_type<uint>(ushort2(2, 4)): {
sid = ushort2(sidx % 2, sidx / 2);
break;
}
case as_type<uint>(ushort2(1, 8)): {
sid = ushort2(0, sidx);
break;
}
#endif
case as_type<uint>(ushort2(4, 1)): {
sid = ushort2(sidx, 0);
break;
}
case as_type<uint>(ushort2(2, 2)): {
sid = ushort2(sidx % 2, sidx / 2);
break;
}
case as_type<uint>(ushort2(1, 4)): {
sid = ushort2(0, sidx);
break;
}
}
amp_gid += uint2(sid * simdgroup_block_mn);
}
uint2 A_index(0, amp_gid.y);
uint2 B_index(amp_gid.x, 0);
uint2 C_index(amp_gid.x, amp_gid.y);
// A_block[threadgroup_block_mn][threadgroup_block_k]
// B_block[threadgroup_block_k][threadgroup_block_mn]
const ushort threadgroup_block_mn = 2 * simdgroup_block_mn;
threadgroup real* A_block = blocks;
threadgroup real* B_block = blocks;
B_block += (threadgroup_block_mn + 0) * threadgroup_block_k;
if (simdgroup_block_mn == 8) {
_gemm_simdgroup<real, 8, 16,
_gemm_simdgroup_init_acc8x8,
_gemm_simdgroup_cache8x8<real, threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop8x8<real, const threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop8x8<real, const device real*, uint>,
_gemm_simdgroup_inner_loop8x8,
_gemm_simdgroup_store_acc8x8
>(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
} else if (simdgroup_block_mn == 16) {
_gemm_simdgroup<real, 16, 32,
_gemm_simdgroup_init_acc16x16,
_gemm_simdgroup_cache16x16<real, threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop16x16<real, const threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop16x16<real, const device real*, uint>,
_gemm_simdgroup_inner_loop16x16,
_gemm_simdgroup_store_acc16x16
>(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
} else if (simdgroup_block_mn == 24) {
_gemm_simdgroup<real, 24, 48,
_gemm_simdgroup_init_acc24x24,
_gemm_simdgroup_cache24x24<real, threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop24x24<real, const threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop24x24<real, const device real*, uint>,
_gemm_simdgroup_inner_loop24x24,
_gemm_simdgroup_store_acc24x24
>(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
} else if (simdgroup_block_mn == 32) {
_gemm_simdgroup<real, 32, 64,
_gemm_simdgroup_init_acc32x32,
_gemm_simdgroup_cache32x32<real, threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop32x32<real, const threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop32x32<real, const device real*, uint>,
_gemm_simdgroup_inner_loop32x32,
_gemm_simdgroup_store_acc32x32
>(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
} else if (simdgroup_block_mn == 40) {
_gemm_simdgroup<real, 40, 80,
_gemm_simdgroup_init_acc40x40,
_gemm_simdgroup_cache40x40<real, threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop40x40<real, const threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop40x40<real, const device real*, uint>,
_gemm_simdgroup_inner_loop40x40,
_gemm_simdgroup_store_acc40x40
>(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
} else if (simdgroup_block_mn == 48) {
_gemm_simdgroup<real, 48, 96,
_gemm_simdgroup_init_acc48x48,
_gemm_simdgroup_cache48x48<real, threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop48x48<real, const threadgroup real*, ushort>,
_gemm_simdgroup_outer_loop48x48<real, const device real*, uint>,
_gemm_simdgroup_inner_loop48x48,
_gemm_simdgroup_store_acc48x48
>(A, B, C, A_block, B_block, A_index, B_index, C_index, sid);
}
}
// MARK: - Declaration of Regular GEMM
template <typename real>
kernel void gemm_simdgroup
(
device real* A [[buffer(0)]],
device real* B [[buffer(1)]],
device real* C [[buffer(2)]],
threadgroup real* blocks [[threadgroup(0)]],
uint2 gid [[threadgroup_position_in_grid]],
ushort2 gsize [[threads_per_threadgroup]],
ushort sidx [[simdgroup_index_in_threadgroup]])
{
_gemm_simdgroup_common(A, B, C, blocks, gid, gsize, sidx);
}
template [[host_name("gemm_simdgroup_f16")]]
kernel void gemm_simdgroup
(
device half*, device half*, device half*,
threadgroup half*, uint2, ushort2, ushort);
template [[host_name("gemm_simdgroup_f32")]]
kernel void gemm_simdgroup
(
device float*, device float*, device float*,
threadgroup float*, uint2, ushort2, ushort);
// MARK: - Declaration of Batched GEMM
template <typename real>
kernel void gemm_simdgroup_batched
(
device real* A [[buffer(0)]],
device real* B [[buffer(1)]],
device real* C [[buffer(2)]],
threadgroup real* blocks [[threadgroup(0)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort3 gsize [[threads_per_threadgroup]],
ushort sidx [[simdgroup_index_in_threadgroup]])
{
auto _A = A + gid.z * A_stride;
auto _B = B + gid.z * B_stride;
auto _C = C + gid.z * C_stride;
_gemm_simdgroup_common(_A, _B, _C, blocks, gid.xy, gsize.xy, sidx);
}
template [[host_name("gemm_simdgroup_batched_f16")]]
kernel void gemm_simdgroup_batched
(
device half*, device half*, device half*,
threadgroup half*, uint3, ushort3, ushort);
template [[host_name("gemm_simdgroup_batched_f32")]]
kernel void gemm_simdgroup_batched
(
device float*, device float*, device float*,
threadgroup float*, uint3, ushort3, ushort);
#endif
// MARK: - SIMD-group GEMM
template <typename real, ushort sb_size_mn>
struct simdgroup_accumulator {
simdgroup_matrix<real, 8, 8> data[sb_size_mn / 8][sb_size_mn / 8];
};
// TODO: Create template functions for 'simdgroup_use_threadgroup'.
// This might even run fastest if you immediately page it from device -> thread
// -> threadgroup. In that case, the name would be very different.
// _gemm_simdgroup_store_threadgroup
template <typename real>
void _gemm_simdgroup_init_acc8x8(thread simdgroup_accumulator<real, 8>& C) {
C.data[0][0] = make_filled_simdgroup_matrix<real, 8, 8>(0);
}
template <typename real>
void _gemm_simdgroup_init_acc16x16(thread simdgroup_accumulator<real, 16>& C) {
C.data[0][0] = make_filled_simdgroup_matrix<real, 8, 8>(0);
C.data[0][1] = make_filled_simdgroup_matrix<real, 8, 8>(0);
C.data[1][0] = make_filled_simdgroup_matrix<real, 8, 8>(0);
C.data[1][1] = make_filled_simdgroup_matrix<real, 8, 8>(0);
}
template <typename real>
void _gemm_simdgroup_init_acc24x24(thread simdgroup_accumulator<real, 24>& C) {
#define ROW_24(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0);
ROW_24(0)
ROW_24(1)
ROW_24(2)
#undef ROW_24
}
template <typename real>
void _gemm_simdgroup_init_acc32x32(thread simdgroup_accumulator<real, 32>& C) {
#define ROW_32(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][3] = make_filled_simdgroup_matrix<real, 8, 8>(0);
ROW_32(0)
ROW_32(1)
ROW_32(2)
ROW_32(3)
#undef ROW_32
}
template <typename real>
void _gemm_simdgroup_init_acc40x40(thread simdgroup_accumulator<real, 40>& C) {
#define ROW_40(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][3] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][4] = make_filled_simdgroup_matrix<real, 8, 8>(0);
ROW_40(0)
ROW_40(1)
ROW_40(2)
ROW_40(3)
ROW_40(4)
#undef ROW_40
}
template <typename real>
void _gemm_simdgroup_init_acc48x48(thread simdgroup_accumulator<real, 48>& C) {
#define ROW_48(m) \
C.data[m][0] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][1] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][2] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][3] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][4] = make_filled_simdgroup_matrix<real, 8, 8>(0); \
C.data[m][5] = make_filled_simdgroup_matrix<real, 8, 8>(0);
ROW_48(0)
ROW_48(1)
ROW_48(2)
ROW_48(3)
ROW_48(4)
ROW_48(5)
#undef ROW_48
}
#define BEGIN_STORE_ACC(size) \
template <typename real> \
void _gemm_simdgroup_store_acc##size##x##size \
( \
device real* C, \
simdgroup_accumulator<real, size> C_value, \
uint2 C_index, \
uint _N) \
{
#define __STORE simdgroup_store
#define STORE_ACC(m, n) \
__STORE(C_value.data[m / 8][n / 8], C, _N, ulong2(C_index + uint2(n, m)));
#define END_STORE_ACC \
}
BEGIN_STORE_ACC(8)
STORE_ACC(0, 0)
END_STORE_ACC
BEGIN_STORE_ACC(16)
STORE_ACC(0, 0)
STORE_ACC(0, 8)
STORE_ACC(8, 0)
STORE_ACC(8, 8)
END_STORE_ACC
BEGIN_STORE_ACC(24)
#define ROW_24(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16)
ROW_24(0)
ROW_24(8)
ROW_24(16)
#undef ROW_24
END_STORE_ACC
BEGIN_STORE_ACC(32)
#define ROW_32(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16) \
STORE_ACC(m, 24)
ROW_32(0)
ROW_32(8)
ROW_32(16)
ROW_32(24)
#undef ROW_32
END_STORE_ACC
BEGIN_STORE_ACC(40)
#define ROW_40(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16) \
STORE_ACC(m, 24) \
STORE_ACC(m, 32)
ROW_40(0)
ROW_40(8)
ROW_40(16)
ROW_40(24)
ROW_40(32)
#undef ROW_40
END_STORE_ACC
BEGIN_STORE_ACC(48)
#define ROW_48(m) \
STORE_ACC(m, 0) \
STORE_ACC(m, 8) \
STORE_ACC(m, 16) \
STORE_ACC(m, 24) \
STORE_ACC(m, 32) \
STORE_ACC(m, 40)
ROW_48(0)
ROW_48(8)
ROW_48(16)
ROW_48(24)
ROW_48(32)
ROW_48(40)
#undef ROW_48
END_STORE_ACC
#undef BEGIN_STORE_ACC
#undef __STORE
#undef STORE_ACC
#undef END_STORE_ACC
// TODO: Wrap the duplicated boiler plate in a macro.
#define BEGIN_OUTER(name, size) \
template < \
typename real, \
typename const_real_ptr, \
typename index_type, \
typename index_type2 = vec<index_type, 2> \
> \
void _gemm_simdgroup_##name##size##x##size \
( \
const_real_ptr A, \
const_real_ptr B, \
thread simdgroup_matrix<real, 8, 8>* A_value, \
thread simdgroup_matrix<real, 8, 8>* B_value, \
index_type2 A_index, \
index_type2 B_index, \
index_type k, \
index_type _N, \
index_type _K)
#define DEFINE_8x8(name, function) \
BEGIN_OUTER(name, 8) \
{ \
function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
}
DEFINE_8x8(outer_loop, simdgroup_load)
DEFINE_8x8(cache, simdgroup_store)
#undef DEFINE_8x8
// Load everything from B before starting A, so the first MACC
// iterations can start before A finishes loading.
//
// ?????????????????????????????????????????????????????????????
//
// Doing that makes it slightly slower.
#define DEFINE_16x16(name, function) \
BEGIN_OUTER(name, 16) \
{ \
function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
\
function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
}
DEFINE_16x16(outer_loop, simdgroup_load)
DEFINE_16x16(cache, simdgroup_store)
#undef DEFINE_16x16
// Not unrolling makes it faster - ???
#define DEFINE_24x24(name, function) \
BEGIN_OUTER(name, 24) \
{ \
for (ushort mn = 0; mn < 24; mn += 8) { \
auto _mn = mn / 8; \
typedef index_type2 id_ty2; \
function(A_value[_mn], A, _K, ulong2(A_index + id_ty2(k, mn)), 0); \
function(B_value[_mn], B, _N, ulong2(B_index + id_ty2(mn, k)), 0); \
} \
}
DEFINE_24x24(outer_loop, simdgroup_load)
DEFINE_24x24(cache, simdgroup_store)
#undef DEFINE_24x24
// Why is this order faster now?
#define DEFINE_32x32(name, function) \
BEGIN_OUTER(name, 32) \
{ \
function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
function(A_value[2], A, _K, ulong2(A_index + index_type2(k, 16)), 0); \
function(B_value[2], B, _N, ulong2(B_index + index_type2(16, k)), 0); \
function(A_value[3], A, _K, ulong2(A_index + index_type2(k, 24)), 0); \
function(B_value[3], B, _N, ulong2(B_index + index_type2(24, k)), 0); \
}
DEFINE_32x32(outer_loop, simdgroup_load)
DEFINE_32x32(cache, simdgroup_store)
#undef DEFINE_32x32
// Not sure whether it's noise, but this feels more consistent too.
#define DEFINE_40x40(name, function) \
BEGIN_OUTER(name, 40) \
{ \
function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
function(A_value[2], A, _K, ulong2(A_index + index_type2(k, 16)), 0); \
function(B_value[2], B, _N, ulong2(B_index + index_type2(16, k)), 0); \
function(A_value[3], A, _K, ulong2(A_index + index_type2(k, 24)), 0); \
function(B_value[3], B, _N, ulong2(B_index + index_type2(24, k)), 0); \
function(A_value[4], A, _K, ulong2(A_index + index_type2(k, 32)), 0); \
function(B_value[4], B, _N, ulong2(B_index + index_type2(32, k)), 0); \
}
DEFINE_40x40(outer_loop, simdgroup_load)
DEFINE_40x40(cache, simdgroup_store)
#undef DEFINE_40x40
// Not sure whether it's just luck, but this configuration produced the
// highest value ever recorded.
#define DEFINE_48x48(name, function) \
BEGIN_OUTER(name, 48) \
{ \
function(A_value[0], A, _K, ulong2(A_index + index_type2(k, 0)), 0); \
function(A_value[1], A, _K, ulong2(A_index + index_type2(k, 8)), 0); \
function(A_value[2], A, _K, ulong2(A_index + index_type2(k, 16)), 0); \
function(A_value[3], A, _K, ulong2(A_index + index_type2(k, 24)), 0); \
function(A_value[4], A, _K, ulong2(A_index + index_type2(k, 32)), 0); \
function(A_value[5], A, _K, ulong2(A_index + index_type2(k, 40)), 0); \
\
function(B_value[0], B, _N, ulong2(B_index + index_type2(0, k)), 0); \
function(B_value[1], B, _N, ulong2(B_index + index_type2(8, k)), 0); \
function(B_value[2], B, _N, ulong2(B_index + index_type2(16, k)), 0); \
function(B_value[3], B, _N, ulong2(B_index + index_type2(24, k)), 0); \
function(B_value[4], B, _N, ulong2(B_index + index_type2(32, k)), 0); \
function(B_value[5], B, _N, ulong2(B_index + index_type2(40, k)), 0); \
}
DEFINE_48x48(outer_loop, simdgroup_load)
DEFINE_48x48(cache, simdgroup_store)
#undef DEFINE_48x48
#undef BEGIN_OUTER
#define BEGIN_INNER(size) \
template <typename real> \
void _gemm_simdgroup_inner_loop##size##x##size \
( \
const thread simdgroup_matrix<real, 8, 8>* A, \
const thread simdgroup_matrix<real, 8, 8>* B, \
thread simdgroup_accumulator<real, size>& C) \
{
#define MACC(a, b, c) \
simdgroup_multiply_accumulate(c, a, b, c);
#define END_INNER \
}
BEGIN_INNER(8)
MACC(A[0], B[0], C.data[0][0]);
END_INNER
BEGIN_INNER(16)
#define ROW_16(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]);
ROW_16(0)
ROW_16(1)
#undef ROW_16
END_INNER
BEGIN_INNER(24)
#define ROW_24(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]);
ROW_24(0)
ROW_24(1)
ROW_24(2)
#undef ROW_24
END_INNER
BEGIN_INNER(32)
#define ROW_32(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]); \
MACC(A[m], B[3], C.data[m][3]);
ROW_32(0)
ROW_32(1)
ROW_32(2)
ROW_32(3)
#undef ROW_32
END_INNER
BEGIN_INNER(40)
#define ROW_40(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]); \
MACC(A[m], B[3], C.data[m][3]); \
MACC(A[m], B[4], C.data[m][4]);
ROW_40(0)
ROW_40(1)
ROW_40(2)
ROW_40(3)
ROW_40(4)
#undef ROW_40
END_INNER
BEGIN_INNER(48)
#define ROW_48(m) \
MACC(A[m], B[0], C.data[m][0]); \
MACC(A[m], B[1], C.data[m][1]); \
MACC(A[m], B[2], C.data[m][2]); \
MACC(A[m], B[3], C.data[m][3]); \
MACC(A[m], B[4], C.data[m][4]); \
MACC(A[m], B[5], C.data[m][5]);
ROW_48(0)
ROW_48(1)
ROW_48(2)
ROW_48(3)
ROW_48(4)
ROW_48(5)
#undef ROW_48
END_INNER
#undef BEGIN_INNER
#undef MACC
#undef END_INNER
#endif // __METAL__
```
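
For the host-side sizing implied by the comments above ("User must allocate threadgroup memory at runtime" and the 32-or-128-thread rule for `gemm_simdgroup`), here is a minimal Swift sketch. It assumes the 2 x 2 simdgroup layout (128 threads) when threadgroup memory is enabled and a single 32-thread simdgroup otherwise; the switch in `_gemm_simdgroup_common` accepts other 128-thread shapes too, and the helper name and exact dispatch shape below are my own, not something defined by MFA.

```swift
import Metal

// Hypothetical helper: derive the threadgroup-memory size and dispatch shape
// for `gemm_simdgroup`, following the comments in the shader source above.
func gemmSimdgroupDispatch(
  M: Int, N: Int,
  simdgroupBlockMN: Int,      // function constant 12
  threadgroupBlockK: Int,     // function constant 15
  elementSize: Int,           // MemoryLayout<Float>.size or MemoryLayout<Float16>.size
  useThreadgroupMemory: Bool  // function constant 20
) -> (threadgroupMemoryBytes: Int, threadsPerThreadgroup: MTLSize, threadgroups: MTLSize) {
  if useThreadgroupMemory {
    // The kernel replaces threadgroup_block_mn with 2 * simdgroup_block_mn internally.
    let tgBlockMN = 2 * simdgroupBlockMN
    // bytes = 2 * threadgroup_block_mn * threadgroup_block_k * sizeof(real)
    let bytes = 2 * tgBlockMN * threadgroupBlockK * elementSize
    // 128 threads = 4 simdgroups in a 2 x 2 layout; the threadgroup covers a
    // (2 * simdgroup_block_mn)^2 tile of C.
    let threads = MTLSize(width: 64, height: 2, depth: 1)
    let groups = MTLSize(width: (N + tgBlockMN - 1) / tgBlockMN,
                         height: (M + tgBlockMN - 1) / tgBlockMN,
                         depth: 1)
    return (bytes, threads, groups)
  } else {
    // Without threadgroup memory, the allocation is set to 16 bytes and a single
    // 32-thread simdgroup covers a simdgroup_block_mn^2 tile of C.
    let threads = MTLSize(width: 32, height: 1, depth: 1)
    let groups = MTLSize(width: (N + simdgroupBlockMN - 1) / simdgroupBlockMN,
                         height: (M + simdgroupBlockMN - 1) / simdgroupBlockMN,
                         depth: 1)
    return (16, threads, groups)
  }
}
```

The byte count would go to `setThreadgroupMemoryLength(_:index:)` on the compute command encoder before dispatching `threadgroups` groups of `threadsPerThreadgroup` threads.
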
I think I can get 8.7 TFLOPS.
Again, I'm not interested in the bounty though.

Changes made in tinygrad/:

| File | Insertions | Deletions |
|---|---|---|
| tinygrad/codegen/cstyle.py | 46 | 7 |
| tinygrad/codegen/linearizer.py | 87 | 104 |
| tinygrad/ops.py | 4 | 3 |
| tinygrad/runtime/ops_gpu.py | 78 | 1 |
| tinygrad/runtime/ops_metal.py | 65 | 0 |
| total | 280 | 115 |

Net lines added in the tinygrad folder: 280 - 115 = 165.

I just got 9.92 TFLOPS (93.4% ALU utilization). Good luck getting this speed with JIT-compiled Metal shaders!
FP16 or FP32? I think it might be doable in FP16, but FP32 has too much register pressure.
> FP32 has too much register pressure.

Exceeded 9 TFLOPS at both precisions.

| Large Square Sizes | 256 x 256 x 256 | 384 x 384 x 384 | 512 x 512 x 512 | 768 x 768 x 768 | 1024 x 1024 x 1024 | 1280 x 1280 x 1280 | 1440 x 1440 x 1440 |
|---|---|---|---|---|---|---|---|
| Accelerate F64 | 333 | 622 | 616 | 696 | 442 | | |
| MFA F64* | <590 | <590 | <590 | <590 | <590 | <590 | <590 |
| Accelerate F32 | 1223 | 1303 | 2282 | 2679 | 2262 | | |
| MPS F32 | 1847 | 3216 | 6200 | 6157 | 8153 | 7771 | 6497 |
| MFA F32 | 1550 | 3131 | 4894 | 7739 | 8185 | 8111 | 8472 |
| MPS F16 | 1730 | 4066 | 5849 | 5680 | 7336 | 7102 | 6433 |
| MFA F16 | 2133 | 3662 | 5372 | 8525 | 9353 | 9109 | 9215 |

| Large Square Sizes | 2048 x 2048 | 3072 x 3072 | 4096 x 4096 | 6144 x 6144 |
|---|---|---|---|---|
| Accelerate F64 | 536 | 516 | 520 | 504 |
| MFA F64* | <590 | <590 | <590 | <590 |
| Accelerate F32 | 1978 | 2058 | 1957 | 1947 |
| MPS F32 | 8472 | 8482 | 8270 | Error |
| MFA F32 | 8992 | 9236 | 9247 | 9257 |
| MPS F16 | 7729 | 7824 | 7771 | Error |
| MFA F16 | 9618 | 9788 | 9778 | 9905 |
@philipturner got M1 tensor cores merged, though not quite getting 9 TFLOPS :)
Are you using async copies? The 4 vs 6 shape thing is something the searcher will find when it's ready.
The 4 vs 6 shape is split across a threadgroup of 128 threads. Your draft was only scoped to a single simd (32 threads).

| Block of Matrix | Address Space | M | N | K |
|---|---|---|---|---|
| A | threadgroup | 32-48 | | 24-40 |
| B | threadgroup | | 32-48 | 24-40 |
| A | thread | 16-24 | | 8 |
| B | thread | | 16-24 | 8 |
| C | thread | 16-24 | 16-24 | |

Threadgroup layout:

| | B (N=0-24) | B (N=24-48) |
|---|---|---|
| A (M=0-24) | simd 0 | simd 1 |
| A (M=24-48) | simd 2 | simd 3 |
The first simd async copies data from device -> threadgroup and immediately waits. Then, the entire threadgroup immediately does a threadgroup_barrier, so it's not actually asynchronous; it just decreases the number of instructions.
The async copy engine can perform zero-padding, but it's extremely expensive. All four simds fetch threadgroup -> thread, padded along K to a multiple of 8. M and N are not padded, so those edges can contain garbage data.

Virtual GEMM (threadgroup scoped):
- B: device KxN -> threadgroup 24x48
- A: device MxK -> threadgroup 48x24
- C: 48x48 accumulator

Virtual GEMM (simd scoped):
- B: threadgroup 24x48 -> thread 8x24
- A: threadgroup 48x24 -> thread 24x8
- C: 24x24 accumulator

Your draft (simd scoped):
- B: device KxN -> thread 8x32
- A: device MxK -> thread 32x8
- C: 32x32 accumulator

Async copies are unavailable to the runtime MSL compiler. You don't need codegen at the MSL level anyway; rather, it's ~500 lines of MSL, which is transformed into AIR offline. At runtime, codegen happens during the AIR -> assembly transformation via MTLFunctionConstantValues. I pre-compiled the AIR and hosted it here.
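
To make that concrete, here is a rough Swift sketch of consuming the precompiled metallib. The function-constant indices and the `gemm_simdgroup_f32` entry point come from the shader source above; the library path, block-size values, and error handling are placeholders, and other entry points (e.g. the batched variants) would need additional constants such as the strides at indices 3-5.

```swift
import Foundation
import Metal

// Sketch: specialize the offline-compiled AIR at pipeline-creation time.
// This is the point where the AIR -> assembly codegen happens.
func makeGEMMPipeline(device: MTLDevice,
                      m: UInt32, n: UInt32, k: UInt32,
                      simdgroupBlockMN: UInt16,
                      threadgroupBlockMN: UInt16,
                      threadgroupBlockK: UInt16,
                      useThreadgroupMemory: Bool) throws -> MTLComputePipelineState {
  // Placeholder path to the precompiled library.
  let url = URL(fileURLWithPath: "MetalFlashAttention.metallib")
  let library = try device.makeLibrary(URL: url)

  var m = m, n = n, k = k
  var sbMN = simdgroupBlockMN, tgbMN = threadgroupBlockMN, tgbK = threadgroupBlockK
  var useTG = useThreadgroupMemory

  // Indices match the [[function_constant(...)]] declarations in the shader source.
  let constants = MTLFunctionConstantValues()
  constants.setConstantValue(&m, type: .uint, index: 0)        // M
  constants.setConstantValue(&n, type: .uint, index: 1)        // N
  constants.setConstantValue(&k, type: .uint, index: 2)        // K
  constants.setConstantValue(&sbMN, type: .ushort, index: 12)  // simdgroup_block_mn
  constants.setConstantValue(&tgbMN, type: .ushort, index: 14) // threadgroup_block_mn
  constants.setConstantValue(&tgbK, type: .ushort, index: 15)  // threadgroup_block_k
  constants.setConstantValue(&useTG, type: .bool, index: 20)   // simdgroup_use_threadgroup

  let function = try library.makeFunction(name: "gemm_simdgroup_f32",
                                          constantValues: constants)
  return try device.makeComputePipelineState(function: function)
}
```
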
For the non-GEMM parts, MSL codegen could be useful. My metallib supports fused activations through Metal's visible functions (MTLVisibleFunctionTable). Would this API suffice for tinygrad?
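
For reference, the host-side shape of that API looks roughly like the sketch below. I don't know how MFA wires its activation hooks internally, and the activation name "gelu_f32", the table size, and the buffer index are made up for illustration.

```swift
import Metal

// Rough sketch of binding a [[visible]] activation function into a compute
// pipeline through a visible function table. Names and indices are illustrative.
func makeFusedPipeline(device: MTLDevice,
                       library: MTLLibrary,
                       gemm: MTLFunction) throws -> (MTLComputePipelineState, MTLVisibleFunctionTable) {
  // Hypothetical activation exposed by the metallib as a visible function.
  let activation = try library.makeFunction(name: "gelu_f32",
                                            constantValues: MTLFunctionConstantValues())

  let linked = MTLLinkedFunctions()
  linked.functions = [activation]

  let descriptor = MTLComputePipelineDescriptor()
  descriptor.computeFunction = gemm
  descriptor.linkedFunctions = linked
  let pipeline = try device.makeComputePipelineState(descriptor: descriptor,
                                                     options: [], reflection: nil)

  // Put the activation's handle into a one-entry table; the kernel indexes into it.
  let tableDescriptor = MTLVisibleFunctionTableDescriptor()
  tableDescriptor.functionCount = 1
  let table = pipeline.makeVisibleFunctionTable(descriptor: tableDescriptor)!
  table.setFunction(pipeline.functionHandle(function: activation)!, index: 0)

  // At encode time: encoder.setVisibleFunctionTable(table, bufferIndex: <illustrative index>)
  return (pipeline, table)
}
```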