When I run with more than 2 GPUs (4 or 6), I get a tensor size error, but with 2 GPUs it works fine. Is there a solution to this problem?
This is my data setting:
dataloader = torch.utils.data.DataLoader(
    ToyDataset([
        {
            "prompt": "....",
            "negative_prompt": "....",
            "num_inference_steps": 500,
            "seed": 0,
            "tiled": False,
            "height": 720,
            "width": 480,
            "output_path": "video_test1.mp4",
        },
    ]),
    collate_fn=lambda x: x,
    num_workers=64,
    pin_memory=True
)
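For context, here is my rough estimate of the DiT token count these settings produce. This is only my own back-of-the-envelope math, assuming 8x spatial VAE downsampling, a 2x2 patchify, 4x temporal compression, and the default 81 frames, so please correct me if any of that is wrong:

# My own token-count estimate, not taken from the DiffSynth-Studio source.
# Assumed: 8x spatial VAE downsampling, 2x2 patchify, 4x temporal
# compression, default num_frames=81.
height, width, num_frames = 720, 480, 81
latent_frames = (num_frames - 1) // 4 + 1               # 21
tokens_per_frame = (height // 16) * (width // 16)       # 45 * 30 = 1350
num_tokens = latent_frames * tokens_per_frame           # 28350
print(num_tokens, [num_tokens % n for n in (2, 4, 6)])  # 28350 [0, 2, 0]

The 28350 matches the tensor size in the error below, and it is divisible by 2 but not by 4.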
when using "CUDA_VISIBLE_DEVICES="4,5,6,7" "python ./examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py"
occured error:
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 166, in
[rank2]: trainer.test(model, dataloader)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 775, in test
[rank2]: return call._call_and_handle_interrupt(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank2]: return trainer_fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 817, in _test_impl
[rank2]: results = self._run(model, ckpt_path=ckpt_path)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank2]: results = self._run_stage()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1049, in _run_stage
[rank2]: return self._evaluation_loop.run()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
[rank2]: return loop_run(self, *args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
[rank2]: self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
[rank2]: output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank2]: output = fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 425, in test_step
[rank2]: return self.lightning_module.test_step(*args, **kwargs)
[rank2]: File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 120, in test_step
[rank2]: video = self.pipe(**data)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank2]: return func(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 286, in call
[rank2]: noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 407, in model_fn_wan_video
[rank2]: x = block(x, context, t_mod, freqs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank2]: return inner()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank2]: result = forward_call(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 216, in forward
[rank2]: x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank2]: return inner()
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank2]: result = forward_call(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 141, in forward
[rank2]: q = rope_apply(q, freqs, self.num_heads)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 93, in rope_apply
[rank2]: x_out = torch.view_as_real(x_out * freqs).flatten(2)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank2]: return disable_fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank2]: return fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 346, in torch_dispatch
[rank2]: return DTensor._op_dispatcher.dispatch(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 170, in dispatch
[rank2]: self.sharding_propagator.propagate(op_info)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 206, in propagate
[rank2]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in call
[rank2]: return self.cache(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 219, in propagate_op_sharding_non_cached
[rank2]: out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 123, in _propagate_tensor_meta_non_cached
[rank2]: fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in call
[rank2]: return self._op(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank2]: return fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in torch_dispatch
[rank2]: return self.dispatch(func, types, args, kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank2]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank2]: output = self._dispatch_impl(func, types, args, kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
[rank2]: r = func(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in call
[rank2]: return self._op(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank2]: result = fn(*args, **kwargs)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 143, in _fn
[rank2]: result = fn(**bound.arguments)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/init.py", line 1095, in _ref
[rank2]: a, b = _maybe_broadcast(a, b)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/init.py", line 437, in _maybe_broadcast
[rank2]: common_shape = _broadcast_shapes(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/init.py", line 425, in _broadcast_shapes
[rank2]: torch._check(
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/init.py", line 1656, in _check
[rank2]: _check_with(RuntimeError, cond, message)
[rank2]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/init.py", line 1638, in _check_with
[rank2]: raise error_type(message_evaluated)
[rank2]: RuntimeError: Attempting to broadcast a dimension of length 28350 at -3! Mismatching argument at index 1 had torch.Size([28350, 1, 64]); but expected shape should be broadcastable to [1, 28352, 40, 64]
Testing DataLoader 0: 0%| | 0/3 [00:03<?, ?it/s]
[W327 12:46:40.142966460 ProcessGroup.cpp:266] Warning: At the time of process termination, there are still 1 unwaited collective calls. Please review your program to ensure that:
- c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,
- c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective called under
with allow_inflight_collective_as_graph_input_ctx():,
before the output tensors of the collective are used. (function ~WorkRegistry)
(the same warning is printed three more times, once by each of the other processes)
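My guess (and it is only a guess) is that the sequence shard pads the activations up to the next multiple of the world size, so with 4 GPUs the 28350 tokens become 28352 while the RoPE freqs tensor keeps 28350, which is exactly the broadcast mismatch above. A quick check of that assumption:

import math

num_tokens = 28350  # taken from the error message above
for world_size in (2, 4, 6):
    padded = math.ceil(num_tokens / world_size) * world_size
    print(world_size, padded)
# 2 -> 28350 (no padding), 4 -> 28352 (the 28352 in the error), 6 -> 28350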
when using "CUDA_VISIBLE_DEVICES="4,5" "python ./examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py"
works well
some time errors like this:
[rank3]: Traceback (most recent call last):
[rank3]: File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 166, in
[rank3]: trainer.test(model, dataloader)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 775, in test
[rank3]: return call._call_and_handle_interrupt(
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank3]: return trainer_fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 817, in _test_impl
[rank3]: results = self._run(model, ckpt_path=ckpt_path)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank3]: results = self._run_stage()
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1049, in _run_stage
[rank3]: return self._evaluation_loop.run()
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
[rank3]: return loop_run(self, *args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
[rank3]: self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
[rank3]: output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank3]: output = fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 425, in test_step
[rank3]: return self.lightning_module.test_step(*args, **kwargs)
[rank3]: File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 120, in test_step
[rank3]: video = self.pipe(**data)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank3]: return func(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 286, in call
[rank3]: noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 378, in model_fn_wan_video
[rank3]: context = dit.text_embedding(context)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
[rank3]: input = module(input)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank3]: return inner()
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank3]: result = forward_call(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
[rank3]: return F.linear(input, self.weight, self.bias)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank3]: return disable_fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank3]: return fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 346, in torch_dispatch
[rank3]: return DTensor._op_dispatcher.dispatch(
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 170, in dispatch
[rank3]: self.sharding_propagator.propagate(op_info)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 206, in propagate
[rank3]: OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in call
[rank3]: return self.cache(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 219, in propagate_op_sharding_non_cached
[rank3]: out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 123, in _propagate_tensor_meta_non_cached
[rank3]: fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in call
[rank3]: return self._op(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank3]: return fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in torch_dispatch
[rank3]: return self.dispatch(func, types, args, kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank3]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank3]: output = self._dispatch_impl(func, types, args, kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
[rank3]: decomposition_table[func](*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank3]: result = fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 83, in inner
[rank3]: r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 1450, in addmm
[rank3]: out = alpha * torch.mm(mat1, mat2)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank3]: return fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in torch_dispatch
[rank3]: return self.dispatch(func, types, args, kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank3]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank3]: output = self._dispatch_impl(func, types, args, kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
[rank3]: r = func(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in call
[rank3]: return self._op(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank3]: result = fn(*args, **kwargs)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_meta_registrations.py", line 2127, in meta_mm
[rank3]: torch._check(
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/init.py", line 1656, in _check
[rank3]: _check_with(RuntimeError, cond, message)
[rank3]: File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/init.py", line 1638, in _check_with
[rank3]: raise error_type(message_evaluated)
[rank3]: RuntimeError: a and b must have same reduction dim, but got [512, 5124] X [5120, 5120].
Testing DataLoader 0: 0%| | 0/3 [00:07<?, ?it/s]
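The 5124 in this second error also looks like a padded shard size to me: 5120 is divisible by 2 and by 4 but not by 6, and ceil(5120 / 6) * 6 = 5124, which would explain the mismatch against the [5120, 5120] weight when running on 6 GPUs. Again just a guess, but the arithmetic lines up:

import math

hidden_dim = 5120  # from the weight shape [5120, 5120] in the error above
for world_size in (2, 4, 6):
    padded = math.ceil(hidden_dim / world_size) * world_size
    print(world_size, padded)
# 2 -> 5120, 4 -> 5120, 6 -> 5124 (the 5124 in the error)

So it looks to me as if both the token count and the hidden dimension need to be divisible by the number of GPUs, but I may be misreading the sharding logic. Is there a recommended way to choose the resolution or the GPU count so that this works?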