
torch_bernoulli() fails on device="cuda" with error: Expected a 'cuda' device type for generator but found 'cpu'

Open · cregouby opened this issue 4 years ago · 1 comment

Hello,

Sampling from the Bernoulli distribution with torch_bernoulli() fails on a CUDA device with the following error message:

Error in (function (self, generator)  : 
  Expected a 'cuda' device type for generator but found 'cpu'

whereas it works as expected on a CPU device.

Workaround

Although it makes little sense performance-wise, a workaround is to draw the Bernoulli samples on the CPU and then move the result to the GPU:

library(torch)
cuda_is_available()
#> [1] TRUE
# workaround
torch_bernoulli(torch_ones(3,3, device="cpu") * 0.5)$to(device="cuda")
#> torch_tensor
#>  1  0  1
#>  0  0  1
#>  0  1  0
#> [ CUDAFloatType{3,3} ]
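A possible alternative that avoids the CPU round trip is to sample directly on the tensor's own device by thresholding uniform draws: for U ~ Uniform[0, 1), P(U < p) = p, so each element comes out as 1 with probability p. This is a swapped-in technique, not a fix for torch_bernoulli() itself. The helper name bernoulli_via_uniform() is hypothetical, and the sketch assumes that torch_rand_like() is not affected by the same generator mismatch on CUDA, which has not been checked against the torch version in this report.

```r
library(torch)

# Hypothetical helper (not part of the torch package).
# Draws Bernoulli(p) samples by comparing uniform draws against p.
# Assumption: torch_rand_like() uses the default generator of p's own device,
# so no CPU generator is involved when p lives on "cuda".
bernoulli_via_uniform <- function(p) {
  u <- torch_rand_like(p)      # U ~ Uniform[0, 1), same shape/device/dtype as p
  (u < p)$to(dtype = p$dtype)  # TRUE/FALSE -> 1/0, cast back to p's dtype
}

# Usage on CPU; on a CUDA machine, create p with device = "cuda" instead.
p <- torch_ones(3, 3) * 0.5
bernoulli_via_uniform(p)
```

Because the comparison and the uniform draws both stay on p's device, the result has the same device and dtype as the probability tensor, just like the output of the CPU workaround after $to(device = "cuda").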

Reprex

library(torch)
cuda_is_available()
#> [1] TRUE
# working function on device="cpu"
torch_bernoulli(torch_ones(3,3, device="cpu") * 0.5)
#> torch_tensor
#>  1  0  0
#>  0  1  0
#>  0  1  1
#> [ CPUFloatType{3,3} ]
# failing function on device = "cuda"
torch_bernoulli(torch_ones(3,3, device="cuda") * 0.5)
#> Error in (function (self, generator) : Expected a 'cuda' device type for generator but found 'cpu'
#> Exception raised from check_generator at /pytorch/aten/src/ATen/Utils.h:108 (most recent call first):
#> frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x7f1b17435b89 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libc10.so)
#> frame #1: <unknown function> + 0x1e2179e (0x7f1acc08479e in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #2: <unknown function> + 0x1e21acb (0x7f1acc084acb in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #3: at::native::bernoulli_tensor_kernel(at::Tensor&, at::Tensor const&, c10::optional<at::Generator>) + 0x22 (0x7f1acc084d12 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #4: <unknown function> + 0xd8d684 (0x7f1b071ff684 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #5: <unknown function> + 0xd8db21 (0x7f1b071ffb21 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #6: at::native::bernoulli_(at::Tensor&, at::Tensor const&, c10::optional<at::Generator>) + 0x32 (0x7f1b071f4702 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #7: <unknown function> + 0x32747f3 (0x7f1acd4d77f3 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #8: <unknown function> + 0x32a9ace (0x7f1acd50cace in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cuda.so)
#> frame #9: <unknown function> + 0x14e2b63 (0x7f1b07954b63 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #10: at::Tensor::bernoulli_(at::Tensor const&, c10::optional<at::Generator>) const + 0xf2 (0x7f1b07ac0b12 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #11: at::native::bernoulli(at::Tensor const&, c10::optional<at::Generator>) + 0x94 (0x7f1b071e4494 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #12: <unknown function> + 0x15b7271 (0x7f1b07a29271 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #13: <unknown function> + 0xb646d1 (0x7f1b06fd66d1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #14: <unknown function> + 0x14e24a1 (0x7f1b079544a1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #15: at::bernoulli(at::Tensor const&, c10::optional<at::Generator>) + 0xda (0x7f1b0785571a in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #16: <unknown function> + 0x2a10704 (0x7f1b08e82704 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #17: <unknown function> + 0xb646d1 (0x7f1b06fd66d1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #18: <unknown function> + 0x14e24a1 (0x7f1b079544a1 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #19: at::bernoulli(at::Tensor const&, c10::optional<at::Generator>) + 0xda (0x7f1b0785571a in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
#> frame #20: _lantern_bernoulli_tensor_generator + 0x6b (0x7f1b179059db in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/liblantern.so)
#> frame #21: cpp_torch_namespace_bernoulli_self_Tensor(XPtrTorchTensor, XPtrTorchGenerator) + 0x35 (0x7f1b18158365 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/libs/torchpkg.so)
#> frame #22: _torch_cpp_torch_namespace_bernoulli_self_Tensor + 0x97 (0x7f1b17fb3ef7 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/libs/torchpkg.so)
#> frame #23: <unknown function> + 0xf5cbc (0x7f1b2204bcbc in /usr/lib/R/lib/libR.so)
#> frame #24: <unknown function> + 0xf6206 (0x7f1b2204c206 in /usr/lib/R/lib/libR.so)
#> frame #25: <unknown function> + 0x1303c1 (0x7f1b220863c1 in /usr/lib/R/lib/libR.so)
#> frame #26: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #27: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #28: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #29: Rf_eval + 0x2af (0x7f1b220a11ef in /usr/lib/R/lib/libR.so)
#> frame #30: <unknown function> + 0xc16ad (0x7f1b220176ad in /usr/lib/R/lib/libR.so)
#> frame #31: <unknown function> + 0x1303c1 (0x7f1b220863c1 in /usr/lib/R/lib/libR.so)
#> frame #32: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #33: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #34: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #35: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #36: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #37: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #38: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #39: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #40: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #41: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #42: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #43: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #44: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #45: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #46: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #47: Rf_eval + 0x2af (0x7f1b220a11ef in /usr/lib/R/lib/libR.so)
#> frame #48: <unknown function> + 0x151172 (0x7f1b220a7172 in /usr/lib/R/lib/libR.so)
#> frame #49: <unknown function> + 0x1303c1 (0x7f1b220863c1 in /usr/lib/R/lib/libR.so)
#> frame #50: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #51: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #52: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #53: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)
#> frame #54: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #55: <unknown function> + 0x14ba5c (0x7f1b220a1a5c in /usr/lib/R/lib/libR.so)
#> frame #56: Rf_eval + 0x39f (0x7f1b220a12df in /usr/lib/R/lib/libR.so)
#> frame #57: <unknown function> + 0x151bf0 (0x7f1b220a7bf0 in /usr/lib/R/lib/libR.so)
#> frame #58: <unknown function> + 0x18f35f (0x7f1b220e535f in /usr/lib/R/lib/libR.so)
#> frame #59: <unknown function> + 0x1301b1 (0x7f1b220861b1 in /usr/lib/R/lib/libR.so)
#> frame #60: Rf_eval + 0x88 (0x7f1b220a0fc8 in /usr/lib/R/lib/libR.so)
#> frame #61: <unknown function> + 0x14ce8f (0x7f1b220a2e8f in /usr/lib/R/lib/libR.so)
#> frame #62: Rf_applyClosure + 0x1a2 (0x7f1b220a3d82 in /usr/lib/R/lib/libR.so)
#> frame #63: <unknown function> + 0x13a6fe (0x7f1b220906fe in /usr/lib/R/lib/libR.so)

Created on 2021-02-22 by the reprex package (v1.0.0)

cregouby · Feb 22 '21 11:02

Thanks @cregouby. This is indeed a bug; I'll work on a fix soon.

dfalbel · Feb 24 '21 12:02

Many thanks for this! I'll be able to remove the workaround.

cregouby · Oct 15 '22 12:10