StructArrays.jl
Supporting `map` on GPU
StructArrays components can be GPU arrays, but many operations on them are not supported yet. For example, a simple map:
```julia
using StructArrays, Metal

A = StructArray(a = MtlArray(rand(Float32, 10^3)), b = MtlArray(rand(Int8, 10^3)))
map(exp, A.a)          # works
map(x -> exp(x.a), A)  # throws "Scalar indexing is disallowed" because `map` falls back to iteration
```
It would be great to make this work! How do you think it could be implemented? I'm not very experienced with GPU arrays and their inner workings.
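One direction that might work, sketched here under assumptions: instead of iterating the `StructArray`, broadcast over its component arrays and reassemble each element inside the broadcast, since broadcasting over plain `MtlArray`s already runs on the GPU. `componentwise_map` and `RowApply` are hypothetical names, not part of the StructArrays API. The sketch assumes named components, so `f` receives a `NamedTuple` rather than the original struct type, and whether it compiles for Metal is untested; the check below runs on CPU arrays.

```julia
using StructArrays

# Hypothetical callable: rebuilds one "row" as a NamedTuple from the scalar
# values of each component, then applies `f` to it. The field names live in
# a type parameter, so no Symbol values are stored in the object itself.
struct RowApply{names,F}
    f::F
end
RowApply{names}(f::F) where {names,F} = RowApply{names,F}(f)
(r::RowApply{names})(xs...) where {names} = r.f(NamedTuple{names}(xs))

# Hypothetical helper (not StructArrays API): map `f` over the rows of a
# StructArray by broadcasting over its component arrays, so the StructArray
# itself is never indexed element by element.
function componentwise_map(f, A::StructArray)
    cols = StructArrays.components(A)  # NamedTuple of component arrays
    return _componentwise_map(f, cols)
end

# Only named components are handled; tuple-backed StructArrays hit a MethodError.
_componentwise_map(f, cols::NamedTuple{names}) where {names} =
    broadcast(RowApply{names}(f), values(cols)...)

# CPU check; with MtlArray components the same call would go through the
# GPU broadcast machinery instead (untested assumption).
A = StructArray(a = rand(Float32, 10), b = rand(Int8, 10))
B = componentwise_map(x -> exp(x.a), A)
```

The field names are kept in the type domain because `Symbol` is not an isbits type, so a closure capturing them as values would likely not be usable inside a GPU kernel.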