RecursiveArrayTools.jl

Excessive allocations with VectorOfArray in broadcast

Open · eschnett opened this issue 3 years ago · 5 comments

I am using a VectorOfArray as the state vector in DifferentialEquations. This leads to an excessive number of allocations.

using DifferentialEquations
using RecursiveArrayTools

# Set up the state vector
U = VectorOfArray([zeros(100,100,100), zeros(100,100,100)]);
rhs!(U̇, U, p, t) = U̇ .= U

# Solve once so that compilation is excluded from the benchmark below
prob = ODEProblem(rhs!, U, (0.0, 1.0));
sol = solve(prob, RK4(); adaptive=false, dt=1.0);

# Benchmark
prob = ODEProblem(rhs!, U, (0.0, 1.0));
@time sol = solve(prob, RK4(); adaptive=false, dt=1.0);

This outputs

  1.639028 seconds (42.00 M allocations: 991.828 MiB, 18.67% gc time)

Note the very large number of allocations. It looks as if some operation performs one small allocation per array element.
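A quick back-of-the-envelope check of that impression, using only the numbers reported above (`@time` rounds them, so the results are approximate): two 100×100×100 arrays hold 2×10⁶ elements, so 42.00 M allocations is about 21 per element over the whole solve, and 991.828 MiB spread over those allocations is roughly 25 bytes each. That size is consistent with many tiny temporary objects rather than whole-array copies, though it does not by itself identify which operation allocates.

```julia
# Back-of-the-envelope numbers taken from the @time line above (rounded values).
nelements = 2 * 100^3            # two 100×100×100 arrays
nallocs = 42.00e6                # "42.00 M allocations"
nbytes = 991.828 * 2^20          # "991.828 MiB"

allocs_per_element = nallocs / nelements   # about 21 allocations per element
bytes_per_alloc = nbytes / nallocs         # about 25 bytes per allocation
println((allocs_per_element, bytes_per_alloc))
```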

(I notice that this problem disappears if I switch to the midpoint rule.)

I tried both Julia 1.7 and the current release branch of Julia 1.8.

eschnett, May 15 '22 00:05

The profiles show that all of the time is spent in broadcast. VectorOfArray just needs a better broadcast overload.

https://github.com/SciML/RecursiveArrayTools.jl/blob/v2.26.3/src/vector_of_array.jl#L288

For some reason, that method seems to allocate.
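One possible shape for such an overload is to split the lazy broadcast tree per component and hand each piece to Base's ordinary `Array` broadcast, the intent being that scalars like `dt / 6` or `2` are only ever combined with plain arrays. The sketch below assumes a hypothetical stand-in type `Parts` with helpers `PartsStyle` and `unpack`; none of these are RecursiveArrayTools' API, the sketch has not been benchmarked here, and it handles only in-place `.=` broadcasts (out-of-place ones such as `u .+ 1` would need additional methods).

```julia
using Base.Broadcast: Broadcasted, BroadcastStyle, DefaultArrayStyle, broadcasted

# Hypothetical stand-in for a vector of arrays (NOT RecursiveArrayTools' VectorOfArray).
struct Parts{A<:AbstractArray}
    parts::Vector{A}
end

# Give Parts its own broadcast style so `@.` builds a lazy tree tagged PartsStyle.
struct PartsStyle <: BroadcastStyle end
Base.Broadcast.broadcastable(x::Parts) = x
Base.Broadcast.BroadcastStyle(::Type{<:Parts}) = PartsStyle()
# A scalar (DefaultArrayStyle{0}) combined with Parts stays PartsStyle.
Base.Broadcast.BroadcastStyle(s::PartsStyle, ::DefaultArrayStyle{0}) = s

# Rebuild the tree for component i: Parts leaves become their i-th array,
# scalars and other leaves pass through unchanged.
unpack(x, i) = x
unpack(x::Parts, i) = x.parts[i]
unpack(bc::Broadcasted, i) = broadcasted(bc.f, map(a -> unpack(a, i), bc.args)...)

# In-place broadcast (`dest .= ...`): one ordinary Array broadcast per component.
function Base.Broadcast.materialize!(dest::Parts, bc::Broadcasted)
    for i in eachindex(dest.parts)
        Base.Broadcast.materialize!(dest.parts[i], unpack(bc, i))
    end
    return dest
end

# Usage: the RK4-style update from the issue, on small arrays.
mk(v) = Parts([fill(Float64(v), 2, 2), fill(Float64(v), 2, 2)])
u, uprev, k₁, k₂, k₃, k₄ = mk(0), mk(1), mk(1), mk(2), mk(3), mk(4)
dt = 0.5
@. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))
```

After the last line every element of `u` is approximately 1 + (0.5/6)·(2·5 + 5) = 2.25.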

ChrisRackauckas, May 15 '22 12:05

Here is the offending call without the ODE solver:

using RecursiveArrayTools

u = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
uprev = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₁ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₂ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₃ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₄ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
dt = 0.5

@time @. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))

It looks like it's triggered by having a scalar in the broadcast tree.
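To make "a scalar in the broadcast tree" concrete: `@.` does not evaluate the right-hand side eagerly but lowers it to nested `Broadcast.broadcasted` calls, and the scalars `dt`, `6` and `2` end up as leaves next to the arrays. The sketch below builds that tree by hand over small plain `Array`s (an assumption, to stay independent of RecursiveArrayTools) and lists the leaf types with a hypothetical helper `leaf_types`; it shows where the `Float64` and `Int` leaves sit, not why VectorOfArray's broadcast allocates.

```julia
using Base.Broadcast: Broadcasted, broadcasted

# Hypothetical helper: collect the types of the leaves of a lazy broadcast tree, depth first.
function leaf_types(x, acc = Type[])
    if x isa Broadcasted
        foreach(a -> leaf_types(a, acc), x.args)
    else
        push!(acc, typeof(x))
    end
    return acc
end

uprev, k₁, k₂, k₃, k₄ = (zeros(2, 2, 2) for _ in 1:5)
dt = 0.5

# Same nesting that `@. uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))` lowers to.
bc = broadcasted(+, uprev,
        broadcasted(*, broadcasted(/, dt, 6),
            broadcasted(+, broadcasted(*, 2, broadcasted(+, k₂, k₃)),
                           broadcasted(+, k₁, k₄))))

types = leaf_types(bc)
println(types)
```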

ChrisRackauckas, May 15 '22 12:05

Does ArrayOfArrays.jl have a better specialization?

Moelf, May 15 '22 13:05

Switching to

@time @. u = uprev + (dt / 6.0) * (2.0 * (k₂ + k₃) + (k₁ + k₄))

seems to speed things up quite a bit. Could there be a type instability hiding somewhere?

Edit: Rewriting the expression as follows is even better (and does not result in any type instabilities):

@time @. u = uprev + (dt/3) * ((k₂ + k₃) + (k₁ + k₄)/2)
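The rewrite is algebraically the same update, since (dt/6)·2(k₂ + k₃) = (dt/3)(k₂ + k₃) and (dt/6)(k₁ + k₄) = (dt/3)(k₁ + k₄)/2; only the floating-point rounding can differ slightly. With Float64 inputs, the integer-literal and float-literal versions compute identical values (2 * x promotes 2 to 2.0), so the speed difference reported above does not come from the arithmetic itself. A minimal check on small random plain `Array`s (an assumption: this verifies the arithmetic, not the allocation behaviour of VectorOfArray):

```julia
# Compare the original RK4 combination with the two rewritten forms on plain arrays.
uprev, k₁, k₂, k₃, k₄ = (rand(5) for _ in 1:5)
dt = 0.5

original  = @. uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))
floatlits = @. uprev + (dt / 6.0) * (2.0 * (k₂ + k₃) + (k₁ + k₄))
rewritten = @. uprev + (dt / 3) * ((k₂ + k₃) + (k₁ + k₄) / 2)

println(original ≈ floatlits ≈ rewritten)
```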

mipals, May 15 '22 16:05

The problem seems to appear when a scalar is present and the Broadcasted object is nested too deeply:

using RecursiveArrayTools

u = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
uprev = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₁ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₂ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₃ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₄ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
dt = 0.5

A = Broadcast.broadcasted(*, 2, Broadcast.broadcasted(+, k₂, k₃))
B = Broadcast.broadcasted(+, k₁, k₄)
C = Broadcast.broadcasted(+, A, B)
C2 = Broadcast.broadcasted(identity, C)

julia> @time Broadcast.materialize!(u, C);
  0.006494 seconds

julia> @time Broadcast.materialize!(u, C2);
  2.928890 seconds (60.00 M allocations: 1.103 GiB, 5.41% gc time)
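`@time` at global scope can include compilation and dispatch overhead, so a common way to isolate the allocations of one call is to warm it up and then measure it with `@allocated` inside a function. The helper name `allocs` is hypothetical. The sketch applies it to the same `C` and `C2` trees built over plain `Array`s instead of VectorOfArray, which can serve as a baseline for the numbers above; no particular byte count is claimed for that baseline here.

```julia
using Base.Broadcast: broadcasted, materialize!

# Hypothetical helper: run once to compile, then count bytes allocated by a second call.
# The type parameters force specialization on `f` and on the argument types.
function allocs(f::F, args::Vararg{Any,N}) where {F,N}
    f(args...)
    return @allocated f(args...)
end

# The C / C2 trees from above, but over plain Arrays instead of VectorOfArray.
u = zeros(100, 100, 100)
k₁, k₂, k₃, k₄ = (ones(100, 100, 100) for _ in 1:4)
A = broadcasted(*, 2, broadcasted(+, k₂, k₃))
B = broadcasted(+, k₁, k₄)
C = broadcasted(+, A, B)
C2 = broadcasted(identity, C)

println("C:  ", allocs(materialize!, u, C), " bytes")
println("C2: ", allocs(materialize!, u, C2), " bytes")
```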

Using Cthulhu, the type instabilities can be traced to base/broadcast.jl#403 and base/broadcast.jl#380:

julia> @descend Broadcast.materialize!(u, C2)
[...]
(::Base.Broadcast.var"#23#24")(head, tail::Vararg{Any, N}) where N in Base.Broadcast at broadcast.jl:403
Body::Tuple{Vararg{Float64}}
404 1 ─ %1 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()), tail)::Tuple{Vararg{Float64}}
 • %1  = call #23(::Float64,::Float64...)::Tuple{Vararg{Float64}}

[...]

(::Base.Broadcast.var"#16#18")(args::Vararg{Any, N}) where N in Base.Broadcast at broadcast.jl:380
Body::Tuple{Float64, Vararg{Float64}}
381 1 ─ %1 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), args)::Tuple{Float64, Float64, Vararg{Float64}}
    │   %3 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %1)::Tuple{Vararg{Float64}}
    │   %4 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#15#17"(), %3)::Tuple{Vararg{Float64}}
    │   %9 = Core._apply_iterate(Base.iterate, Core.tuple, %8, %4)::Tuple{Float64, Vararg{Float64}}
 • %1  = call #13(::Float64,::Float64...)::Tuple{Float64, Float64, Vararg{Float64}}
   %2  = call #19(::Float64,::Float64,::Float64...)::Tuple{Float64, Float64}
   %3  = call #23(::Float64,::Float64,::Float64...)::Tuple{Vararg{Float64}}
   %4  = call #15(::Float64...)::Tuple{Vararg{Float64}}

chengchingwen, May 15 '22 19:05