AbstractDifferentiation.jl

Fix type instabilities

Open: gdalle opened this issue 2 years ago • 7 comments

One of the big downsides of AbstractDifferentiation.jl is its heavy use of (sometimes nested) closures and if-based dispatch, which can cause type instabilities. I think many of these are fixable, in the worst case by replacing closures with callable structs.
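A minimal sketch of the closure-to-callable-struct fix, assuming the instability comes from a captured variable that Julia has to box (one common cause of closure-related instability; it is not confirmed that this is what happens inside AbstractDifferentiation). The closure is adapted from the captured-variable example in the Julia performance tips; `abmult_closure`, `abmult_struct` and `Mult` are hypothetical names.

```julia
using Test

# Closure version, adapted from the Julia performance tips: `r` is
# reassigned before being captured, so lowering stores it in a `Core.Box`
# and the closure body cannot infer the type of `r`.
function abmult_closure(r::Int)
    if r < 0
        r = -r
    end
    return x -> x * r
end

# Callable-struct version (hypothetical name `Mult`): the captured value
# sits in a concretely typed field, so calls on it are inferable.
struct Mult{T}
    r::T
end
(m::Mult)(x) = x * m.r

function abmult_struct(r::Int)
    if r < 0
        r = -r
    end
    return Mult(r)
end

m = abmult_struct(-3)
@inferred m(2)   # passes and returns 6
# By contrast, `@inferred abmult_closure(-3)(2)` is expected to throw,
# since inference only sees `x * r` with `r` of unknown type.
```

Both versions compute the same values; the struct simply makes the captured state part of the type, which is what inference needs.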

gdalle, Sep 22 '23 22:09

Did you actually check type stability? Or is it just a general comment that there could be issues?

devmotion, Sep 22 '23 22:09

I switched my package ImplicitDifferentiation from ChainRulesCore + ZygoteRuleConfig to AbstractDifferentiation + Zygote, and that introduced type instabilities which were not there before. I'm not sure whether it is Zygote's fault, but given how the code here looks, I'd be very surprised if we can't improve it.

You're right, though: the first thing we want to do is diagnose it. I'm thinking we run JET.test_opt on each interface function for each backend and see how we fare.
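A rough sketch of such a JET-based check, assuming JET.jl, AbstractDifferentiation and Zygote are installed. The test functions mirror the MWE in this issue; the choice of backends and the `target_modules` report filter (meant to keep only reports originating in AbstractDifferentiation itself) are assumptions, not an existing test suite of the package.

```julia
using Test
using JET                          # assumption: JET.jl is installed
import AbstractDifferentiation as AD
import Zygote

# Test functions, as in the MWE in this issue.
f(x::Number) = abs2(x)
f(x::Array) = abs2.(x)
g(x) = sum(abs2, x)

# Backends to check; extend this tuple with any other backend of interest.
backends = (
    AD.ZygoteBackend(),
    AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()),
)

# One testset per backend; each @test_opt fails if JET detects runtime
# dispatch (a symptom of type instability) in the call.
@testset "JET checks: $(nameof(typeof(ba)))" for ba in backends
    @test_opt target_modules=(AD,) AD.derivative(ba, f, 1.0)
    @test_opt target_modules=(AD,) AD.second_derivative(ba, f, 1.0)
    @test_opt target_modules=(AD,) AD.jacobian(ba, f, [1.0])
    @test_opt target_modules=(AD,) AD.hessian(ba, g, [1.0])
end
```

As with any top-level `@testset`, failing checks are summarized and then raised as an error, so any reports JET produces here are the diagnosis itself rather than a bug in the test.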

gdalle, Sep 22 '23 22:09

Did you try the master branch? Hopefully the recent simplifications already improved things.

devmotion, Sep 22 '23 22:09

Not yet. I'm calling it a day, but adding some JET tests is my next order of business.

gdalle, Sep 22 '23 22:09

Thanks for the collaboration, by the way. It's energizing and I'm learning a lot.

gdalle, Sep 22 '23 22:09

Can you post an MWE?

mohdibntarek, Mar 13 '24 11:03

Of course, the type instabilities are backend-dependent: if every function is reimplemented in a type-stable backend, we won't see them. That's why I went for the ChainRules wrapper, which rebuilds everything from the rrule up:

```julia
julia> import AbstractDifferentiation as AD

julia> import Zygote

julia> using Test

julia> f(x::Number) = abs2(x);

julia> f(x::Array) = abs2.(x);

julia> g(x) = sum(abs2, x);

julia> ad_backend = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}(Zygote.ZygoteRuleConfig{Zygote.Context{false}}(Zygote.Context{false}(nothing)))

julia> @inferred AD.derivative(ad_backend, f, 1.0)
(2.0,)

julia> @inferred AD.second_derivative(ad_backend, f, 1.0)
ERROR: return type Tuple{Float64} does not match inferred return type Tuple
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ ~/Downloads/ad.jl:11

julia> @inferred AD.jacobian(ad_backend, f, [1.0])
ERROR: return type Tuple{Matrix{Float64}} does not match inferred return type Any
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ ~/Downloads/ad.jl:12

julia> @inferred AD.hessian(ad_backend, g, [1.0])
ERROR: return type Tuple{Matrix{Float64}} does not match inferred return type Any
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ ~/Downloads/ad.jl:13
```

This demonstrates type instabilities in the fallback structure of the package itself, not in the extension.
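To make "rebuilding everything from the rrule up" concrete, here is a minimal sketch of obtaining a first derivative from a reverse rule through the ChainRulesCore rule-config interface that Zygote implements. It assumes ChainRulesCore and Zygote are installed; `derivative_from_rrule` is a hypothetical helper, not AbstractDifferentiation's actual code path.

```julia
import ChainRulesCore
import Zygote

f(x::Number) = abs2(x)

# Hypothetical helper: ask the rule config for the primal value and a
# pullback, then pull back a unit cotangent to get df/dx.
function derivative_from_rrule(config, h, x::Number)
    y, pullback = ChainRulesCore.rrule_via_ad(config, h, x)
    _, dx = pullback(one(y))   # first entry is the tangent of `h` itself
    return dx
end

config = Zygote.ZygoteRuleConfig()
derivative_from_rrule(config, f, 1.0)   # 2.0, matching AD.derivative above
```

Higher-order operators such as second derivatives and Hessians are then built by differentiating through functions like this one again, which is where nested closures enter the picture.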

gdalle, Mar 13 '24 20:03