TeaCache with multi-GPU (FSDP) for Wan2.1

Open · mali-afridi opened this issue 4 months ago · 0 comments

I see the examples showing how to run it on a single GPU, but when I try something like this:

torchrun --nproc_per_node=8 teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --teacache_thresh 0.06 --use_ret_steps --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --save_file "ali_tf32.mp4"
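A note on the `--teacache_thresh 0.06` flag in the command: as I understand TeaCache, at each denoising step it adds a rescaled relative L1 distance between the current and previous timestep embedding to a running total, and skips the transformer blocks (reusing a cached residual) while that total stays below the threshold. The following is a simplified, hypothetical sketch of that rule, not the code in teacache_generate.py. Plain float lists stand in for tensors, the rescale polynomial is a placeholder identity rather than the fitted coefficients, and the warmup handling tied to `--use_ret_steps` is reduced to "always compute the first and last step".

```python
# Simplified sketch of a TeaCache-style skip rule (hypothetical, not the repo's code).
# Assumptions: embeddings are plain lists of floats, the rescale polynomial is a
# placeholder identity, and warmup/cutoff handling is "always compute first and last step".


def rel_l1(curr, prev):
    """Mean absolute change relative to the mean magnitude of the previous input."""
    num = sum(abs(c - p) for c, p in zip(curr, prev)) / len(curr)
    den = sum(abs(p) for p in prev) / len(prev)
    return num / den


def rescale(value, coefficients=(1.0, 0.0)):
    """Evaluate a polynomial (highest degree first) with Horner's rule.
    The default (1.0, 0.0) is a placeholder identity, not TeaCache's fitted values."""
    result = 0.0
    for c in coefficients:
        result = result * value + c
    return result


def skip_schedule(embeddings, thresh):
    """Return one bool per step: True = run the blocks, False = reuse the cache."""
    decisions = []
    accumulated = 0.0
    prev = None
    last = len(embeddings) - 1
    for step, emb in enumerate(embeddings):
        if step == 0 or step == last:
            compute = True
            accumulated = 0.0
        else:
            accumulated += rescale(rel_l1(emb, prev))
            if accumulated < thresh:
                compute = False  # change is small enough: skip and reuse cache
            else:
                compute = True   # too much drift: recompute and reset the total
                accumulated = 0.0
        decisions.append(compute)
        prev = emb  # the reference input is updated on every step
    return decisions


steps = [[1.0, 1.0], [1.02, 1.02], [1.04, 1.04], [1.2, 1.2], [1.21, 1.21]]
print(skip_schedule(steps, 0.06))  # [True, False, False, True, True]
```

With these inputs only the jump at step 3 pushes the running total past 0.06. Because the rule only looks at the timestep embedding, which should be identical on every rank, all ranks would reach the same skip schedule; that matters under FSDP and Ulysses, where a rank that skipped the blocks while another ran them would leave collective calls unmatched.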

I get the following error on every rank. Can TeaCache be supported with FSDP? If so, how? I would be happy to contribute.

[2025-09-11 18:17:00,012] INFO: Generating video ...
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 1024, in <module>
[rank2]:     generate(args)
[rank2]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 897, in generate
[rank2]:     video = wan_t2v.generate(
[rank2]:   File "/home/afridi/Wan2.1/teacache_generate.py", line 114, in t2v_generate
[rank2]:     context = self.text_encoder([input_prompt], self.device)
[rank2]:   File "/home/afridi/Wan2.1/wan/modules/t5.py", line 512, in __call__
[rank2]:     context = self.model(ids, mask)
[rank2]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/afridi/basic/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]: TypeError: teacache_forward() missing 2 required positional arguments: 'context' and 'seq_len'
(ranks 5, 1, 4, 7, 6, 0 and 3 print the identical traceback, ending in the same TypeError)
[rank0]:[W911 18:17:03.490515524 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
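A likely explanation, which I have not verified against this exact script: assume teacache_generate.py installs TeaCache the way the single-GPU version does, through class-level assignments such as `wan_t2v.model.__class__.forward = teacache_forward`. With `--dit_fsdp`, `wan_t2v.model` is a `FullyShardedDataParallel` wrapper, so that assignment replaces `forward` on the FSDP class itself. With `--t5_fsdp` the T5 encoder is wrapped in the same class, so `self.model(ids, mask)` in t5.py ends up calling `teacache_forward(self, ids, mask)`, which matches the missing `context` and `seq_len`. The torch-free sketch below reproduces that mechanism with hypothetical stand-in classes (`Wrapper` for FSDP, `DiT` and `T5` for the two models) and shows one possible fix: patching the wrapped model's class instead of the wrapper's.

```python
# Torch-free sketch of the suspected failure mode. All classes are hypothetical
# stand-ins: Wrapper plays FullyShardedDataParallel, DiT the Wan2.1 diffusion
# transformer, T5 the text encoder.


class Wrapper:
    """A single wrapper class shared by every sharded model, like FSDP."""

    def __init__(self, module):
        self.module = module

    def forward(self, *args):
        # Delegate to the wrapped model, as FSDP does.
        return self.module.forward(*args)

    def __call__(self, *args):
        return self.forward(*args)


class DiT:
    def forward(self, x, t, context, seq_len):
        return "original DiT forward"


class T5:
    def forward(self, ids, mask):
        return "T5 encoding"


def teacache_forward(self, x, t, context, seq_len, clip_fea=None, y=None):
    # Signature mirrors the argument names in the reported TypeError.
    return "TeaCache DiT forward"


def build():
    # --dit_fsdp and --t5_fsdp wrap both models in the same wrapper class.
    return Wrapper(DiT()), Wrapper(T5())


original_wrapper_forward = Wrapper.forward

# Buggy: dit.__class__ is Wrapper, so this patches every wrapped model.
dit, t5 = build()
dit.__class__.forward = teacache_forward
try:
    t5("ids", "mask")  # resolves to teacache_forward(t5, "ids", "mask")
    buggy_error = None
except TypeError as err:
    buggy_error = str(err)

# Possible fix: undo the wrapper patch, then patch only the wrapped DiT's class.
Wrapper.forward = original_wrapper_forward
dit, t5 = build()
dit.module.__class__.forward = teacache_forward
fixed_dit = dit("x", "t", "context", 1024)
fixed_t5 = t5("ids", "mask")

print(buggy_error)
print(fixed_dit, "|", fixed_t5)
```

The first print shows the same TypeError as the log above; after the fix the DiT call goes through TeaCache while the T5 call is untouched. In the real script the analogous change would presumably target the unwrapped class (for example `WanModel` from `wan/modules/model.py`, or `wan_t2v.model.module.__class__`), and any TeaCache state set through `wan_t2v.model.__class__` would have to move there too. One more caveat I have not checked: if Wan2.1 binds a sequence-parallel `forward` on the model instance when `--ulysses_size` is greater than 1, that instance attribute would shadow a class-level patch, so the TeaCache logic would likely also need to be folded into that sequence-parallel forward.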

mali-afridi, Sep 12 '25 01:09