Add support for jamba model with Liger Kernel
Summary
Add support for the Jamba model with Liger Kernel. The following ops can be patched with Liger kernels:
- RMSNorm
- cross_entropy
- swiglu
- lce_forward
Testing Done
- Hardware Type: A100-80G-PCIe
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
CI is failing. We probably need to either set use_mamba_kernels=False in the tests or install mamba-ssm in the GPU CI.
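One way to keep the tests independent of whether the kernels are installed is to detect the packages and derive the flag from that. The sketch below uses only the standard library. The helper names `kernels_available` and `jamba_test_config_kwargs` are hypothetical and not part of this PR, and it assumes `use_mamba_kernels` is passed to the Jamba config as a keyword argument.

```python
import importlib.util

# Import names of the optional kernel packages: the PyPI packages
# mamba-ssm and causal-conv1d import as mamba_ssm and causal_conv1d.
MAMBA_KERNEL_MODULES = ("mamba_ssm", "causal_conv1d")


def kernels_available(modules=MAMBA_KERNEL_MODULES):
    """Return True only if every listed top-level module can be located."""
    return all(importlib.util.find_spec(name) is not None for name in modules)


def jamba_test_config_kwargs():
    """Hypothetical helper: extra kwargs for the Jamba config used in tests."""
    # Fall back to the slower pure-PyTorch path when the kernels are missing.
    return {"use_mamba_kernels": kernels_available()}
```

With this, a test could build its config with `**jamba_test_config_kwargs()` and run on machines both with and without mamba-ssm installed.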
pip install .'[dev]' fails for this PR after adding mamba-ssm to the dependencies. The reason is that mamba-ssm has a bug in its setup.py that makes it non-compliant with PEP 517: torch, packaging, and wheel have to be installed before pip install mamba-ssm runs, otherwise it fails with a "no module found" error. The same applies to causal-conv1d.
Fixes have been proposed for both repos, but they were never merged or released:
- https://github.com/state-spaces/mamba/pull/402#issuecomment-2283915089
- https://github.com/Dao-AILab/causal-conv1d/pull/26
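Until those fixes land, the ordering constraint can likely be worked around by installing the build-time requirements first and then installing the two packages with pip's `--no-build-isolation` flag, so that their setup.py sees the torch already in the environment. This is a minimal sketch, assuming unpinned package versions. The function names are hypothetical and nothing here is wired into this PR's CI.

```python
import subprocess
import sys

# Packages that the setup.py scripts import at build time.
BUILD_REQUIREMENTS = ["torch", "packaging", "wheel"]
# Packages whose builds need the requirements above to be present already.
KERNEL_PACKAGES = ["causal-conv1d", "mamba-ssm"]


def pip_install_commands(python=sys.executable):
    """Return the pip invocations in the order they must run."""
    pip = [python, "-m", "pip", "install"]
    return [
        pip + BUILD_REQUIREMENTS,
        # Build in the current environment instead of an isolated one.
        pip + ["--no-build-isolation"] + KERNEL_PACKAGES,
    ]


def install_mamba_kernels():
    """Run both steps, stopping at the first failure."""
    for command in pip_install_commands():
        subprocess.run(command, check=True)
```

Keeping the command construction separate from the execution makes the ordering easy to check without actually building the CUDA extensions.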
My current solution is to comment out the tests until the above issues are fixed. However, I have run the convergence test locally. Any other suggestions are welcome.
No activity for a long time, so closing this PR. Feel free to create a new PR if there is progress. Thanks!