FiniteDiff.jl

`finite_difference_gradient` is mutating arrays

Open YichengDWu opened this issue 3 years ago • 4 comments

julia> using FiniteDiff, Zygote

julia> function f(x,p)
           grad = FiniteDiff.finite_difference_gradient(y -> sum(y.^3), x)
           return grad .* p
       end
f (generic function with 1 method)

julia> x,p = rand(3),rand(3);

julia> Zygote.gradient(p->sum(f(x,p)), p)[1]
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] _throw_mutation_error(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:70
  [3] (::Zygote.var"#444#445"{Vector{Float64}})(#unused#::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:82
  [4] (::Zygote.var"#2496#back#446"{Zygote.var"#444#445"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [5] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:277 [inlined]
  [6] (::typeof(∂(#finite_difference_gradient!#16)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [7] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:224 [inlined]
  [8] (::typeof(∂(finite_difference_gradient!##kw)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [9] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:88 [inlined]
 [10] (::typeof(∂(#finite_difference_gradient#12)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [11] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:70 [inlined]
 [12] (::typeof(∂(finite_difference_gradient)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [13] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:70 [inlined]
 [14] (::typeof(∂(finite_difference_gradient)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ .\REPL[94]:2 [inlined]
 [16] (::typeof(∂(f)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ .\REPL[95]:1 [inlined]
 [18] (::typeof(∂(#30)))(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] (::Zygote.var"#60#61"{typeof(∂(#30))})(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [20] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
 [21] top-level scope
    @ REPL[95]:1
 [22] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52

YichengDWu, Jul 27 '22 07:07

It is tricky to make this completely non-mutating for AbstractArray.

If you look at how this case is handled in ForwardDiff, it also has mutations in there, even for out-of-place f: https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/gradient.jl#L106 https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/apiutils.jl#L36 https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/apiutils.jl#L58

StaticArrays is then handled by a specialized dispatch.

In Zygote, this seems to be handled by custom rules: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/forward.jl#L140 Notice how this is not really reverse-over-forward.

The issue is that you have a bunch of partial derivatives: https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/src/gradients.jl#L287 But how do you gather them into an array of the correct type without mutating? Here the correct type would be the array type of the input of f, with the element type changed to the type of the output of f.

ArnoStrouwen, Nov 11 '22 01:11
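
To make the gathering question above concrete, here is a minimal sketch of a forward-difference gradient that never calls `setindex!` explicitly. The name `gradient_nomutate` and the step-size default are assumptions made for illustration; this is not FiniteDiff.jl's implementation. Note that the comprehension returns a plain `Array`, so the sketch illustrates the problem rather than solving it: the input's array type is not preserved.

```julia
# Hypothetical helper, not part of FiniteDiff.jl.
# Forward differences: df/dx_i ≈ (f(x + h*e_i) - f(x)) / h
function gradient_nomutate(f, x::AbstractArray; h = sqrt(eps(float(eltype(x)))))
    idx = LinearIndices(x)   # same shape as x, holds the linear indices
    fx = f(x)
    # (idx .== i) is a one-hot mask, so each perturbed input is a fresh array
    # produced by broadcasting; x itself is never written to.
    # The comprehension gathers the partials into an Array shaped like x.
    return [(f(x .+ h .* (idx .== i)) - fx) / h for i in idx]
end

x = [1.0, 2.0, 3.0]
g = gradient_nomutate(y -> sum(y .^ 3), x)   # approximately 3 .* x .^ 2
```

For an input such as a GPU array or a `StaticArray`, the result would still be an ordinary `Array`, which is exactly the type question raised above.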

If you use out-of-place `setindex`, it should be fine?

ChrisRackauckas, Nov 11 '22 02:11

I have not used `setindex` much, but I thought it was only defined for Tuple and StaticArray, not for Array? Are there things in ArrayInterface that extend this?

ArnoStrouwen, Nov 11 '22 02:11

FiniteDiff.jl has an extended version (shadowed)

ChrisRackauckas, Nov 11 '22 02:11
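
For readers unfamiliar with the out-of-place `setindex` discussed above, the following sketch shows the general idea for a plain `Array`. The name `setindex_oop` is hypothetical and the body is an assumption about what such a helper could look like; it is not claimed to match the version FiniteDiff.jl defines internally. `Base.setindex` on a tuple is shown for comparison.

```julia
# Hypothetical out-of-place counterpart of setindex! for arrays.
# It returns a modified copy and leaves the caller's array untouched.
function setindex_oop(x::AbstractArray, v, i...)
    y = copy(x)      # assumes copy(x) yields a mutable array
    y[i...] = v      # the mutation is confined to the fresh copy
    return y
end

x = [1.0, 2.0, 3.0]
y = setindex_oop(x, 10.0, 2)          # x keeps its original values

# Base already provides the same semantics for tuples:
t = Base.setindex((1, 2, 3), 10, 2)
```

Each call copies the whole array, so building a length-n gradient this way costs O(n^2) element copies. Whether an AD tool accepts such a helper also depends on a differentiation rule being available for it, since the copy is still written to internally.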