SliceMap.jl

slicemap on TrackedArrays produces Array{TrackedReal}

Open · baggepinnen opened this issue 6 years ago · 3 comments

This is a problem, since the code below will not run on the GPU (unless one allows scalar operations, which is not ideal):

julia> using SliceMap, Flux, LinearAlgebra

julia> slicemap(norm, Flux.param(randn(2,2,2,2)), dims=(1,2))
2×2 Array{Tracker.TrackedReal{Float64},2}:
 1.91925  2.69252
 1.26406  1.22966

baggepinnen · Nov 21 '19 04:11
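For comparison, on an ordinary Array the per-slice computation in the call above can be written with Base's mapslices. This is a plain-Julia sketch for reference only: it is not how SliceMap.jl implements slicemap, it involves no Tracker types, and the variable names are illustrative.

```julia
using LinearAlgebra

A = randn(2, 2, 2, 2)

# mapslices applies norm to each slice A[:, :, k, l]. Because norm returns a
# scalar, dimensions 1 and 2 survive as singletons, so drop them to get 2×2.
viaslices = dropdims(mapslices(norm, A; dims=(1, 2)); dims=(1, 2))

# The same computation as a comprehension over the remaining indices
ref = [norm(A[:, :, k, l]) for k in 1:2, l in 1:2]

viaslices ≈ ref  # true
```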

Oops, I see now that this function is really only supported for Zygote. Feel free to close this issue if appropriate.

baggepinnen · Nov 21 '19 04:11

That’s right. I don’t recall exactly why, but I couldn’t make the gradients work in general for Tracker. (For mapcols I explicitly handle the pullback functions myself, rather than getting Tracker to keep track of them.)

The simplest version of this was the @grad function gluecol (https://github.com/mcabbott/SliceMap.jl/blob/master/src/SliceMap.jl#L197). Since norm returns a scalar, you don’t actually need to glue anything, but the slicing stage also seems to have problems.

mcabbott · Nov 21 '19 10:11
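The slice, apply, glue pattern mentioned above can be sketched for columns in plain Julia. The helper mapcols_plain is hypothetical (it is not SliceMap.jl's mapcols or gluecol) and does no gradient bookkeeping; it only illustrates why a scalar-valued function like norm leaves nothing to glue, while a vector-valued one needs its outputs concatenated.

```julia
using LinearAlgebra

# Hypothetical helper, not part of SliceMap.jl: slice a matrix into columns,
# apply f to each column, then glue the results back together.
# Assumes M has at least one column.
function mapcols_plain(f, M::AbstractMatrix)
    outs = [f(c) for c in eachcol(M)]   # slicing + mapping
    if first(outs) isa Number
        return outs                     # scalar results: nothing to glue
    else
        return reduce(hcat, outs)       # vector results: glue into a matrix
    end
end

M = randn(3, 4)
mapcols_plain(norm, M)          # 4-element Vector{Float64}
mapcols_plain(x -> 2 .* x, M)   # 3×4 Matrix{Float64}
```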

However, you can do this:

julia> reshape(mapcols(norm, reshape(Tracker.param(randn(2,2,2,3)), 4,6)), 2,3)
Tracked 2×3 Array{Float64,2}:
 1.2732   1.34758  2.0512 
 1.42643  1.55332  1.70304

I guess anything with dims=1:M can be handled like this, since norm(vec(m)) == norm(m); for other functions, you could insert a reshape there too.

mcabbott · Nov 21 '19 11:11
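The reshape workaround can be checked on a plain Array. This sketch assumes the same 2×2×2×3 shape as the example above and uses Base's eachcol in place of mapcols, so it shows only the index bookkeeping, not the Tracker gradients: in column-major order, column k + 2(l-1) of the 4×6 matrix is the flattened slice A[:, :, k, l]. Here tr is just a stand-in for any function that needs the 2×2 shape back.

```julia
using LinearAlgebra

A = randn(2, 2, 2, 3)

# Reference: Frobenius norm of each 2×2 slice A[:, :, k, l]
ref = [norm(A[:, :, k, l]) for k in 1:2, l in 1:3]

# Reshape trick: each column of the 4×6 matrix is one flattened slice,
# and norm(vec(m)) == norm(m), so norm can be applied column by column.
cols = reshape(A, 4, 6)
viacols = reshape([norm(c) for c in eachcol(cols)], 2, 3)

# For a function that needs the matrix shape, reshape each column back to 2×2.
viatr = reshape([tr(reshape(c, 2, 2)) for c in eachcol(cols)], 2, 3)

viacols ≈ ref                                        # true
viatr ≈ [tr(A[:, :, k, l]) for k in 1:2, l in 1:3]   # true
```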