Adapt.jl

Use with multiple wrappers

Open: baggepinnen opened this issue 6 years ago • 3 comments

using CuArrays
CuArrays.allowscalar(false)             # make scalar indexing of GPU arrays an error
cu(reshape(cu(randn(2,2))', 1,4))       # fails when the result is displayed
CuArray(reshape(cu(randn(2,2))', 1,4))  # fails at the outermost call to CuArray

Details on Julia:

julia> versioninfo()
Julia Version 1.3.0-rc5.1
Commit 36c4eb251e (2019-11-17 19:04 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: Intel(R) Core(TM) i7-8700 CPU @ 3.20GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, skylake)
Environment:
  JULIA_NUM_THREADS = 6

Relevant package versions:
  [3895d2a7] CUDAapi v2.0.0
  [c5f51814] CUDAdrv v4.0.4
  [be33ccc6] CUDAnative v2.5.5
  [3a865a2d] CuArrays v1.4.7

I hit this error constantly while backpropagating with Tracker, though I'm not sure which adjoint definition causes it.

Both the adjoint and the reshape are needed to trigger the error; either one on its own works fine.
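A CPU-only sketch of the wrapper nesting involved may help explain why only the combination fails. Here a plain Array stands in for the CuArray (an assumption, so the snippet runs without a GPU): reshaping an adjoint produces a ReshapedArray wrapped around an Adjoint, i.e. two wrapper layers, while each operation on its own yields at most one.

```julia
using LinearAlgebra

A = randn(2, 2)          # stand-in for cu(randn(2, 2)); assumption: CPU array instead of GPU
B = reshape(A', 1, 4)    # the combination from the report

@assert reshape(A, 1, 4) isa Matrix    # reshape alone: still a plain array, no wrapper
@assert A' isa Adjoint                 # adjoint alone: one wrapper layer
@assert B isa Base.ReshapedArray       # both: an outer ReshapedArray ...
@assert parent(B) isa Adjoint          # ... around an Adjoint ...
@assert parent(parent(B)) === A        # ... around the original storage

println(typeof(B))
```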

baggepinnen, Nov 22 '19 03:11

This method definition works around the problem, though I'm not sure whether it's a hack:

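using CuArrays, LinearAlgebra
import CuArrays: CuArray  # required to add a method to CuArray from outside the package

# Convert the Adjoint parent first, then rebuild the reshape around the resulting CuArray
function CuArray(x::Base.ReshapedArray{<:Any, <:Any, <:Adjoint})
    xp = CuArray(x.parent)
    ra = Base.ReshapedArray(xp, x.dims, x.mi)
    CuArray(ra)
end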
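The idea behind this workaround (materialize the inner wrapper, then rebuild the outer ReshapedArray from its stored dims and mi fields) can be checked on the CPU. In this sketch Matrix(...) plays the role of CuArray(...), an assumption made so it runs without a GPU; it only checks that rebuilding the reshape this way leaves one wrapper layer and preserves the values.

```julia
using LinearAlgebra

A = randn(2, 2)
x = reshape(A', 1, 4)                      # ReshapedArray around an Adjoint

xp = Matrix(x.parent)                      # stand-in for CuArray(x.parent): materialize the adjoint
ra = Base.ReshapedArray(xp, x.dims, x.mi)  # rebuild the reshape around the plain array

@assert parent(ra) isa Matrix              # only one wrapper layer is left
@assert ra == x                            # same shape and values as before
```

Reusing x.mi is valid here because the materialized parent has the same size as the Adjoint it replaces.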

baggepinnen, Nov 22 '19 03:11

Another instance from Slack, where a reshape of a Transpose doesn't broadcast properly. Workaround:

julia> GPUArrays.BroadcastStyle(::Type{<:Base.ReshapedArray{<:Any,<:Any,<:LinearAlgebra.Transpose{<:Any,AT},<:Any}}) where {AT<:GPUArray} = GPUArrays.BroadcastStyle(AT)
julia> GPUArrays.backend(::Type{<:Base.ReshapedArray{<:Any,<:Any,<:LinearAlgebra.Transpose{<:Any,AT},<:Any}}) where {AT<:GPUArray} = GPUArrays.backend(AT)
julia> vec(transpose(cu(rand(2,2)))) .+ 1
4-element CuArray{Float32,1}:
 1.974818 
 1.2322885
 1.6275826
 1.213689 
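The forwarding trick in this workaround can be reproduced without a GPU. In the sketch below, MyArray and MyStyle are hypothetical stand-ins for a GPUArray type and its broadcast style; the point is only that the broadcast style of a vec(transpose(...)) wrapper resolves to the parent's style once the extra method is defined.

```julia
using LinearAlgebra

# Hypothetical minimal array type standing in for a GPUArray
struct MyArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(A::MyArray) = size(A.data)
Base.getindex(A::MyArray{T,N}, I::Vararg{Int,N}) where {T,N} = A.data[I...]

# Hypothetical broadcast style attached to MyArray
struct MyStyle <: Broadcast.BroadcastStyle end
Base.BroadcastStyle(::Type{<:MyArray}) = MyStyle()

# Forward the style through a ReshapedArray wrapping a Transpose of a MyArray,
# mirroring the GPUArrays workaround above
Base.BroadcastStyle(::Type{<:Base.ReshapedArray{<:Any,<:Any,<:Transpose{<:Any,AT}}}) where {AT<:MyArray} =
    Base.BroadcastStyle(AT)

w = vec(transpose(MyArray(rand(2, 2))))   # ReshapedArray around a Transpose around a MyArray
@assert Base.BroadcastStyle(typeof(w)) === MyStyle()
```

Actually broadcasting over w with a custom style needs more methods (style-combination rules and similar), which this sketch does not define.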

IIUC, we might need something like https://github.com/JuliaLang/julia/pull/31563 to deal with this more fundamentally?

maleadt, Nov 22 '19 05:11

Could something like this make sense? It could be extended with other known wrappers:

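using CuArrays, LinearAlgebra
import CuArrays: CuArray  # required to add methods to CuArray from outside the package

# Wrapper types that may appear nested inside one another
const Wrapper = Union{Base.ReshapedArray, LinearAlgebra.Transpose, LinearAlgebra.Adjoint}

# Each method converts the wrapped parent first, then rebuilds the outer wrapper around it
function CuArray(x::Base.ReshapedArray{<:Any, <:Any, <:Wrapper})
    xp = CuArray(x.parent)
    ra = Base.ReshapedArray(xp, x.dims, x.mi)
    CuArray(ra)
end

function CuArray(x::LinearAlgebra.Transpose{<:Any, <:Wrapper})
    xp = CuArray(x.parent)
    ra = LinearAlgebra.Transpose(xp)  # Transpose lives in LinearAlgebra, not Base
    CuArray(ra)
end

function CuArray(x::LinearAlgebra.Adjoint{<:Any, <:Wrapper})
    xp = CuArray(x.parent)
    ra = LinearAlgebra.Adjoint(xp)    # Adjoint lives in LinearAlgebra, not Base
    CuArray(ra)
end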
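The Wrapper-union proposal generalizes to a recursion that peels off known wrappers, converts the innermost array, and rebuilds the wrappers on the way out, similar in spirit to how Adapt.jl's adapt recurses into wrapper types. This sketch is not Adapt.jl's API: rewrap is a hypothetical name, and a Float32 conversion stands in for CuArray so it runs on the CPU.

```julia
using LinearAlgebra

# Hypothetical helper: apply f to the innermost storage, then rebuild each wrapper around the result
rewrap(f, x::AbstractArray) = f(x)                          # base case: no known wrapper left
rewrap(f, x::Adjoint) = adjoint(rewrap(f, parent(x)))
rewrap(f, x::Transpose) = transpose(rewrap(f, parent(x)))
rewrap(f, x::Base.ReshapedArray) = reshape(rewrap(f, parent(x)), size(x))

A = randn(2, 2)
x = reshape(A', 1, 4)

# Float32 conversion stands in for CuArray (assumption, so this runs without a GPU)
y = rewrap(a -> Float32.(a), x)

@assert eltype(y) == Float32
@assert y isa Base.ReshapedArray && parent(y) isa Adjoint   # wrapper structure preserved
@assert y ≈ x
```

Rebuilding the reshape with reshape(..., size(x)) recomputes the index helpers for the new parent instead of copying x.mi. With CuArrays loaded, f could be CuArray (an untested assumption); the result is then still a wrapper around a GPU array, so it relies on broadcasting and display working for such wrappers.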

baggepinnen, Nov 22 '19 06:11