[RFC] Liger FlexChunkLoss: Alignment and Distillation loss
🚀 The feature, motivation and pitch
We want to support various alignment and distillation loss functions. Refer to the ORPO PR: #362
Progress
Alignment
- [x] ORPO https://github.com/linkedin/Liger-Kernel/pull/362
- [x] CPO https://github.com/linkedin/Liger-Kernel/pull/382
- [x] DPO https://github.com/linkedin/Liger-Kernel/pull/378
- [x] SimPO https://github.com/linkedin/Liger-Kernel/pull/386
- [x] IRPO
- [x] KTO https://github.com/linkedin/Liger-Kernel/pull/475
- [ ] f-PO
Distillation
- [ ] KL divergence
- [ ] cosine_similarity
- [ ] earth_mover_distance
- [x] JSD https://github.com/linkedin/Liger-Kernel/pull/425
- [ ] KVD
Design
Approach Overview:
The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:
- Modular Optimization Process:
- Every alignment algorithm’s optimization can be broken into three key steps:
- Linear layer computation
- Loss computation
- Gradient calculation
- Fused Linear and Loss Computation:
- Similar to FLCE, we aim to fuse the linear layer with the loss computation for efficiency.
- Chunking & Forward Optimization:
- Since this is the final step in the model’s forward pass, we can also compute gradients directly during the forward pass instead of waiting for a separate backward pass.
- We also chunk the input within the forward pass of the model, allowing a significant reduction in peak GPU memory.
- Torch Compile for Kernel Optimization:
- Instead of manually handling kernel-level optimizations, we let torch.compile automatically optimize kernel execution. This reduces the need for low-level optimizations while still achieving performance gains.
By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.
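The three steps above (linear computation, loss, gradients in the forward pass) can be sketched end-to-end. This is a minimal NumPy stand-in for chunked fused linear cross-entropy, not the Liger implementation: the function name is hypothetical, the weight gradient is omitted for brevity, and the real kernels operate on torch tensors under torch.compile.

```python
import numpy as np

def chunked_fused_linear_ce(x, weight, targets, chunk_size=2):
    # Process one chunk at a time so the full (n, vocab) logits matrix is
    # never materialized; input gradients are computed inside the forward
    # loop using the analytic softmax gradient (softmax - one_hot).
    n = x.shape[0]
    total_loss = 0.0
    grad_x = np.empty_like(x)
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        logits = x[start:end] @ weight.T             # (chunk, vocab)
        logits -= logits.max(axis=1, keepdims=True)  # numerically stable softmax
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        rows = np.arange(end - start)
        total_loss += -np.log(probs[rows, targets[start:end]]).sum()
        # d(loss)/d(logits) = softmax - one_hot; chain through the linear layer
        dlogits = probs
        dlogits[rows, targets[start:end]] -= 1.0
        grad_x[start:end] = dlogits @ weight
    return total_loss / n, grad_x / n
```

Because each chunk's logits are discarded after its loss and gradient contributions are accumulated, peak memory scales with `chunk_size * vocab` instead of `n * vocab`.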
Key Findings
By leveraging torch.compile alongside optimization techniques like chunking and online softmax, we observed performance close to that of custom Triton kernels with reduced development time. This is why we want to introduce torch.compile as a key component of Liger. References:
- #227
- https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
Interface
Have a base class FlexChunkLoss that handles the chunking, accumulation, and compiling strategies.
A custom loss class wraps FlexChunkLoss and implements the loss function that operates on a given chunk.
```python
class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, input_chunk, target_chunk):
        # compute the loss for this chunk here
        ...
```
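A minimal stand-in for this interface could look as follows. This is a sketch of the proposed design only, not the Liger API: the constructor signature and the `SquaredErrorChunkLoss` example subclass are hypothetical, and the real base class would also fuse the linear layer and apply torch.compile.

```python
import numpy as np

class FlexChunkLoss:
    # Base class: split the batch into chunks, delegate the per-chunk
    # loss to the subclass, and accumulate the results.
    def __init__(self, chunk_size=2):
        self.chunk_size = chunk_size

    def loss_fn(self, input_chunk, target_chunk):
        raise NotImplementedError

    def __call__(self, inputs, targets):
        n = inputs.shape[0]
        total = 0.0
        for start in range(0, n, self.chunk_size):
            end = min(start + self.chunk_size, n)
            total += self.loss_fn(inputs[start:end], targets[start:end])
        return total / n

class SquaredErrorChunkLoss(FlexChunkLoss):
    # Hypothetical example subclass; real subclasses would implement
    # ORPO, DPO, JSD, etc.
    def loss_fn(self, input_chunk, target_chunk):
        return float(((input_chunk - target_chunk) ** 2).sum())
```

Because accumulation lives in the base class, a new loss only needs to define its per-chunk computation.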
Alternatives
No response
Additional context
No response
take DPO
I can take fused linear kl div. BTW, really nice illustration on the chunk linear op fusion from the paper. Very clear to new contributors 😄
@shivam15s @ByronHsu I think we should also consider including some of the loss functions commonly used for training embedding models, especially the popular ones supported in Sentence transformers.
It's quite common for embedding models to require large batch sizes to be trained well. Coupled with the fact that their batch/input structure is similar to RLHF, where we have positive and negative pairs, I believe this can prove to be useful. I'd recommend supporting CoSENTLoss, MatryoshkaLoss, and TripletLoss for starters https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cosentloss. Perhaps this can be its own roadmap, separate from this one, although the idea of chunking and fusing remains the same.
@pramodith that is a good idea! Do you know if embedding models also have large vocabs and suffer from the memory bottleneck?
@ByronHsu most embedding models have a final linear layer of shape (hidden_dim, hidden_dim), so vocab size doesn't really come into the picture for them (you're right to point that out), but it is common to have an effective batch size of 65k.
Then I think the chunked loss is still helpful given the large batch size.
Yes, I think so too. I can give this a try after we wrap up all the important RLHF and distillation losses. I'll also get Tom Aarsen's perspective since he's the lead of Sentence Transformers.
#take SimPO and IRPO since they are just extensions of CPO.
I will #take KTO as the next
A little update on KTO: I am now working on the tests.
@Chillee FYI We are working on a set of post-training losses based on your compiled chunked loss implementation for CE. Thanks for the reference!
Update on KTO loss: I am done with the loss but have a problem with the assertions. I am working on it.
I was following this thread and working on a chunked, fused linear KL-divergence implementation for distillation use cases. Since distillation losses differ from preference losses, introducing a LigerFusedLinearDistillationBase parent class could be helpful.
In general, the distillation pipeline involves three key inputs: teacher_logits, student_logits, and ground_truth_label. The first two inputs are used to calculate the soft loss (KL divergence), while the latter two are used to compute the hard loss (cross-entropy). The final distillation loss is typically a weighted sum of these two components.
To leverage chunked, linear-fused optimizations, we could design the solution to accept inputs as teacher_tensor (BT, hidden_dim_teacher), student_tensor (BT, hidden_dim_student), and true_label (BT,). Using these inputs, we can apply the chunked, linear-fused approach to efficiently compute both the KL-divergence loss and the cross-entropy loss.
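The per-chunk distillation loss described above (weighted sum of a soft KL term and a hard cross-entropy term) could look roughly like this. A sketch only: the function name, argument names, and the `hard_weight`/`temperature` parameterization are assumptions, not the proposed `LigerFusedLinearDistillationBase` API, and the upstream linear projections are omitted.

```python
import numpy as np

def _softmax(z):
    # Numerically stable row-wise softmax.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def distillation_loss_chunk(teacher_logits, student_logits, labels,
                            hard_weight=0.5, temperature=1.0):
    # Soft loss: KL(teacher || student) at the given temperature.
    t = _softmax(teacher_logits / temperature)
    s = _softmax(student_logits / temperature)
    soft = (t * (np.log(t) - np.log(s))).sum(axis=1).mean()
    # Hard loss: student cross-entropy against the ground-truth labels.
    probs = _softmax(student_logits)
    hard = -np.log(probs[np.arange(labels.shape[0]), labels]).mean()
    # Final loss: weighted sum of the two components.
    return (1.0 - hard_weight) * soft + hard_weight * hard
```

Chunking along the `BT` dimension then works exactly as for the preference losses, since both terms decompose per row.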
cc @ByronHsu, @shivam15s, @pramodith: What are your thoughts on this? Do you think it makes sense to include the cross-entropy loss as part of the DistillationBase class? Thanks for your feedback!
@hongpeng-guo yes! I like your approach; it's cleaner to create a new base class for distillation losses. We're doing something similar for the alignment losses by computing the NLL (cross-entropy loss of the accepted responses) inside the base class: https://github.com/linkedin/Liger-Kernel/blob/7e3683e23f8a9a5663913fd0ea7b0b03ea1a667b/src/liger_kernel/chunked_loss/fused_linear_preference.py#L37.
+1 on @hongpeng-guo proposal. @shivam15s can help polish the base class
Sounds good @hongpeng-guo, a separate base class for distillation is absolutely needed!
Please review and comment on my KTO PR here: https://github.com/linkedin/Liger-Kernel/pull/410
there is an update about #410
Is CPO-SimPO planned? This can be implemented in SimPO.
Reference: https://github.com/fe1ixxu/CPO_SIMPO
Quote
CPO and SimPO share similar objectives but have different goals. CPO adds a BC-regularizer to prevent the model from deviating too much from the preferred data distribution.
$L_{CPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \beta \log \pi_{\theta}(y_w | x) - \beta \log \pi_{\theta}(y_l | x) \Big) + \log \pi_\theta(y_w| x)\Big]$
SimPO incorporates length normalization and target reward margin to improve model performance and prevent the generation of long but low-quality sequences:
$L_{SimPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \frac{\beta}{|y_w|} \log \pi_{\theta}(y_w | x) - \frac{\beta}{|y_l|} \log \pi_{\theta}(y_l | x) - \gamma \Big) \Big]$
These two objectives can be jointly used, which we call CPO-SimPO:
$L_{CPO-SimPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \frac{\beta}{|y_w|} \log \pi_{\theta}(y_w | x) - \frac{\beta}{|y_l|} \log \pi_{\theta}(y_l | x) - \gamma \Big)+ \alpha \log \pi_\theta(y_w| x)\Big]$
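For a single preference pair, the quoted CPO-SimPO objective can be sketched directly from sequence log-probs. A hypothetical helper, not Liger code; `logp_chosen`/`logp_rejected` are the policy log-probs of the chosen and rejected responses.

```python
import math

def cpo_simpo_loss(logp_chosen, logp_rejected, len_chosen, len_rejected,
                   beta=2.0, gamma=0.5, alpha=1.0):
    # Length-normalized SimPO margin between chosen and rejected responses.
    margin = (beta / len_chosen) * logp_chosen \
           - (beta / len_rejected) * logp_rejected - gamma
    # -log(sigmoid(margin)), computed stably for large |margin|.
    if margin >= 0:
        pref = math.log1p(math.exp(-margin))
    else:
        pref = -margin + math.log1p(math.exp(margin))
    # alpha-weighted BC (NLL) regularizer on the chosen response.
    return pref + alpha * (-logp_chosen)
```

With `alpha=0` this reduces to SimPO; with no length normalization and `gamma=0` it reduces to CPO, which matches the observation below that the existing hyperparameters already cover this combination.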
@ccdv-ai I think this can be done via the existing set of hyperparams: set compute_nll_loss=True and use alpha for the BC regularizer. Right now all our alignment loss functions assume length normalization: https://github.com/linkedin/Liger-Kernel/blob/bd65c47999cebc2ac3dce39447ecec051b8b6159/src/liger_kernel/chunked_loss/fused_linear_preference.py#L51
there is an update about KTO on #410
KTO PR merged: https://github.com/linkedin/Liger-Kernel/pull/475
Hi, maybe I could help take IRPO, following https://arxiv.org/abs/2404.19733 and https://github.com/huggingface/trl/issues/1611. Based on the paper, it could build on the DPO Liger support and add the DPO + NLL loss. Any suggestions are appreciated.
Moreover, may I add BCO here? Liger support for the BCO trainer is mentioned in https://github.com/huggingface/trl/issues/2495, and it seems similar to other trainers such as DPO.