[BUG]: `LowLevelZeroOptimizer` gets stuck in `all_reduce` when executing `check_overflow`
🐛 Describe the bug
I'm having a problem running the ColossalChat code: strategy.optimizer_step gets stuck, and specifically, it hangs inside the torch.all_reduce call. I'm running the code with CUDA 12.1.105 on RTX 4090 GPUs, which require torch 2+. The same code runs normally with CUDA 11.7 and torch 1.13 on RTX 3090 GPUs.
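To narrow down where the collective hangs, one option is to turn the silent block into a logged, fast-failing error. The sketch below is a hedged suggestion, not part of the original report: the NCCL-related environment variables (`NCCL_DEBUG`, `NCCL_ASYNC_ERROR_HANDLING`, `NCCL_BLOCKING_WAIT`) are the names documented for torch 1.x/2.1, and the `timeout` argument to `init_process_group` is standard PyTorch API.

```python
import os
from datetime import timedelta

# Diagnostic sketch (assumption: torch 1.x/2.1 NCCL env-var names).
# These settings make a stuck collective log and fail instead of
# blocking forever.
os.environ["NCCL_DEBUG"] = "INFO"              # log NCCL init and collective calls
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"  # abort on desynchronized collectives
os.environ["NCCL_BLOCKING_WAIT"] = "1"         # make waits raise on timeout

# When initializing the process group, pass an explicit timeout so the
# all_reduce inside check_overflow fails fast instead of hanging, e.g.:
#   torch.distributed.init_process_group("nccl", timeout=timedelta(minutes=5))
timeout = timedelta(minutes=5)
print(timeout.total_seconds())  # 300.0
```

With `NCCL_DEBUG=INFO`, the per-rank logs usually show which rank never entered the collective, which helps distinguish a desynchronized rank from a driver/library incompatibility.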
Environment
Python==3.9.18. Installed packages:
Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
absl-py 2.0.0 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
annotated-types 0.6.0 pypi_0 pypi
anyio 3.7.1 pypi_0 pypi
appdirs 1.4.4 pypi_0 pypi
async-timeout 4.0.3 pypi_0 pypi
attrs 23.1.0 pypi_0 pypi
beautifulsoup4 4.12.2 pypi_0 pypi
blas 1.0 mkl
ca-certificates 2023.08.22 h06a4308_0
cachetools 5.3.1 pypi_0 pypi
certifi 2023.7.22 pypi_0 pypi
cffi 1.16.0 pypi_0 pypi
charset-normalizer 3.3.1 pypi_0 pypi
click 8.1.7 pypi_0 pypi
colossalai 0.3.3 pypi_0 pypi
cryptography 41.0.4 pypi_0 pypi
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.3.52 0 nvidia
cuda-runtime 12.1.0 0 nvidia
decorator 5.1.1 pypi_0 pypi
deprecated 1.2.14 pypi_0 pypi
exceptiongroup 1.1.3 pypi_0 pypi
fastapi 0.104.0 pypi_0 pypi
filelock 3.9.0 py39h06a4308_0
flash-attn 2.0.5 pypi_0 pypi
frozenlist 1.4.0 pypi_0 pypi
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py39heeb90bb_0
google 3.0.0 pypi_0 pypi
google-auth 2.23.3 pypi_0 pypi
google-auth-oauthlib 1.0.0 pypi_0 pypi
grpcio 1.59.0 pypi_0 pypi
idna 3.4 pypi_0 pypi
importlib-metadata 6.8.0 pypi_0 pypi
intel-openmp 2023.1.0 hdb19cb5_46305
jinja2 3.1.2 py39h06a4308_0
ld_impl_linux-64 2.38 h1181459_1
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.8.0.34 0 nvidia
libcurand 10.3.4.52 0 nvidia
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libnpp 12.0.2.50 0 nvidia
libnvjitlink 12.1.105 0 nvidia
libnvjpeg 12.1.1.14 0 nvidia
libstdcxx-ng 11.2.0 h1234567_1
llvm-openmp 14.0.6 h9e868ea_0
markdown 3.5 pypi_0 pypi
markdown-it-py 3.0.0 pypi_0 pypi
markupsafe 2.1.3 pypi_0 pypi
mdurl 0.1.2 pypi_0 pypi
mkl 2023.1.0 h213fc3f_46343
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.3.0 py39h06a4308_0
multidict 6.0.4 pypi_0 pypi
mypy-extensions 1.0.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
networkx 3.2 pypi_0 pypi
ninja 1.11.1.1 pypi_0 pypi
numpy 1.26.1 pypi_0 pypi
nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
nvidia-nccl-cu12 2.18.1 pypi_0 pypi
nvidia-nvjitlink-cu12 12.3.52 pypi_0 pypi
nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
oauthlib 3.2.2 pypi_0 pypi
openssl 3.0.11 h7f8727e_2
packaging 23.2 pypi_0 pypi
pandas 2.1.1 pypi_0 pypi
pip 23.3 py39h06a4308_0
protobuf 4.24.4 pypi_0 pypi
psutil 5.9.6 pypi_0 pypi
pyasn1 0.5.0 pypi_0 pypi
pyasn1-modules 0.3.0 pypi_0 pypi
pycparser 2.21 pypi_0 pypi
pydantic 2.4.2 pypi_0 pypi
pydantic-core 2.10.1 pypi_0 pypi
pygments 2.16.1 pypi_0 pypi
python 3.9.18 h955ad1f_0
python-dateutil 2.8.2 pypi_0 pypi
pytorch 2.1.0 py3.9_cuda12.1_cudnn8.9.2_0 pytorch
pytorch-cuda 12.1 ha16c6d3_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pytz 2023.3.post1 pypi_0 pypi
pyyaml 6.0.1 py39h5eee18b_0
readline 8.2 h5eee18b_0
requests 2.31.0 pypi_0 pypi
requests-oauthlib 1.3.1 pypi_0 pypi
rich 13.6.0 pypi_0 pypi
rsa 4.9 pypi_0 pypi
setuptools 68.0.0 py39h06a4308_0
six 1.16.0 pypi_0 pypi
sniffio 1.3.0 pypi_0 pypi
soupsieve 2.5 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0
starlette 0.27.0 pypi_0 pypi
sympy 1.12 pypi_0 pypi
tbb 2021.8.0 hdb19cb5_0
tenacity 8.2.3 pypi_0 pypi
tensorboard 2.14.0 pypi_0 pypi
tensorboard-data-server 0.7.2 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
torchtriton 2.1.0 py39 pytorch
tqdm 4.66.1 pypi_0 pypi
triton 2.1.0 pypi_0 pypi
typing-extensions 4.8.0 pypi_0 pypi
typing_extensions 4.7.1 py39h06a4308_0
tzdata 2023.3 pypi_0 pypi
ujson 5.8.0 pypi_0 pypi
wcwidth 0.2.8 pypi_0 pypi
werkzeug 3.0.0 pypi_0 pypi
wheel 0.41.2 py39h06a4308_0
wrapt 1.15.0 pypi_0 pypi
xformers 0.0.22.post4 pypi_0 pypi
xxhash 3.4.1 pypi_0 pypi
xz 5.4.2 h5eee18b_0
yaml 0.2.5 h7b6447c_0
yarl 1.9.2 pypi_0 pypi
zipp 3.17.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
We don't support torch>=2.0. Since you are using CUDA 12.1 and pytorch 2.1.0, my suggestion would be to downgrade your CUDA driver and pytorch.
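For reference, downgrading to a pre-2.0 torch build typically looks like the following. The exact version pin is an assumption here (1.13.1+cu117 matches the reporter's working RTX 3090 setup); check the ColossalAI documentation for the officially supported combination.

```shell
# Hedged sketch: install a pre-2.0 torch wheel built against CUDA 11.7
# (version pin is an assumption, matching the working 3090 environment).
pip install "torch==1.13.1+cu117" \
    --extra-index-url https://download.pytorch.org/whl/cu117
```

Note that the CUDA *driver* itself usually does not need downgrading, since newer drivers are backward compatible with older CUDA runtime wheels; downgrading the torch/CUDA-runtime pairing is normally sufficient.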