
[RFC] Liger FlexChunkLoss: Alignment and Distillation loss

Open shivam15s opened this issue 1 year ago • 23 comments

🚀 The feature, motivation and pitch

We want to support various alignment and distillation loss functions. Refer to this PR on ORPO: #362

Progress

Alignment

  • [x] ORPO https://github.com/linkedin/Liger-Kernel/pull/362
  • [x] CPO https://github.com/linkedin/Liger-Kernel/pull/382
  • [x] DPO https://github.com/linkedin/Liger-Kernel/pull/378
  • [x] SimPO https://github.com/linkedin/Liger-Kernel/pull/386
  • [x] IRPO
  • [x] KTO https://github.com/linkedin/Liger-Kernel/pull/475
  • [ ] f-PO

Distillation

  • [ ] KL divergence
  • [ ] cosine_similarity
  • [ ] earth_mover_distance
  • [x] JSD https://github.com/linkedin/Liger-Kernel/pull/425
  • [ ] KVD

Design

Approach Overview:

The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:

  1. Modular Optimization Process:
    • Every alignment algorithm’s optimization can be broken into three key steps:
      • Linear layer computation
      • Loss computation
      • Gradient calculation
  2. Fused Linear and Loss Computation:
    • Similar to FLCE, we aim to fuse the linear layer with the loss computation for efficiency.
  3. Chunking & Forward Optimization:
    • Since this is the final step in the model’s forward pass, we can also compute gradients directly during the forward pass instead of waiting for a separate backward pass.
    • We also chunk the input within the forward pass of the model, allowing a significant reduction in peak GPU memory.
  4. Torch Compile for Kernel Optimization:
    • Instead of manually handling kernel-level optimizations, we let torch.compile automatically optimize kernel execution. This reduces the need for low-level optimizations while still achieving performance gains.

By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.
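For illustration, here is a minimal sketch of this pattern (not the actual Liger implementation): the function names, the chunk_size default, and the generic loss_fn(logits, target_chunk) signature are placeholders for whatever loss a given algorithm needs.

```python
import torch

def _chunk_forward_backward(x_chunk, weight, target_chunk, loss_fn):
    # Fuse the linear projection with the loss for one chunk only, so the
    # full (BT, vocab) logits tensor is never materialized at once.
    x_chunk = x_chunk.detach().requires_grad_(True)
    w = weight.detach().requires_grad_(True)
    logits = x_chunk @ w.t()
    loss = loss_fn(logits, target_chunk)
    # Gradients are computed here, inside the model's forward pass,
    # instead of in a separate backward pass over the full input.
    grad_x, grad_w = torch.autograd.grad(loss, (x_chunk, w))
    return loss.detach(), grad_x, grad_w

def chunked_fused_linear_loss(x, weight, target, loss_fn, chunk_size=1024):
    grad_x = torch.zeros_like(x)
    grad_w = torch.zeros_like(weight)
    total_loss = x.new_zeros(())
    # Let torch.compile fuse the per-chunk matmul + loss + gradient math
    # instead of hand-writing Triton kernels.
    step = torch.compile(_chunk_forward_backward)
    for start in range(0, x.shape[0], chunk_size):
        end = start + chunk_size
        loss, gx, gw = step(x[start:end], weight, target[start:end], loss_fn)
        grad_x[start:end] = gx
        grad_w += gw
        # Accumulate (global renormalization across chunks omitted for brevity).
        total_loss += loss
    return total_loss, grad_x, grad_w
```

With loss_fn set to a cross-entropy this reduces to the FLCE case; swapping in a preference or distillation loss gives the alignment and distillation variants tracked above.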

Key Findings

By leveraging torch.compile alongside optimization techniques like chunking and online softmax, we observed performance close to that of custom Triton kernels with reduced development time. This is why we want to introduce torch.compile as a key component of Liger. References:

  1. #227
  2. https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py

Interface

Have a base class FlexChunkLoss that handles the chunking, accumulation, and compilation strategies. A custom loss class extends FlexChunkLoss and implements the loss function that operates on a given chunk.

class MyCustomLoss(FlexChunkLoss):
  def loss_fn(self, *chunk_args, **kwargs):
    # compute and return the loss for a single chunk here
    ...
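For example, a concrete subclass might look like the following (a hedged sketch: the SigmoidPreferenceLoss name and the loss_fn signature are assumptions for illustration, not the finalized API):

```python
import torch.nn.functional as F

class SigmoidPreferenceLoss(FlexChunkLoss):
    def loss_fn(self, chosen_logps, rejected_logps, beta=0.1):
        # Operates on one chunk of chosen/rejected log-probabilities;
        # chunking, gradient accumulation, and torch.compile are handled
        # by the FlexChunkLoss base class.
        return -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()
```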

Alternatives

No response

Additional context

No response

shivam15s avatar Nov 08 '24 22:11 shivam15s

take DPO

austin362667 avatar Nov 13 '24 02:11 austin362667

I can take fused linear KL div. BTW, really nice illustration of the chunked linear op fusion from the paper. Very clear to new contributors 😄

hongpeng-guo avatar Nov 13 '24 10:11 hongpeng-guo

@shivam15s @ByronHsu I think we should also consider including some of the loss functions commonly used for training embedding models, especially the popular ones supported in Sentence Transformers.

It's quite common for embedding models to require large batch sizes to be trained well. Coupled with the fact that their batch/input structure is similar to RLHF, where we have positive and negative pairs, I believe this can prove to be useful. I'd recommend supporting CoSENTLoss, MatryoshkaLoss, and TripletLoss for starters: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cosentloss. Perhaps this can be its own roadmap, separate from this one, although the idea of chunking and fusing remains the same.

pramodith avatar Nov 13 '24 11:11 pramodith

@pramodith that is a good idea! Do you know if embedding models also have a large vocab and suffer from the memory bottleneck?

ByronHsu avatar Nov 13 '24 17:11 ByronHsu

@ByronHsu most embedding models have a final Linear layer of shape (hidden_dim, hidden_dim), so vocab size doesn't really come into the picture for them; you're right to point that out. However, it is common to have an effective batch size of 65k.

pramodith avatar Nov 13 '24 18:11 pramodith

Then I think chunked loss is still helpful given the large batch size

ByronHsu avatar Nov 13 '24 22:11 ByronHsu

Then I think chunked loss is still helpful given the large batch size

Yes, I think so too. I can give this a try after we wrap up all the important RLHF and distillation losses. I'll also get Tom Aarsen's perspective since he's the lead of Sentence Transformers.

pramodith avatar Nov 13 '24 22:11 pramodith

#take SimPO and IRPO since they are just extensions of CPO.

pramodith avatar Nov 15 '24 11:11 pramodith

I will #take KTO next

vulkomilev avatar Nov 19 '24 18:11 vulkomilev

A little update on KTO: I am now working on the tests.

vulkomilev avatar Nov 20 '24 22:11 vulkomilev

@Chillee FYI We are working on a set of post-training losses based on your compiled chunked loss implementation for CE. Thanks for the reference!

ByronHsu avatar Nov 22 '24 17:11 ByronHsu

Update on the KTO loss: I am done with the loss but I have a problem with the assertions. I am working on it.

vulkomilev avatar Nov 23 '24 15:11 vulkomilev

I was following this thread and working on a chunked, fused linear KL-divergence implementation for distillation use cases. Since distillation losses differ from preference losses, introducing a LigerFusedLinearDistillationBase parent class could be helpful.

In general, the distillation pipeline involves three key inputs: teacher_logits, student_logits, and ground_truth_label. The first two inputs are used to calculate the soft loss (KL divergence), while the latter two are used to compute the hard loss (cross-entropy). The final distillation loss is typically a weighted sum of these two components.

To leverage chunked, linear-fused optimizations, we could design the solution to accept inputs as teacher_tensor (BT, hidden_dim_teacher), student_tensor (BT, hidden_dim_student), and true_label (BT,). Using these inputs, we can apply the chunked, linear-fused approach to efficiently compute both the KL-divergence loss and the cross-entropy loss.
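A hedged per-chunk sketch of this proposal (the function name, soft_weight, and temperature arguments are illustrative, not the final LigerFusedLinearDistillationBase API):

```python
import torch
import torch.nn.functional as F

def distillation_chunk_loss(student_chunk, teacher_chunk, labels_chunk,
                            student_head, teacher_head,
                            temperature=1.0, soft_weight=0.5):
    # Project each chunk through its own head so full logits never materialize.
    student_logits = student_chunk @ student_head.t()   # (chunk, vocab)
    teacher_logits = teacher_chunk @ teacher_head.t()   # (chunk, vocab)

    # Soft loss: KL(teacher || student) on temperature-scaled distributions.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
        log_target=True,
    ) * (temperature ** 2)

    # Hard loss: cross-entropy of student logits against ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels_chunk)

    # Final distillation loss: weighted sum of the two components.
    return soft_weight * soft_loss + (1.0 - soft_weight) * hard_loss
```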

cc @ByronHsu, @shivam15s, @pramodith: What are your thoughts on this? Do you think it makes sense to include the cross-entropy loss as part of the DistillationBase class? Thanks for your feedback!

hongpeng-guo avatar Nov 25 '24 06:11 hongpeng-guo

@hongpeng-guo yes! I like your approach; it's cleaner to create a new base class for distillation losses. We're kind of doing the same for the alignment losses too, by computing the NLL (cross-entropy loss) of the accepted responses inside the base class: https://github.com/linkedin/Liger-Kernel/blob/7e3683e23f8a9a5663913fd0ea7b0b03ea1a667b/src/liger_kernel/chunked_loss/fused_linear_preference.py#L37

pramodith avatar Nov 25 '24 10:11 pramodith

+1 on @hongpeng-guo's proposal. @shivam15s can help polish the base class.

ByronHsu avatar Nov 25 '24 17:11 ByronHsu

Sounds good @hongpeng-guo, a separate base class for distillation is absolutely needed!

shivam15s avatar Nov 25 '24 19:11 shivam15s

Please review and comment on my KTO PR here: https://github.com/linkedin/Liger-Kernel/pull/410

vulkomilev avatar Nov 28 '24 20:11 vulkomilev

There is an update on #410.

vulkomilev avatar Dec 04 '24 22:12 vulkomilev

Is CPO-SimPO planned? This can be implemented in SimPO.

Reference: https://github.com/fe1ixxu/CPO_SIMPO

Quote

CPO and SimPO share similar objectives but have different goals. CPO adds a BC-regularizer to prevent the model from deviating too much from the preferred data distribution.

$L_{CPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \beta \log \pi_{\theta}(y_w | x) - \beta \log \pi_{\theta}(y_l | x) \Big) + \log \pi_\theta(y_w| x)\Big]$

SimPO incorporates length normalization and target reward margin to improve model performance and prevent the generation of long but low-quality sequences:

$L_{SimPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \frac{\beta}{|y_w|} \log \pi_{\theta}(y_w | x) - \frac{\beta}{|y_l|} \log \pi_{\theta}(y_l | x) - \gamma \Big) \Big]$

These two objectives can be jointly used, which we call CPO-SimPO:

$L_{CPO-SimPO}(\pi_\theta;U) = -E_{(x,y_w,y_l) \sim \mathcal{D}} \Big[ \log \sigma \Big( \frac{\beta}{|y_w|} \log \pi_{\theta}(y_w | x) - \frac{\beta}{|y_l|} \log \pi_{\theta}(y_l | x) - \gamma \Big)+ \alpha \log \pi_\theta(y_w| x)\Big]$
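For concreteness, the combined objective above maps to a small per-batch function like this (a hedged sketch on summed per-sequence log-probabilities; the names and defaults are illustrative):

```python
import torch.nn.functional as F

def cpo_simpo_loss(chosen_logps, rejected_logps, chosen_lens, rejected_lens,
                   beta=2.0, gamma=1.0, alpha=1.0):
    # chosen_logps / rejected_logps: summed log pi_theta(y | x) per sequence.
    # SimPO term: length-normalized margin minus the target reward margin gamma.
    margin = (beta * chosen_logps / chosen_lens
              - beta * rejected_logps / rejected_lens
              - gamma)
    # CPO term: alpha-weighted BC regularizer, i.e. the log-likelihood of y_w.
    return (-F.logsigmoid(margin) - alpha * chosen_logps).mean()
```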

ccdv-ai avatar Dec 08 '24 01:12 ccdv-ai

@ccdv-ai I think this can be done via the existing set of hyperparameters: set compute_nll_loss=True and alpha for the BC regularizer. Right now all our alignment loss functions assume length normalization: https://github.com/linkedin/Liger-Kernel/blob/bd65c47999cebc2ac3dce39447ecec051b8b6159/src/liger_kernel/chunked_loss/fused_linear_preference.py#L51

pramodith avatar Dec 08 '24 15:12 pramodith

There is an update on KTO in #410.

vulkomilev avatar Dec 11 '24 21:12 vulkomilev

KTO PR merged: https://github.com/linkedin/Liger-Kernel/pull/475

hebiao064 avatar Jan 22 '25 21:01 hebiao064

Hi, maybe I could help take IRPO, following https://arxiv.org/abs/2404.19733 and https://github.com/huggingface/trl/issues/1611. Based on the paper, it could follow the Liger DPO support and add DPO + NLL loss. Any suggestions are appreciated.

Moreover, may I add BCO here, since Liger support for the BCO trainer is mentioned in https://github.com/huggingface/trl/issues/2495? BCO Liger support seems to be similar to that of other trainers such as DPO.

1485840691-eng avatar Feb 22 '25 03:02 1485840691-eng