llama-bench : Add `--override-tensors` arg
A small group over at BeaverAI has been making extensive use of the --override-tensors (-ot) flag to run massive MoE models faster by keeping attention on the GPU and offloading the expert FFNs to the CPU. Informal experimentation in llama-server or llama-cli doesn't compare to a proper llama-bench run, though, so this PR adds the --override-tensors arg (and the -ot short form) to llama-bench.
I noticed the // FIXME about leaking memory in args.cpp when copying the --override-tensors argument parsing, and chose to stamp null terminators into the argv, rather than accept the memory leak, as llama-bench calls parse_cmd_params only once. Let me know if you'd like that swapped out for the memory-leaking version from the common arg parser, as it's only a handful of user-entered bytes leaked.
Also planning to do some documentation of --override-tensors a little later on, as it's proving very useful and we'd love to spread the word.
Sketchy performance comparison on my laptop to show why --override-tensors helps MoE models. I set longer context lengths than the standard llama-bench to emphasize why keeping the attention operations on the GPU is important.
My hardware is an ASUS TUF A14 gaming laptop, so a Ryzen AI 9 HX 370 with 7500MHz LPDDR5 and an RTX 4060 Mobile. For these tests I run it in the ASUS-standard "Turbo" mode.
First, a CPU-only test on my hardware (used 0.3 GB of VRAM during prompt processing)
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf -t 8 -ngl 0 -p 4096 -n 4096
| model | size | params | backend | ngl | threads | test | t/s |
|---|---|---|---|---|---|---|---|
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 0 | 8 | pp4096 | 631.85 ± 15.23 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 0 | 8 | tg4096 | 44.04 ± 1.76 |
Next, running with -ngl 4 to offload some layers. I use such a low layer offload number to limit the VRAM use to just 2.2 GB, i.e. pretending this is a massive model that doesn't fit. I didn't have the time to keep re-running the tests until I got exactly the same VRAM use as with -ot below.
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf -t 8 -ngl 4 -p 4096 -n 4096
| model | size | params | backend | ngl | threads | test | t/s |
|---|---|---|---|---|---|---|---|
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 4 | 8 | pp4096 | 750.98 ± 4.15 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 4 | 8 | tg4096 | 36.27 ± 0.19 |
Next, enabling --override-tensors via the -ot short form. Because the expert tensors are overridden to the CPU, we can set -ngl 99 and still only use 1.3 GB of VRAM.
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf -t 8 -ngl 99 -ot "\d+\.ffn_.*exp.=CPU" -p 4096 -n 4096
| model | size | params | backend | ngl | threads | test | t/s |
|---|---|---|---|---|---|---|---|
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 99 | 8 | pp4096 | 736.91 ± 2.13 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,RPC | 99 | 8 | tg4096 | 46.26 ± 0.93 |
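To make the -ot argument above more concrete, here is a minimal Python sketch of how a PATTERN=BUFFER override could be applied to tensor names. It assumes the pattern is matched as an unanchored regex search against GGUF-style names such as blk.0.ffn_gate_exps.weight, and that the argument is split at the first "=". The helper names and the "GPU" default label are hypothetical and only stand in for whatever -ngl would otherwise choose; this is not llama.cpp's actual code.

```python
import re


def split_override(arg: str) -> tuple[str, str]:
    """Split 'PATTERN=BUFFER' at the first '=' (assumed parsing rule)."""
    pattern, sep, buffer_name = arg.partition("=")
    if not sep or not pattern or not buffer_name:
        raise ValueError(f"expected PATTERN=BUFFER, got {arg!r}")
    return pattern, buffer_name


def assign_buffers(tensor_names, override, default="GPU"):
    """Map each tensor name to the override buffer if the regex matches, else to the default."""
    pattern, buffer_name = split_override(override)
    regex = re.compile(pattern)
    return {
        name: (buffer_name if regex.search(name) else default)
        for name in tensor_names
    }


# Hypothetical subset of one MoE block's tensors.
names = [
    "blk.0.attn_q.weight",
    "blk.0.attn_output.weight",
    "blk.0.ffn_gate_inp.weight",   # router, not an expert tensor
    "blk.0.ffn_gate_exps.weight",
    "blk.0.ffn_up_exps.weight",
    "blk.0.ffn_down_exps.weight",
]

placement = assign_buffers(names, r"\d+\.ffn_.*exp.=CPU")
for name, buffer_name in placement.items():
    print(f"{buffer_name:>3}  {name}")
```

Under these assumptions only the three expert FFN tensors land on the CPU, while the attention and router tensors keep the default placement.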
Effects are significantly more pronounced in larger MoE models, especially those with more experts and some experts that are re-used for every pass (e.g. Llama 4 Scout and Maverick, although those models are beyond my devices' capabilities). I tried to demonstrate with DeepSeek-V2-Lite, but ran into CUDA errors if I tried to apply flash attention, cache quantization, or override-tensors. I don't have the experience with llama.cpp's codebase to track those down, but another Beaver has suggested it may be related to #12798.
PR #12891 has resolved my issue running flash attention and override-tensors with DeepSeek-V2-Lite. Some performance numbers for that, on the same hardware as my last set:
CPU Only (Used 0.8GB of VRAM during prompt processing)
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
-p 4096 -n 4096 -t 8 ^
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 0
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 8 | q8_0 | q8_0 | 1 | pp4096 | 76.48 ± 2.94 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 8 | q8_0 | q8_0 | 1 | tg4096 | 20.13 ± 1.65 |
Completely Filled GPU (Used 8.0GB of VRAM during prompt processing)
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
-p 4096 -n 4096 -t 8 ^
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 14
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 14 | 8 | q8_0 | q8_0 | 1 | pp4096 | 102.89 ± 0.54 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 14 | 8 | q8_0 | q8_0 | 1 | tg4096 | 15.36 ± 1.39 |
Comparable VRAM GPU (Used 2.8GB of VRAM during prompt processing)
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
-p 4096 -n 4096 -t 8 ^
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 4
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 8 | q8_0 | q8_0 | 1 | pp4096 | 61.07 ± 10.01 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 8 | q8_0 | q8_0 | 1 | tg4096 | 13.25 ± 0.36 |
Override-Tensors Run (Used 1.8GB of VRAM during prompt processing)
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
-p 4096 -n 4096 -t 8 ^
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "\d+\.ffn_.*exp.=CPU"
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | pp4096 | 100.06 ± 1.92 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | tg4096 | 13.11 ± 0.21 |
Tuned Override-Tensors (Used 6.3GB of VRAM during prompt processing)
This run, I'm leaving 6 of the 26 layers' conditional experts on the GPU as well as all the shexp (shared expert) layers, to try to better fill the VRAM and hopefully get the full best of both worlds.
.\build\bin\Release\llama-bench.exe -m ..\models\DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf ^
-p 4096 -n 4096 -t 8 ^
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "[12]\d\.ffn_.*exps.=CPU"
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | pp4096 | 63.12 ± 0.37 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 8 | q8_0 | q8_0 | 1 | tg4096 | 13.98 ± 0.13 |
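When tuning a layer-range pattern like the one in this run, it can help to list which block indices it actually selects before benchmarking. The sketch below is a rough check, assuming unanchored regex search against GGUF-style tensor names and assuming the routed experts live in blocks 1 to 26 (the block range and helper name are assumptions for illustration, not read from the model file).

```python
import re

PATTERN = r"[12]\d\.ffn_.*exps."   # regex part of the -ot argument above
MOE_BLOCKS = range(1, 27)          # assumption: blocks 1..26 hold routed experts


def blocks_matching(pattern: str, blocks) -> list[int]:
    """Return the block indices whose routed-expert tensor name matches the pattern."""
    regex = re.compile(pattern)
    return [i for i in blocks if regex.search(f"blk.{i}.ffn_gate_exps.weight")]


cpu_blocks = blocks_matching(PATTERN, MOE_BLOCKS)
gpu_blocks = [i for i in MOE_BLOCKS if i not in cpu_blocks]

print("routed experts overridden to CPU:", cpu_blocks)
print("routed experts left to -ngl:     ", gpu_blocks)

# The pattern ends in 'exps.', so shared-expert tensors ('shexp') are not selected.
shared_matches = re.search(PATTERN, "blk.12.ffn_gate_shexp.weight") is not None
print("shared expert matched:", shared_matches)
```

Under these assumptions the pattern only selects two-digit block indices starting with 1 or 2, so it is worth comparing the printed split against the real tensor listing and the observed VRAM use.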
Turns out my GPU was far more underpowered than I expected, but y'all can see the point of being able to benchmark this kind of thing.
Ran another set of experiments on another device (RTX 3070 and an AMD Ryzen 7 5800X 8-Core with two sticks of 2133MHz DDR4)
CPU Only (Used 836MB of VRAM during prompt processing)
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
-p 4096 -n 4096 -t 4 \
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 0
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 4 | q8_0 | q8_0 | 1 | pp4096 | 62.50 ± 0.09 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 0 | 4 | q8_0 | q8_0 | 1 | tg4096 | 9.51 ± 0.19 |
Full GPU (Used 7626MB of VRAM during prompt processing)
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
-p 4096 -n 4096 -t 4 \
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 13
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 13 | 4 | q8_0 | q8_0 | 1 | pp4096 | 67.20 ± 0.11 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 13 | 4 | q8_0 | q8_0 | 1 | tg4096 | 11.80 ± 0.03 |
Comparable VRAM GPU (Used 2930MB of VRAM during prompt processing)
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
-p 4096 -n 4096 -t 4 \
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 4
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 4 | q8_0 | q8_0 | 1 | pp4096 | 62.74 ± 0.14 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 4 | 4 | q8_0 | q8_0 | 1 | tg4096 | 10.13 ± 0.01 |
Override-Tensors Full CPU Experts (except shared) (Used 2276MB of VRAM during prompt processing)
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
-p 4096 -n 4096 -t 4 \
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "\d+.ffn_.*exps.=CPU"
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | pp4096 | 62.79 ± 0.13 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | tg4096 | 11.80 ± 0.03 |
Override-Tensors Tuned (Used 7034MB of VRAM during prompt processing)
./build/bin/llama-bench -m ../models/DeepSeek-Coder-V2-Lite-Base-Q6_K_L.gguf \
-p 4096 -n 4096 -t 4 \
-fa 1 -ctk q8_0 -ctv q8_0 -ngl 99 -ot "[2.]\d.ffn_.*exps.=CPU"
| model | size | params | backend | ngl | threads | type_k | type_v | fa | test | t/s |
|---|---|---|---|---|---|---|---|---|---|---|
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | pp4096 | 66.80 ± 0.06 |
| deepseek2 16B Q6_K | 13.56 GiB | 15.71 B | CUDA,RPC | 99 | 4 | q8_0 | q8_0 | 1 | tg4096 | 14.05 ± 0.02 |
Now, as this processor has neither AVX512 nor relatively high-bandwidth memory, we see the GPU eking out a performance boost and override-tensors helping significantly.
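One detail of the tuned pattern above is easy to miss: inside a character class, "." is a literal dot, and outside one an unescaped "." matches any character. The sketch below prints the substring that an unanchored search would match for a few block indices. It assumes GGUF-style tensor names and regex search semantics, and the block range is an assumption for illustration.

```python
import re

PATTERN = r"[2.]\d.ffn_.*exps."   # regex part of the tuned -ot argument above


def matched_text(pattern: str, block: int):
    """Return the substring matched in a routed-expert tensor name, or None."""
    name = f"blk.{block}.ffn_gate_exps.weight"
    match = re.search(pattern, name)
    return match.group(0) if match else None


for block in (5, 15, 25):
    print(block, repr(matched_text(PATTERN, block)))

# Assumption: blocks 1..26 hold routed experts.
cpu_blocks = [b for b in range(1, 27) if matched_text(PATTERN, b) is not None]
print("routed experts overridden to CPU:", cpu_blocks)
```

With these assumptions, single-digit blocks match through the dot that precedes the index (".5.ffn_"), blocks 20 and up match through the leading "2", and blocks 10 to 19 do not match at all. Whether that split is the intended one is best confirmed against the model's tensor list.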
You can also use this to offload the entire KV cache to GPU while keeping the model on CPU: -ngl 999 -ot "^.*$=CPU"
Got it @slaren. As for splitting the test grid entries, would you prefer that I use semicolons instead of commas the same way that we do for tensor split? Or should I reverse their behavior? Or should I require separate instances of the -ot flag?
EDIT: Going to update the PR with the same behaviour as tensor split for now, just so that I can get started.
Either way is fine as long as it is consistent. I don't mind if the way -ts is parsed is changed, but both options should behave in the same way.
I've implemented the behaviour the same way as tensor-split, for now. That is, ; is now the internal separator for different overrides and , is now the separator between test inputs, consistent with every other test argument. I've also tested this with multiple instances of -ot. Samples below (though do not trust these performance numbers as I was in a Zoom meeting simultaneously.)
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf ^
-t 6 -ts 1;0 -pg 2048,128 ^
-ngl 99 -ot "\d+\.ffn_.*exp.=CPU,1\d\.ffn_.*exp.=CPU,1\d\.ffn_.*exps.=CPU"
| model | size | params | backend | ngl | threads | ts | ot | test | t/s |
|---|---|---|---|---|---|---|---|---|---|
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | \d+.ffn_.*exp.=CPU | pp512 | 751.36 ± 6.24 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | \d+.ffn_.*exp.=CPU | tg128 | 48.84 ± 2.26 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | \d+.ffn_.*exp.=CPU | pp2048+tg128 | 408.56 ± 1.34 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exp.=CPU | pp512 | 1436.22 ± 15.76 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exp.=CPU | tg128 | 58.47 ± 0.47 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exp.=CPU | pp2048+tg128 | 548.25 ± 32.14 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exps.=CPU | pp512 | 1207.03 ± 51.72 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exps.=CPU | tg128 | 50.43 ± 1.20 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exps.=CPU | pp2048+tg128 | 534.71 ± 10.05 |
.\build\bin\Release\llama-bench.exe -m ..\models\OLMoE-1B-7B-0924-Instruct-Q8_0.gguf ^
-t 6 -ts 1;0 -pg 2048,128 ^
-ngl 99 -ot "\d+\.ffn_.*exp.=CPU" -ot "1\d\.ffn_.*exp.=CPU" -ot "1\d\.ffn_.*exps.=CPU"
| model | size | params | backend | ngl | threads | ts | ot | test | t/s |
|---|---|---|---|---|---|---|---|---|---|
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | \d+.ffn_.*exp.=CPU | pp512 | 747.16 ± 5.14 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | \d+.ffn_.*exp.=CPU | tg128 | 39.11 ± 1.93 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | \d+.ffn_.*exp.=CPU | pp2048+tg128 | 373.83 ± 4.53 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exp.=CPU | pp512 | 1403.84 ± 16.44 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exp.=CPU | tg128 | 51.81 ± 0.40 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exp.=CPU | pp2048+tg128 | 548.47 ± 4.81 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exps.=CPU | pp512 | 1400.16 ± 10.49 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exps.=CPU | tg128 | 51.18 ± 0.95 |
| olmoe A1.7B Q8_0 | 6.85 GiB | 6.92 B | CUDA,Vulkan,RPC | 99 | 6 | 1.00 | 1\d.ffn_.*exps.=CPU | pp2048+tg128 | 545.08 ± 6.43 |
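The separator rules described above (a comma between test-grid entries, a semicolon between overrides inside one entry, and repeated -ot flags adding further entries) can be sketched in Python as follows. This is only an illustration of the described behaviour, not the C++ in llama-bench; the function name and the error handling are hypothetical, and splitting at the first "=" is an assumption.

```python
def parse_ot_values(values: list[str]) -> list[list[tuple[str, str]]]:
    """Turn the strings given to each -ot flag into a list of test-grid entries.

    Each entry is a list of (pattern, buffer) overrides applied together.
    """
    grid = []
    for value in values:                      # one string per -ot occurrence
        for entry in value.split(","):        # ',' separates test-grid entries
            overrides = []
            for item in entry.split(";"):     # ';' separates overrides in one entry
                pattern, sep, buffer_name = item.partition("=")
                if not sep or not pattern or not buffer_name:
                    raise ValueError(f"expected PATTERN=BUFFER, got {item!r}")
                overrides.append((pattern, buffer_name))
            grid.append(overrides)
    return grid


# One flag with commas and three separate flags describe the same three tests.
combined = parse_ot_values(
    [r"\d+\.ffn_.*exp.=CPU,1\d\.ffn_.*exp.=CPU,1\d\.ffn_.*exps.=CPU"]
)
separate = parse_ot_values(
    [r"\d+\.ffn_.*exp.=CPU", r"1\d\.ffn_.*exp.=CPU", r"1\d\.ffn_.*exps.=CPU"]
)
print(combined == separate)

# A ';' keeps two overrides inside a single test entry (hypothetical patterns).
print(parse_ot_values([r"attn=CUDA0;ffn_.*exps.=CPU"]))
```

One consequence of this scheme, at least in the sketch, is that a regex containing a literal comma or semicolon (for example a {1,2} quantifier) would be split at that character.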
I understand now why all of the other functions in that file were marked static. I'll see if I can get my Linux desktop up and run the full CI before pushing the next change to my branch, to figure out why all these functions need static markings.
All I can say is the CPU CI ran to completion on my Ubuntu 22.04 machine with no errors I was aware of. I'll try to take a look at this again tomorrow or Friday.
What's the minimum NVCC version for the CUDA CI? I have CUDA toolkit 12.4.131, which the CMake configuration finds both in and out of CI, and which builds happily when I build llama.cpp for my own purposes. In the CI run, however, CMake reports that compiling the CUDA compiler identification source file "CMakeCUDACompilerId.cu" failed, and I get:
Compiler: /usr/local/cuda/bin/nvcc
Build flags:
Id flags: --keep;--keep-dir;tmp;-gencode=arch=compute_,code=sm_ -v
The output was:
1
nvcc fatal : Unsupported gpu architecture 'compute_'
Call Stack (most recent call first):
/usr/share/cmake-3.22/Modules/CMakeDetermineCompilerId.cmake:6 (CMAKE_DETERMINE_COMPILER_ID_BUILD)
/usr/share/cmake-3.22/Modules/CMakeDetermineCompilerId.cmake:48 (__determine_compiler_id_test)
/usr/share/cmake-3.22/Modules/CMakeDetermineCUDACompiler.cmake:298 (CMAKE_DETERMINE_COMPILER_ID)
ggml/src/ggml-cuda/CMakeLists.txt:25 (enable_language)
Tried the Vulkan CI (because I can't run the CUDA CI on my desktop with my nvcc, apparently) and that failed on an unused parameter in a file my change didn't even touch, both before and after merging the latest master:
llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp:4209:175: error: unused parameter ‘src1_type’ [-Werror=unused-parameter]
4209 | static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
| ~~~~~~~~~~^~~~~~~~~
cc1plus: all warnings being treated as errors
make[2]: *** [ggml/src/ggml-vulkan/CMakeFiles/ggml-vulkan.dir/build.make:189: ggml/src/ggml-vulkan/CMakeFiles/ggml-vulkan.dir/ggml-vulkan.cpp.o] Error 1
make[2]: *** Waiting for unfinished jobs....
make[1]: *** [CMakeFiles/Makefile2:1814: ggml/src/ggml-vulkan/CMakeFiles/ggml-vulkan.dir/all] Error 2
make: *** [Makefile:146: all] Error 2
Let me know if you have any next steps for me. I'd love to be able to test this properly locally, rather than hitting the GitHub CI. I don't see any errors from the CPU CI when run locally, and I'm unable to run the CUDA and Vulkan CI, so I'm not sure what other actions I can take to make this small PR compatible (beyond bringing it up to the latest master, as I've just done).
The minimum CUDA version is 11.7, you should be good with 12.4. I am not sure what happened there, it seems that it failed to pick which architecture to build for. You could try manually specifying the architecture by setting CMAKE_CUDA_ARCHITECTURES=89. 89 should work for your 4060, but check the compute capability if you are using a different device.
Adding CMAKE_CUDA_ARCHITECTURES=86 (for the 3070 in my desktop) resulted in the same message. It's possible that my driver and NVCC CUDA versions are desynced, as nvidia-smi reports CUDA version 12.7.
Digging further into the online CI failures, I'm noticing they're only occurring on arm64 and riscv64, and all of them occur during a sudo dpkg --add-architecture step, not during llama.cpp's compilation or runtime. We see the same failure on the riscv64 architecture over here, https://github.com/ggml-org/llama.cpp/actions/runs/14492913171/job/40711973425?pr=12955, on a PR approved 8 hours ago. Could the online checks be a flaky CI situation?
I was able to run the CUDA CI on my x86_64 laptop's 4060 using CUDA 12.8 installed in WSL2 Ubuntu 24.04. No errors.
Are you able to re-run the failed GitHub checks to give the package list retrieval from Azure a chance of success? I can't determine any connection it would have to this PR.
Don't worry about the failed linux-cross CI, it's not related to this PR. I will review this again when I have a chance.
This would be a really useful addition for benchmarking deepseek-r1 and deepseek-v3!
@4onen Hi, I’m very interested in the --override-tensors parameter as well. Could you please share any additional public documentation or resources I could review?
Hey @jklincn, I'm afraid I'm a little too hammered with my school & research work to make progress on that. I asked around and the best I'm currently aware of is https://github.com/ggml-org/llama.cpp/pull/11397 (the original PR for the feature.) Not terribly detailed but makes clear how to use regex to select layers. On HuggingFace you can get a listing of all the layers in the model to help you build the regex you need.
Good luck!
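As a starting point for building such a regex from a tensor listing, here is a small hypothetical Python helper that turns a list of block indices into an -ot value for the routed-expert tensors. It assumes GGUF-style names of the form blk.N.ffn_*_exps.weight and unanchored regex search; the function name and its defaults are made up for illustration.

```python
import re


def expert_override(layers, buffer_name: str = "CPU") -> str:
    """Build 'PATTERN=BUFFER' selecting the routed-expert tensors of the given blocks."""
    indices = sorted(set(layers))
    if not indices:
        raise ValueError("at least one block index is required")
    alternation = "|".join(str(i) for i in indices)
    # Escaped dots keep block 1 from also matching blocks 10, 11, ...
    return rf"blk\.({alternation})\.ffn_.*_exps\.={buffer_name}"


override = expert_override([10, 11, 12])
print(override)

# Quick self-check of the regex part against a few example names.
pattern = override.rpartition("=")[0]
for name in (
    "blk.10.ffn_up_exps.weight",
    "blk.1.ffn_up_exps.weight",
    "blk.10.ffn_up_shexp.weight",
):
    print(name, bool(re.search(pattern, name)))
```

The resulting string contains no commas or semicolons, so it can be passed as a single quoted -ot value.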