`finite_difference_gradient` mutates arrays, so Zygote cannot differentiate through it
```julia-repl
julia> using FiniteDiff, Zygote

julia> function f(x, p)
           grad = FiniteDiff.finite_difference_gradient(y -> sum(y.^3), x)
           return grad .* p
       end
f (generic function with 1 method)

julia> x, p = rand(3), rand(3);

julia> Zygote.gradient(p -> sum(f(x, p)), p)[1]
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] _throw_mutation_error(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:70
  [3] (::Zygote.var"#444#445"{Vector{Float64}})(#unused#::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\lib\array.jl:82
  [4] (::Zygote.var"#2496#back#446"{Zygote.var"#444#445"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote C:\Users\Luffy\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [5] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:277 [inlined]
  [6] (::typeof(∂(#finite_difference_gradient!#16)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [7] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:224 [inlined]
  [8] (::typeof(∂(finite_difference_gradient!##kw)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
  [9] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:88 [inlined]
 [10] (::typeof(∂(#finite_difference_gradient#12)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [11] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:70 [inlined]
 [12] (::typeof(∂(finite_difference_gradient)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [13] Pullback
    @ C:\Users\Luffy\.julia\packages\FiniteDiff\KkXlb\src\gradients.jl:70 [inlined]
 [14] (::typeof(∂(finite_difference_gradient)))(Δ::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [15] Pullback
    @ .\REPL[94]:2 [inlined]
 [16] (::typeof(∂(f)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [17] Pullback
    @ .\REPL[95]:1 [inlined]
 [18] (::typeof(∂(#30)))(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [19] (::Zygote.var"#60#61"{typeof(∂(#30))})(Δ::Float64)
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [20] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\Luffy\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:76
 [21] top-level scope
    @ REPL[95]:1
 [22] top-level scope
    @ C:\Users\Luffy\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
```
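For reference, the gradient that this reproduction asks Zygote for can be worked out by hand. With `x` fixed, the inner gradient $g(x) = \nabla_x \sum_k x_k^3$ does not depend on `p`. The call `sum(f(x, p))` is therefore linear in `p`, and the expected answer is simply `3 .* x.^2`. Even so, the FiniteDiff frames in the stack trace show that Zygote still builds a pullback through the mutating code.

```math
g_k(x) = 3x_k^2, \qquad
L(p) = \sum_k g_k(x)\,p_k, \qquad
\frac{\partial L}{\partial p_j} = g_j(x) = 3x_j^2 .
```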
It is tricky to make this completely non-mutating for `AbstractArray`.
If you look at how this case is handled in ForwardDiff, it also has mutations there, even for an out-of-place `f`:
https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/gradient.jl#L106
https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/apiutils.jl#L36
https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/apiutils.jl#L58
StaticArrays are then handled by a specialized method selected via dispatch.
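The split described above (a buffer-filling method for general arrays, plus a separate non-mutating method for static arrays) can be sketched in plain Julia. This is a minimal illustration under stated assumptions, not ForwardDiff's or FiniteDiff's actual code. `central_grad` is a hypothetical name, and `NTuple` stands in for a static array so the example needs no packages. In the generic method, every `g[i] = ...` and `xh[i] = ...` is a `setindex!` call, the same operation named in the Zygote error in this issue.

```julia
# Hypothetical helper `central_grad`; illustrates the dispatch pattern only.

# Generic AbstractArray method: allocate a buffer and fill it element by element.
# Assumes f returns a value of the same float type T as the elements of x.
function central_grad(f, x::AbstractArray{T}) where {T<:AbstractFloat}
    h = cbrt(eps(T))        # a common step size for central differences
    g = similar(x)          # output buffer: same array type and eltype as x
    xh = copy(x)            # work array that is perturbed in place
    for i in eachindex(x)
        xi = xh[i]
        xh[i] = xi + h      # setindex! on the work array
        fp = f(xh)
        xh[i] = xi - h
        fm = f(xh)
        xh[i] = xi          # restore the original entry
        g[i] = (fp - fm) / (2h)  # setindex! on the output buffer
    end
    return g
end

# Specialized tuple method (a stand-in for static arrays): ntuple and
# Base.setindex both return new tuples, so nothing is mutated.
function central_grad(f, x::NTuple{N,T}) where {N,T<:AbstractFloat}
    h = cbrt(eps(T))
    return ntuple(N) do i
        xp = Base.setindex(x, x[i] + h, i)  # new tuple with entry i shifted up
        xm = Base.setindex(x, x[i] - h, i)  # new tuple with entry i shifted down
        (f(xp) - f(xm)) / (2h)
    end
end
```

The generic method works for any mutable `AbstractArray`, but only because it writes into buffers. The tuple method avoids mutation, but it relies on the length `N` being part of the type, which a general `Array` does not provide.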
When Zygote differentiates ForwardDiff, this seems to be handled by custom rules in Zygote:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/forward.jl#L140
Note that this is not really reverse-over-forward.
The issue is that you end up with a set of separate partial derivatives:
https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/src/gradients.jl#L287
But how do you gather them into an array of the correct type without mutating? Here, the correct type means the array type of `f`'s input, with the element type changed to the type of `f`'s output.
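One partial answer is to collect the partial derivatives with `map` instead of writing them into a preallocated buffer. The sketch below is a hypothetical helper, `nonmutating_fd_gradient`, and is not part of FiniteDiff.jl. It uses central differences and contains no explicit `setindex!`, which is the call Zygote rejects. Whether Zygote then differentiates it cleanly end to end has not been tested here.

```julia
# Hypothetical helper, not part of FiniteDiff.jl: central differences whose
# partial derivatives are collected with `map` instead of written into a buffer.
function nonmutating_fd_gradient(f, x::AbstractArray{T}) where {T<:AbstractFloat}
    h = cbrt(eps(T))  # a common step size for central differences
    # Shift x by s along linear index i. The BitArray mask has the same shape
    # as x, so the broadcast allocates a new array and leaves x untouched.
    perturbed(i, s) = x .+ s .* (LinearIndices(x) .== i)
    partial(i) = (f(perturbed(i, h)) - f(perturbed(i, -h))) / (2h)
    # `map` over LinearIndices(x) returns a plain Array with the same shape as x;
    # its element type is inferred from what `partial` returns.
    return map(partial, LinearIndices(x))
end
```

This only half-solves the type question. The element type follows `f`'s output and the shape follows `x`, but the container is always a plain `Array`, so a GPU or static input would not round-trip. It also allocates two full copies of `x` per coordinate, which costs O(n²) work and memory traffic.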
If you use an out-of-place `setindex`, shouldn't it be fine?
I have not used `setindex` much, but I thought it was only defined for `Tuple` and `StaticArray`, not for `Array`.
Is there something in ArrayInterface that extends it?
FiniteDiff.jl has its own extended version of `setindex` (shadowed).
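For context, the tuple version of `Base.setindex` returns a new tuple and leaves the original untouched. For arrays, one possible out-of-place version copies the array and writes into the copy. `setindex_copy` below is a hypothetical illustration. It is not FiniteDiff.jl's shadowed version (whose implementation is not shown in this thread) and not an ArrayInterface API. It relies on `Base.copymutable`, an unexported Base function.

```julia
# Base.setindex exists for tuples: it builds a new tuple instead of mutating.
t = (1.0, 2.0, 3.0)
t2 = Base.setindex(t, 10.0, 2)   # t itself is unchanged

# Hypothetical out-of-place setindex for arrays (illustrative name only).
# The caller never observes a mutation, but the body still calls setindex!
# on the copy, so an AD system would likely still need a custom rule for it.
function setindex_copy(x::AbstractArray, v, i...)
    y = Base.copymutable(x)  # mutable copy; e.g. a Vector when x is a range
    y[i...] = v              # in-place write into the fresh copy only
    return y
end
```

Building an n-element gradient this way would copy the whole array once per entry, so it removes visible mutation at a quadratic cost in allocations.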