ReverseDiff.jl

ChainRulesCore projection & ReverseDiff.TrackedArray

Status: Open. Opened by mcabbott · 0 comments

When using ReverseDiff over Zygote, its tracked arrays will often be turned into arrays of TrackedReal:

julia> using ReverseDiff, Zygote

julia> _, back = pullback(x -> cumsum(x .^ 3), rand(3))
([0.0417770851525806, 0.08898338554941161, 0.13448629430576223], Zygote.var"#52#53"{typeof(∂(#11))}(∂(#11)))

julia> ta = ReverseDiff.track([1,2,3.0]);

julia> back(ta)[1]
3-element Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}:
 TrackedReal<Ha5>(2.1671946936904316, 0.0, BcR, ---)
 TrackedReal<J2r>(1.9592562290038964, 0.0, BcR, ---)
 TrackedReal<5TB>(1.147101768209455, 0.0, BcR, ---)

The cause is ChainRulesCore's projection mechanism: it looks at the element type of the incoming array and decides that it needs correcting:

julia> using ChainRulesCore

julia> pt = ProjectTo(ta)
ProjectTo{AbstractArray}(element = ProjectTo{Real}(), axes = (Base.OneTo(3),))

julia> pt(ta)  # this works fine: the result is still a TrackedArray
3-element ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}:
 TrackedReal<EIm>(1.0, 0.0, BcR, 1, 96O)
 TrackedReal<I8K>(2.0, 0.0, BcR, 2, 96O)
 TrackedReal<9lo>(3.0, 0.0, BcR, 3, 96O)

julia> pr = ProjectTo(rand(3))
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3),))

julia> pr(ta)  # makes a Vector{TrackedReal}
3-element Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}:
 TrackedReal<KBA>(1.0, 0.0, BcR, 1, 96O)
 TrackedReal<29L>(2.0, 0.0, BcR, 2, 96O)
 TrackedReal<7oX>(3.0, 0.0, BcR, 3, 96O)

julia> map(pr.element, ta)  # this is what projection calls, as TrackedReal <: Float64 is false
3-element Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}:
 TrackedReal<AQQ>(1.0, 0.0, BcR, 1, 96O)
 TrackedReal<5IS>(2.0, 0.0, BcR, 2, 96O)
 TrackedReal<2yC>(3.0, 0.0, BcR, 3, 96O)
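The rule in the comment above can be reproduced without loading either package. Projection keeps the array only if its element type already satisfies the target type, and otherwise `map`s the element projector over it. The toy model below is a sketch under that assumption. `Tracked`, `TrackedVec`, `project_element` and `project_elements` are hypothetical stand-ins, not ChainRulesCore or ReverseDiff code.

```julia
# Toy model of the eltype check; every name here is hypothetical and only
# mimics the behaviour shown above, it is not ChainRulesCore or ReverseDiff code.

# Stand-in for TrackedReal: a Real that is not a Float64.
struct Tracked <: Real
    value::Float64
end

# Stand-in for TrackedArray: a custom container whose eltype is Tracked.
struct TrackedVec <: AbstractVector{Tracked}
    values::Vector{Float64}
end
Base.size(v::TrackedVec) = size(v.values)
Base.getindex(v::TrackedVec, i::Int) = Tracked(v.values[i])

# Element projector: passes each element through unchanged, as the
# TrackedReal outputs above suggest ProjectTo{Float64} does here.
project_element(::Type, x::Real) = x

# Keep `dx` only if its eltype already satisfies the target type T;
# otherwise map the element projector over it.
function project_elements(::Type{T}, dx::AbstractArray{S}) where {T,S}
    return S <: T ? dx : map(x -> project_element(T, x), dx)
end

tv = TrackedVec([1.0, 2.0, 3.0])
out = project_elements(Float64, tv)   # Tracked <: Float64 is false, so `map` runs
typeof(out)                           # Vector{Tracked}: the TrackedVec wrapper is gone
project_elements(Real, tv) === tv     # Tracked <: Real holds, so dx is returned as-is
```

Because `map` over a custom `AbstractVector` allocates its result through the default `similar`, the output is a plain `Vector`. This is why the container type is lost even though each element passes through unchanged.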

I think this should probably be avoided by adding methods (::ProjectTo)(::TrackedArray) that bypass this behaviour. Since ReverseDiff already depends on ChainRulesCore (CRC), no extra dependencies are required.
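The proposed bypass might look roughly like the following. This is a hypothetical sketch, not code from ReverseDiff or ChainRulesCore. It assumes both packages are installed and that a `ProjectTo{AbstractArray}` exposes the primal's axes as `project.axes`, which the printed projectors above (`axes = (Base.OneTo(3),)`) suggest. It keeps a shape check but returns the `TrackedArray` unchanged instead of mapping the element projector over it.

```julia
using ChainRulesCore, ReverseDiff

# Hypothetical sketch of the proposed bypass, not ReverseDiff's actual code.
# A method specialised on TrackedArray is more specific than ChainRulesCore's
# generic AbstractArray projection, so it is selected for tracked gradients.
function (project::ChainRulesCore.ProjectTo{AbstractArray})(dx::ReverseDiff.TrackedArray)
    # Keep a shape check (assumes the projector stores the primal's axes)...
    if axes(dx) != project.axes
        throw(DimensionMismatch("expected axes $(project.axes), got $(axes(dx))"))
    end
    # ...but skip the per-element `map`, so the TrackedArray container survives.
    return dx
end

ta = ReverseDiff.track([1, 2, 3.0])
pr = ProjectTo(rand(3))
pr(ta)  # returns `ta` itself rather than a Vector{TrackedReal}
```

Returning `dx` as-is skips the per-element projection entirely. Any element-type correction that projection would apply to a plain array is therefore not applied here, and a real fix would need to decide whether that matters.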

mcabbott, Jul 01 '22 22:07