HF AutoModel train script fails with --enable-grad-ckpt or --sequence-parallel enabled
**Describe the bug**
Hello, and thank you for developing the NeMo framework.

I'm testing examples/llm/finetune/automodel.py with a Llama checkpoint (NousResearch/Llama-2-7b-hf). The baseline configuration trains fine, but the run fails as soon as either --enable-grad-ckpt or --sequence-parallel is added.
**Baseline**

Run script:

```bash
#!/bin/bash
DEVICES=8,9,10,11
CUDA_VISIBLE_DEVICES=${DEVICES} torchrun --nproc-per-node=4 \
examples/llm/finetune/automodel.py \
--strategy fsdp2 \
--seq-length 2048 \
--devices 4 \
--model /home/ckpt/nous_research_llama2_7b_hf \
--ckpt-folder "output" \
--tp-size 1 \
--cp-size 2 \
--dp-size 2 \
--liger
```

Output:

```
Epoch 0: 0%| | 79/43800 [00:39<6:06:25, 1.99it/s, global_step=4.000, reduced_train_loss=12.30, tps=9484.0, lr=3e-6]
Epoch 0: 0%| | 80/43800 [00:40<6:08:06, 1.98it/s, global_step=4.000, reduced_train_loss=12.30, tps=9484.0, lr=3e-6]
Epoch 0: 0%| | 80/43800 [00:40<6:08:07, 1.98it/s, global_step=4.000, reduced_train_loss=11.60, tps=9398.0, lr=3e-6]
Epoch 0: 0%| | 81/43800 [00:40<6:06:32, 1.99it/s, global_step=4.000, reduced_train_loss=11.60, tps=9398.0, lr=3e-6]
```
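All three runs in this report use 4 ranks, and each keeps the product of the parallel sizes equal to that rank count. A quick sanity check of the layouts (assuming, as I understand it, that automodel.py shards the world into a TP × CP × DP device mesh):

```python
# Parallel layouts used in this report (assumption: automodel.py builds a
# TP x CP x DP device mesh over all ranks).
layouts = [
    (1, 2, 2),  # baseline and --enable-grad-ckpt runs: --tp-size 1 --cp-size 2 --dp-size 2
    (2, 2, 1),  # --sequence-parallel run:               --tp-size 2 --cp-size 2 --dp-size 1
]
for tp, cp, dp in layouts:
    assert tp * cp * dp == 4  # matches --devices 4 and --nproc-per-node=4
```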
**--enable-grad-ckpt option enabled**

Run script:

```bash
#!/bin/bash
DEVICES=8,9,10,11
CUDA_VISIBLE_DEVICES=${DEVICES} torchrun --nproc-per-node=4 \
examples/llm/finetune/automodel.py \
--strategy fsdp2 \
--seq-length 2048 \
--devices 4 \
--model /home/ckpt/nous_research_llama2_7b_hf \
--ckpt-folder "output" \
--tp-size 1 \
--cp-size 2 \
--dp-size 2 \
--liger \
--enable-grad-ckpt
```

Output:

```
Epoch 0: 0%| | 0/43800 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/test/NeMo/examples/llm/finetune/automodel.py", line 474, in <module>
[rank1]:     main()
[rank1]:   File "/home/test/NeMo/examples/llm/finetune/automodel.py", line 446, in main
[rank1]:     llm.api.finetune(
[rank1]:   File "/home/test/NeMo/nemo/collections/llm/api.py", line 236, in finetune
[rank1]:     return train(
[rank1]:            ^^^^^^
[rank1]:   File "/home/test/NeMo/nemo/collections/llm/api.py", line 135, in train
[rank1]:     trainer.fit(model, data)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank1]:     call._call_and_handle_interrupt(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank1]:     return trainer_fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank1]:     self._run(model, ckpt_path=ckpt_path)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank1]:     results = self._run_stage()
[rank1]:               ^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank1]:     self.fit_loop.run()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank1]:     self.advance()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank1]:     self.epoch_loop.run(self._data_fetcher)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank1]:     self.advance(data_fetcher)
[rank1]:   File "/home/test/NeMo/nemo/lightning/pytorch/trainer.py", line 47, in advance
[rank1]:     super().advance(data_fetcher)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
[rank1]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 183, in run
[rank1]:     closure()
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
[rank1]:     self._result = self.closure(*args, **kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 138, in closure
[rank1]:     self._backward_fn(step_output.closure_loss)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 239, in backward_fn
[rank1]:     call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
[rank1]:     output = fn(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/strategies/strategy.py", line 212, in backward
[rank1]:     self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/plugins/precision/precision.py", line 72, in backward
[rank1]:     model.backward(tensor, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/core/module.py", line 1101, in backward
[rank1]:     loss.backward(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 648, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 353, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/checkpoint.py", line 278, in backward
[rank1]:     tensors = ctx.saved_tensors
[rank1]:               ^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 2048]] is at version 4; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
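This looks like the classic reentrant-checkpoint pitfall: with reentrant activation checkpointing, every tensor argument of the checkpointed segment is saved for backward, and mutating any of them in place between forward and backward trips the version-counter check at the exact ctx.saved_tensors line shown above. A minimal sketch in plain PyTorch (my own hypothetical repro, not NeMo code) that raises the same error class:

```python
import torch
from torch.utils.checkpoint import checkpoint

emb = torch.nn.Embedding(4096, 64)
hidden = torch.randn(1, 2048, 64, requires_grad=True)  # float input so grad flows through the segment
ids = torch.arange(2048).unsqueeze(0)                   # stands in for the [1, 2048] LongTensor

def segment(h, indices):
    # both tensor args are saved for backward by the reentrant CheckpointFunction
    return h + emb(indices)

out = checkpoint(segment, hidden, ids, use_reentrant=True)
ids.add_(1)            # in-place edit of a saved tensor, after the checkpointed forward
out.sum().backward()   # raises "... modified by an inplace operation: [torch.LongTensor [1, 2048]] ..."
```

So my (unverified) guess is that with --cp-size 2 and --enable-grad-ckpt, some [1, 2048] LongTensor in the batch (position ids or similar) is reused and modified in place across micro-steps instead of being cloned.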
**--sequence-parallel option enabled**

Run script:

```bash
#!/bin/bash
DEVICES=8,9,10,11
CUDA_VISIBLE_DEVICES=${DEVICES} torchrun --nproc-per-node=4 \
examples/llm/finetune/automodel.py \
--strategy fsdp2 \
--seq-length 2048 \
--devices 4 \
--model /home/ckpt/nous_research_llama2_7b_hf \
--ckpt-folder "output" \
--tp-size 2 \
--cp-size 2 \
--dp-size 1 \
--liger \
--sequence-parallel
```

Output:

```
Training: 0%| | 0/87599 [00:00<?, ?it/s]
Epoch 0: 0%| | 0/87599 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank0]:     return trainer_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank0]:     results = self._run_stage()
[rank0]:               ^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank0]:     self.fit_loop.run()
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank0]:     self.advance()
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank0]:     self.epoch_loop.run(self._data_fetcher)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank0]:     self.advance(data_fetcher)
[rank0]:   File "/home/test/NeMo/nemo/lightning/pytorch/trainer.py", line 47, in advance
[rank0]:     super().advance(data_fetcher)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
[rank0]:     batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 183, in run
[rank0]:     closure()
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
[rank0]:     self._result = self.closure(*args, **kwargs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
[rank0]:     step_output = self._step_fn()
[rank0]:                   ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
[rank0]:     training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/test/NeMo/nemo/lightning/pytorch/strategies/fsdp2_strategy.py", line 404, in training_step
[rank0]:     loss = self.lightning_module.training_step(batch, batch_idx, context_parallel=True)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/test/NeMo/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py", line 370, in training_step
[rank0]:     outputs = self.forward(batch)
[rank0]:               ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/test/NeMo/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py", line 298, in forward
[rank0]:     return self.model(**batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1805, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/liger_kernel/transformers/model/llama.py", line 196, in lce_forward
[rank0]:     outputs = self.model(
[rank0]:               ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 571, in forward
[rank0]:     layer_outputs = decoder_layer(
[rank0]:                     ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1805, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 318, in forward
[rank0]:     hidden_states, self_attn_weights = self.self_attn(
[rank0]:                                        ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama/modeling_llama.py", line 252, in forward
[rank0]:     query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1805, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py", line 125, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_compile.py", line 51, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 749, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_api.py", line 348, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/tensor/_dispatch.py", line 217, in dispatch
[rank0]:     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 756, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, lda, b, CUDA_R_16BF, ldb, &fbeta, c, CUDA_R_16BF, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
```
**Environment overview**

- Method of NeMo install: Docker
- Docker image: nvcr.io/nvidia/nemo:25.04
**Environment details**

The run used the NVIDIA NeMo container above, so OS, PyTorch, and Python versions come from the image. Relevant package versions inside it:

```
flash_attn 2.3.4
flashinfer-python 0.2.5
open-clip-torch 2.24.0
pytorch-lightning 2.5.1.post0
pytorch-triton 3.2.0+gitb2684bf3b.nvinternal
torch 2.7.0a0+7c8ec84dab.nv25.3
torch-geometric 2.6.1
torch_tensorrt 2.7.0a0
torchdiffeq 0.2.5
torchmetrics 1.7.1
torchprofile 0.0.4
torchsde 0.2.6
torchvision 0.22.0a0
torchx 0.7.0
```

**Additional context**

Tested on a machine with 4× A100 GPUs.