Vanishing Gradient at ~7-10k tokens across multiple models/platforms
I'm running the JVM with this command line:
java -server -Xmx26g --enable-preview --add-modules jdk.incubator.vector -jar llama3.jar --model Llama-3.2-3B-Instruct-Q8_0.gguf --chat -n 75000 --temperature 0
under JDK 25 GraalVM EA on Windows 11 (Ryzen 7, 32 GB RAM). I also ran the 3.2-1B model under both Windows 11 and Ubuntu 22.04 on a Ryzen 9 HX 370 with the same VM options, except -Xmx96g and -n 128000, since that machine has 128 GB. I notice the gradient vanishing at around 7-10k tokens: the models either repeat an endless loop of doggerel or issue only sullen, repetitive two- or three-word responses. I'm going to try 3.1-8B under Ubuntu and see what happens, but so far it has consistently become deranged at right around the same 7-10k token mark.
All computations happen in float32, so in theory this implementation should be more resilient to the accumulation of precision errors. But the operations themselves must be implemented carefully; I tried my best to do so, but something may have slipped through.
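To illustrate why float32 accumulation needs care, here is a minimal standalone sketch (not code from llama3.java; class and method names are hypothetical) that sums the same value ten million times in three ways: naively in float32, with Kahan compensated summation in float32, and in double as a reference.

```java
import java.util.Arrays;

/**
 * Sketch: how float32 rounding error accumulates in a long naive sum,
 * and how Kahan (compensated) summation keeps it bounded.
 */
public final class Float32Accumulation {

    // Naive left-to-right float32 sum: each addition rounds, and the errors add up.
    static float naiveSum(float[] xs) {
        float sum = 0f;
        for (float x : xs) {
            sum += x;
        }
        return sum;
    }

    // Kahan summation: `c` carries the low-order bits lost by the previous
    // addition and feeds them back into the next one.
    static float kahanSum(float[] xs) {
        float sum = 0f;
        float c = 0f;
        for (float x : xs) {
            float y = x - c;
            float t = sum + y;
            c = (t - sum) - y;
            sum = t;
        }
        return sum;
    }

    // Reference value accumulated in double precision.
    static double doubleSum(float[] xs) {
        double sum = 0.0;
        for (float x : xs) {
            sum += x;
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] xs = new float[10_000_000];
        Arrays.fill(xs, 0.1f);
        System.out.printf("naive float32 : %.4f%n", naiveSum(xs));
        System.out.printf("Kahan float32 : %.4f%n", kahanSum(xs));
        System.out.printf("double (ref.) : %.4f%n", doubleSum(xs));
    }
}
```

Running it, the naive float32 result drifts visibly away from the double reference, while the Kahan result stays within float32 rounding of it. Long reductions, such as attention over thousands of positions, are where this kind of error can build up.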
The current attention implementation can be vastly improved for numerical stability; I already have a proper, faster implementation using matmuls but cannot share it yet (I'm on it). The small models are more prone to drift towards insanity with temperature 0; can you try the recommended temperatures? For reproducible results, use e.g. --seed 42 rather than --temperature 0.
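A minimal sketch of the two ideas in this reply, assuming a plain float[] of logits: a softmax that subtracts the maximum before exponentiating (the standard trick for keeping attention and sampling numerically stable), and temperature sampling driven by a seeded java.util.Random, so a fixed seed reproduces the same tokens without falling back to greedy temperature 0. This is a hypothetical illustration; it is not how llama3.java implements --seed or its attention.

```java
import java.util.Random;

/**
 * Sketch: max-subtracted (numerically stable) softmax and seeded temperature
 * sampling. Hypothetical helper, not the sampler used by llama3.java.
 */
public final class StableSampling {

    // softmax(logits / temperature) computed as exp((l - max) / T) / sum.
    // Subtracting the max keeps every exponent <= 0, so exp() cannot overflow,
    // and the largest term is exp(0) = 1, so the sum is never zero.
    static float[] softmax(float[] logits, float temperature) {
        float max = Float.NEGATIVE_INFINITY;
        for (float l : logits) {
            max = Math.max(max, l);
        }
        float[] probs = new float[logits.length];
        double sum = 0.0; // accumulate the normalizer in double precision
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) Math.exp((logits[i] - max) / temperature);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] = (float) (probs[i] / sum);
        }
        return probs;
    }

    // Greedy decoding: index of the largest logit (first one on ties).
    static int argmax(float[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    // Temperature 0 means greedy argmax; otherwise draw from softmax(logits / T)
    // with the caller's Random, so a fixed seed gives a reproducible token sequence.
    static int sample(float[] logits, float temperature, Random rng) {
        if (temperature == 0f) {
            return argmax(logits);
        }
        float[] probs = softmax(logits, temperature);
        double r = rng.nextDouble();
        double cdf = 0.0;
        for (int i = 0; i < probs.length; i++) {
            cdf += probs[i];
            if (r < cdf) {
                return i;
            }
        }
        return probs.length - 1; // rounding can leave cdf just below 1.0
    }

    public static void main(String[] args) {
        float[] logits = {2.0f, 1.0f, 0.5f, -1.0f};
        Random a = new Random(42);
        Random b = new Random(42);
        for (int step = 0; step < 5; step++) {
            System.out.println(sample(logits, 0.7f, a) + " " + sample(logits, 0.7f, b));
        }
    }
}
```

Both Random instances are created with seed 42, so the two columns printed by main are identical: the output is still a sample from the tempered distribution, just a repeatable one.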
Another possible cause is the context extension: note that these models are trained on 8K contexts and extended later via RoPE tricks. I hard-coded the constants, which may be completely wrong for some model variants/sizes; note that, at the time, even llama.cpp had issues with this. These days the constants for the context extension are included in the rope_freqs.weight tensor, which (I think) is ignored by this implementation. This is fixed in the new implementation I'm working on.
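For Llama 3.1/3.2 the context-extension constants live in the rope_scaling block of the model config (rope_type "llama3"): factor, low_freq_factor, high_freq_factor and original_max_position_embeddings. Below is a sketch, transcribed to Java, of how the Hugging Face Transformers reference code applies them to the RoPE inverse frequencies; it is not llama3.java's code and it recomputes the scaling from the config constants rather than reading rope_freqs.weight. High-frequency pairs keep their original frequency, low-frequency pairs are divided by factor, and the band in between is interpolated smoothly. The example values (head size 64, rope_theta 500000, factor 32, low 1, high 4, original context 8192) are what I believe the Llama 3.2-1B config uses (Llama 3.1 uses factor 8); treat them as assumptions and read them from your model.

```java
/** Sketch of Llama 3-style RoPE frequency scaling (rope_type "llama3"). */
public final class Llama3RopeScaling {

    /**
     * Returns one (scaled) inverse frequency per rotated pair of a head of size headDim.
     * Parameter names mirror the Hugging Face rope_scaling config keys.
     */
    static double[] scaledInvFreqs(int headDim, double ropeTheta, double factor,
                                   double lowFreqFactor, double highFreqFactor,
                                   int originalContextLength) {
        double lowFreqWavelen = originalContextLength / lowFreqFactor;
        double highFreqWavelen = originalContextLength / highFreqFactor;
        double[] invFreqs = new double[headDim / 2];
        for (int i = 0; i < headDim / 2; i++) {
            double invFreq = 1.0 / Math.pow(ropeTheta, (2.0 * i) / headDim);
            double wavelen = 2.0 * Math.PI / invFreq;
            if (wavelen < highFreqWavelen) {
                invFreqs[i] = invFreq;                 // high frequency: unchanged
            } else if (wavelen > lowFreqWavelen) {
                invFreqs[i] = invFreq / factor;        // low frequency: fully rescaled
            } else {
                // medium band: blend smoothly between rescaled and unchanged
                double smooth = (originalContextLength / wavelen - lowFreqFactor)
                        / (highFreqFactor - lowFreqFactor);
                invFreqs[i] = (1 - smooth) * invFreq / factor + smooth * invFreq;
            }
        }
        return invFreqs;
    }

    public static void main(String[] args) {
        // Assumed Llama 3.2-1B values; check your model's config before relying on them.
        double[] scaled = scaledInvFreqs(64, 500_000.0, 32.0, 1.0, 4.0, 8192);
        for (int i = 0; i < scaled.length; i++) {
            double unscaled = 1.0 / Math.pow(500_000.0, (2.0 * i) / 64);
            System.out.printf("pair %2d: divisor = %.4f%n", i, unscaled / scaled[i]);
        }
    }
}
```

The printed divisor runs from 1 for the fastest-rotating pairs to factor for the slowest ones; hard-coding the wrong factor for a given model shifts every low-frequency pair.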
See https://huggingface.co/meta-llama/Llama-3.2-1B
OK, thanks for the response. I should have read the model card more closely. For some reason I assumed the 128k context carried over to the quantized version. If it's 8k, you can't be blamed for it losing its mind at that mark. I was using temp 0 because I find it a bit too quippy at the default, but I'll change all the scripts to use --seed 42.
This is great work! I have big plans for it in robotics, so even with a longer context I will have to figure out some tricks to extend it to a reliable long horizon.
I wrote some code to call out to the RKNN NPU on the RK3588 via JNI, and I see they recently added matmul to the API, so if the 1B model works out I might try patching that in and getting it running on the Rock5B SBC. If we can wring maximum performance out of that chip, it might be a great edge solution, even if it's just an extra 4 or 5 TOPS from the NPU.
I eagerly await the new implementation! Thanks again for the great work!