ReverseDiff.jl
ChainRulesCore projection & ReverseDiff.TrackedArray
When ReverseDiff is used over Zygote, ReverseDiff's tracked arrays are often turned into arrays of TrackedReal:
julia> using ReverseDiff, Zygote
julia> _, back = pullback(x -> cumsum(x .^ 3), rand(3))
([0.0417770851525806, 0.08898338554941161, 0.13448629430576223], Zygote.var"#52#53"{typeof(∂(#11))}(∂(#11)))
julia> ta = ReverseDiff.track([1,2,3.0]);
julia> back(ta)[1]
3-element Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}:
TrackedReal<Ha5>(2.1671946936904316, 0.0, BcR, ---)
TrackedReal<J2r>(1.9592562290038964, 0.0, BcR, ---)
TrackedReal<5TB>(1.147101768209455, 0.0, BcR, ---)
The cause is ChainRulesCore's projection mechanism, which compares the gradient's eltype with the element type of the primal array and decides that it needs correcting:
julia> using ChainRulesCore
julia> pt = ProjectTo(ta)
ProjectTo{AbstractArray}(element = ProjectTo{Real}(), axes = (Base.OneTo(3),))
julia> pt(ta) # this works fine, TrackedArray
3-element ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}:
TrackedReal<EIm>(1.0, 0.0, BcR, 1, 96O)
TrackedReal<I8K>(2.0, 0.0, BcR, 2, 96O)
TrackedReal<9lo>(3.0, 0.0, BcR, 3, 96O)
julia> pr = ProjectTo(rand(3))
ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3),))
julia> pr(ta) # makes a Vector{TrackedReal}
3-element Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}:
TrackedReal<KBA>(1.0, 0.0, BcR, 1, 96O)
TrackedReal<29L>(2.0, 0.0, BcR, 2, 96O)
TrackedReal<7oX>(3.0, 0.0, BcR, 3, 96O)
julia> map(pr.element, ta) # this is what projection calls, as TrackedReal <: Float64 is false
3-element Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}:
TrackedReal<AQQ>(1.0, 0.0, BcR, 1, 96O)
TrackedReal<5IS>(2.0, 0.0, BcR, 2, 96O)
TrackedReal<2yC>(3.0, 0.0, BcR, 3, 96O)
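The comment on the `map` line points at the subtype test that chooses between the two paths. The snippet below is a small sketch that assumes only ReverseDiff is loaded. `sketch_project_elements` is a hypothetical helper that simplifies the decision; it is not ChainRulesCore's actual implementation.

```julia
using ReverseDiff

ta = ReverseDiff.track([1.0, 2.0, 3.0])
S = eltype(ta)  # a ReverseDiff.TrackedReal{Float64, Float64, ...}

# The element projectors above are ProjectTo{Real} (from ta) and
# ProjectTo{Float64} (from rand(3)); only the first accepts S as-is.
@show S <: Real     # true: TrackedReal is a subtype of Real
@show S <: Float64  # false: so projection falls back to map

# Hypothetical helper mirroring that decision (a simplification, not CRC's code):
# pass the array through when its eltype already fits, otherwise map the
# element projector over it, which is what produced Vector{TrackedReal} above.
sketch_project_elements(T::Type, f, dy::AbstractArray) =
    eltype(dy) <: T ? dy : map(f, dy)

sketch_project_elements(Real, identity, ta) === ta  # the TrackedArray itself
```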
I think this should be avoided by adding methods for (::ProjectTo)(::TrackedArray) that bypass this behaviour. Since ReverseDiff already depends on ChainRulesCore, no extra dependencies are required.
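A minimal sketch of such a method follows. It assumes the method would be added inside ReverseDiff, which owns TrackedArray; defining it from user code, as done here for illustration, is type piracy. It covers only ProjectTo{AbstractArray}, keeps a simple axes check, and returns the TrackedArray untouched. Unlike the generic method, it does not attempt to reshape gradients whose axes differ only by trivial dimensions, and whether skipping the element projection is always safe is not settled by this sketch.

```julia
using ChainRulesCore, ReverseDiff

# Sketch of the proposed bypass: intended to live in ReverseDiff itself.
function (project::ChainRulesCore.ProjectTo{AbstractArray})(dx::ReverseDiff.TrackedArray)
    # Basic shape check; no attempt is made to reshape trivial dimensions.
    axes(dx) == project.axes ||
        throw(DimensionMismatch("gradient axes $(axes(dx)) do not match $(project.axes)"))
    # Skip the per-element projection so the TrackedArray is returned as-is,
    # instead of being mapped into a Vector{TrackedReal}.
    return dx
end

ta = ReverseDiff.track([1.0, 2.0, 3.0])
pr = ProjectTo(rand(3))
pr(ta)  # with the method above, this is ta itself rather than a Vector{TrackedReal}
```

With this in place, the `pr(ta)` call from the REPL session above should hand back the TrackedArray, matching what `pt(ta)` already does.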