TL;DR: Recomputed values for the following tensors have different metadata than during the forward pass.
The command is as follows:
python examples/wanvideo/train_wan_t2v.py \
--task train \
--train_architecture lora \
--dataset_path /mnt/data/wan_lora \
--output_path ./models \
--dit_path "/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors,/mnt/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors" \
--num_frames 21 \
--steps_per_epoch 500 \
--max_epochs 10 \
--learning_rate 1e-4 \
--lora_rank 8 \
--lora_alpha 8 \
--lora_target_modules "q" \
--accumulate_grad_batches 1 \
--training_strategy deepspeed_stage_3 \
--use_gradient_checkpointing \
--use_gradient_checkpointing_offload
The full error message is as follows:
[rank4]: Traceback (most recent call last):
[rank4]: File "/mnt/DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py", line 593, in
[rank4]: train(args)
[rank4]: File "/mnt/DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py", line 585, in train
[rank4]: trainer.fit(model, dataloader)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
[rank4]: call._call_and_handle_interrupt(
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank4]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank4]: return function(*args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
[rank4]: self._run(model, ckpt_path=ckpt_path)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank4]: results = self._run_stage()
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
[rank4]: self.fit_loop.run()
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
[rank4]: self.advance()
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
[rank4]: self.epoch_loop.run(self._data_fetcher)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 150, in run
[rank4]: self.advance(data_fetcher)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 320, in advance
[rank4]: batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
[rank4]: self._optimizer_step(batch_idx, closure)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
[rank4]: call._call_lightning_module_hook(
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 176, in _call_lightning_module_hook
[rank4]: output = fn(*args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py", line 1302, in optimizer_step
[rank4]: optimizer.step(closure=optimizer_closure)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/optimizer.py", line 154, in step
[rank4]: step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step
[rank4]: optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
[rank4]: return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 129, in optimizer_step
[rank4]: closure_result = closure()
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in call
[rank4]: self._result = self.closure(*args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank4]: return func(*args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in closure
[rank4]: self._backward_fn(step_output.closure_loss)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 241, in backward_fn
[rank4]: call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank4]: output = fn(*args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py", line 213, in backward
[rank4]: self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/deepspeed.py", line 117, in backward
[rank4]: deepspeed_engine.backward(tensor, *args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank4]: ret_val = func(*args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2216, in backward
[rank4]: self._do_optimizer_backward(loss, retain_graph)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2162, in _do_optimizer_backward
[rank4]: self.optimizer.backward(loss, retain_graph=retain_graph)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank4]: ret_val = func(*args, **kwargs)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2280, in backward
[rank4]: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank4]: scaled_loss.backward(retain_graph=retain_graph)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 648, in backward
[rank4]: torch.autograd.backward(
[rank4]: File "/usr/local/lib/python3.10/dist-packages/torch/autograd/init.py", line 353, in backward
[rank4]: _engine_run_backward(
[rank4]: File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank4]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank4]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 1128, in unpack_hook
[rank4]: frame.check_recomputed_tensors_match(gid)
[rank4]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 902, in check_recomputed_tensors_match
[rank4]: raise CheckpointError(
[rank4]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank4]: tensor at position 13:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 23:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 42:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 44:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 56:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: tensor at position 64:
[rank4]: saved metadata: {'shape': torch.Size([5120]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
[rank4]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.float16, 'device': device(type='cuda', index=4)}
Hi @Artiprocher
I used your new framework with accelerate and ran into similar errors when enabling zero_3 together with use_gradient_checkpointing.
[rank0]: raise CheckpointError(
[rank0]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank0]: tensor at position 13:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 23:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 52:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 62:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 91:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 101:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 130:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'shape': torch.Size([0]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: tensor at position 140:
[rank0]: saved metadata: {'shape': torch.Size([1536]), 'dtype': torch.bfloat16, 'device': device(type='cuda', index=0)}
[rank0]: recomputed metadata: {'
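Not verified against this exact setup, but two directions that are commonly tried for this class of ZeRO-3 + non-reentrant checkpointing mismatch are (a) forcing the reentrant checkpoint implementation, which does not perform the metadata comparison that raises CheckpointError, and (b) making sure the block's partitioned parameters are gathered for the recomputation as well. A rough sketch, where `block` stands for one checkpointed DiT block and the wrapper names are hypothetical, not DiffSynth-Studio APIs:

```python
import torch
import deepspeed

def checkpointed_forward_reentrant(block, *args):
    # (a) The reentrant path skips the recomputed-tensor metadata check.
    return torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=True)

def checkpointed_forward_gathered(block, *args):
    # (b) Gather the ZeRO-3-partitioned parameters inside the checkpointed
    # function, so the recomputation also sees the full-size tensors.
    def run(*inner_args):
        with deepspeed.zero.GatheredParameters(list(block.parameters())):
            return block(*inner_args)
    return torch.utils.checkpoint.checkpoint(run, *args, use_reentrant=False)
```

Switching to deepspeed_stage_2 (which does not partition parameters) or training without --use_gradient_checkpointing should also avoid the mismatch, at the cost of more memory.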