AbstractDifferentiation.jl
Handling of thunks and tangents
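When the VJP returned by the pullback is not an AbstractArray (for example a Thunk or a NoTangent), higher-order operations such as second_derivative and hessian fail: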
julia> import AbstractDifferentiation as AD
julia> import Zygote
julia> ad_backend = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}(Zygote.ZygoteRuleConfig{Zygote.Context{false}}(Zygote.Context{false}(nothing)))
julia> AD.second_derivative(ad_backend, identity, 1)
ERROR: MethodError: no method matching length(::ChainRulesCore.NoTangent)
julia> AD.hessian(ad_backend, sum, [1.0])
ERROR: MethodError: no method matching size(::ChainRulesCore.Thunk{ChainRulesCore.var"#48#49"{ChainRulesCore.Thunk{ChainRulesCore.var"#48#49"{…}}}})
The correct thing to do with thunks is to unthunk them before using them.
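A minimal sketch of why unthunking fixes the size error above, using only ChainRulesCore's exported @thunk and unthunk. The array values are illustrative; nothing here is AbstractDifferentiation.jl code.

```julia
using ChainRulesCore  # provides @thunk, unthunk and the Thunk type

# A thunk wraps a deferred tangent computation, like the Thunk in the hessian error.
t = @thunk([1.0, 2.0] .* 2)

# Thunk is not an AbstractArray, so calling size(t) would throw a MethodError.
# unthunk runs the deferred computation and returns the underlying value.
v = unthunk(t)
size(v)

# unthunk is the identity on values that are not thunks, so it is safe to
# call unconditionally on whatever a pullback returns.
w = unthunk([3.0])
```

Because unthunk is a no-op on ordinary values, a backend can apply it to every VJP before calling size or length, without first checking the type.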
The correct thing to do with special tangent types such as NoTangent and ZeroTangent is generally to handle them specifically, or failing that, to pull information from the primal. For NoTangent the former usually applies; for ZeroTangent, the latter (see how Diffractor turns ZeroTangent into zero_tangent).
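One way this could look in practice is sketched below. materialize_tangent is a hypothetical helper, not part of AbstractDifferentiation.jl or ChainRulesCore. It unthunks first, fills a ZeroTangent from the primal with Base.zero, and treats NoTangent as a case that needs specific handling. Raising an error is only a placeholder for whatever handling the caller actually needs.

```julia
using ChainRulesCore  # provides unthunk, ZeroTangent, NoTangent, @thunk

# Hypothetical helper: turn a raw pullback output `dx` into a concrete value
# shaped like the primal `x`.
function materialize_tangent(x, dx)
    dx = unthunk(dx)  # always force thunks before inspecting the value
    if dx isa ZeroTangent
        # A structural zero carries no shape, so take the shape from the primal.
        return zero(x)
    elseif dx isa NoTangent
        # No tangent space exists; this needs case-specific handling.
        # This sketch simply raises an error.
        throw(ArgumentError("primal of type $(typeof(x)) has no tangent (NoTangent)"))
    else
        # Already a concrete tangent: return it unchanged.
        return dx
    end
end
```

For example, materialize_tangent([1.0, 2.0], ZeroTangent()) yields an array of zeros matching the primal, while a thunked tangent is unthunked and returned.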