
[PPO] feat: Add LoRA support for PPO

Open · StephenXie opened this pull request 1 year ago

This PR adds LoRA (Low-Rank Adaptation) support for PPO (#159)

Changes

  • Added LoRA support to the actor and critic configurations (see #127)
  • Merged the PEFT adapter into the base model before serving it with vLLM, and unmerged it afterward
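
A minimal sketch of that merge/unmerge flow in PEFT terms (the rollout helper name is a placeholder, not verl's API; merge_adapter/unmerge_adapter and target_modules="all-linear" require a reasonably recent PEFT release):

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # model used later in this thread
    actor = get_peft_model(base, LoraConfig(r=32, lora_alpha=16, target_modules="all-linear"))

    # Fold the LoRA deltas into the base weights so vLLM serves plain dense tensors ...
    actor.merge_adapter()
    rollout_with_vllm(actor)   # placeholder for the vLLM generation step
    # ... then split them back out so PPO keeps training only the adapter weights.
    actor.unmerge_adapter()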

Features

  • Configurable LoRA rank and alpha parameters
  • Target module specification for selective adaptation
  • Compatible with FSDP sharding strategy
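
To make the first two bullets concrete, they map onto PEFT's LoraConfig roughly as follows (the explicit projection names are just the usual attention modules in Qwen/Llama-style models, not a list fixed by this PR):

    from peft import LoraConfig

    # Adapt every linear layer ("all-linear" needs a recent PEFT release) ...
    cfg_all = LoraConfig(r=32, lora_alpha=16, target_modules="all-linear")
    # ... or adapt only selected modules for cheaper, more targeted adaptation.
    cfg_attn = LoraConfig(r=32, lora_alpha=16,
                          target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])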

Some known issues:

  • Merging the Ref and Actor models when LoRA is enabled requires modifying the ppo_trainer logic; we need some help with this
  • No thorough testing yet
  • Line 80 of fsdp_vllm.py needs to be cleaned up: params = OrderedDict((k.replace(".base_layer.", "."), v) for k, v in params.items() if not ".lora_" in k)
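
For readers following along, here is a commented version of that one-liner; the key layout it assumes is PEFT's standard LoRA wrapping:

    from collections import OrderedDict

    # `params` is the full state_dict gathered from the FSDP-wrapped PEFT model,
    # as in the line quoted above.
    params = OrderedDict(
        # PEFT renames each adapted weight to "<module>.base_layer.weight"; strip the
        # infix so the keys match the plain Hugging Face / vLLM parameter names again ...
        (k.replace(".base_layer.", "."), v)
        for k, v in params.items()
        # ... and drop the lora_A / lora_B adapter tensors, which have no counterpart
        # in the vLLM model (their contribution is assumed to have been merged into
        # the base weights before this point).
        if ".lora_" not in k
    )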

StephenXie · Feb 05 '25

Relevant threads:

https://github.com/volcengine/verl/issues/159

https://github.com/Jiayi-Pan/TinyZero/issues/15

Jiayi-Pan · Feb 05 '25

Hey @StephenXie, I'm interested in picking this back up. How functional is this feature? Have any of the known issues been resolved, or have new ones come up? And what do you mean by "Merge Ref and Actor when LoRA is on requires modifying ppo_trainer logic" - what are you referring to by the "ppo_trainer" logic? Thank you!

eligotts · Mar 13 '25

Hi all! Thank you for your contributions. Is there an expected timeline for this PR to get merged?

cfpark00 · Mar 13 '25

Hi @eligotts! Thanks for your interest in helping out! We have yet to produce an experiment with the right LoRA parameters, so we would really appreciate it if anyone could test this out. As for the "ppo_trainer" logic, I think @Jiayi-Pan can explain this in more depth.

StephenXie · Mar 13 '25

@StephenXie I'm not sure if this is the kind of test you are looking for:

I have a setting where I do GRPO on MATH, starting from Qwen2.5-1.5B-Instruct. I would be happy to try LoRA on the same run, keeping the other hyperparameters the same.

cfpark00 · Mar 16 '25

That'd be awesome - really appreciate this! I'm also running some experiments on gsm8k on my end

StephenXie · Mar 16 '25

I just modified examples/ppo_trainer/run_qwen2_7b.sh by adding:

    actor_rollout_ref.model.lora_rank=32 \
    actor_rollout_ref.model.lora_alpha=16 \
    actor_rollout_ref.model.target_modules=all-linear \
    critic.model.lora_rank=32 \
    critic.model.lora_alpha=16 \
    critic.model.target_modules=all-linear \

But I got this error:

ray.exceptions.RayTaskError(AttributeError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=2231957, ip=172.27.33.105, actor_id=8130c099b06d86ba6bd1754c01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7fa91b53b550>)
  File "/mnt/data132/chuxiong/code/verl/verl/single_controller/ray/base.py", line 399, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data132/chuxiong/code/verl/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data132/chuxiong/code/verl/verl/workers/fsdp_workers.py", line 474, in generate_sequences
    with self.rollout_sharding_manager:
  File "/mnt/data132/chuxiong/code/verl/verl/workers/sharding_manager/fsdp_vllm.py", line 87, in __enter__
    self.inference_engine.sync_model_weights(params, load_format=load_format)
  File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/llm.py", line 197, in sync_model_weights
    self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
  File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py", line 405, in sync_model_weights
    self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
  File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py", line 213, in sync_model_weights
    self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
  File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/worker.py", line 281, in sync_model_weights
    load_dtensor_weights(actor_weights, self.model_runner.model)
  File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py", line 365, in load_dtensor_weights
    weight_loader(actor_weights, vllm_model)
  File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py", line 183, in qwen2_dtensor_weight_loader
    local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py", line 322, in redistribute_dtensor
    local_loaded_weights = loaded_weights.full_tensor()
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Tensor' object has no attribute 'full_tensor'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

skepsun · Mar 19 '25

@skepsun I suppose you got this issue after merging this branch with main locally? This happens because the dict in the sharding manager holds references to the un-sharded tensors; once the context exits, the underlying memory is resharded, but the dict entries still point at the old unsharded tensors (there are some shenanigans with Python references here).
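
A stripped-down illustration of that pitfall, with hypothetical names (sync_weights_to_vllm is a placeholder, not the sharding manager's real code):

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    with FSDP.summon_full_params(actor_fsdp):       # actor_fsdp: an FSDP-wrapped module (assumed)
        params = dict(actor_fsdp.state_dict())      # entries reference the gathered, full tensors
        sync_weights_to_vllm(params)                # fine here, while parameters are still gathered
    # After the context exits, FSDP reshards the parameter storage in place, so a cached
    # `params` dict now aliases sharded / flattened tensors. Re-collect the state_dict on
    # every rollout instead of reusing the old dict across context boundaries.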

You can use my branch (forked from @StephenXie's) with the merge conflicts resolved and this bug fixed: https://github.com/hav4ik/verl/tree/lora-dev. Also, one thing I want to point out is that veRL's FSDP wrapping is atrociously bad, especially when mixed precision and LoRA are involved. One way to fix that is to specify +actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap, but I think there should be a more systematic approach.

@StephenXie should I make a PR to your branch? What would be the most convenient way to do that?
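
For what it's worth, here is one shape such a systematic approach could take, sketched under the assumption of a Qwen2-style model (this is not verl's code): combine the usual transformer-block policy with a lambda policy that also wraps the small trainable LoRA leaves, so frozen base weights and trainable adapters don't get flattened into the same FlatParameter.

    import functools
    from torch.distributed.fsdp.wrap import (
        _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy,
    )
    from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer  # assumed model family

    def lora_aware_wrap_policy():
        # Wrap whole decoder blocks, as is already done for full fine-tuning.
        block_policy = functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={Qwen2DecoderLayer}
        )
        # Additionally wrap leaf modules that still carry trainable weights
        # (the LoRA A/B matrices).
        lora_leaf_policy = functools.partial(
            lambda_auto_wrap_policy,
            lambda_fn=lambda m: (
                len(list(m.named_children())) == 0
                and getattr(m, "weight", None) is not None
                and m.weight.requires_grad
            ),
        )
        return functools.partial(_or_policy, policies=[lora_leaf_policy, block_policy])

The returned partial would then be passed as auto_wrap_policy when constructing the FSDP wrapper.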

hav4ik · Mar 20 '25

Hey @hav4ik this is great! Feel free to make a PR and I can merge it right away.

StephenXie · Mar 20 '25

@StephenXie @hav4ik Thank you for your work! I tested your fix for LoRA support and noticed that some of the vLLM model weights become zero when vllm_mode == 'customized'. I suspect this is due to a state_dict mismatch between the actor module and the vLLM module. Specifically, after applying get_peft_model to the actor module, vLLMRollout is still initialized with actor_module=self.actor_module_fsdp, which might cause inconsistencies.

Upon examining the FSDP vLLM sharding manager, I found that before sync_model_weights, the word embedding weights in the vLLM module are entirely zero. Additionally, the merged parameters from params = self.module._fsdp_wrapped_module.base_model.model.state_dict() contain only a subset of the parameters, with the word embedding weights missing (possibly because transformer_layer_cls_to_wrap was specified). This leads to zero-weight word embeddings in the vLLM model after the sync, resulting in broken generation outputs.

I also checked the vLLM word embedding weights with LoRA disabled; in that case the weights are non-zero both before and after sync_model_weights in the sharding manager.

I'm currently working on a workaround for this issue. Have you encountered this before or have any insights?
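
A quick check along those lines can make the mismatch visible (sketch only, run inside the sharding manager; vllm_model and the embedding parameter name are assumptions based on the modules quoted above):

    # Keys that would actually be synced, after the same rename/filter used in the
    # sharding manager.
    synced = {
        k.replace(".base_layer.", ".")
        for k in self.module._fsdp_wrapped_module.base_model.model.state_dict()
        if ".lora_" not in k
    }
    print([k for k in synced if "embed_tokens" in k])   # empty -> embeddings were never gathered

    # And whether the vLLM-side embedding is already zero before sync_model_weights.
    emb = dict(vllm_model.named_parameters())["model.embed_tokens.weight"]  # name is an assumption
    print(float(emb.abs().sum()))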

CZWin32768 · Mar 26 '25

I have tried to work around this by initializing vLLMRollout with the Hugging Face model name rather than the actor module, but it didn't work.

CZWin32768 · Mar 26 '25

Very interested in this feature.

jacklanda · Mar 26 '25

interested

zhusq20 · Apr 07 '25

Moving discussions to https://github.com/volcengine/verl/pull/1127

eric-haibin-lin · Apr 19 '25