
Support FSDP in PyTorch

Open priyakasimbeg opened this issue 1 year ago • 6 comments

It is useful to shard optimizer state across devices, since it saves significant memory and reflects current practice. We want to support it.

  • We want to switch from no sharding to naive model-parameter sharding in both frameworks.
  • We will forbid (in the rules) any hacks that change the model parallelization strategy, and keep a workload-default sharding.
  • Allow submitters to opt out of it on a per-workload basis.

priyakasimbeg avatar Oct 17 '24 18:10 priyakasimbeg

From the meeting minutes (via Michael Shi): the challenge is ensuring that JAX and PyTorch are equivalent. The PyTorch side should be doable by changing the DDP wrapper to the FSDP wrapper.
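Roughly, the wrapper swap could look like the sketch below. This is a minimal illustration, not the benchmark's actual runner code; it assumes a generic `model` and that the process group has already been initialized by the existing distributed setup.

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_model(model: torch.nn.Module, use_fsdp: bool, device: torch.device):
    """Wraps a model for multi-GPU training (sketch only)."""
    model = model.to(device)
    if use_fsdp:
        # FSDP shards parameters, gradients, and optimizer state across ranks.
        return FSDP(model, device_id=device)
    # DDP replicates the model on every rank and all-reduces gradients.
    return DDP(model, device_ids=[device])
```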

priyakasimbeg avatar Oct 17 '24 19:10 priyakasimbeg

Adding some notes and possible approaches here, for future reference.

For the PyTorch case, there are two ways of doing this:

For the JAX case, which is the one I am less familiar with:

  • Manual parallelism: a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data (see the sketch after this list).
  • Scalax package: an external package that lets you write model and training code for a single GPU/TPU and rely on scalax to automatically scale it up to hundreds of GPUs/TPUs.
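As a reference point for the manual-parallelism option, a toy example with `shard_map` might look like the following. This is only a sketch of the API shape (the update function and array sizes are made up), not the workload code.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One mesh axis spanning all local devices.
mesh = Mesh(np.array(jax.devices()), ("devices",))

def per_shard_update(state_shard, grad_shard):
    # Each device only sees (and stores) its own shard of the state.
    return state_shard - 0.1 * grad_shard

sharded_update = shard_map(
    per_shard_update, mesh=mesh,
    in_specs=(P("devices"), P("devices")),
    out_specs=P("devices"))

state = jnp.zeros(8 * jax.device_count())
grads = jnp.ones_like(state)
state = sharded_update(state, grads)
```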

IFFranciscoME avatar Oct 29 '24 05:10 IFFranciscoME

Hi, I have been working on this.

  • The code I have is running on CIFAR (on Kaggle), and it seems to be fine.
  • The wrapping I used is by size (see the sketch after this list), but we would need to make it equivalent to the JAX code.
  • I am getting some errors when I try to save the model checkpoints. This may have to do with the torch version; I am not sure. You can see the branch I am working from here: https://github.com/davidtweedle/algorithmic-efficiency/tree/fsdp_cifar
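For context, size-based wrapping with FSDP's auto-wrap policy typically looks like the sketch below; the threshold is illustrative and not necessarily what the branch uses, and an initialized process group is assumed.

```python
import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

def wrap_by_size(model: nn.Module, min_params: int = 1_000_000) -> FSDP:
    # Any submodule with at least `min_params` parameters becomes its own
    # FSDP unit; the remaining parameters are flattened into the root unit.
    policy = functools.partial(size_based_auto_wrap_policy,
                               min_num_params=min_params)
    return FSDP(model, auto_wrap_policy=policy)
```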

Edited to add: Also, I turned off torch.compile for this workload. I think that is also due to the PyTorch version.

davidtweedle avatar Nov 05 '24 21:11 davidtweedle

Thanks for the update! The model checkpoints are expected to break because, if I remember correctly, the checkpointing code makes specific assumptions about the model; some submitters also ran into issues because it makes assumptions about the optimizer state. We probably want to fix model checkpointing as part of this FSDP migration, but I would focus on that at a later stage and just disable it for now if it is blocking.
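If/when we do fix it, gathering a full (unsharded) state dict from the FSDP-wrapped model is one option. The sketch below is illustrative only and does not reflect what the current checkpointing code assumes.

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def full_model_state_dict(fsdp_model: FSDP) -> dict:
    # Consolidate shards into a single CPU state dict on rank 0 only.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_model.state_dict()
```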

Regarding torch.compile, that seems a little more problematic. When you have time, could you paste a traceback of the torch.compile issue (maybe with https://gist.github.com/ or in the GH issue thread)? If the fix requires updating PyTorch, we should probably bump the priority on that.

priyakasimbeg avatar Nov 07 '24 17:11 priyakasimbeg

Hi, OK, for now I will disable the model checkpoints. Here is a gist of the logs for this run: https://gist.github.com/davidtweedle/a870a7dd0d409e920604565a2e08b638

I am not sure what to make of this error, yet.

Also, there is this related blog post: https://dev-discuss.pytorch.org/t/torchdynamo-update-11-making-fsdp-and-dynamo-work-together/1037

davidtweedle avatar Nov 08 '24 18:11 davidtweedle

Hi, I hope it is appropriate to give a quick update on what could be going on here. When batch norm is updated during the training step, "module.apply" is called to update the batch norm layers. Because this is called on the FSDP wrapper of the module, FSDP asserts that its training state must be "IDLE" and wants to all-gather the parameters, which is not necessary here: all we want to do is tell the batch norm layers to keep track of the running stats. So hopefully it is possible to apply "update_batch_norm_fn" without calling "module.apply" on the FSDP wrapper.
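For example, something along these lines might avoid the assertion: walk the submodules directly and call the update function only on the batch norm layers, instead of going through the FSDP wrapper's apply. This is only a sketch; the real "update_batch_norm_fn" in the repo may have a different signature.

```python
import torch.nn as nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

def apply_to_batch_norm_layers(fsdp_model: nn.Module, update_batch_norm_fn):
    # Plain module traversal: no FSDP pre/post-forward hooks fire, so no
    # all-gather of the flattened parameters is triggered.
    for module in fsdp_model.modules():
        if isinstance(module, _BN_TYPES):
            # Only touches flags/buffers on the BN layer itself.
            update_batch_norm_fn(module)
```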

davidtweedle avatar Nov 09 '24 16:11 davidtweedle

Won't fix per discussion in the WG meetings.

priyakasimbeg avatar Aug 21 '25 17:08 priyakasimbeg