
Why does torchtune FSDP have a lower memory footprint than accelerate?

Open jasper-lu opened this issue 10 months ago • 3 comments

I am trying to fine-tune Qwen2.5-32B using huggingface transformers with FSDP and a max context length of 20k on 4xA100-80GB.

I'm hitting OOM even with all the usual memory optimizations enabled. However, when I fine-tune the same model with torchtune on the same GPU setup, training works without issue, even with most of torchtune's custom optimizations turned off and only activation checkpointing + CPU offload enabled.
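
For context, this is roughly the shape of the Trainer setup (simplified and illustrative, not my exact script):

```python
from transformers import TrainingArguments

# Roughly the shape of the setup described above (illustrative, not the exact script):
# FSDP1 full-shard with CPU offload, plus activation checkpointing.
training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,                 # activation checkpointing
    fsdp="full_shard auto_wrap offload",         # FSDP1 with CPU offload
    fsdp_config={"transformer_layer_cls_to_wrap": ["Qwen2DecoderLayer"]},
)
```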

jasper-lu avatar Jun 20 '25 20:06 jasper-lu

Hm, part of this may be because torchtune uses FSDP2, while I was using FSDP1 through transformers.

Is there any native way to use FSDP2 through the transformers Trainer today? Does setting the plugin version to 2 after Trainer creation work?

jasper-lu avatar Jun 21 '25 03:06 jasper-lu

I'm not entirely sure what you mean; you can run the code with fsdp_version=2 in your plugin, in your config file, or as a command-line argument. If you convert the args to FSDP2 properly (we have a utility, see `accelerate to-fsdp2 --help`), it will use FSDP2. It could also be that torchtune uses less memory than accelerate even in that case.
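
For example, a minimal sketch in code (assuming a recent accelerate release with FSDP2 support):

```python
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Request FSDP2 explicitly via the plugin; existing FSDP1 configs can be
# converted with the `accelerate to-fsdp2` utility mentioned above.
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```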

S1ro1 avatar Jun 22 '25 13:06 S1ro1

I looked into this a little further and figured it out. The two main differences between torchtune and accelerate + transformers Trainer that make Qwen 32B trainable on torchtune are:

  1. Torchtune uses FSDP2 by default, which has a lower memory footprint with offloading + activation checkpointing enabled. I can't find a native way to use FSDP2 in the huggingface Trainer today, but it's possible by setting the FSDP_VERSION env variable after training args creation (see the first sketch after this list).
  2. Torchtune by default uses a chunked version of the cross-entropy loss. This significantly reduces peak GPU memory (see the second sketch below).
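
A minimal sketch of the workaround from point 1 (how TrainingArguments and accelerate's env parsing interact may vary across versions):

```python
import os
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="out", fsdp="full_shard auto_wrap offload")

# Set AFTER the training args are created, so the value is still in place when
# the Trainer builds accelerate's FSDP plugin, which reads FSDP_VERSION from the env.
os.environ["FSDP_VERSION"] = "2"
```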

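And a simplified sketch of the idea behind the chunked loss from point 2 (torchtune's actual CEWithChunkedOutputLoss differs in detail):

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(logits, labels, num_chunks=8, ignore_index=-100):
    """Cross-entropy computed chunk-by-chunk along the sequence dimension, so
    only one chunk of logits is upcast to float32 at a time instead of the
    full (batch, seq, vocab) tensor. Simplified sketch, not torchtune's code."""
    total_tokens = (labels != ignore_index).sum().clamp(min=1)
    loss = logits.new_zeros((), dtype=torch.float32)
    for logit_chunk, label_chunk in zip(
        logits.chunk(num_chunks, dim=1), labels.chunk(num_chunks, dim=1)
    ):
        loss = loss + F.cross_entropy(
            logit_chunk.float().flatten(0, 1),  # upcast only this chunk
            label_chunk.flatten(),
            ignore_index=ignore_index,
            reduction="sum",
        )
    return loss / total_tokens
```

The saving matters most for large-vocabulary models like Qwen, where the full fp32 logits can dominate peak memory.
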
Some benchmarking below. All benchmarks used Model=Llama 3.1-8B, Batch size=1, Sequence length=1024, with model weights in bf16 and the Trainer's bf16 mixed-precision flag off.

| Technique | Peak Memory Active (GiB) | % Change Memory vs Previous |
| --- | --- | --- |
| FSDP Version 1 | 21.02 | (baseline) |
| + Activation Checkpointing | 17.25 | -17.94% |
| + CPU Offload | 16.21 | -6.03% |
| + Shard Embedding Layer | 13.19 | -18.63% |
| + Chunked Loss | 13.32 | +0.99% |
| FSDP Version 2 | 23.07 | (baseline) |
| + Activation Checkpointing | 19.68 | -14.70% |
| + CPU Offload | 13.67 | -30.52% |
| + Shard Embedding Layer | 10.88 | -20.41% |
| + Chunked Loss | 9.93 | -8.73% |

jasper-lu avatar Jun 28 '25 01:06 jasper-lu