PyTorch 2 compile + FSDP + transformers crash
System Info
- `transformers` version: 4.28.0.dev0
- Platform: Linux-5.10.147+-x86_64-with-glibc2.31
- Python version: 3.9.16
- Huggingface_hub version: 0.13.3
- PyTorch version (GPU?): 2.0.0+cu117 (False)
- Tensorflow version (GPU?): 2.11.0 (False)
- Flax version (CPU?/GPU?/TPU?): 0.6.7 (cpu)
- Jax version: 0.4.6
- JaxLib version: 0.4.6
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
text models: @ArthurZucker and @younesbelkada
trainer: @sgugger
PyTorch: @sgugger
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Training with the official `run_clm.py` script works on TPU in each of these configurations:
- base training
- base training + PyTorch compile
- base training + FSDP

But it crashes when I combine FSDP + PyTorch compile.
I have created an example here to reproduce the problem: https://colab.research.google.com/drive/1RmarhGBIjeWHIngO7fAp239eqt5Za8bZ?usp=sharing
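For reference, the failing combination boils down to applying `torch.compile` on top of an FSDP-wrapped model. Below is a minimal sketch of that layering; the toy model, the `backend="eager"` choice, and the XLA FSDP wrapping shown in comments are illustrative assumptions, not the actual Trainer code:

```python
import torch
import torch.nn as nn

# Toy stand-in for the causal LM trained by run_clm.py.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# On TPU, the Trainer would first wrap the model in XLA FSDP.
# That wrapping needs a TPU runtime, so it is left as a comment here:
#   from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
#   model = FSDP(model)

# The reported crash happens when torch.compile is applied on top of the
# FSDP-wrapped model. Without FSDP, compilation alone succeeds; the
# "eager" backend keeps this sketch runnable on a plain CPU machine.
compiled = torch.compile(model, backend="eager")
out = compiled(torch.randn(4, 8))
print(tuple(out.shape))
```

Running either compile or FSDP alone matches the working cases listed above; the issue only appears once both are stacked.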
Expected behavior
The script should work using both FSDP + PyTorch compile.
I'm not sure PyTorch XLA supports `torch.compile` + FSDP yet.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@agemagician, did you resolve this issue? If so, could you share the details with me?
No, it is not supported yet.