Excessive allocations with VectorOfArray in broadcast
I am using a VectorOfArray as the state vector in DifferentialEquations.jl, and this leads to an excessive number of allocations.
using DifferentialEquations
using RecursiveArrayTools
# Set up the state vector
U = VectorOfArray([zeros(100,100,100), zeros(100,100,100)]);
rhs!(U̇, U, p, t) = U̇ .= U
# Compile all code
prob = ODEProblem(rhs!, U, (0.0, 1.0));
sol = solve(prob, RK4(); adaptive=false, dt=1.0);
# Benchmark
prob = ODEProblem(rhs!, U, (0.0, 1.0));
@time sol = solve(prob, RK4(); adaptive=false, dt=1.0);
This outputs
1.639028 seconds (42.00 M allocations: 991.828 MiB, 18.67% gc time)
Note the very large number of allocations. It looks as if some operation were performing one small allocation per array element.
(I notice that this problem disappears if I switch to the midpoint rule.)
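For reference, the midpoint-rule comparison can be reproduced with the same setup (a sketch of my check; `Midpoint()` is the fixed-step midpoint method from OrdinaryDiffEq, and timings are machine-dependent):

```julia
# Hypothetical check: the same problem solved with the midpoint rule,
# which does not show the allocation blow-up
using DifferentialEquations
using RecursiveArrayTools
U = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
rhs!(U̇, U, p, t) = U̇ .= U
prob = ODEProblem(rhs!, U, (0.0, 1.0));
solve(prob, Midpoint(); adaptive=false, dt=1.0);  # compile first
@time solve(prob, Midpoint(); adaptive=false, dt=1.0);
```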
I tried with both Julia 1.7 and the current release branch of Julia 1.8.
The profiles show that it's all in broadcast. VectorOfArray just needs a better broadcast overload.
https://github.com/SciML/RecursiveArrayTools.jl/blob/v2.26.3/src/vector_of_array.jl#L288
For some reason that method seems to be allocating.
Here is the offending call without the ODE solver:
using RecursiveArrayTools
u = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
uprev = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₁ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₂ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₃ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₄ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
dt = 0.5
@time @. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))
It looks like it's triggered by having a scalar in the broadcast tree.
Does ArrayOfArrays.jl have a better specialization?
Switching to
@time @. u = uprev + (dt / 6.0) * (2.0 * (k₂ + k₃) + (k₁ + k₄))
seems to speed things up quite a bit, so there might be a type instability hiding somewhere.
Edit: Rewriting the expression as follows is even better (and does not trigger any type instabilities):
@time @. u = uprev + (dt/3) * ((k₂ + k₃) + (k₁ + k₄)/2)
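To make the comparison robust against global-scope effects, one can wrap both variants in functions and compare allocations directly. This is a sketch reusing the arrays and `dt` from the snippet above; `slow!` and `fast!` are names I made up:

```julia
# Sketch: wrap both expressions in functions so non-const globals do not
# skew the measurement, then compare allocation counts after a warm-up call
slow!(u, uprev, k₁, k₂, k₃, k₄, dt) =
    @. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))
fast!(u, uprev, k₁, k₂, k₃, k₄, dt) =
    @. u = uprev + (dt / 3) * ((k₂ + k₃) + (k₁ + k₄) / 2)
slow!(u, uprev, k₁, k₂, k₃, k₄, dt);  # compile
fast!(u, uprev, k₁, k₂, k₃, k₄, dt);  # compile
@show @allocated slow!(u, uprev, k₁, k₂, k₃, k₄, dt)
@show @allocated fast!(u, uprev, k₁, k₂, k₃, k₄, dt)
```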
The problem seems to occur when a scalar appears in the broadcast and the Broadcasted tree is nested too deeply:
using RecursiveArrayTools
u = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
uprev = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₁ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₂ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₃ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₄ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
dt = 0.5
A = Broadcast.broadcasted(*, 2, Broadcast.broadcasted(+, k₂, k₃))
B = Broadcast.broadcasted(+, k₁, k₄)
C = Broadcast.broadcasted(+, A, B)
C2 = Broadcast.broadcasted(identity, C)
julia> @time Broadcast.materialize!(u, C);
0.006494 seconds
julia> @time Broadcast.materialize!(u, C2);
2.928890 seconds (60.00 M allocations: 1.103 GiB, 5.41% gc time)
and with Cthulhu one can see that the type instabilities come from base/broadcast.jl#403 and base/broadcast.jl#380:
julia> @descend Broadcast.materialize!(u, C2)
[...]
(::Base.Broadcast.var"#23#24")(head, tail::Vararg{Any, N}) where N in Base.Broadcast at broadcast.jl:403
Body::Tuple{Vararg{Float64}}
404 1 ─ %1 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()), tail)::Tuple{Vararg{Float64}}
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
Toggles: [o]ptimize, [w]arn, [h]ide type-stable statements, [d]ebuginfo, [r]emarks, [e]ffects, [i]nlining costs, [t]ype annotations, [s]yntax highlight for Source/LLVM/Native.
Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code
Actions: [E]dit source code, [R]evise and redisplay
Advanced: dump [P]arams cache.
• %1 = call #23(::Float64,::Float64...)::Tuple{Vararg{Float64}}
[...]
(::Base.Broadcast.var"#16#18")(args::Vararg{Any, N}) where N in Base.Broadcast at broadcast.jl:380
Body::Tuple{Float64, Vararg{Float64}}
381 1 ─ %1 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), args)::Tuple{Float64, Float64, Vararg{Float64}}
│ %3 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %1)::Tuple{Vararg{Float64}}
│ %4 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#15#17"(), %3)::Tuple{Vararg{Float64}} │
│ %9 = Core._apply_iterate(Base.iterate, Core.tuple, %8, %4)::Tuple{Float64, Vararg{Float64}} │
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
• %1 = call #13(::Float64,::Float64...)::Tuple{Float64, Float64, Vararg{Float64}}
%2 = call #19(::Float64,::Float64,::Float64...)::Tuple{Float64, Float64}
%3 = call #23(::Float64,::Float64,::Float64...)::Tuple{Vararg{Float64}}
%4 = call #15(::Float64...)::Tuple{Vararg{Float64}}
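One possible direction for a better overload: unpack the Broadcasted tree per component array, so that Base's broadcast machinery only ever sees plain Arrays and the VectorOfArray wrapper never enters the hot loop. This is a rough sketch under my own naming (`unpack`, `materialize_per_component!` are hypothetical helpers, not the package's actual API):

```julia
using RecursiveArrayTools
import Base.Broadcast: Broadcasted, broadcasted, instantiate

# Rebuild the lazy broadcast tree for the i-th component array
unpack(bc::Broadcasted, i) = broadcasted(bc.f, map(a -> unpack(a, i), bc.args)...)
unpack(a::VectorOfArray, i) = a.u[i]   # .u holds the component arrays
unpack(x, i) = x                       # scalars etc. pass through unchanged

# Materialize into each component array separately
function materialize_per_component!(dest::VectorOfArray, bc::Broadcasted)
    for i in eachindex(dest.u)
        copyto!(dest.u[i], instantiate(unpack(bc, i)))
    end
    dest
end
```

With this, `materialize_per_component!(u, C2)` should hit the ordinary Array broadcast kernel once per 100×100×100 block instead of going element-by-element through the VectorOfArray wrapper.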