
During training, enabling `deepspeed_stage_3` together with `use_gradient_checkpointing` and `use_gradient_checkpointing_offload` raises an error

Open · huqng opened this issue 8 months ago · 3 comments

TL;DR: Recomputed values for the following tensors have different metadata than during the forward pass.

The command is:

```shell
python examples/wanvideo/train_wan_t2v.py \
   --task train \
   --train_architecture lora \
   --dataset_path /mnt/data/wan_lora \
   --output_path ./models \
   --dit_path "/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors" \
   --num_frames 21 \
   --steps_per_epoch 500 \
   --max_epochs 10 \
   --learning_rate 1e-4 \
   --lora_rank 8 \
   --lora_alpha 8 \
   --lora_target_modules "q" \
   --accumulate_grad_batches 1 \
   --training_strategy deepspeed_stage_3 \
   --use_gradient_checkpointing \
   --use_gradient_checkpointing_offload
```

The full error message is:

```
[rank4]: Traceback (most recent call last):
[rank4]:   File "/mnt/DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py", line 593, in <module>
[rank4]:     train(args)
[rank4]:   File "/mnt/DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py", line 585, in train
[rank4]:     trainer.fit(model, dataloader)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
[rank4]:     call._call_and_handle_interrupt(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank4]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank4]:     return function(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
[rank4]:     self._run(model, ckpt_path=ckpt_path)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank4]:     results = self._run_stage()
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
[rank4]:     self.fit_loop.run()
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
[rank4]:     self.advance()
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
[rank4]:     self.epoch_loop.run(self._data_fetcher)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 150, in run
[rank4]:     self.advance(data_fetcher)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 320, in advance
[rank4]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
[rank4]:     self._optimizer_step(batch_idx, closure)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
[rank4]:     call._call_lightning_module_hook(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 176, in _call_lightning_module_hook
[rank4]:     output = fn(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py", line 1302, in optimizer_step
[rank4]:     optimizer.step(closure=optimizer_closure)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/optimizer.py", line 154, in step
[rank4]:     step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step
[rank4]:     optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
[rank4]:     return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 129, in optimizer_step
[rank4]:     closure_result = closure()
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
[rank4]:     self._result = self.closure(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank4]:     return func(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in closure
[rank4]:     self._backward_fn(step_output.closure_loss)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 241, in backward_fn
[rank4]:     call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank4]:     output = fn(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 213, in backward
[rank4]:     self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 117, in backward
[rank4]:     deepspeed_engine.backward(tensor, *args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank4]:     ret_val = func(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2216, in backward
[rank4]:     self._do_optimizer_backward(loss, retain_graph)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2162, in _do_optimizer_backward
[rank4]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank4]:     ret_val = func(*args, **kwargs)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2280, in backward
[rank4]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank4]:     scaled_loss.backward(retain_graph=retain_graph)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 648, in backward
[rank4]:     torch.autograd.backward(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 353, in backward
[rank4]:     _engine_run_backward(
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank4]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 1128, in unpack_hook
[rank4]:     frame.check_recomputed_tensors_match(gid)
[rank4]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 902, in check_recomputed_tensors_match
[rank4]:     raise CheckpointError(
[rank4]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank4]: tensor at position 13:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 23:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 42:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 44:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 56:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 64:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
```
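
For context, the check that raises this error can be reproduced with a small toy example (illustrative only; this is not DeepSpeed's actual code path). Non-reentrant activation checkpointing re-runs the forward pass during backward and verifies that every recomputed saved tensor has the same shape/dtype/device as during the original forward. If a tensor the checkpointed block reads from outside is replaced between the two passes, the same `CheckpointError` fires; the `torch.Size([0])` in the log above is consistent with ZeRO-3 re-partitioning a gathered parameter back to an empty local shard before the recomputation runs.

```python
# Toy sketch (not DeepSpeed's code path): swap a tensor used inside a
# checkpointed block between the forward pass and the backward pass.
import torch
from torch.utils.checkpoint import checkpoint

state = {"scale": torch.randn(5120)}   # "full" tensor during the forward pass

def block(x):
    return x * state["scale"]          # `scale` is saved for backward

x = torch.randn(8, 5120, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False).sum()

# Simulate the tensor being freed/re-partitioned after the forward pass.
state["scale"] = torch.randn(1)

y.backward()
# torch.utils.checkpoint.CheckpointError: Recomputed values for the following
# tensors have different metadata than during the forward pass.
#   saved metadata:      {'shape': torch.Size([5120]), ...}
#   recomputed metadata: {'shape': torch.Size([1]), ...}
```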

huqng · May 17 '25 08:05

I also encountered a similar error. My solution was to switch to deepspeed_stage_2. I hope my suggestion helps.

zfw-cv · May 19 '25 06:05

> I also encountered a similar error. My solution was to switch to deepspeed_stage_2. I hope my suggestion helps.

@zfw-cv Thank you, but using deepspeed_stage_2 will cause an OOM on my device.

huqng · May 21 '25 04:05

> I also encountered a similar error. My solution was to switch to deepspeed_stage_2. I hope my suggestion helps.

> @zfw-cv Thank you, but using deepspeed_stage_2 will cause an OOM on my device.

This may be because the number of video frames (`--num_frames`) you set is too large. Try a value that is a multiple of 8 plus 1, e.g. 17, 25, or 33.

zfw-cv · May 22 '25 02:05

+1, I ran into this too. Did you manage to solve it?

njzxj · Jun 03 '25 03:06

@huqng @zfw-cv @njzxj This problem is caused by Lightning. We have updated our training and inference framework; the new framework is based on accelerate. Feel free to try it out!
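
For anyone trying the accelerate-based path with ZeRO-3, a DeepSpeed config file for `accelerate launch` typically looks like the sketch below. This is an illustrative config, not one shipped with DiffSynth-Studio; the file name, process count, precision, and offload settings are assumptions to adapt to your setup.

```yaml
# zero3_example.yaml -- illustrative accelerate + DeepSpeed ZeRO-3 config
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 3                    # ZeRO-3 parameter partitioning
  zero3_init_flag: true
  offload_optimizer_device: cpu    # optional CPU offload
  offload_param_device: cpu
  gradient_accumulation_steps: 1
mixed_precision: bf16
num_machines: 1
num_processes: 8
machine_rank: 0
main_training_function: main
use_cpu: false
```

It would then be used as `accelerate launch --config_file zero3_example.yaml <your training script> ...`, with the script name and its arguments taken from the repo's updated training examples.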

Artiprocher · Jun 24 '25 05:06

Hi @Artiprocher, I tried your new accelerate-based framework and ran into a similar error when enabling zero_3 together with use_gradient_checkpointing.

```
[rank0]:     raise CheckpointError(
[rank0]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank0]: tensor at position 13:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 23:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 52:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 62:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 91:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 101:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 130:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 140:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'
```

zoezhou1999 · Sep 26 '25 20:09