
DPO Llava and PaliGemma support

Open · qgallouedec opened this pull request 1 year ago

Closes #1784

Llava:

accelerate launch examples/scripts/dpo_visual.py \
    --dataset_name HuggingFaceH4/rlaif-v_formatted \
    --model_name_or_path llava-hf/llava-1.5-7b-hf \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --output_dir dpo_llava_rlaif-v \
    --bf16 \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_target_modules=all-linear

PaliGemma:

accelerate launch examples/scripts/dpo_visual.py \
    --dataset_name HuggingFaceH4/rlaif-v_formatted \
    --model_name_or_path google/paligemma-3b-pt-224 \
    --trust_remote_code \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --output_dir dpo_paligemma_rlaif-v \
    --bf16 \
    --torch_dtype bfloat16
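
For readers who want to see roughly what the script wires together, here is a minimal sketch assuming the public TRL and transformers APIs (this is not examples/scripts/dpo_visual.py itself; names and defaults are illustrative):

import torch
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from trl import DPOConfig, DPOTrainer

model_id = "llava-hf/llava-1.5-7b-hf"
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

# The formatted RLAIF-V dataset carries prompt/chosen/rejected columns plus images.
dataset = load_dataset("HuggingFaceH4/rlaif-v_formatted")

training_args = DPOConfig(
    output_dir="dpo_llava_rlaif-v",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    bf16=True,
)

trainer = DPOTrainer(
    model,
    ref_model=None,  # DPOTrainer sets up the reference model when none is given
    args=training_args,
    train_dataset=dataset["train"],  # assumes a "train" split
    tokenizer=processor,  # for vision models the processor is passed here
)
trainer.train()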

qgallouedec avatar Jul 03 '24 15:07 qgallouedec

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

The tests are failing due to some time-outs, I believe.

kashif avatar Jul 08 '24 11:07 kashif

Re-running solved the issue

qgallouedec avatar Jul 08 '24 17:07 qgallouedec

How do I set other parameters, like the learning rate?

Liuziyu77 avatar Jul 22 '24 10:07 Liuziyu77
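
A hedged note on the question above: the example scripts parse their arguments into DPOConfig, which subclasses transformers.TrainingArguments, so standard flags such as --learning_rate can be appended to the launch command. The equivalent in code (the value below is illustrative, not taken from the thread):

from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo_llava_rlaif-v",
    learning_rate=5e-6,  # illustrative value, not a recommendation from the thread
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    bf16=True,
)
print(training_args.learning_rate)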

I ran the same code for llava-hf/llava-1.5-7b-hf, but it says: 'LlavaProcessor' object has no attribute 'apply_chat_template'. I tried transformers==4.37.2 and 4.36.2, but both failed. What's your transformers version?

Liuziyu77 avatar Jul 22 '24 10:07 Liuziyu77
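
For context on the error above: apply_chat_template was only added to processor classes in later transformers releases, so pinning 4.36/4.37 predates it, and upgrading is the usual fix (the exact minimum version is not stated in the thread). A quick check, as a sketch:

import transformers
from transformers import AutoProcessor

print(transformers.__version__)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
# The dpo_visual.py script relies on this method; if it is missing, upgrade transformers.
print(hasattr(processor, "apply_chat_template"))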

If I want to support multi-image DPO, where should I modify the code? Thank you very much for your patience!

Liuziyu77 avatar Jul 22 '24 14:07 Liuziyu77
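
On the multi-image question above: vision-language processors generally accept a list of images per prompt, so a multi-image variant would mainly mean changing the script's data-collation step to pass every image of an example rather than the first. A self-contained sketch of the processor side, with stand-in images (whether llava-1.5 itself handles multi-image prompts well is a separate question):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Two blank stand-in images; the prompt carries one <image> token per image.
images = [Image.new("RGB", (336, 336)), Image.new("RGB", (336, 336))]
prompt = "USER: <image>\n<image>\nCompare the two images. ASSISTANT:"

inputs = processor(text=prompt, images=images, return_tensors="pt")
print(inputs["pixel_values"].shape)  # one pixel_values entry per image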

@Liuziyu77 do you mind opening an issue for your questions? It's way easier for us to keep track of them

qgallouedec avatar Aug 06 '24 13:08 qgallouedec

Could we possibly try llava_next_vicuna_7b?

g0kul6 avatar Oct 03 '24 19:10 g0kul6

You can give it a try. Feel free to open an issue; it will help us track progress.

qgallouedec avatar Oct 03 '24 19:10 qgallouedec
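
For what it's worth, LLaVA-NeXT checkpoints load through the same auto classes, so swapping the model id in the commands above is the natural first attempt. A minimal load check, assuming the llava-hf Hub mirror of the vicuna-7b variant:

from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"  # assumed Hub id for llava_next_vicuna_7b
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id)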