RFC: add `take_along_axis` to take values along a specified dimension
This RFC proposes the addition of a new API in the array API specification for taking values from an input array by matching one-dimensional index and data slices.
Overview
Based on array comparison data, the API is available across most major array libraries in the PyData ecosystem.
take_along_axis was previously discussed in https://github.com/data-apis/array-api/issues/177 as a potential standardization candidate and has been mentioned in downstream usage. As indexing with multidimensional integer arrays (see https://github.com/data-apis/array-api/issues/669) is not yet supported in the specification, the specification lacks a means to concisely select multiple values along multiple one-dimensional slices. This RFC aims to fill this gap.
Additionally, even with advanced indexing, replicating take_along_axis is more verbose and trickier to get right. For example, consider the following NumPy session:
In [1]: import numpy as np
In [2]: a = np.array([[10,30,20], [60,40,50]])
In [3]: a
Out[3]:
array([[10, 30, 20],
       [60, 40, 50]])
In [4]: indices = np.array([[2,0,1],[1,2,0]])
In [5]: indices
Out[5]:
array([[2, 0, 1],
       [1, 2, 0]])
In [6]: np.take_along_axis(a, indices, axis=1)
Out[6]:
array([[20, 10, 30],
       [40, 50, 60]])
To replicate with advanced indexing,
In [7]: i0 = np.arange(a.shape[0])[:, np.newaxis]
In [8]: a[i0, indices]
Out[8]:
array([[20, 10, 30],
       [40, 50, 60]])
where we need to create an integer index (with expanded dimensions) for the first dimension, which can then be broadcast against the integer index array indices. For higher-dimensional arrays, replicating take_along_axis becomes even more verbose. E.g., for a 3-dimensional array,
a = np.random.rand(3, 4, 5)
indices = np.random.randint(5, size=(3, 4, 2))
i0 = np.arange(a.shape[0])[:, np.newaxis, np.newaxis]
i1 = np.arange(a.shape[1])[np.newaxis, :, np.newaxis]
result = a[i0, i1, indices]
In general, while "advanced indexing" can be used for replicating take_along_axis, doing so is less ergonomic and has a higher likelihood of mistakes.
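To make the bookkeeping concrete, here is a minimal sketch of what a general-purpose replacement built on advanced indexing might look like. The helper name take_along_axis_via_indexing is hypothetical, and the sketch assumes that indices has the same number of dimensions as a and matches its shape on every dimension except axis.

```python
import numpy as np

def take_along_axis_via_indexing(a, indices, axis):
    """Hypothetical helper: emulate np.take_along_axis with advanced indexing.

    Assumes indices.ndim == a.ndim and matching shapes off-axis.
    """
    axis = axis % a.ndim  # normalize a negative axis
    index = []
    for dim in range(a.ndim):
        if dim == axis:
            # Along the target axis, use the caller's indices directly.
            index.append(indices)
        else:
            # Elsewhere, build an arange shaped to broadcast against `indices`.
            shape = [1] * a.ndim
            shape[dim] = -1
            index.append(np.arange(a.shape[dim]).reshape(shape))
    return a[tuple(index)]

rng = np.random.default_rng(0)
a = rng.random((3, 4, 5))
indices = rng.integers(5, size=(3, 4, 2))
result = take_along_axis_via_indexing(a, indices, axis=2)
```

Every non-target dimension needs its own correctly shaped arange, which is exactly the part that is easy to get wrong by hand.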
Prior art
- NumPy: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
- Added in 1.15.
- Discussion: https://github.com/numpy/numpy/issues/6078
- Discussion: https://github.com/numpy/numpy/issues/8708
- PR: https://github.com/numpy/numpy/pull/11105
- CuPy: https://docs.cupy.dev/en/stable/reference/generated/cupy.take_along_axis.html
- Dask: (does not currently implement)
- Issue: https://github.com/dask/dask/issues/3663
- PR: https://github.com/dask/dask/pull/11076
- PyTorch: https://pytorch.org/docs/stable/generated/torch.take_along_dim.html
- Named take_along_dim, rather than take_along_axis
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/take_along_axis
- JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.take_along_axis.html
- Supports additional mode and fill kwargs
Proposal
def take_along_axis(x: array, indices: array, /, axis: int) -> array
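As an illustration of the intended semantics, the following loop-based sketch spells out which element lands where. It is not proposed spec text: the name take_along_axis_reference is hypothetical, NumPy is used only as a stand-in array library, and the sketch assumes x and indices have the same number of dimensions.

```python
import itertools
import numpy as np

def take_along_axis_reference(x, indices, /, axis):
    """Hypothetical loop-based reference; assumes x.ndim == indices.ndim."""
    out = np.empty(indices.shape, dtype=x.dtype)
    # Visit every position of the output, which has the shape of `indices`.
    for pos in itertools.product(*(range(n) for n in indices.shape)):
        src = list(pos)
        # Only the coordinate along `axis` is replaced by the looked-up index.
        src[axis] = indices[pos]
        out[pos] = x[tuple(src)]
    return out

x = np.array([[10, 30, 20], [60, 40, 50]])
indices = np.array([[2, 0, 1], [1, 2, 0]])
out = take_along_axis_reference(x, indices, 1)        # axis passed positionally
same = take_along_axis_reference(x, indices, axis=1)  # or by keyword
```

The `/` in the signature makes x and indices positional-only, while axis can be given either way, mirroring the proposed signature.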
Notes
- For the Array API compat layer, for libraries without this functionality, the workaround proposed in https://github.com/numpy/numpy/issues/8708#issue-210637557 and implemented in https://github.com/numpy/numpy/pull/11105 can be done in pure Python, without needing any additional Array API functions.
Questions
- NumPy and its kin allow axis to be None in order to indicate that the input array x should be flattened prior to indexing. This allows consistency with NumPy's sort and argsort functions. However, the specification for sort and argsort does not support None (i.e., flattening). Accordingly, this RFC does not propose supporting axis=None. Are we okay with this?
- NumPy and kin allow keyword and positional arguments. This RFC makes x and indices positional-only and allows axis to be both positional or keyword. Any concerns?
- As elsewhere with this specification, presumably PyTorch will be okay aliasing take_along_dim as take_along_axis and dim as axis to ensure spec compliance?
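For context on the first question, this short NumPy sketch shows the pairing with argsort and what axis=None does today. It only illustrates existing NumPy behavior. It also shows that, without axis=None, the flattened case can still be written with an explicit reshape.

```python
import numpy as np

a = np.array([[10, 30, 20], [60, 40, 50]])

# Along an explicit axis, argsort + take_along_axis reproduces sort.
order = np.argsort(a, axis=1)
sorted_rows = np.take_along_axis(a, order, axis=1)

# With axis=None, NumPy flattens `a` and expects 1-D indices.
flat_order = np.argsort(a, axis=None)
sorted_flat = np.take_along_axis(a, flat_order, axis=None)

# Without axis=None support, the same result needs an explicit reshape.
sorted_flat_explicit = np.take_along_axis(a.reshape(-1), flat_order, axis=0)
```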
Thanks @kgryte
Accordingly, this RFC does not propose supporting axis=None. Are we okay with this?
That makes perfect sense to me. The along_axis in the name would be a bit meaningless if the operation isn't happening along an axis in the end:)
This RFC makes x and indices positional-only and allows axis to be both positional or keyword. Any concerns?
Looks like a good choice to me.
presumably PyTorch will be okay
That's consistent with the rest of the design, so I don't see a problem here.
It's used quite rarely at least in SciPy and scikit-learn, and it's fairly new even in NumPy (1.15.0). It looks like there's enough justification for adding take_along_axis, but it'd be valuable to spell that justification out @kgryte, as we just discussed.
While take_along_axis may not be used on its own in SciPy, it is commonly used indirectly. For example, as @mdhaber mentioned offline, it is used in rankdata (see https://github.com/scipy/scipy/pull/20639), which in turn appears in a variety of important statistical tests, such as wilcoxon, mannwhitneyu, kendalltau, and various hypothesis tests. Consequently, limiting our search to direct usage may not paint the full picture.
Regardless, I've updated the OP with some additional justification. Namely, replicating take_along_axis with just fancy indexing is harder to get right and more verbose. IMO, there are definite ergonomic benefits to take_along_axis.
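To give a flavor of the indirect usage pattern (sort along an axis, compute on the sorted values, then map the result back to the original order), here is a rough sketch of dense ranking along the last axis. This is not SciPy's rankdata implementation, and the helper name dense_ranks_last_axis is hypothetical.

```python
import numpy as np

def dense_ranks_last_axis(a):
    """Hypothetical helper: 1-based dense ranks along the last axis."""
    order = np.argsort(a, axis=-1, kind="stable")
    # Gather values in sorted order along the last axis.
    sorted_vals = np.take_along_axis(a, order, axis=-1)
    # Mark positions where a new distinct value starts.
    is_new = np.ones(sorted_vals.shape, dtype=bool)
    is_new[..., 1:] = sorted_vals[..., 1:] != sorted_vals[..., :-1]
    dense_sorted = np.cumsum(is_new, axis=-1)
    # The inverse permutation maps sorted positions back to original positions.
    inverse = np.argsort(order, axis=-1, kind="stable")
    return np.take_along_axis(dense_sorted, inverse, axis=-1)

data = np.array([[10, 30, 10], [5, 5, 7]])
ranks = dense_ranks_last_axis(data)
```

Both the gather into sorted order and the scatter back rely on take_along_axis, even though a caller of such a ranking function never sees it.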
@kgryte The documentation states that "The behavior for out-of-bounds indices is left unspecified.", but do I understand correctly that some negative indices are in bounds, as described in Single-axis indexing?
@mdhaber Yes, we need to update the spec to explicitly include language for allowing negative indices, as done elsewhere for other APIs.
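For reference, here is how NumPy currently treats negative indices in take_along_axis. This only illustrates existing NumPy behavior, which counts negative indices from the end of the axis. It is not the eventual spec wording.

```python
import numpy as np

a = np.array([[10, 30, 20], [60, 40, 50]])

# Negative indices count from the end of the indexed axis.
neg = np.array([[-1], [-3]])
picked = np.take_along_axis(a, neg, axis=1)

# Equivalent non-negative indices, obtained by adding the axis size.
pos = np.where(neg < 0, neg + a.shape[1], neg)
same = np.take_along_axis(a, pos, axis=1)
```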