Tcc0403
I believe I've implemented softcap in the cross entropy function correctly, along with the FLCE support for gemma2. But since gemma2 currently can't pass the test even without FLCE, do I need...
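For reference, a minimal sketch of what softcapping means here, assuming Gemma 2-style tanh capping applied to the logits before the loss (the 30.0 cap value and function names are illustrative, not the kernel code):

```python
import torch
import torch.nn.functional as F

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Tanh softcapping: squashes logits into (-softcap, softcap).
    return softcap * torch.tanh(logits / softcap)

def softcapped_cross_entropy(logits: torch.Tensor, target: torch.Tensor, softcap: float = 30.0):
    # Eager reference for cross entropy on softcapped logits, useful as a
    # ground truth to test the fused kernel against.
    return F.cross_entropy(softcap_logits(logits, softcap), target)
```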
[bench](https://gist.github.com/Tcc0403/d1e64df7f604a9ed3878016caae64cf2) I did some benchmarks on H100: adding any torch in-place op increases the time cost by roughly 50% (original -> with_hint = 23 -> 32 ms for 128k vocab size)....
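(For context, a minimal sketch of how such a comparison can be timed with CUDA events; the shapes and the `mul_(1.0)` stand-in for the extra in-place op are illustrative, not the gist's script.)

```python
import torch
import torch.nn.functional as F

def bench(fn, warmup=10, iters=50):
    # Time a GPU callable with CUDA events, returning ms per iteration.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

V = 128 * 1024  # 128k vocab size
logits = torch.randn(8192, V, device="cuda")
target = torch.randint(0, V, (8192,), device="cuda")

print("baseline    :", bench(lambda: F.cross_entropy(logits, target)))
# mul_(1.0) leaves the values unchanged; it only adds one extra in-place kernel pass.
print("with inplace:", bench(lambda: F.cross_entropy(logits.mul_(1.0), target)))
```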
@ByronHsu
> I am wondering why the error does not happen for normal case?

I left an [explanation](https://github.com/linkedin/Liger-Kernel/issues/272#issuecomment-2390882798) in the issue.
Following @mgrabban's suggestion in #343, I made another implementation with [mark_dirty()](https://pytorch.org/docs/stable/generated/torch.autograd.function.FunctionCtx.mark_dirty.html#torch-autograd-function-functionctx-mark-dirty). **Note**: I haven't benchmarked this new approach against the current liger_ce yet; I will do it in a few days. Draft [gist](https://gist.github.com/Tcc0403/73dfbbaab7cc6208598eebb0aa10f5cd) for...
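A rough sketch of the idea (not the gist's code; the eager gradient computation and mean reduction are assumptions): the forward overwrites the logits buffer with d(loss)/d(logits) to avoid an extra allocation, and declares the in-place modification via `mark_dirty()`.

```python
import torch
import torch.nn.functional as F

class CrossEntropyInplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, target):
        # Assumes logits is a non-leaf tensor (e.g. the lm_head output); autograd
        # rejects the mark_dirty path for leaf tensors that require grad.
        loss = F.cross_entropy(logits, target)
        with torch.no_grad():
            grad = torch.softmax(logits.float(), dim=-1)
            grad[torch.arange(target.numel(), device=logits.device), target] -= 1.0
            grad /= target.numel()          # reduction="mean"
            logits.copy_(grad)              # reuse the logits storage for the gradient
        ctx.mark_dirty(logits)              # tell autograd the input was modified in place
        ctx.save_for_backward(logits)
        # A tensor passed to mark_dirty() must also be returned as an output.
        return loss, logits

    @staticmethod
    def backward(ctx, grad_loss, _grad_logits_out):
        (grad_logits,) = ctx.saved_tensors
        return grad_logits * grad_loss, None
```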
Update: Here's the benchmark against liger's CE: speed is slower by 17% (TODO: investigate where the overhead occurs) and memory is doubled (I guess it's because the tensors passed to `mark_dirty()`...
not required for now
@ByronHsu #take To support z loss, I just need a few small add-ons to #198. I'll work on it after merging the label_smoothing PR.
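For reference, z loss here means the PaLM-style auxiliary term added on top of cross entropy (a minimal sketch; the `lse_square_scale` name and the mean reduction are illustrative choices):

```python
import torch
import torch.nn.functional as F

def cross_entropy_with_z_loss(logits, target, lse_square_scale=1e-4):
    # z loss penalizes large log-partition values: lse_square_scale * logsumexp(logits)**2
    # per token, averaged and added to the standard cross entropy.
    lse = torch.logsumexp(logits.float(), dim=-1)    # (N,)
    ce = F.cross_entropy(logits.float(), target)     # mean over tokens
    return ce + lse_square_scale * (lse ** 2).mean()
```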
@shimizust Thanks for the help. I only checked the VLM-related issues; I haven't checked what caused the `_init_weights` errors.
Passed all tests. Ready for review!
Ignore the OOM errors. The current custom CrossEntropyWithZLoss (a torch.nn.Module), used as the ground-truth implementation, has precision issues in gradient calculations with bfloat16 and reduction="sum". The LigerCrossEntropyLoss in this PR has no issue...
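(A quick way to see the kind of gap involved, as an illustrative check rather than the test code: compare bfloat16 gradients of a reduction="sum" cross entropy against an fp32 reference.)

```python
import torch
import torch.nn.functional as F

V, N = 32000, 4096  # illustrative shapes
logits = torch.randn(N, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (N,), device="cuda")

loss_bf16 = F.cross_entropy(logits, target, reduction="sum")
(grad_bf16,) = torch.autograd.grad(loss_bf16, logits)

logits_fp32 = logits.detach().float().requires_grad_(True)
loss_fp32 = F.cross_entropy(logits_fp32, target, reduction="sum")
(grad_fp32,) = torch.autograd.grad(loss_fp32, logits_fp32)

print((grad_bf16.float() - grad_fp32).abs().max())  # inspect the bf16 gradient error
```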