Jamba: Liger fused linear + cross entropy
Summary
Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
Awesome! Please make sure you add both convergence tests (with logits and without logits) and unit tests. We are very focused on testing.
https://github.com/linkedin/Liger-Kernel/issues/63
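For reference, "with logits" vs. "without logits" refers to whether the full logits tensor is materialized before computing the loss. Below is a toy parity sketch of that idea; the shapes are made up, and the weight-first argument order of `LigerFusedLinearCrossEntropyLoss` is my reading of the module, so treat both as assumptions to verify against the actual tests.

```python
# Toy sketch of the with/without-logits parity idea. Sizes are illustrative;
# the fused path computes the lm_head projection + cross entropy without
# ever materializing the (batch * seq_len, vocab) logits tensor.
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

B_T, H, V = 32, 64, 512  # (batch * seq_len, hidden, vocab)
x = torch.randn(B_T, H, device="cuda")
weight = torch.randn(V, H, device="cuda")
target = torch.randint(0, V, (B_T,), device="cuda")

# "with logits": materialize logits, then standard cross entropy
ref_loss = torch.nn.functional.cross_entropy(x @ weight.t(), target)

# "without logits": fused projection + loss (weight-first argument order
# is an assumption from my reading of the module)
fused_loss = LigerFusedLinearCrossEntropyLoss()(weight, x, target)

torch.testing.assert_close(fused_loss, ref_loss, rtol=1e-4, atol=1e-4)
```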
I added the following additional monkey patch for Jamba.
```python
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from transformers.models.jamba import modeling_jamba

if rms_norm:
    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
    modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cross_entropy:
    modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if swiglu:
    modeling_jamba.JambaMLP = LigerSwiGLUMLP
```
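For context, here is how the patch would be applied in practice, assuming it were wrapped in a helper in the style of the existing `apply_liger_kernel_to_llama`; the name `apply_liger_kernel_to_jamba` and the model id are hypothetical:

```python
# Usage sketch: patch BEFORE constructing the model, since replacing the
# classes on modeling_jamba does not affect modules already instantiated.
from transformers import AutoModelForCausalLM

apply_liger_kernel_to_jamba()  # hypothetical wrapper around the patch above
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
```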
However, the convergence test seems to be failing for some values in the tensor:
```
E Mismatch at index (0, 5): tensor1[(0, 5)] = 1.1513792276382446, tensor2[(0, 5)] = 1.1512681245803833
E Mismatch at index (0, 27): tensor1[(0, 27)] = 0.6227690577507019, tensor2[(0, 27)] = 0.6227344870567322
E Mismatch at index (0, 28): tensor1[(0, 28)] = 0.7790964841842651, tensor2[(0, 28)] = 0.7790292501449585
E Mismatch at index (0, 29): tensor1[(0, 29)] = 0.524261474609375, tensor2[(0, 29)] = 0.5243569612503052
E Mismatch at index (0, 30): tensor1[(0, 30)] = 0.8967938423156738, tensor2[(0, 30)] = 0.8968125581741333
```
I tracked this down to LigerRMSNorm, but I need more time to investigate why there is a difference.
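One way to isolate this would be a direct parity check between the two RMSNorm implementations on identical weights and inputs; a minimal sketch, assuming fp32 inputs and illustrative shapes, eps, and tolerances:

```python
# Hedged sketch: compare LigerRMSNorm against HF's JambaRMSNorm on the
# same parameters and input. Shapes, eps, and tolerances are assumptions.
import torch
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from transformers.models.jamba.modeling_jamba import JambaRMSNorm

hidden_size = 64
x = torch.randn(1, 128, hidden_size, device="cuda", dtype=torch.float32)

ref = JambaRMSNorm(hidden_size, eps=1e-6).to("cuda")
liger = LigerRMSNorm(hidden_size, eps=1e-6).to("cuda")
liger.weight.data.copy_(ref.weight.data)  # identical parameters

# If this passes in fp32 but fails in bf16, the gap is more likely a
# casting/precision difference than a logic bug in the kernel.
torch.testing.assert_close(liger(x), ref(x), rtol=1e-5, atol=1e-5)
```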
Hi @winglian, I created a PR against the main branch of your fork. Do you want to merge it first and then rebase this PR on top of that? https://github.com/winglian/Liger-Kernel/pull/1
Or I can create a separate PR against linkedin:main: https://github.com/linkedin/Liger-Kernel/pull/214
@ByronHsu thoughts?
@yubofredwang if your PR captures all the changes, I'm happy to have it supersede mine. Thanks!
@yubofredwang there are a few conflicts.
No activity for a long time. Closing this PR. Feel free to create a new PR if there is progress. Thanks!