DeepSpeed
[ROCm] Temporary workaround until __double2half support is enabled in HIP
This is a temporary fix for the following build error, which occurs when running Stable Diffusion inference with DeepSpeed inference on ROCm. It is needed only until HIP supports the __double2half intrinsic.
FAILED: gelu.cuda.o
/opt/rocm/bin/hipcc -DWITH_HIP -DTORCH_EXTENSION_NAME=transformer_inference -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -I/opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/transformer/inference/includes -I/opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/includes -isystem /opt/conda/lib/python3.8/site-packages/torch/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THC -isystem /opt/conda/lib/python3.8/site-packages/torch/include/THH -isystem /opt/rocm/include -isystem /opt/rocm/miopen/include -isystem /opt/rocm/hip/include -isystem /opt/conda/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++14 -O3 -std=c++14 -g -Wno-reorder -fPIC -D__HIP_PLATFORM_HCC__=1 -DUSE_ROCM=1 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -O3 -std=c++14 -U__HIP_NO_HALF_OPERATORS__ -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF2_OPERATORS__ -DROCM_VERSION_MAJOR=5 -DROCM_VERSION_MINOR=4 --amdgpu-target=gfx900 --amdgpu-target=gfx906 --amdgpu-target=gfx908 --amdgpu-target=gfx90a --amdgpu-target=gfx1030 -fno-gpu-rdc -c /opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/gelu.hip -o gelu.cuda.o
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
Warning: The --amdgpu-target option has been deprecated and will be removed in the future. Use --offload-arch instead.
In file included from /opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/gelu.hip:7:
/opt/conda/lib/python3.8/site-packages/deepspeed/ops/csrc/includes/conversion_utils_hip.h:269:12: error: use of undeclared identifier '__double2half'; did you mean '__double2hiint'?
    return __double2half(val);
           ^~~~~~~~~~~~~
           __double2hiint
/opt/rocm-5.4.0/include/hip/amd_detail/amd_device_functions.h:440:30: note: '__double2hiint' declared here
__device__ static inline int __double2hiint(double x) {
                             ^
1 error generated when compiling for gfx1030.
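For reviewers, here is a minimal sketch of the kind of workaround this implies, assuming the conversion can go through float on ROCm: narrow the double to float, then use `__float2half`, which HIP does provide. The names `sketch_half`, `float_to_half`, `double_to_half` and `SKETCH_INLINE` are hypothetical and used only for illustration, so this is not necessarily the exact code in `conversion_utils.h`. The host-only branch uses a stand-in type so the sketch compiles without a GPU toolchain.

```cpp
// Sketch only: double -> half without __double2half, by narrowing to float first.
#if defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#define SKETCH_INLINE __device__ inline
using sketch_half = __half;
SKETCH_INLINE sketch_half float_to_half(float x) { return __float2half(x); }
#elif defined(__CUDACC__)
#include <cuda_fp16.h>
#define SKETCH_INLINE __device__ inline
using sketch_half = __half;
SKETCH_INLINE sketch_half float_to_half(float x) { return __float2half(x); }
#else
// Hypothetical host-only stand-in: stores a float, NOT a real fp16 value.
#define SKETCH_INLINE inline
struct sketch_half {
    float value;
};
SKETCH_INLINE sketch_half float_to_half(float x) { return sketch_half{x}; }
#endif

// Hypothetical helper name, not DeepSpeed's actual conversion function.
SKETCH_INLINE sketch_half double_to_half(double val)
{
#if !defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_AMD__) && defined(__CUDACC__)
    // CUDA path: the direct intrinsic is available.
    return __double2half(val);
#else
    // ROCm (and host stand-in) path: go through float, because
    // __double2half is not declared by the HIP headers in the log above.
    return float_to_half(static_cast<float>(val));
#endif
}
```

One caveat with this approach: rounding twice (double to float, then float to half) can in rare cases differ by one ulp from a single direct rounding, which is presumably acceptable for a temporary fix.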
@jithunnair-amd @amathews-amd @jeffdaily