
Tensor size error when using wan_14b_text_to_video_tensor_parallel on more than 2 GPUs

Open meareabc opened this issue 10 months ago • 2 comments

When I use more than 2 GPUs (4 or 6), I get a tensor size error, but with 2 GPUs it works fine. Is there a solution to this problem?

This is my dataloader setup:

```python
dataloader = torch.utils.data.DataLoader(
    ToyDataset([
        {
            "prompt": "....",
            "negative_prompt": "....",
            "num_inference_steps": 500,
            "seed": 0,
            "tiled": False,
            "height": 720,
            "width": 480,
            "output_path": "video_test1.mp4",
        },
    ]),
    collate_fn=lambda x: x,
    num_workers=64,
    pin_memory=True,
)
```

Running `CUDA_VISIBLE_DEVICES="4,5,6,7" python ./examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py` fails with this error:

```
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 166, in <module>
[rank2]:     trainer.test(model, dataloader)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 775, in test
[rank2]:     return call._call_and_handle_interrupt(
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank2]:     return trainer_fn(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 817, in _test_impl
[rank2]:     results = self._run(model, ckpt_path=ckpt_path)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank2]:     results = self._run_stage()
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1049, in _run_stage
[rank2]:     return self._evaluation_loop.run()
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
[rank2]:     return loop_run(self, *args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
[rank2]:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
[rank2]:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank2]:     output = fn(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 425, in test_step
[rank2]:     return self.lightning_module.test_step(*args, **kwargs)
[rank2]:   File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 120, in test_step
[rank2]:     video = self.pipe(**data)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank2]:     return func(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 286, in __call__
[rank2]:     noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 407, in model_fn_wan_video
[rank2]:     x = block(x, context, t_mod, freqs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank2]:     return inner()
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank2]:     result = forward_call(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 216, in forward
[rank2]:     x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank2]:     return inner()
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank2]:     result = forward_call(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 141, in forward
[rank2]:     q = rope_apply(q, freqs, self.num_heads)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/models/wan_video_dit.py", line 93, in rope_apply
[rank2]:     x_out = torch.view_as_real(x_out * freqs).flatten(2)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 346, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 170, in dispatch
[rank2]:     self.sharding_propagator.propagate(op_info)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 206, in propagate
[rank2]:     OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in __call__
[rank2]:     return self.cache(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 219, in propagate_op_sharding_non_cached
[rank2]:     out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 123, in _propagate_tensor_meta_non_cached
[rank2]:     fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
[rank2]:     return self._op(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank2]:     return fn(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank2]:     return self.dispatch(func, types, args, kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank2]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank2]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
[rank2]:     r = func(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
[rank2]:     return self._op(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank2]:     result = fn(*args, **kwargs)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 143, in _fn
[rank2]:     result = fn(**bound.arguments)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1095, in _ref
[rank2]:     a, b = _maybe_broadcast(a, b)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/__init__.py", line 437, in _maybe_broadcast
[rank2]:     common_shape = _broadcast_shapes(
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_refs/__init__.py", line 425, in _broadcast_shapes
[rank2]:     torch._check(
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/__init__.py", line 1656, in _check
[rank2]:     _check_with(RuntimeError, cond, message)
[rank2]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/__init__.py", line 1638, in _check_with
[rank2]:     raise error_type(message_evaluated)
[rank2]: RuntimeError: Attempting to broadcast a dimension of length 28350 at -3! Mismatching argument at index 1 had torch.Size([28350, 1, 64]); but expected shape should be broadcastable to [1, 28352, 40, 64]
Testing DataLoader 0:   0%| | 0/3 [00:03<?, ?it/s]
[W327 12:46:40.142966460 ProcessGroup.cpp:266] Warning: At the time of process termination, there are still 1 unwaited collective calls. Please review your program to ensure that:
1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,
2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective called under `with allow_inflight_collective_as_graph_input_ctx():`, before the output tensors of the collective are used. (function ~WorkRegistry)
```

The same ProcessGroup warning is then printed three more times (at 12:46:40.36, 12:46:41.05, and 12:46:41.08), once per remaining rank.
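For what it's worth, the shapes in the error look suggestive: 28352 is exactly 28350 rounded up to the next multiple of 4. One plausible reading (speculation on my part, not code from DiffSynth-Studio) is that the sequence-parallel path pads the token dimension to a multiple of the world size while the RoPE frequency table keeps its original length, so the two no longer broadcast:

```python
# Hypothetical sketch, not repository code: reproduce the shape arithmetic
# behind "28350 ... should be broadcastable to [1, 28352, 40, 64]".

def padded_seq_len(seq_len: int, world_size: int) -> int:
    """Round seq_len up to the next multiple of world_size (ceiling division)."""
    return -(-seq_len // world_size) * world_size

seq_len = 28350  # latent token count taken from the traceback above
print(padded_seq_len(seq_len, 2))  # 28350 -> same length as the RoPE table, runs fine
print(padded_seq_len(seq_len, 4))  # 28352 -> the mismatching length in the error
```

This would explain why 2 GPUs happen to work here (28350 is already divisible by 2) while 4 GPUs trip the broadcast check.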

Running `CUDA_VISIBLE_DEVICES="4,5" python ./examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py` works well.

meareabc avatar Mar 27 '25 05:03 meareabc

Sometimes the error looks like this instead:

```
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 166, in <module>
[rank3]:     trainer.test(model, dataloader)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 775, in test
[rank3]:     return call._call_and_handle_interrupt(
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank3]:     return trainer_fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 817, in _test_impl
[rank3]:     results = self._run(model, ckpt_path=ckpt_path)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank3]:     results = self._run_stage()
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1049, in _run_stage
[rank3]:     return self._evaluation_loop.run()
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
[rank3]:     return loop_run(self, *args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
[rank3]:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
[rank3]:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
[rank3]:     output = fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 425, in test_step
[rank3]:     return self.lightning_module.test_step(*args, **kwargs)
[rank3]:   File "/home/lkh/sd/DiffSynth-Studio/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py", line 120, in test_step
[rank3]:     video = self.pipe(**data)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 286, in __call__
[rank3]:     noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/diffsynth/pipelines/wan_video.py", line 378, in model_fn_wan_video
[rank3]:     context = dit.text_embedding(context)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
[rank3]:     input = module(input)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
[rank3]:     return F.linear(input, self.weight, self.bias)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 346, in __torch_dispatch__
[rank3]:     return DTensor._op_dispatcher.dispatch(
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 170, in dispatch
[rank3]:     self.sharding_propagator.propagate(op_info)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 206, in propagate
[rank3]:     OutputSharding, self.propagate_op_sharding(op_info.schema)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 46, in __call__
[rank3]:     return self.cache(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 219, in propagate_op_sharding_non_cached
[rank3]:     out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/distributed/tensor/_sharding_prop.py", line 123, in _propagate_tensor_meta_non_cached
[rank3]:     fake_out = op_schema.op(*fake_args, **fake_kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
[rank3]:     return self._op(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank3]:     return self.dispatch(func, types, args, kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank3]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank3]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
[rank3]:     decomposition_table[func](*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank3]:     result = fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 83, in inner
[rank3]:     r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 1450, in addmm
[rank3]:     out = alpha * torch.mm(mat1, mat2)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank3]:     return self.dispatch(func, types, args, kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank3]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank3]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2384, in _dispatch_impl
[rank3]:     r = func(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
[rank3]:     return self._op(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank3]:     result = fn(*args, **kwargs)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/_meta_registrations.py", line 2127, in meta_mm
[rank3]:     torch._check(
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/__init__.py", line 1656, in _check
[rank3]:     _check_with(RuntimeError, cond, message)
[rank3]:   File "/home/lkh/anaconda3/envs/DiffSynth/lib/python3.10/site-packages/torch/__init__.py", line 1638, in _check_with
[rank3]:     raise error_type(message_evaluated)
[rank3]: RuntimeError: a and b must have same reduction dim, but got [512, 5124] X [5120, 5120].
Testing DataLoader 0:   0%| | 0/3 [00:07<?, ?it/s]
```

meareabc avatar Mar 27 '25 07:03 meareabc

@meareabc In tensor parallel, the number of devices must be 2, 4, or 8. This is a limitation of PyTorch.
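Given that constraint, one option is to fail fast before building the pipeline. The guard below is a hypothetical sketch (`check_world_size` is not part of the repository), just encoding the 2/4/8 rule above:

```python
# Hypothetical guard, not part of DiffSynth-Studio: reject unsupported
# device counts up front instead of failing deep inside DTensor dispatch.

def check_world_size(visible_devices: str) -> int:
    """Count devices in a CUDA_VISIBLE_DEVICES-style string and validate it."""
    world_size = len([d for d in visible_devices.split(",") if d.strip()])
    if world_size not in (2, 4, 8):
        raise ValueError(
            f"tensor parallel here supports 2, 4, or 8 devices, got {world_size}"
        )
    return world_size

print(check_world_size("4,5"))      # -> 2
print(check_world_size("4,5,6,7"))  # -> 4
```

With this in place, a run launched on 6 GPUs would raise a clear `ValueError` immediately rather than a shape error mid-inference.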

Artiprocher avatar Apr 03 '25 02:04 Artiprocher