[PPO] feat: Add LoRA support for PPO
This PR adds LoRA (Low-Rank Adaptation) support for PPO (#159)
Changes
- Added LoRA support to actor and critic configuration (see #127)
- Merge the PEFT adapter into the base model before serving it with vLLM, and unmerge it afterward (see the sketch below)
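For context, a minimal sketch of this merge/unmerge flow, assuming the PEFT merge_adapter()/unmerge_adapter() API (the actual code lives in the FSDP/vLLM sharding manager and may differ in detail):
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative setup; rank/alpha/target_modules mirror the options added in this PR.
base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
peft_model = get_peft_model(base, LoraConfig(r=32, lora_alpha=16, target_modules="all-linear"))

peft_model.merge_adapter()    # fold the LoRA deltas into the base Linear weights before rollout
merged = {k.replace(".base_layer.", "."): v
          for k, v in peft_model.base_model.model.state_dict().items() if ".lora_" not in k}
# ... hand `merged` to the vLLM weight sync (see fsdp_vllm.py below), then:
peft_model.unmerge_adapter()  # restore the adapters so training keeps updating only the LoRA weights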
Features
- Configurable LoRA rank and alpha parameters
- Target module specification for selective adaptation
- Compatible with FSDP sharding strategy
Known issues
- Merging the Ref and Actor workers when LoRA is enabled requires modifying the ppo_trainer logic; we could use some help with this
- No thorough testing yet
- Line 80 of fsdp_vllm.py needs to be cleaned up:
# Drop the LoRA adapter weights and strip PEFT's ".base_layer." segment before syncing merged weights to vLLM
params = OrderedDict((k.replace(".base_layer.", "."), v) for k, v in params.items() if ".lora_" not in k)
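One possible cleanup (just a sketch of the intent, not necessarily the final shape) is to split the filtering and key renaming into an explicit helper:
from collections import OrderedDict

def strip_peft_keys(params):
    # Keep only merged base weights for vLLM: drop LoRA adapter tensors and
    # remove PEFT's ".base_layer." segment so key names match the vanilla model.
    cleaned = OrderedDict()
    for name, tensor in params.items():
        if ".lora_" in name:
            continue
        cleaned[name.replace(".base_layer.", ".")] = tensor
    return cleaned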
Relevant threads
https://github.com/volcengine/verl/issues/159
https://github.com/Jiayi-Pan/TinyZero/issues/15
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
0 out of 3 committers have signed the CLA.
:x: StephenXie
:x: Jiayi-Pan
:x: TonyLianLong
Hey @StephenXie, I'm interested in picking this back up. How functional is this feature? Have any of the known issues been resolved, or have new ones come up? And what do you mean by "Merge Ref and Actor when LoRA is on requires modifying ppo_trainer logic" - what are you referring to by the "ppo_trainer" logic? Thank you!
Hi all! Thank you for your contributions. Is there an expected timeline for this PR to get merged?
Hi! Thanks for your interest in helping out! We have yet to produce an experiment with the right LoRA parameters, so we would really appreciate it if anyone could test this out. As for the "ppo_trainer" logic, I think @Jiayi-Pan can explain it more in-depth.
@StephenXie I'm not sure if this is the kind of test you are looking for:
I have a setting where I do GRPO on MATH, starting from Qwen2.5-1.5B-Instruct. I would be happy to try LoRA on the same run, keeping all other hyperparameters the same.
That'd be awesome - really appreciate this! I'm also running some experiments on gsm8k on my end
I just modified examples/ppo_trainer/run_qwen2_7b.sh by adding:
actor_rollout_ref.model.lora_rank=32 \
actor_rollout_ref.model.lora_alpha=16 \
actor_rollout_ref.model.target_modules=all-linear \
critic.model.lora_rank=32 \
critic.model.lora_alpha=16 \
critic.model.target_modules=all-linear \
But I got this error:
ray.exceptions.RayTaskError(AttributeError): ray::WorkerDict.actor_rollout_generate_sequences() (pid=2231957, ip=172.27.33.105, actor_id=8130c099b06d86ba6bd1754c01000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7fa91b53b550>)
File "/mnt/data132/chuxiong/code/verl/verl/single_controller/ray/base.py", line 399, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data132/chuxiong/code/verl/verl/single_controller/base/decorator.py", line 404, in inner
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data132/chuxiong/code/verl/verl/workers/fsdp_workers.py", line 474, in generate_sequences
with self.rollout_sharding_manager:
File "/mnt/data132/chuxiong/code/verl/verl/workers/sharding_manager/fsdp_vllm.py", line 87, in __enter__
self.inference_engine.sync_model_weights(params, load_format=load_format)
File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/llm.py", line 197, in sync_model_weights
self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py", line 405, in sync_model_weights
self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py", line 213, in sync_model_weights
self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format)
File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/worker.py", line 281, in sync_model_weights
load_dtensor_weights(actor_weights, self.model_runner.model)
File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py", line 365, in load_dtensor_weights
weight_loader(actor_weights, vllm_model)
File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py", line 183, in qwen2_dtensor_weight_loader
local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/data132/chuxiong/code/verl/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py", line 322, in redistribute_dtensor
local_loaded_weights = loaded_weights.full_tensor()
^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Tensor' object has no attribute 'full_tensor'
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
@skepsun I suppose you hit this issue after merging this branch with main locally? It happens because the dict in the sharding manager holds references to the un-sharded tensors; when the context exits, the underlying storage is resharded, but those references still point at the old unsharded tensors (there are some shenanigans with Python references here).
You can use my branch (forked from @StephenXie's) with merge conflicts resolved and this bug fixed: https://github.com/hav4ik/verl/tree/lora-dev. Also, one thing I want to point out is that veRL's FSDP wrapping is atrociously bad, especially when mixed precision and LoRA are involved. One way to fix that is to specify +actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap, but I think there should be a more systematic approach.
@StephenXie should I make a PR to your branch? What would be the most convenient way to do that?
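For anyone hitting the same AttributeError before switching branches: the crash itself comes from calling full_tensor() on something that is no longer a DTensor, so a guard along these lines avoids it (a sketch with an illustrative helper name, assuming the torch.distributed._tensor import path; note it only avoids the crash, and the stale-reference problem described above still needs the proper fix):
from torch.distributed._tensor import DTensor

def to_full_tensor(loaded_weight):
    # Only a DTensor needs to be gathered into a full tensor; a plain torch.Tensor
    # (e.g. a stale reference left over after FSDP reshards) has no full_tensor().
    if isinstance(loaded_weight, DTensor):
        return loaded_weight.full_tensor()
    return loaded_weight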
Hey @hav4ik this is great! Feel free to make a PR and I can merge it right away.
@StephenXie @hav4ik Thank you for your work! I tested your fix for LoRA support and noticed that some of the vLLM model weights become zero when vllm_mode == 'customized'. I suspect this is due to a state_dict mismatch between the actor module and the vLLM module. Specifically, after applying get_peft_model to the actor module, vLLMRollout is still initialized with actor_module=self.actor_module_fsdp, which might cause inconsistencies.
Upon examining the FSDP vLLM sharding manager, I found that before sync_model_weights, the word embedding weights in the vLLM module are entirely zero. Additionally, the merged parameters from params = self.module._fsdp_wrapped_module.base_model.model.state_dict() contain only a subset of parameters, with the word embedding weights missing (possibly because transformer_layer_cls_to_wrap was specified). This leads to zero-weight word embeddings in the vLLM model after the sync, resulting in broken generation outputs.
I also checked the vLLM word embedding weights with LoRA disabled; they are non-zero both before and after sync_model_weights in the sharding manager.
I'm currently working on a workaround for this issue. Have you encountered this before or have any insights?
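A small check along these lines, run right before sync_model_weights, makes the mismatch visible (a debugging sketch with illustrative names, not verl code; note that vLLM fuses some projections such as qkv, so the missing-key comparison is only a rough signal):
import torch

def report_missing_or_zero(merged_params, vllm_model):
    # Flag vLLM parameters that are absent from the merged PEFT state_dict or
    # still all zeros (embed_tokens is the one that shows up broken here).
    vllm_named = dict(vllm_model.named_parameters())
    missing = sorted(set(vllm_named) - set(merged_params))
    if missing:
        print("in vLLM but not in merged state_dict:", missing)
    for name, weight in vllm_named.items():
        if torch.count_nonzero(weight) == 0:
            print(f"{name} is all zeros before sync")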
I have tried to work around this by initializing vLLMRollout with the Hugging Face model name rather than the actor module, but it didn't work.
Very interested in this feature.
interested
Moving discussions to https://github.com/volcengine/verl/pull/1127.