WIP: Implement broadcasting with AxisArrays on Julia 0.7
This PR proposes an implementation of broadcasting for AxisArrays that will be possible using Julia 0.7. I'm getting a bit ahead of myself because not all AxisArrays tests pass on 0.7, and I'm also aware that the new broadcasting API may continue to change (e.g. https://github.com/JuliaLang/julia/pull/25377). However, broadcasting is important enough for how I intend to use AxisArrays that I want to give an early demo, and also want to solicit some feedback before I sink too much time into this approach.
High-level description
- Broadcasting with AxisArray and scalar args works
- Broadcasting with both AxisArray and AbstractArray args is permitted if the array dimensions are compatible. In other words, we assume the user knows what they are doing, and that the axes are compatible semantically so long as indices match, even if there's no axis name/values available on the generic array.
- Broadcasting with different AxisArrays requires the axis names to match in corresponding dimensions, and the axis values must also be the same. There is no auto-alignment. This rule is relaxed slightly if an axis is considered to be a default axis, like `Axis{:row}(Base.OneTo(1))` in a 1xN matrix.
- Axes that are determined to be "default" are allowed to grow if broadcasting demands a different (larger) output shape.
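The rules above can be sketched as a pairwise compatibility check for two axes in the same dimension. This is a minimal, hypothetical sketch, not the PR's code: axes are modelled as `name => values` pairs instead of `AxisArrays.Axis`, a "default" axis is approximated as one whose values are a `Base.OneTo` (the heuristic from step 3a of the algorithm), and only length-1 default axes are allowed to grow, which is one reading of the relaxed rule. Values are compared with `==`, so no tolerance is applied.

```julia
# Axes are modelled as `name => values` pairs, a simplified, hypothetical
# stand-in for AxisArrays.Axis{name}(values).

# Heuristic from the PR: an axis whose values are a Base.OneTo range is "default".
isdefault(ax::Pair{Symbol}) = ax.second isa Base.OneTo

# Assumption: only a length-1 default axis may grow to the broadcast length.
cangrow(ax::Pair{Symbol}) = isdefault(ax) && length(ax.second) == 1

# Axes in the same dimension are compatible if either one can grow, or if the
# names match and the values are exactly equal (no auto-alignment, no tolerance).
function compatible(a::Pair{Symbol}, b::Pair{Symbol})
    (cangrow(a) || cangrow(b)) && return true
    return a.first == b.first && a.second == b.second
end
```

For example, `compatible(:x => [0.1 + 0.2], :x => [0.3])` is `false`, because `0.1 + 0.2 == 0.3` is `false` in `Float64` arithmetic.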
Algorithm description
The following discussion relies upon understanding the new broadcasting API described in the interfaces section of the latest Julia docs. Broadcasting is intercepted after styles are combined, but before eltypes and indices are computed.
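As a point of reference, the interception point can be illustrated with the Julia 1.x form of this interface, which has changed since the 0.7 development snapshot this PR targets. The sketch is adapted from the `ArrayAndChar` example in the Julia manual; `Labeled` and `find_labeled` are hypothetical names, and a single `Symbol` label stands in for the full set of axes an AxisArray would carry. The custom `BroadcastStyle` makes the wrapper win style combination, and `similar` is the hook that runs once the combined style and the element type are known (so, unlike the approach in this PR, it runs after the eltype is computed).

```julia
# Hypothetical wrapper that carries a single label through broadcasting
# (adapted from the ArrayAndChar example in the Julia manual).
struct Labeled{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
    label::Symbol
end
Base.size(x::Labeled) = size(x.data)
Base.getindex(x::Labeled{T,N}, inds::Vararg{Int,N}) where {T,N} = x.data[inds...]
Base.setindex!(x::Labeled{T,N}, v, inds::Vararg{Int,N}) where {T,N} = (x.data[inds...] = v)

# Opt in to a custom style; ArrayStyle wins when combined with plain arrays or scalars.
Base.BroadcastStyle(::Type{<:Labeled}) = Broadcast.ArrayStyle{Labeled}()

# Called once the combined style is known, before the output is filled:
# allocate the result and copy metadata from the first Labeled argument.
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{Labeled}}, ::Type{ElType}) where {ElType}
    x = find_labeled(bc)
    return Labeled(similar(Array{ElType}, axes(bc)), x.label)
end

# Return the first Labeled among the (possibly nested) broadcast arguments.
find_labeled(bc::Broadcast.Broadcasted) = find_labeled(bc.args)
find_labeled(args::Tuple) = find_labeled(find_labeled(args[1]), Base.tail(args))
find_labeled(x) = x
find_labeled(::Tuple{}) = nothing
find_labeled(a::Labeled, rest) = a
find_labeled(::Any, rest) = find_labeled(rest)

v = Labeled([1, 2, 3], :asdf)
v .* 2      # a Labeled carrying the label :asdf
1 .+ v      # argument order does not matter
```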
1. Run `combine_indices` over any AxisArrays (but not other kinds of arrays) in the broadcasting operation. AxisArray axis names and values are returned from a new `broadcast_indices` method. As currently implemented, this demands exact equality of axis values, so even tiny floating-point differences count as a mismatch. This returns a tuple of `AxisArrays.Axis` that we'll call `axesAs`.
2. Provided that was successful, do broadcasting over all broadcast args using the underlying arrays (`array.data` if `array` is an AxisArray). Call the result `broadcasted`.
3. Compare the axes in `axesAs` with the `default_axes` for `broadcasted` (which is not an AxisArray). We'll call the tuple of default axes `defaxesBs`. Note that `length(axesAs) <= length(defaxesBs)`. Process the two tuples `axesAs` and `defaxesBs` pairwise, taking elements `axA`, `axB` from each and recursing with `Base.tail`, etc.
   3a. If the axis names match, decide whether `axA` was originally a default axis. This PR treats an axis like `Axis{:row, <:Base.OneTo}` as a default axis. If so, return e.g. `Axis{:row}(Base.OneTo(length(axB)))` so that the default axis is resized to match the size required for `broadcasted`. If the values are not a `Base.OneTo`, then it is not a default axis, and the arrays cannot be broadcast.
   3b. If the axis names don't match, there's no need to worry about default axes; just return `axA`.
4. Step 3 yields a tuple of axes. Wrap `broadcasted` into an AxisArray using these axes. The number of axes obtained in step 3 may be less than the number of dimensions of `broadcasted`, in which case the AxisArray constructor uses `default_axes` for the remainder.
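Steps 3 and 4 can be sketched as a recursion over the two axis tuples with `Base.tail`. This is a simplified, hypothetical sketch rather than the PR's code: axes are modelled as `name => values` pairs, `default_pairs` is a stand-in for `default_axes` that assumes the `:row`, `:col`, `:page`, `:dim_N` naming convention, and, as an added assumption, a non-default axis whose name matches is kept when its length already agrees with `broadcasted` (and rejected otherwise).

```julia
# Axes are modelled as `name => values` pairs (a simplified stand-in for
# AxisArrays.Axis). An axis with Base.OneTo values is treated as "default".
isdefault(ax::Pair{Symbol}) = ax.second isa Base.OneTo

# Hypothetical default names, assumed to follow :row, :col, :page, :dim_N.
defaultname(i::Int) = i == 1 ? :row : i == 2 ? :col : i == 3 ? :page : Symbol(:dim_, i)

# Default axes of a plain array, e.g. (:row => OneTo(3), :col => OneTo(3)).
default_pairs(B::AbstractArray) = ntuple(i -> (defaultname(i) => axes(B, i)), ndims(B))

# Step 3: walk `axesAs` and `defaxesBs` pairwise, recursing with Base.tail.
merge_axes(::Tuple{}, ::Tuple) = ()
function merge_axes(axesAs::Tuple, defaxesBs::Tuple)
    axA, axB = first(axesAs), first(defaxesBs)
    rest = merge_axes(Base.tail(axesAs), Base.tail(defaxesBs))
    if axA.first == axB.first
        # 3a: names match; a default axis is resized to the broadcast length.
        isdefault(axA) && return (axA.first => Base.OneTo(length(axB.second)), rest...)
        # Assumption: a non-default axis survives only if no resizing is needed.
        length(axA.second) == length(axB.second) ||
            throw(DimensionMismatch("axis $(axA.first) is not a default axis and cannot be resized"))
        return (axA, rest...)
    else
        # 3b: names differ; keep the AxisArray's axis unchanged.
        return (axA, rest...)
    end
end

# Step 4: pad with default axes for any trailing dimensions of `broadcasted`.
function result_axes(axesAs::Tuple, broadcasted::AbstractArray)
    defaxesBs = default_pairs(broadcasted)
    merged = merge_axes(axesAs, defaxesBs)
    return (merged..., defaxesBs[length(merged)+1:end]...)
end

# Usage, mirroring `A' .+ [10, 20, 30]`:
broadcasted = [1 2 3] .+ [10, 20, 30]   # plain 3×3 Matrix
result_axes((:row => Base.OneTo(1), :asdf => [1.0, 2.0, 5.0]), broadcasted)
```

With these inputs the sketch yields `(:row => Base.OneTo(3), :asdf => [1.0, 2.0, 5.0])`, which matches the axis types in the `A' .+ [10,20,30]` output shown under Known limitations.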
Examples
See `test/broadcast.jl`; more tests and examples are to come.
Relation to previous AxisArrays.jl issues and PRs concerning broadcasting
Issue 128
This PR satisfies what @omus considers an ideal solution in https://github.com/JuliaArrays/AxisArrays.jl/issues/128#issuecomment-340018520 (I've sanitized some deprecation warnings):
```julia
julia> A = AxisArray([1,2,3], Axis{:asdf}([1.0, 2.0, 5.0]))
3-element AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
1
2
3
julia> A .* 2
3-element AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
2
4
6
julia> A .== 2
3-element AxisArray{Bool,1,BitArray{1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
false
true
false
```
PR 54
This PR also doesn't care about argument order, which was a limitation in PR https://github.com/JuliaArrays/AxisArrays.jl/pull/54:
```julia
julia> 1 .+ A
3-element AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}:
2
3
4
```
It also doesn't care if the eltypes are Real, another limitation in https://github.com/JuliaArrays/AxisArrays.jl/pull/54:
```julia
julia> AxisArray([1+im, 2+im]) .+ (3.0+4.5im)
2-element AxisArray{Complex{Float64},1,Array{Complex{Float64},1},Tuple{Axis{:row,Base.OneTo{Int64}}}}:
4.0 + 5.5im
5.0 + 5.5im
```
Note that broadcasting is not oblivious to the underlying storage order, as mentioned in the high-level description, and there are differing opinions on that [1] [2]. However, this PR is very conservative, in that you can do strictly more with broadcasting while preserving the AxisArray wrapper. If there were another PR that paid no attention to the underlying storage order / did auto alignment, I think you would again have strictly more functionality, for some sense of the word strictly :) I'm not sure how broadcasting should be treated when combining both AxisArrays and AbstractArrays in that case; there you kind of need to pay attention to the storage order.
Known limitations
- Not everything is inferable yet; I'm trying to identify why.
- Some of the error messages are opaque when broadcasting fails for AxisArray-specific reasons. I don't think this is insurmountable, but fixing it would require some more boilerplate.
- ~~Axis info can get lost when using wrappers around AxisArrays, like with adjoint.~~ This has been resolved as follows:
```julia
julia> A'
1×3 Adjoint{Int64,AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}}:
1 2 3
julia> A' .+ [10,20,30]
3×3 AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:row,Base.OneTo{Int64}},Axis{:asdf,Array{Float64,1}}}}:
11 12 13
21 22 23
31 32 33
julia> A' .+ AxisArray([10,20,30])
3×3 AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:row,Base.OneTo{Int64}},Axis{:asdf,Array{Float64,1}}}}:
11 12 13
21 22 23
31 32 33
julia> A' .+ [10 20 30]
1×3 AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:row,Base.OneTo{Int64}},Axis{:asdf,Array{Float64,1}}}}:
11 22 33
julia> A' .+ A'
1×3 Adjoint{Int64,AxisArray{Int64,1,Array{Int64,1},Tuple{Axis{:asdf,Array{Float64,1}}}}}:
2 4 6
```
Note that, as a consequence of requiring unique axis names for each dimension, `A .+ A'` fails: the result array would have the same axis name (`:asdf`) for both rows and columns. At first I wondered whether `Adjoint` should really wrap AxisArrays as it does now, but that is actually consistent with this PR, in that the underlying storage order matters in broadcasting. I think I'm fine with that; perhaps the README should say specifically that indexing can be oblivious to the storage order of the underlying array.
- Take a look at `transpose(A)`:
```julia
julia> transpose(A)
1×3 AxisArray{Int64,2,Transpose{Int64,Array{Int64,1}},Tuple{Axis{:transpose,Base.OneTo{Int64}},Axis{:abc,Array{Float64,1}}}}:
1 2 3
```
Probably AxisArrays should be updated to use the Transpose type that was introduced.
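The unique-name requirement that makes broadcasting `A` against `A'` fail can be illustrated with a small check. This is a hypothetical sketch, not AxisArrays' actual validation code, and it again models axes as `name => values` pairs: `A` contributes `:asdf` along dimension 1 and `A'` contributes `:asdf` along dimension 2, so the candidate result axes repeat a name.

```julia
# Reject a tuple of `name => values` axes in which a name repeats.
function check_unique_names(axs::Tuple)
    names = map(first, axs)   # first(::Pair) returns the axis name
    allunique(names) || throw(ArgumentError("axis names must be unique, got $names"))
    return axs
end

vals = [1.0, 2.0, 5.0]
check_unique_names((:row => Base.OneTo(3), :asdf => vals))  # accepted, as for A' .+ [10,20,30]
# check_unique_names((:asdf => vals, :asdf => vals))        # throws, as for broadcasting A with A'
```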
I came across this package as I was looking for some "professional" solution instead of my own stupid hack, and it looks really good. It would be great to have AxisArray-preserving broadcasting, so what's the status of this PR?
What's blocking this? Is there a fundamental problem, or is it just some busywork implementing tests? It seems like a solution to #156 and maybe #128.