[fsdp] fix: shard LoRA modules separately to prevent dtype mismatch errors with FSDP2
What does this PR do?
Fixes #3470
Currently, when the base model is loaded in a dtype other than fp32, its parameters end up in a different dtype than the LoRA adapters (since we do not set `autocast_adapter_dtype=False`, the adapter parameters are cast to fp32 by default). With FSDP1, the wrapping policies for base-model modules and LoRA modules are handled separately, so this mismatch does not cause issues. FSDP2, however, has no equivalent handling, so base-model parameters and LoRA parameters can end up under the same FSDP wrapper, leading to dtype mismatch errors. This PR fixes that.
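For illustration, a minimal sketch of how the dtype split arises (the model name and LoRA settings below are placeholders, not the configuration used in verl):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load the base model in bf16 and attach LoRA with PEFT defaults
# (autocast_adapter_dtype is left at its default, True).
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B", torch_dtype=torch.bfloat16
)
model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

print({p.dtype for p in model.parameters()})
# -> {torch.bfloat16, torch.float32}: base weights in bf16, LoRA weights in fp32
```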
Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/pulls?q=is%3Apr+is%3Aopen+FSDP2
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
Test
Running `python -m verl.trainer.main_ppo` with `actor_rollout_ref.actor.strategy=fsdp2` and `actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16` (and LoRA enabled) on the current main branch fails with:

`AssertionError: FSDP expects uniform original parameter dtype but got {torch.float32, torch.bfloat16}`

whereas it runs correctly on this branch.
API and Usage Example
No API changes.
Design & Code Changes
This change updates the FSDP2 wrapping logic to explicitly identify LoRA modules and shard them separately from the base model modules.
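As a rough illustration of the idea (not the exact verl code; helper names such as `is_lora_module` and `shard_with_separate_lora` are made up for this sketch), each LoRA submodule gets its own `fully_shard` call before the rest of the model is sharded, so fp32 adapter parameters and bf16 base parameters never land in the same FSDP parameter group:

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 API; older torch exposes it under torch.distributed._composable.fsdp


def is_lora_module(module: nn.Module) -> bool:
    # PEFT LoRA layers carry `lora_A` / `lora_B` submodules.
    return hasattr(module, "lora_A") and hasattr(module, "lora_B")


def shard_with_separate_lora(model: nn.Module, **fsdp_kwargs) -> nn.Module:
    # Shard each LoRA module on its own so its fp32 parameters are not
    # grouped with bf16 base-model parameters under one wrapper.
    for module in model.modules():
        if is_lora_module(module):
            fully_shard(module, **fsdp_kwargs)
    # Parameters already managed by a nested fully_shard call are excluded
    # from the root group, which then holds only uniform-dtype base weights.
    fully_shard(model, **fsdp_kwargs)
    return model
```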
Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- [x] Read the Contribute Guide.
- [x] Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update the documentation.
- [x] Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)