AxisArrays.jl
AxisArrays.jl copied to clipboard
Faster version of permutation() and axisnames()
Hi! I came up with a faster version of `AxisArrays.permutation()`. It may lack some functionality or generality, but I haven't yet been able to come up with a case where it falls short - can you?
using AxisArrays
using BenchmarkTools
# the wooden hammer
"""
    check_duplicates(arr)

Throw `ArgumentError("duplicate")` if two elements of `arr` at different
positions compare equal under `==`/`!=`; otherwise return `nothing`.

Every pair of positions is compared directly, so the cost is quadratic in
`length(arr)`, but no auxiliary container is built.

# Throws
- `ArgumentError`: when `arr[i] != arr[j]` is `false` for some `i < j`.
"""
@inline function check_duplicates(arr)
    # Derive the bounds from the container instead of assuming it starts at 1.
    lo, hi = firstindex(arr), lastindex(arr)
    # `@inbounds` has to wrap the indexing expressions themselves; placed in
    # front of the `function` keyword it does not reach into the body.
    @inbounds for i in lo:hi-1
        for j in i+1:hi
            arr[i] != arr[j] || throw(ArgumentError("duplicate"))
        end
    end
    return nothing
end
"""
    foo_perm(to, from)

Return a `Vector{Int}` `p` such that `p[i]` is the first index in `from`
whose element equals `to[i]` (compared with `==`).

# Throws
- `ArgumentError("not same length")`: `to` and `from` differ in length.
- `ArgumentError("a not in b")`: an element of `to` has no match in `from`.
- `ArgumentError("duplicate")`: raised by `check_duplicates` when two
  elements of `to` resolve to the same index of `from`.
"""
function foo_perm(to, from)
    length(to) == length(from) || throw(ArgumentError("not same length"))
    res = Vector{Int}(undef, length(from))
    # `enumerate` counts from 1 and `res` has the same length as `to`
    # (checked above), so the write to `res[i]` stays in bounds.
    @inbounds for (i, t) in enumerate(to)
        # Search with a predicate rather than materialising `from .== t`.
        idx = findfirst(x -> x == t, from)
        # `nothing` is a singleton: test identity, not `==`/`!=`.
        idx === nothing && throw(ArgumentError("a not in b"))
        res[i] = idx
    end
    check_duplicates(res)
    return res
end
# Example input: the same four symbols in two different orders.
to=(:c,:w,:h,:d)
from=(:c,:h,:w,:d)
# Sanity check: the result should match AxisArrays.permutation on this input.
foo_perm(to, from) == AxisArrays.permutation(to, from)
# Benchmarks; `$` interpolates the globals into the benchmarked expression.
# The trailing numbers are the timings reported by the author.
@btime foo_perm($to, $from) # 42ns
@btime AxisArrays.permutation($to, $from) # 307ns
I think axisnames() is typically called with an AxisArray as its argument, but I found it a bit slow. Again, I came up with an alternative that might lack generality but is faster:
"""
    axname(a::AxisArrays.Axis{name})

Return `name`, the type parameter that `a`'s `Axis` type is matched against.
The value of `a` itself is not inspected; only its type is used.
"""
function axname(a::AxisArrays.Axis{name}) where {name}
    return name
end
"""
    axnames(a::AxisArray)

Apply `axname` to every entry of `a.axes` and return the collected results
in the same order.
"""
function axnames(a::AxisArray)
    return map(axname, a.axes)
end
# Example array built from random 3x4x5 data; the trailing symbols are
# presumably the axis names (confirm against the AxisArray constructor).
a = AxisArray(rand(3,4,5), :c,:h,:w)
# Sanity check: the result should match axisnames on this input.
axnames(a) == axisnames(a)
# Benchmarks; `$` interpolates the global into the benchmarked expression.
# The trailing numbers are the timings reported by the author.
@btime axnames($a) # 340ns
@btime axisnames($a) # 4us