
LoRA fine-tuning of the Wan2.1 14B I2V model OOMs on 8x 4090 (48 GB) cards. How can this be solved?

Open qiji2023 opened this issue 9 months ago • 19 comments

qiji2023 avatar Apr 15 '25 08:04 qiji2023

Adding use_gradient_checkpointing_offload and training_strategy deepspeed_stage_3 still results in OOM.

qiji2023 avatar Apr 15 '25 08:04 qiji2023

@Artiprocher @wenmengzhou Could you help me? Thanks!

qiji2023 avatar Apr 15 '25 08:04 qiji2023

Where is the I2V training code? Do I just take the repo's T2V training code, change the loaded model to the I2V model, and change the images in the data to videos for training? Thanks.

mz2sj avatar Apr 15 '25 13:04 mz2sj

@qiji2023 Fine-tuning the I2V model requires a GPU with 80 GB of VRAM.

Artiprocher avatar Apr 16 '25 02:04 Artiprocher

@mz2sj Yes.

Artiprocher avatar Apr 16 '25 02:04 Artiprocher

@Artiprocher Do you have plans to support pipeline-parallel fine-tuning that splits the model across multiple GPUs for training? Ordinary users can hardly get 80 GB of VRAM; even 24 GB is already a lot.

qiji2023 avatar Apr 16 '25 02:04 qiji2023

@Artiprocher Is there a QLoRA option, so that at least a 48 GB card can run it?

qiji2023 avatar Apr 16 '25 06:04 qiji2023

@Artiprocher I tried an 80 GB A100 and it also OOMs. How can this be resolved?

qiji2023 avatar Apr 17 '25 05:04 qiji2023

@Artiprocher I tried an 80 GB A100 and it also OOMs. How can this be resolved?

This is my training command:

python ./train_wan_t2v.py \
  --task train \
  --training_strategy deepspeed_stage_3 \
  --train_architecture lora \
  --dataset_path ./data/ \
  --output_path ./models \
  --dit_path "xxx" \
  --max_epochs 10 \
  --learning_rate 1e-4 \
  --lora_rank 8 \
  --lora_alpha 8 \
  --lora_target_modules "q" \
  --accumulate_grad_batches 1

qiji2023 avatar Apr 17 '25 05:04 qiji2023

@qiji2023 Add the arguments --use_gradient_checkpointing and --use_gradient_checkpointing_offload.
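
For example, applied to the command you posted above (only a sketch; keep your own paths and hyperparameter values):

# Same command as before, with the two gradient-checkpointing flags appended.
python ./train_wan_t2v.py \
  --task train \
  --training_strategy deepspeed_stage_3 \
  --train_architecture lora \
  --dataset_path ./data/ \
  --output_path ./models \
  --dit_path "xxx" \
  --max_epochs 10 \
  --learning_rate 1e-4 \
  --lora_rank 8 \
  --lora_alpha 8 \
  --lora_target_modules "q" \
  --accumulate_grad_batches 1 \
  --use_gradient_checkpointing \
  --use_gradient_checkpointing_offload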

Artiprocher avatar Apr 17 '25 08:04 Artiprocher

@qiji2023 Add the arguments --use_gradient_checkpointing and --use_gradient_checkpointing_offload.

@Artiprocher After adding these two args, I get this error: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass. Could we exchange contact info so you can help take a look? Thanks a lot!

qiji2023 avatar Apr 17 '25 08:04 qiji2023

@Artiprocher Any solution for this? Please help!

qiji2023 avatar Apr 17 '25 09:04 qiji2023

@Artiprocher The error message is as follows:

root@4d32d75c4ed5:~/data/DiffSynth-Studio/examples/wanvideo# ./lora_train.sh 
17 videos in metadata.
17 tensors cached in metadata.
Loading models from: ['/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors']
    model_name: wan_video_dit model_class: WanModel
        This model is initialized with extra kwargs: {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    The following models are loaded: ['wan_video_dit'].
No wan_video_text_encoder models available.
Using wan_video_dit from ['/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors'].
No wan_video_vae models available.
No wan_video_image_encoder models available.
No wan_video_motion_controller models available.
No wan_video_vace models available.
/usr/local/lib/python3.10/dist-packages/lightning/fabric/connector.py:571: `precision=bf16` is supported for historical reasons but its usage is discouraged. Please set your precision to bf16-mixed instead!
[2025-04-17 09:36:45,623] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/2
17 videos in metadata.
17 tensors cached in metadata.
Loading models from: ['/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors']
    model_name: wan_video_dit model_class: WanModel
        This model is initialized with extra kwargs: {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
    The following models are loaded: ['wan_video_dit'].
No wan_video_text_encoder models available.
Using wan_video_dit from ['/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors', '/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors'].
No wan_video_vae models available.
No wan_video_image_encoder models available.
No wan_video_motion_controller models available.
No wan_video_vace models available.
[2025-04-17 09:37:08,482] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)
initializing deepspeed distributed: GLOBAL_RANK: 1, MEMBER: 2/2
Enabling DeepSpeed BF16. Model parameters and inputs will be cast to `bfloat16`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
Parameter Offload: Total persistent parameters: 12107584 in 974 params

  | Name | Type             | Params | Params per Device | Mode
---------------------------------------------------------------------
0 | pipe | WanVideoPipeline | 16.4 B | 8.2 B             | eval
---------------------------------------------------------------------
6.6 M     Trainable params
16.4 B    Non-trainable params
16.4 B    Total params
65,605.730 Total estimated model params size (MB)
1944      Modules in train mode
1         Modules in eval mode
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
Epoch 0:   0%|                                                                                                      | 0/250 [00:00, ?it/s][rank0]: Traceback (most recent call last):
[rank0]:   File "/root/data/DiffSynth-Studio/examples/wanvideo/./train_wan_t2v.py", line 593, in <module>
[rank0]:     train(args)
[rank0]:   File "/root/data/DiffSynth-Studio/examples/wanvideo/./train_wan_t2v.py", line 585, in train
[rank0]:     trainer.fit(model, dataloader)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank0]:     results = self._run_stage()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
[rank0]:     self.fit_loop.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
[rank0]:     self.advance()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
[rank0]:     self.epoch_loop.run(self._data_fetcher)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 150, in run
[rank0]:     self.advance(data_fetcher)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 320, in advance
[rank0]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
[rank0]:     self._optimizer_step(batch_idx, closure)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
[rank0]:     call._call_lightning_module_hook(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 176, in _call_lightning_module_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py", line 1302, in optimizer_step
[rank0]:     optimizer.step(closure=optimizer_closure)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/optimizer.py", line 154, in step
[rank0]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step
[rank0]:     optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
[rank0]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 129, in optimizer_step
[rank0]:     closure_result = closure()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
[rank0]:     self._result = self.closure(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in closure
[rank0]:     self._backward_fn(step_output.closure_loss)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 241, in backward_fn
[rank0]:     call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 213, in backward
[rank0]:     self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 117, in backward
[rank0]:     deepspeed_engine.backward(tensor, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2216, in backward
[rank0]:     self._do_optimizer_backward(loss, retain_graph)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2162, in _do_optimizer_backward
[rank0]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2280, in backward
[rank0]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank0]:     scaled_loss.backward(retain_graph=retain_graph)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 626, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 1129, in unpack_hook
[rank0]:     frame.check_recomputed_tensors_match(gid)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 903, in check_recomputed_tensors_match
[rank0]:     raise CheckpointError(
[rank0]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank0]: tensor at position 13:
[rank0]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 19:
[rank0]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 32:
[rank0]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 34:
[rank0]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 46:
[rank0]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}

[rank1]: Traceback (most recent call last):
[rank1]:   File "/root/data/DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py", line 593, in <module>
[rank1]:     train(args)
[rank1]:   File "/root/data/DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py", line 585, in train
[rank1]:     trainer.fit(model, dataloader)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
[rank1]:     call._call_and_handle_interrupt(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank1]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank1]:     return function(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
[rank1]:     self._run(model, ckpt_path=ckpt_path)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank1]:     results = self._run_stage()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
[rank1]:     self.fit_loop.run()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
[rank1]:     self.advance()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
[rank1]:     self.epoch_loop.run(self._data_fetcher)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 150, in run
[rank1]:     self.advance(data_fetcher)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 320, in advance
[rank1]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
[rank1]:     self._optimizer_step(batch_idx, closure)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
[rank1]:     call._call_lightning_module_hook(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 176, in _call_lightning_module_hook
[rank1]:     output = fn(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py", line 1302, in optimizer_step
[rank1]:     optimizer.step(closure=optimizer_closure)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/optimizer.py", line 154, in step
[rank1]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step
[rank1]:     optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
[rank1]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 129, in optimizer_step
[rank1]:     closure_result = closure()
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
[rank1]:     self._result = self.closure(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in closure
[rank1]:     self._backward_fn(step_output.closure_loss)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 241, in backward_fn
[rank1]:     call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank1]:     output = fn(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 213, in backward
[rank1]:     self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 117, in backward
[rank1]:     deepspeed_engine.backward(tensor, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2216, in backward
[rank1]:     self._do_optimizer_backward(loss, retain_graph)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2162, in _do_optimizer_backward
[rank1]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2280, in backward
[rank1]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank1]:     scaled_loss.backward(retain_graph=retain_graph)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 626, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 1129, in unpack_hook
[rank1]:     frame.check_recomputed_tensors_match(gid)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 903, in check_recomputed_tensors_match
[rank1]:     raise CheckpointError(
[rank1]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank1]: tensor at position 13:
[rank1]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: tensor at position 19:
[rank1]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: tensor at position 32:
[rank1]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: tensor at position 34:
[rank1]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: tensor at position 46:
[rank1]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=1)}

root@4d32d75c4ed5:~/data/DiffSynth-Studio/examples/wanvideo# 

qiji2023 avatar Apr 17 '25 09:04 qiji2023

@Artiprocher Hi, could you help look into this issue? It seems others have run into it as well.

qiji2023 avatar Apr 17 '25 13:04 qiji2023

You're using deepspeed_stage_3 (ZeRO stage 3); switching to ZeRO stage 2 should run. @qiji2023
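
In the launch command, only the strategy flag needs to change (a sketch, assuming the training script accepts deepspeed_stage_2 the same way it accepts deepspeed_stage_3):

# replace
--training_strategy deepspeed_stage_3
# with
--training_strategy deepspeed_stage_2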

Rookienovice avatar Apr 18 '25 02:04 Rookienovice

@Artiprocher @Rookienovice Training runs now. I followed the docs exactly to test the LoRA, but ran into another problem. Could you take a look? Thanks a lot!

import torch
from PIL import Image
from torchvision import transforms
from diffsynth import ModelManager, WanVideoPipeline, save_video

# Load models
model_manager = ModelManager(device="cpu")
model_manager.load_models(
    ["/root/data/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
    torch_dtype=torch.float32, # Image Encoder is loaded with float32
)
model_manager.load_models(
    [
        [
            "/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors",
            "/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors",
            "/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors",
            "/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors",
            "/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors",
            "/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors",
            "/root/data/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors",
        ],
        "/root/data/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth",
        "/root/data/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
)
model_manager.load_lora("/root/data/DiffSynth-Studio/examples/wanvideo/models/lightning_logs/version_14/checkpoints/epoch=4-step=250.ckpt", lora_alpha=1.0)

pipe = WanVideoPipeline.from_model_manager(model_manager, 
                                           torch_dtype=torch.bfloat16,
                                           device=f"cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.

# Download example image
# dataset_snapshot_download(
#     dataset_id="DiffSynth-Studio/examples_in_diffsynth",
#     local_dir="./",
#     allow_file_pattern=f"data/examples/wan/input_image.jpg"
# )
w = 480
h = 832
image = Image.open("./c.png").convert("RGB")
resize = transforms.Resize([h, w])
image = resize(image)

# Image-to-video
video = pipe(
    prompt="360Rotate",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    input_image=image,
    num_inference_steps=50,
    seed=2025,
    height=h,
    width=w ,
    num_frames=81,
    tiled=True
)
save_video(video, "video.mp4", fps=16, quality=9)

Error:

Loading LoRA models from file: /root/data/DiffSynth-Studio/examples/wanvideo/models/lightning_logs/version_14/checkpoints/epoch=4-step=250.ckpt
Traceback (most recent call last):
  File "/root/data/DiffSynth-Studio/examples/wanvideo/./test_lora.py", line 32, in <module>
    model_manager.load_lora("/root/data/DiffSynth-Studio/examples/wanvideo/models/lightning_logs/version_14/checkpoints/epoch=4-step=250.ckpt", lora_alpha=1.0)
  File "/root/data/DiffSynth-Studio/diffsynth/models/model_manager.py", line 381, in load_lora
    state_dict = load_state_dict(file_path)
  File "/root/data/DiffSynth-Studio/diffsynth/models/utils.py", line 69, in load_state_dict
    return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
  File "/root/data/DiffSynth-Studio/diffsynth/models/utils.py", line 83, in load_state_dict_from_bin
    state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 1425, in load
    with _open_file_like(f, "rb") as opened_file:
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 751, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 732, in __init__
    super().__init__(open(name, mode))
IsADirectoryError: [Errno 21] Is a directory: '/root/data/DiffSynth-Studio/examples/wanvideo/models/lightning_logs/version_14/checkpoints/epoch=4-step=250.ckpt'
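
(Note: the .ckpt path in the error is a directory, since the Lightning DeepSpeed strategy saves the checkpoint as a folder of shards rather than a single file. A hypothetical consolidation step, assuming DeepSpeed's zero_to_fp32.py helper script was written into that folder; whether model_manager.load_lora accepts the resulting file is not verified here:)

# Hypothetical: merge the sharded DeepSpeed checkpoint into a single file.
cd "/root/data/DiffSynth-Studio/examples/wanvideo/models/lightning_logs/version_14/checkpoints/epoch=4-step=250.ckpt"
python zero_to_fp32.py . consolidated.pth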

qiji2023 avatar Apr 18 '25 06:04 qiji2023

What does the correct training .sh for an 80 GB card look like? I also got OOM doing a full train on an 80 GB card.

huangjch526 avatar May 01 '25 06:05 huangjch526

You're using deepspeed_stage_3 (ZeRO stage 3); switching to ZeRO stage 2 should run. @qiji2023

@Rookienovice Hello. Why can't stage_3 be used? stage_2 requires too much VRAM per card; it can't handle it.

huqng avatar May 17 '25 06:05 huqng

@Artiprocher @Rookienovice Training runs now. I followed the docs exactly to test the LoRA, but ran into another problem. Could you take a look? Thanks a lot!

@qiji2023 Did you end up training on the A100 or on 8x 4090? I also keep getting OOM when LoRA fine-tuning the Wan2.1-T2V-14B model on 8x 4090.

qh-deng avatar May 20 '25 06:05 qh-deng