gemma.cpp
Near-term roadmap
We're sharing a roadmap of ideas for improving and speeding up Gemma. If you'd like to join in and help us get there faster, please reach out so we can coordinate :)
Pending
MatMul2
- [x] (jan-wassenberg) Bind tensors to NUMA node
- [x] De-templatize matmul; wrap combinations like bf16, bf16, float into normal functions to speed up compile time and enable per-tensor type dispatch
- [x] Support bf16 output from matmul
- [x] Add bf16 support to AddFrom and MulByConstAndAdd
- [ ] .. then change most activations to bf16: C1, C2, ffw_out, maybe att_sums and att_out
- [x] For F32 input, use F32 mul instead of forcing conversion to bf16
- [ ] Control over GEMM threading strategy
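The bf16-output items above depend on narrowing f32 results to bf16. Below is a minimal C++ sketch of one common way to do that: round-to-nearest-even on the 16 discarded mantissa bits. The function names are hypothetical, and this is not gemma.cpp's actual conversion code.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper: converts an f32 to bf16 bits (the upper half of an
// IEEE f32) using round-to-nearest-even on the discarded low 16 bits.
uint16_t F32ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if (std::isnan(f)) {
    // Truncate, but force the top mantissa bit so the result stays a NaN.
    return static_cast<uint16_t>((bits >> 16) | 0x0040u);
  }
  // Adding 0x7FFF plus the lowest kept bit rounds ties to even.
  const uint32_t lsb = (bits >> 16) & 1u;
  bits += 0x7FFFu + lsb;
  return static_cast<uint16_t>(bits >> 16);
}

// Widening is exact: bf16 bits become the upper 16 bits of an f32.
float BF16ToF32(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```

Plain truncation would be cheaper, but it biases results toward zero. That bias matters more if activations such as C1/C2 stay in bf16 across many layers.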
Faster startup
- [x] IOBatch interface calling preadv on Linux
Infra improvements/simplification
- [x] KVCache use MatPtr, Row() instead of CachePosSize
- [x] in ops.h, pass RowVectorBatch to functions rather than pointers
- [x] Replace RowVectorBatch with MatStorageT
- [ ] Replace MMAutoTune with AutoTune from Highway, update Highway version
- [x] rename compression/shared.h -> types.h
Optimizations
- [x] Replace attention matVec with matmul - requires reshaping a matrix
- [x] Use MatMul in EmbedImagePatches
- [x] Fuse softmax and sampling
- [x] Vectorize RoPE
- [ ] Unroll WeightedSumV (4 V rows at a time)
- [ ] Flash attention
- [ ] Improved KV and att layout
- [ ] SFP embedding instead of bf16 - convert at load-time until the exporter is updated
- [x] Vectorize RMSNorm
- [x] Smaller KVCache: bf16, possibly reorder for better locality
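For the pending WeightedSumV item, here is a scalar C++ sketch of unrolling over four V rows at a time. The signature and the row-major V layout (one row per position) are assumptions, not gemma.cpp's actual code. A real version would also vectorize the inner loop.

```cpp
#include <cstddef>

// Hypothetical sketch: out[d] = sum_i w[i] * V[i][d], with V row-major
// (num_rows x dim). Handling four V rows per pass means each out[d] is
// loaded and stored once per four rows instead of once per row.
void WeightedSumV(const float* w, const float* v, size_t num_rows, size_t dim,
                  float* out) {
  for (size_t d = 0; d < dim; ++d) out[d] = 0.0f;
  size_t i = 0;
  for (; i + 4 <= num_rows; i += 4) {
    const float w0 = w[i], w1 = w[i + 1], w2 = w[i + 2], w3 = w[i + 3];
    const float* v0 = v + i * dim;
    const float* v1 = v0 + dim;
    const float* v2 = v1 + dim;
    const float* v3 = v2 + dim;
    for (size_t d = 0; d < dim; ++d) {
      out[d] += w0 * v0[d] + w1 * v1[d] + w2 * v2[d] + w3 * v3[d];
    }
  }
  // Remainder: fewer than four rows left.
  for (; i < num_rows; ++i) {
    const float* vi = v + i * dim;
    for (size_t d = 0; d < dim; ++d) out[d] += w[i] * vi[d];
  }
}
```

The four products are summed before they are added to out[d]. As a result, the output can differ in the last bits from a one-row-at-a-time loop.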
Usability
- [ ] warn if unknown arguments given. std::map of known arg names?
- [x] multiple .cc files to speed up builds
- [x] move eval/test files to tests/
- [ ] Ctrl+C signal handler to ensure profiler results are printed without requiring %q input
- [x] add --prompt flag to run.cc
- [ ] random prompt generation for debug_prompt.cc
- [ ] gemma_test: ensure deterministic output (same output given two of the same prompts)
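For the pending unknown-argument warning, one option is to check each `--name` in argv against a collection of known flag names. The roadmap suggests a std::map. This sketch uses std::set because only membership is checked, although a map could also hold help text. The function names and the `--name value` syntax are assumptions, and this is not gemma.cpp's argument parser.

```cpp
#include <cstdio>
#include <set>
#include <string>
#include <vector>

// Hypothetical sketch: returns every "--name" argument whose name is not in
// `known`. Tokens not starting with "--" are treated as flag values.
std::vector<std::string> UnknownArgs(int argc, const char* const argv[],
                                     const std::set<std::string>& known) {
  std::vector<std::string> unknown;
  for (int i = 1; i < argc; ++i) {
    const std::string arg = argv[i];
    if (arg.rfind("--", 0) != 0) continue;  // a value, not a flag name
    if (known.count(arg.substr(2)) == 0) unknown.push_back(arg);
  }
  return unknown;
}

// Prints a warning to stderr for each unrecognized flag.
void WarnUnknownArgs(int argc, const char* const argv[],
                     const std::set<std::string>& known) {
  for (const std::string& arg : UnknownArgs(argc, argv, known)) {
    std::fprintf(stderr, "Warning: unknown argument %s\n", arg.c_str());
  }
}
```

This catches typos such as `--tokenzier`, which would otherwise be silently ignored.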
Threading
- [x] (jan-wassenberg) detect: total #logical, per-logical: package, chiplet, core, smt
- [x] (jan-wassenberg) detect: CPU name, L2D/L3 size
- [x] (Z.A.) CCX-aware pinning - ready, awaiting Highway 1.2 release
- [ ] (jan-wassenberg) more efficient ThreadPool (collaborative work requesting, not stealing)
- [x] command line arg to disable pinning
- [x] detect NUMA
Done
[x] Compression
- [x] (pculliton, A.R.) Eval infrastructure
- [x] (A.R.) Arbiter model for eval
- [x] (Ray) add metadata to tensors, remove RawWeights
- [x] add TOC to BlobStore
[x] File format
- [x] store ModelInfo in weights BlobStore
- [x] store tensor info in BlobStore
- [x] store tokenizer in BlobStore
[x] New models
- [x] (Daniel) Support PaliGemma
- [x] Split Model into ModelFamily and ModelSize
- [x] (jan-wassenberg) Land single-file format
[x] General infra
- [x] (pculliton) Python wrapper
- [x] (pculliton, ...) Improved CI: run on Kaggle infra
- [x] AuxOut to hold timing info instead of printing in GenerateImpl.
- [x] Sampling struct holds rng and temperature, to reduce length of args
- [x] (P. C.) use new HWY_EXPORT_T to simplify dispatch - ready, awaiting Highway 1.2 release
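The Sampling-struct item above bundles the RNG and temperature so sampling code takes one argument instead of several. Below is a minimal sketch of that idea. The struct name, the field names, and the use of std::mt19937 are assumptions, not gemma.cpp's actual type.

```cpp
#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

// Hypothetical sketch: the RNG and temperature travel together.
struct SamplingSketch {
  std::mt19937 gen;
  float temperature;  // assumed > 0

  // Softmax over logits / temperature, then draw one token index.
  int Sample(const std::vector<float>& logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<double> weights(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
      // Subtracting the max keeps exp() from overflowing.
      weights[i] = std::exp(static_cast<double>(logits[i] - max_logit) /
                            temperature);
    }
    // discrete_distribution normalizes the weights itself.
    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    return dist(gen);
  }
};
```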
[x] Dot product
- [x] Add _mm*_dpbf16_ps to HWY_AVX3_SPR and HWY_AVX3_ZEN4 targets, plus define HWY_NATIVE_DOT_BF16 in set_macros-inl.h
- [x] Faster SFP decode via table lookup
- [x] Add new NEON_* target that uses vbfdot for ReorderWidenMulAccumulate
- [x] If !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16, decompress bf16->f32 to temp array before MatVec (idea by Samuel, thank you!) - in #166
- [x] Apply even/odd trick to SFP
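The fallback item above (no native bf16 dot product) can be read as follows: widen each bf16 weight row into an f32 scratch buffer once, then run a plain f32 dot product. Below is a scalar sketch under that reading. The function names and the row-major bf16 layout are assumptions, not gemma.cpp's actual MatVec.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// bf16 is the upper 16 bits of an IEEE f32, so widening is a shift.
inline float BF16ToF32(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Hypothetical sketch: out[r] = sum_c W[r][c] * x[c], W row-major bf16.
void MatVecBF16Fallback(const uint16_t* w, const float* x, size_t rows,
                        size_t cols, float* out) {
  std::vector<float> row(cols);  // scratch, allocated once and reused
  for (size_t r = 0; r < rows; ++r) {
    const uint16_t* wr = w + r * cols;
    // Step 1: decompress the bf16 row to f32.
    for (size_t c = 0; c < cols; ++c) row[c] = BF16ToF32(wr[c]);
    // Step 2: ordinary f32 dot product against the activations.
    float sum = 0.0f;
    for (size_t c = 0; c < cols; ++c) sum += row[c] * x[c];
    out[r] = sum;
  }
}
```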
[x] Matmul
- [x] (pculliton) implement basic matmul and test. Not using BLAS because we want to fuse matmul and decompression.
- [x] (pculliton) 4x4 unrolled and vectorized matmul
- [x] (szabadka, B.B.) Update Prefill to use matmul (activation @ weights) instead of MatVec. Almost there.
- [x] Fused decompression inside matmul
- [x] Support offsets within the matrix, required by some call sites
- [x] (jan-wassenberg) Decompress weights to bf16 when native
- [x] (jan-wassenberg) Cache-aware tiling/packing
- [x] (jan-wassenberg) NUMA aware
- [x] (jan-wassenberg) 64-bit precision
- [x] (B.B.) Larger batch size
- [x] (A.V.) Avoid allocations for decompression
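The "4x4 unrolled" matmul item refers to register blocking. Below is a scalar C++ sketch of the idea. It assumes C = A * B^T with both A (m x k) and B (n x k) row-major, so each weight row yields one output column. It also assumes m and n are multiples of 4. The real kernel is vectorized and fuses decompression, which this sketch does not attempt.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sketch: each 4x4 tile of C lives in 16 local accumulators,
// so every loaded element of A and B is reused four times.
void MatMul4x4(const float* a, const float* b, float* c, size_t m, size_t n,
               size_t k) {
  assert(m % 4 == 0 && n % 4 == 0);  // remainders omitted for brevity
  for (size_t i = 0; i < m; i += 4) {
    for (size_t j = 0; j < n; j += 4) {
      float acc[4][4] = {};
      for (size_t p = 0; p < k; ++p) {
        float av[4], bv[4];
        for (int r = 0; r < 4; ++r) av[r] = a[(i + r) * k + p];
        for (int s = 0; s < 4; ++s) bv[s] = b[(j + s) * k + p];
        for (int r = 0; r < 4; ++r)
          for (int s = 0; s < 4; ++s) acc[r][s] += av[r] * bv[s];
      }
      for (int r = 0; r < 4; ++r)
        for (int s = 0; s < 4; ++s) c[(i + r) * n + j + s] = acc[r][s];
    }
  }
}
```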
Making good progress :)
Is Paligemma part of the scope of gemma.cpp?
Let's discuss in #185 :)