MatX
[BUG] Warning about a host function called from a host/device function in einsum
Describe the Bug
When using the einsum API, the compiler emits a warning that a __host__ function is called from a __host__ __device__ function, which is not allowed.
I have verified that this is not an issue in practice, but I'm not sure whether that is because we only ever call this from the host, or because the warning about a host call being used on the device is incorrect.
Below is the warning:
[ 50%] Building CUDA object CMakeFiles/einsumTest.dir/einsumTest.cu.o
/scratch/playground/matxExample/MatX/include/matx/core/tensor_impl.h(657): warning #20011-D: calling a __host__ function("std::__cxx11::basic_string<char, ::std::char_traits<char> , ::std::allocator<char> > ::basic_string(const ::std::__cxx11::basic_string<char, ::std::char_traits<char> , ::std::allocator<char> > &)") from a __host__ __device__ function("matx::detail::EinsumOp< ::matx::tensor_t<float, (int)3, ::matx::basic_storage< ::matx::raw_pointer_buffer<float, ::matx::matx_allocator<float> > > , ::matx::tensor_desc_t< ::cuda::std::__4::array<long long, (unsigned long)3ul> , ::cuda::std::__4::array<long long, (unsigned long)3ul> , (int)3> > , ::matx::tensor_t<float, (int)3, ::matx::basic_storage< ::matx::raw_pointer_buffer<float, ::matx::matx_allocator<float> > > , ::matx::tensor_desc_t< ::cuda::std::__4::array<long long, (unsigned long)3ul> , ::cuda::std::__4::array<long long, (unsigned long)3ul> , (int)3> > > ::EinsumOp") is not allowed
Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
[100%] Linking CUDA executable einsumTest
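The warning names the EinsumOp constructor as a __host__ __device__ function that calls the std::string copy constructor (presumably copying the einsum subscript string, though that is an inference from the log). The sketch below is a hypothetical, self-contained illustration of that pattern, not MatX code. HypotheticalOp copies a std::string inside a __host__ __device__ constructor, which nvcc would be expected to flag with a similar warning. HypotheticalOpHostCtor shows one possible way to avoid it, assuming the object is only ever constructed on the host. All type names are made up, and the shim at the top lets the file also build with a plain host compiler.

```cpp
#include <string>
#include <utility>

// Portability shim (assumption): under nvcc these are real execution-space
// qualifiers; with a plain host compiler they expand to nothing.
#ifndef __CUDACC__
#define __host__
#define __device__
#endif

// Hypothetical operator: stores a host-only std::string next to plain data.
struct HypotheticalOp {
  std::string subscripts;  // std::string's copy constructor is __host__ only
  int rank;

  // Under nvcc, copying `s` here calls a __host__ function from a
  // __host__ __device__ function, the pattern reported by warning #20011-D.
  __host__ __device__ HypotheticalOp(const std::string &s, int r)
      : subscripts(s), rank(r) {}
};

// Variant assuming construction only ever happens on the host: the
// constructor is __host__ only, while accessors that may run on the
// device stay __host__ __device__ and touch no host-only members.
struct HypotheticalOpHostCtor {
  std::string subscripts;
  int rank;

  __host__ HypotheticalOpHostCtor(std::string s, int r)
      : subscripts(std::move(s)), rank(r) {}

  __host__ __device__ int Rank() const { return rank; }
};
```

Marking the constructor __host__ only states the intent directly. The warning's own remark suggests the alternative of suppressing warning 20011 with -diag-suppress, which hides the diagnostic without changing the code.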
To Reproduce
Compile any code that calls einsum, for example the code snippet below.
Expected Behavior
No warning is emitted.
Code Snippets
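#include "matx.h"
int main() {
  auto a = matx::make_tensor<float>({3, 1, 1});
  auto b = matx::make_tensor<float>({1, 3, 2});
  auto c = matx::make_tensor<float>({1, 2});
  // a has shape 3x1x1, so its initializer nests 3 / 1 / 1 deep
  a.SetVals({{{1}}, {{2}}, {{3}}});
  b.SetVals({{{1, 2}, {3, 4}, {5, 6}}});
  // Requires MatX built with cuTENSOR support
  (c = matx::cutensor::einsum("ijk,jil->kl", a, b)).run();
  cudaStreamSynchronize(0);  // wait for the einsum before printing
  matx::print(c);
}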
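To check what matx::print(c) should show, the contraction "ijk,jil->kl" can be written as plain loops: c[k][l] is the sum over i and j of a[i][j][k] * b[j][i][l]. The sketch below is a host-only reference under the shapes used above (a is 3x1x1, b is 1x3x2, c is 1x2). It is an illustration, not how cuTENSOR evaluates the expression, and EinsumReference and the shape constants are hypothetical names.

```cpp
#include <array>
#include <cstddef>

// Shapes from the snippet: a is (i,j,k) = 3x1x1, b is (j,i,l) = 1x3x2,
// c is (k,l) = 1x2. Indices i and j are summed because they are absent
// from the output subscripts "kl".
constexpr std::size_t NI = 3, NJ = 1, NK = 1, NL = 2;

using TensorA = std::array<std::array<std::array<float, NK>, NJ>, NI>;
using TensorB = std::array<std::array<std::array<float, NL>, NI>, NJ>;
using TensorC = std::array<std::array<float, NL>, NK>;

// Reference evaluation of "ijk,jil->kl" with explicit loops.
TensorC EinsumReference(const TensorA &a, const TensorB &b) {
  TensorC c{};  // zero-initialised accumulator
  for (std::size_t k = 0; k < NK; ++k)
    for (std::size_t l = 0; l < NL; ++l)
      for (std::size_t i = 0; i < NI; ++i)
        for (std::size_t j = 0; j < NJ; ++j)
          c[k][l] += a[i][j][k] * b[j][i][l];
  return c;
}
```

With a holding 1, 2, 3 along i and b holding {{1, 2}, {3, 4}, {5, 6}}, this gives c = [[22, 28]] (1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28).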
System Details
- OS: Ubuntu 22.04
- CUDA version: 12.5
- g++ version: 11.4