Why does torchtune FSDP have a lower memory footprint than accelerate?
I am trying to fine-tune Qwen2.5-32B using huggingface transformers with FSDP and a max context length of 20k on 4xA100-80GB.
I'm hitting OOM even with all the memory optimizations enabled. However, when I fine-tune the model with torchtune, I can train on the same GPU setup without any problems, even with most of torchtune's custom optimizations turned off, using only activation checkpointing + CPU offload.
Hm, part of this may be because torchtune uses FSDP2, while I was using FSDP1 through transformers.
Is there any native way to use fsdp2 through transformers trainer today? Does setting the plugin version to 2 after trainer creation work?
I'm not entirely sure what you mean. You can run the code with `fsdp_version=2` in your plugin, in your config file, or as a command-line argument. If you convert the args to FSDP2 properly (we have a utility, see `accelerate to-fsdp2 --help`), it will use FSDP2. It could also be that torchtune uses less memory than accelerate even in that case.
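For the plugin route, something like this should work (a minimal sketch; it assumes a recent accelerate release whose `FullyShardedDataParallelPlugin` accepts `fsdp_version`, and omits the rest of the training loop):

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Assumes a recent accelerate release where the plugin accepts fsdp_version;
# older releases only implement FSDP1 and will reject this argument.
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```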
I looked into this a little further and figured it out. The two main differences between torchtune and accelerate + transformers Trainer that make Qwen 32B trainable on torchtune are:
- Torchtune uses FSDP2 by default, and FSDP2 has a lower memory footprint with offloading + activation checkpointing on. I can't find a native way to use FSDP2 in the huggingface Trainer today, but it's possible to opt in by setting the `FSDP_VERSION` env variable after the training args are created (see the sketch after this list).
- Torchtune by default uses a chunked version of the cross-entropy loss. This significantly reduces peak GPU memory (a sketch of the idea follows the benchmark table).
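A minimal sketch of the `FSDP_VERSION` workaround. The model and dataset here are placeholders, and it assumes an accelerate version that reads `FSDP_VERSION` from the environment when building its FSDP plugin:

```python
import os
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="qwen-sft",
    per_device_train_batch_size=1,
    bf16=True,
    fsdp="full_shard auto_wrap",
    fsdp_config={"activation_checkpointing": True},
)

# Set the env var *after* the TrainingArguments are created (their __post_init__
# populates the FSDP_* env vars) but *before* the Trainer is built, since the
# Trainer constructs the Accelerator, and with it the FSDP plugin, on init.
os.environ["FSDP_VERSION"] = "2"

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-32B")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # train_dataset: your tokenized dataset
)
trainer.train()
```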
Some benchmarking below.
These benchmarks were all run with Model = Llama 3.1-8B, Precision = bf16, Batch size = 1, Sequence length = 1024.
| Technique | Peak Active Memory (GiB) | % Change vs Previous Row |
|---|---|---|
| FSDP Version 1 | 21.02 | baseline |
| + Activation Checkpointing | 17.25 | -17.94% |
| + CPU Offload | 16.21 | -6.03% |
| + Shard Embedding Layer | 13.19 | -18.63% |
| + Chunked Loss | 13.32 | +0.99% |
| FSDP Version 2 | 23.07 | baseline |
| + Activation Checkpointing | 19.68 | -14.70% |
| + CPU Offload | 13.67 | -30.52% |
| + Shard Embedding Layer | 10.88 | -20.41% |
| + Chunked Loss | 9.93 | -8.73% |
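For reference, here is a minimal sketch of the chunked cross-entropy idea (illustrative only, not torchtune's actual `CEWithChunkedOutputLoss` implementation). The point is that the logits, and in particular their fp32 upcast, only ever exist for one chunk of the sequence at a time instead of as a full `[batch, seq, vocab]` tensor:

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(hidden, lm_head_weight, labels, num_chunks=8):
    """Illustrative chunked CE: project hidden states to the vocabulary one
    sequence chunk at a time, so the full [batch, seq, vocab] logits (and
    their fp32 upcast) are never materialized at once."""
    total_loss = hidden.new_zeros((), dtype=torch.float32)
    total_tokens = 0
    for h, y in zip(hidden.chunk(num_chunks, dim=1), labels.chunk(num_chunks, dim=1)):
        logits = h @ lm_head_weight.T  # [batch, chunk_len, vocab] for this chunk only
        total_loss = total_loss + F.cross_entropy(
            logits.flatten(0, 1).float(),  # fp32 upcast per chunk, not per full sequence
            y.flatten(),
            ignore_index=-100,
            reduction="sum",
        )
        total_tokens += int((y != -100).sum())
    return total_loss / max(total_tokens, 1)
```

In torchtune this is paired with the model returning its output projection in chunks (and the chunked computation can be compiled); the sketch above only captures the peak-memory argument.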