
[fsdp] fix: shard LoRA modules separately to prevent dtype mismatch errors with FSDP2


What does this PR do?

Fixes #3470

Currently, when the base model is loaded in a dtype other than fp32, its parameters end up in a different dtype than the LoRA adapters: we do not pass autocast_adapter_dtype=False, so PEFT casts the adapter parameters to fp32 by default. With FSDP1, the wrapping policies handle base-model modules and LoRA modules separately, so this mismatch causes no issues. FSDP2 has no equivalent handling, so base-model parameters and LoRA parameters can end up in the same FSDP param group, leading to dtype mismatch errors. This PR fixes that by sharding LoRA modules separately (see the sketch below).
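
For illustration, here is a minimal sketch (not verl code; the model id and LoRA config are placeholders) of how the mixed dtypes arise with PEFT's defaults:

```python
# Sketch only: any causal LM and LoRA target modules work; the model id below
# is just an example.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16
)
lora_cfg = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])

# get_peft_model defaults to autocast_adapter_dtype=True, which casts the
# adapter weights up to fp32 while the base weights stay in bf16.
model = get_peft_model(base, lora_cfg)

print({p.dtype for p in model.parameters()})
# {torch.bfloat16, torch.float32}
```

Flattening these parameters into a single FSDP2 param group then trips the uniform-dtype assertion shown in the Test section below.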

Checklist Before Starting

  • [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/pulls?q=is%3Apr+is%3Aopen+FSDP2
  • [x] Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

On the current main branch, running python -m verl.trainer.main_ppo with actor_rollout_ref.actor.strategy=fsdp2 and actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 fails with:

AssertionError: FSDP expects uniform original parameter dtype but got {torch.float32, torch.bfloat16}

With this branch, the same command runs successfully.

API and Usage Example

No API changes.

Design & Code Changes

This change updates the FSDP2 wrapping logic to explicitly identify LoRA modules and shard them separately from the base-model modules; a simplified sketch of the idea follows.
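
As a rough illustration only (the helper name is hypothetical, torch.distributed is assumed to be initialized, and PyTorch >= 2.6 is assumed so that fully_shard is importable from torch.distributed.fsdp; the real wrapping logic also shards each transformer layer), the idea looks roughly like this:

```python
# Sketch only: `model` is a peft-wrapped module whose LoRA weights live under
# modules named ...lora_A / ...lora_B.
from torch.distributed.fsdp import fully_shard


def shard_lora_separately(model):  # hypothetical helper, not the verl API
    # Give every LoRA adapter module its own FSDP2 param group first
    # (children must be sharded before their parents), so fp32 adapter
    # parameters are never flattened together with bf16 base parameters.
    for name, module in model.named_modules():
        if name.endswith(("lora_A", "lora_B")):
            fully_shard(module)
    # Then shard the remaining base-model parameters at the root.
    fully_shard(model)
    return model
```

Because each lora_A / lora_B module gets its own param group, every FSDP2 group sees a single dtype and the assertion no longer fires.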

Checklist Before Submitting

[!IMPORTANT] Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
