TaylorDiff.jl

Zygote compatibility does not work for Julia 1.10+

Open mrazomej opened this issue 1 year ago • 5 comments

I raised this as an issue in the Zygote.jl repository, but it might belong here:

Zygote fails to use rrules defined by TaylorDiff when run with Julia 1.10+

In Julia 1.9.4:

import Zygote
import TaylorDiff

TaylorDiff.derivative(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [0.0, 1.0, 0.0], :1) # works

Zygote.withgradient([1.0, 2.0, 3.0]) do x
    TaylorDiff.derivative(x -> sum(x .^ 2), x, [0.0, 1.0, 0.0], :1)
end # works, returning (val = 4.0, grad = ([0.0, 2.0, 0.0],))

In Julia 1.10+:

import Zygote
import TaylorDiff

TaylorDiff.derivative(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [0.0, 1.0, 0.0], :1) # works

Zygote.withgradient([1.0, 2.0, 3.0]) do x
    TaylorDiff.derivative(x -> sum(x .^ 2), x, [0.0, 1.0, 0.0], :1)
end # doesn't work

The last line gives the following error:

ERROR: Need an adjoint for constructor TaylorDiff.TaylorScalar{Float64, 2}. Gradient is of type TaylorDiff.TaylorScalar{Float64, 2}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{TaylorDiff.TaylorScalar{Float64, 2}, Nothing, false})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:330
  [3] (::Zygote.var"#2210#back#313"{Zygote.Jnew{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [4] TaylorScalar
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/scalar.jl:17 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [6] TaylorScalar
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/scalar.jl:22 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/primitive.jl:143 [inlined]
  [8] ^
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/primitive.jl:128 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [10] literal_pow
    @ ./intfuncs.jl:351 [inlined]
 [11] (::Zygote.var"#1368#1374")(::Tuple{…}, ȳ₁::TaylorDiff.TaylorScalar{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:218
 [12] #4
    @ ./generator.jl:36 [inlined]
 [13] iterate(g::Base.Generator, s::Vararg{Any})
    @ Base ./generator.jl:47 [inlined]
 [14] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Base.var"#4#5"{Zygote.var"#1368#1374"}})
    @ Base ./array.jl:834
 [15] map
    @ ./abstractarray.jl:3406 [inlined]
 [16] (::Zygote.var"#∇broadcasted#1373"{…})(ȳ::FillArrays.Fill{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:218
 [17] #4117#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [18] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [19] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [20] broadcasted
    @ ./broadcast.jl:1347 [inlined]
 [21] #8
    @ ./REPL[4]:2 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [23] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:37 [inlined]
 [24] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:23 [inlined]
 [25] #7
    @ ./REPL[4]:2 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [28] withgradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:0
 [29] top-level scope
    @ REPL[4]:1

mrazomej • Feb 22 '24 02:02

I confirm this issue. Have you found a workaround?

mBarreau • Feb 23 '24 15:02

Sorry, I haven't had a chance to look at the breaking changes in v1.10. It looks strange, since rrules are just function overloads and shouldn't depend much on the language core...

tansongchen • Feb 23 '24 17:02

@ToucheSir Could you take a look?

YichengDWu • Feb 24 '24 17:02

@mrazomej when cross-posting issues, please link back to the original so that readers have some context. In this case, there's plenty of background in https://github.com/FluxML/Zygote.jl/issues/1502. We also have a Slack discussion in the #autodiff channel about ideas to fix this. It may be that changes made in the Julia compiler need to be reverted.

ToucheSir • Feb 24 '24 18:02

Hi all,

If we want to solve this issue, we will need to put more effort into it. From the Slack discussion, it appears that we need to provide the simplest MWE possible. If we use Zygote.gradient on TaylorDiff.derivative, we do get the error. How can we remove the dependency on TaylorDiff? I don't understand the package well enough to do this. @tansongchen, can you do that?

Then we can post an issue in Zygote.

mBarreau • Mar 11 '24 07:03

Hi @mBarreau ,

I haven't been working on this for a while due to other obligations at school, but now I'm back and willing to solve this compatibility issue.

I ran a small experiment, and it seems that I can still get Zygote-over-TaylorDiff to work by using Zygote's own rule-definition tool, @adjoint, so it isn't necessary to change Julia Base or Zygote to solve this. I will post a fix soon.

tansongchen • Sep 25 '24 20:09

Nevertheless, it still makes sense to submit an issue to Julia Base or Zygote, since we want to use ChainRulesCore in the end. So I'm working on an MWE:

using Zygote
import ChainRulesCore: rrule

struct MyNumber{T}
    x::T
end

function rrule(::Type{MyNumber{T}}, v::T) where T
    return MyNumber{T}(v), (result) -> result.x
end

square(s::MyNumber{T}) where T = MyNumber{T}(s.x * s.x)

but it didn't reproduce the issue

julia> y, back = Zygote._pullback(x -> square(MyNumber(x)), 1.0)
(MyNumber{Float64}(1.0), ∂(#21))

and I don't know why...

tansongchen • Sep 25 '24 20:09

What about @noinline on the TaylorScalar{T, N} constructor?

ChrisRackauckas • Sep 26 '24 10:09

Fixed by https://github.com/JuliaDiff/TaylorDiff.jl/commit/c11dd72ce8dc2eaf1ffc2def171edd50ebe72632

tansongchen • Sep 26 '24 19:09

Note that there has been a significant change under the hood: TaylorScalar is now a subtype of Real (TaylorScalar <: Real), so I would expect some correctness problems when using it together with Zygote because the necessary @opt_out declarations are missing. Use with caution.

tansongchen • Sep 26 '24 19:09

@tansongchen and @mBarreau, I don't think this issue has been resolved. Although it no longer throws an error in Julia 1.10.5, it does not give the correct answer.

On the CPU:

import Zygote
import TaylorDiff

TaylorDiff.derivative(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [0.0, 1.0, 0.0], :1) # works

Zygote.withgradient([1.0, 2.0, 3.0]) do x
    TaylorDiff.derivative(x -> sum(x .^ 2), x, [0.0, 1.0, 0.0], :1)
end # doesn't work, returns (val = 4.0, grad = ([0.0, 0.0, 0.0],))

Moreover, doing this with CUDA arrays gives something like this:

import Zygote
import TaylorDiff
import CUDA

TaylorDiff.derivative(x -> sum(x .^ 2), CUDA.cu([1.0, 2.0, 3.0]), CUDA.cu([0.0, 1.0, 0.0]), :1) # works

Zygote.withgradient(CUDA.cu([1.0, 2.0, 3.0])) do x
    TaylorDiff.derivative(x -> sum(x .^ 2), x, CUDA.cu([0.0, 1.0, 0.0]), :1)
end # doesn't work, returns (val = Dual{Nothing}(4.0,2.0,12.0,0.0), grad = (nothing,))

mrazomej • Oct 04 '24 20:10

I confirm that the same issue persists for Julia 1.11.0

mrazomej • Oct 08 '24 22:10
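
For reference, the value and gradient expected from the snippets in this thread can be cross-checked without Zygote or TaylorDiff. The following is a minimal sketch that uses central finite differences in plain Julia. The helper names directional_derivative and fd_gradient are hypothetical and are not part of either package, and the step size h is an arbitrary choice.

```julia
# Cross-check of the expected result using only Base Julia.
# `directional_derivative` and `fd_gradient` are illustrative helpers,
# not functions from TaylorDiff or Zygote.

f(x) = sum(x .^ 2)

# Central-difference approximation of the first directional derivative
# of f at x along the direction l.
directional_derivative(f, x, l; h = 1e-3) =
    (f(x .+ h .* l) - f(x .- h .* l)) / (2h)

# Central-difference approximation of the gradient of a scalar function g at x.
function fd_gradient(g, x; h = 1e-3)
    grad = similar(x)
    for i in eachindex(x)
        e = zeros(eltype(x), length(x))  # i-th unit vector
        e[i] = one(eltype(x))
        grad[i] = (g(x .+ h .* e) - g(x .- h .* e)) / (2h)
    end
    return grad
end

x = [1.0, 2.0, 3.0]
l = [0.0, 1.0, 0.0]

# Analytically, the directional derivative is 2 * sum(x .* l) = 2 * x[2].
val = directional_derivative(f, x, l)

# Analytically, the gradient of that quantity with respect to x is 2 .* l.
grad = fd_gradient(y -> directional_derivative(f, y, l), x)
```

Up to floating-point error, val should be close to 4.0 and grad close to [0.0, 2.0, 0.0], which agrees with the Julia 1.9.4 result reported at the top of this issue. By that measure, the all-zero gradient returned on Julia 1.10.5 is a wrong answer and not just a different convention.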