slicemap on TrackedArrays produces Array{TrackedReal}
This is a problem, since the code below will not run on the GPU (unless one allows scalar operations, which is not ideal).
julia> using SliceMap, Flux
julia> slicemap(norm, Flux.param(randn(2,2,2,2)), dims=(1,2))
2×2 Array{Tracker.TrackedReal{Float64},2}:
1.91925 2.69252
1.26406 1.22966
Oops, I see now that this function is really only supported for Zygote. Feel free to close this issue if appropriate.
That’s right. I don’t recall exactly why, but I couldn’t make the gradients work in general for Tracker. (For mapcols I explicitly handle the pullback functions myself, rather than getting Tracker to keep track of them.)
The simplest version of this was the @grad function gluecol here:
https://github.com/mcabbott/SliceMap.jl/blob/master/src/SliceMap.jl#L197
Since norm returns a scalar, you don’t actually need to glue anything here, but the slicing stage also seems to have problems.
However, you can do this:
julia> reshape(mapcols(norm, reshape(Tracker.param(randn(2,2,2,3)), 4,6)), 2,3)
Tracked 2×3 Array{Float64,2}:
1.2732 1.34758 2.0512
1.42643 1.55332 1.70304
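A stdlib-only sketch of why the reshape trick above gives the same numbers as slicing with dims=(1,2). It uses plain Arrays with no Tracker, SliceMap or GPU involved, and a comprehension over eachcol stands in for mapcols. Because Julia arrays are column-major, reshaping a 2×2×2×3 array to 4×6 turns each 2×2 slice into one contiguous column, in order. The names slice_norms and col_norms are made up for this illustration.

```julia
using LinearAlgebra

x = randn(2, 2, 2, 3)

# Reference: Frobenius norm of each 2×2 slice x[:, :, i, j], giving a 2×3 matrix
slice_norms = [norm(x[:, :, i, j]) for i in 1:2, j in 1:3]

# Reshape trick: each column of the 4×6 matrix is one flattened slice
m = reshape(x, 4, 6)
col_norms = reshape([norm(c) for c in eachcol(m)], 2, 3)

slice_norms ≈ col_norms  # true
```

The final reshape back to 2×3 works for the same reason: the 6 columns are ordered with the third dimension varying fastest, matching the layout of a 2×3 array.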
I guess anything with dims=1:M can be handled like this. It works directly for norm because norm(vec(m)) == norm(m); for other functions you could insert a reshape inside the function too.
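One possible generalisation of this idea, sketched with plain Julia arrays under the assumption that the dimensions being sliced are the leading ones (dims=1:M). The helper map_leading is hypothetical, not part of SliceMap: it flattens the leading M dimensions into columns, reshapes each column back to the slice shape before calling f (needed for functions like tr, which are not defined on a flattened vector), and reshapes the results to the trailing dimensions. With Tracker, the comprehension would presumably be replaced by mapcols as in the example above, but that is untested here.

```julia
using LinearAlgebra

# Hypothetical helper: apply f to every slice over dims 1:M of x
function map_leading(f, x::AbstractArray, M::Int)
    sz = size(x)
    inner, outer = sz[1:M], sz[M+1:end]
    m = reshape(x, prod(inner), prod(outer))   # one flattened slice per column
    # Restore the slice shape inside the function, then lay results out over the trailing dims
    reshape([f(reshape(c, inner)) for c in eachcol(m)], outer)
end

x = reshape(collect(1.0:24.0), 2, 2, 2, 3)

map_leading(tr, x, 2)    # 2×3 matrix of traces of the 2×2 slices
map_leading(norm, x, 2)  # same as the mapcols(norm, ...) pattern
```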