RFC: add support for a tuple of axes in `expand_dims`
Hello all! I raised this issue on array-api-compat earlier (https://github.com/data-apis/array-api-compat/issues/105), but I think it might be more properly directed here.
In the array API, expand_dims supports only a single axis (https://data-apis.org/array-api/latest/API_specification/generated/array_api.expand_dims.html), as opposed to a tuple of axes. This differs from NumPy, CuPy, and JAX, which all support a tuple of axes; PyTorch, however, supports only a single axis. I don't know the justification for why the array API supports only a single axis rather than a tuple, but the consequence is that expand_dims no longer works in many places when adopting the array API.
In practice, expand_dims is just a light wrapper for reshape, see https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L594. But it's not great to force users to write their own version of expand_dims in every library now. Is the array API willing to update expand_dims to support a tuple of axes? If not, and if expand_dims will only support a single axis going forward, that effectively makes all users of expand_dims copy and paste the NumPy implementation.
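To illustrate what "a light wrapper for reshape" means, here is a minimal single-axis sketch of the reshape trick (a hypothetical helper for illustration, not NumPy's actual code):

```python
import numpy as np

def expand_dims_one(x, axis):
    """Insert a single size-1 dimension at `axis` via reshape (sketch)."""
    out_ndim = x.ndim + 1
    # Normalize a negative axis against the *output* rank.
    if axis < 0:
        axis += out_ndim
    shape = list(x.shape)
    shape.insert(axis, 1)
    return x.reshape(shape)

print(expand_dims_one(np.empty((2, 3)), 1).shape)   # (2, 1, 3)
print(expand_dims_one(np.empty((2, 3)), -1).shape)  # (2, 3, 1)
```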
@lucascolley pointed out to me that when expand_dims was added to the array API, only NumPy supported a tuple of axes (https://github.com/data-apis/array-api/pull/42). That was 4 years ago, and the situation has since changed, as described above.
It seems tuple support was omitted because torch doesn't support it (https://github.com/data-apis/array-api/pull/42). I found a few feature requests for it for torch.unsqueeze (the PyTorch equivalent of expand_dims): https://github.com/pytorch/pytorch/issues/30702, https://github.com/pytorch/pytorch/pull/4692#issuecomment-394927742. It seems the feature was intentionally omitted there due to the ambiguity that arises from mixing negative and positive indices.
I agree this ambiguity is a potential concern. If we standardize this, we should somehow only require a subset of behavior that omits this ambiguity, e.g., by leaving the mixing of negative and nonnegative indices unspecified.
Consider for example:
```python
>>> np.expand_dims(np.empty((2,)), (1, -1)).shape
(2, 1, 1)
```
The resulting shape has 1 in positions 1 and -1, but a result shape of (2, 1) would also satisfy this. I suppose one could argue that exactly len(axes) dimensions should be added.
But also consider:

```python
>>> np.expand_dims(np.empty((2, 3, 4, 5)), (3, -3)).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/lib/shape_base.py", line 597, in expand_dims
    axis = normalize_axis_tuple(axis, out_ndim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/miniconda3/envs/array-apis/lib/python3.11/site-packages/numpy/core/numeric.py", line 1385, in normalize_axis_tuple
    raise ValueError('repeated axis')
ValueError: repeated axis
```
There's no way to insert size-1 dimensions into (2, 3, 4, 5) so that they appear at indices 3 and -3. Here's a small proof: there is no list of any length from which you can remove the indices 3 and -3 and end up with a list of length 4.
```python
>>> def remove_indices(n, idxes):
...     """Return range(n) with `idxes` indices removed"""
...     x = list(range(n))
...     vals = [x[i] for i in idxes]
...     for v in vals:
...         try:
...             x.remove(v)
...         except ValueError:  # Already removed
...             pass
...     return x
>>> [remove_indices(n, (-3, 3)) for n in range(4, 10)]
[[0, 2], [0, 1, 4], [0, 1, 2, 4, 5], [0, 1, 2, 5, 6], [0, 1, 2, 4, 6, 7], [0, 1, 2, 4, 5, 7, 8]]
>>> [len(remove_indices(n, (-3, 3))) for n in range(4, 10)]
[2, 3, 5, 5, 6, 7]
```
At the same time, if the goal of expand_dims is for the axes to refer to the dimensions after unsqueezing/expanding, then it's not entirely trivial to express as a sequence of single-axis expand_dims calls: if you apply the expansions in the wrong order, each insertion shifts the positions of the remaining target dimensions. The correct logic is not hard, but it's the sort of thing that's easy to get wrong, so I think there is value in having native support for multiple axes.
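For concreteness, here is a sketch of that "correct logic" (a hypothetical helper): normalize every axis against the final rank first, then apply the single-axis expansions in ascending order so that earlier insertions don't shift later targets. This sketch assumes the normalized axes contain no duplicates:

```python
import numpy as np

def expand_dims_multi(x, axes):
    """Emulate a multi-axis expand_dims via repeated single-axis calls."""
    out_ndim = x.ndim + len(axes)
    # Normalize negative axes against the final number of dimensions,
    # then sort so each insertion leaves later targets in place.
    norm = sorted(a % out_ndim for a in axes)
    for a in norm:
        x = np.expand_dims(x, a)
    return x

print(expand_dims_multi(np.empty((2,)), (1, -1)).shape)        # (2, 1, 1)
print(expand_dims_multi(np.empty((2, 3, 4, 5)), (2, -1)).shape)  # (2, 3, 1, 4, 5, 1)
```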
Regarding removing the ambiguity, I think it would suffice to impose an ordering in which the dims are expanded, right? For example, if we specify "negative indices get resolved first", then, borrowing your example above, it could be resolved as

```python
x = np.empty((2, 3, 4, 5))
xp.expand_dims(x, (3, -3)) == np.expand_dims(np.expand_dims(x, -3), 3)
```

so that the final output shape is (2, 3, 1, 1, 4, 5), which seems reasonable.
Still, I'm not sure it's worth it, since users could do this as a two-step expansion in the first place (albeit with some more thought), and the resolution order (positive or negative indices first?) is rather arbitrary.
When you do repeated expand_dims, the inserted dimensions in the final shape won't necessarily be at the indices you initially specified (that's the whole point of this feature request: you need a way to do them all at once). (2, 3, 1, 1, 4, 5) has 1s at indices 2 and -3 (remember 0-based indexing), because the 1 that was inserted at index -3 got shifted over.
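Checking that numerically, using the "negative first" composition from the example above:

```python
import numpy as np

x = np.empty((2, 3, 4, 5))
# Resolve the negative axis first, then the positive one.
y = np.expand_dims(np.expand_dims(x, -3), 3)
print(y.shape)  # (2, 3, 1, 1, 4, 5)
# The 1s land at indices 2 and 3 (i.e. 2 and -3), not at 3 and -3.
print([i for i, d in enumerate(y.shape) if d == 1])  # [2, 3]
```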
In case it affects whether this makes v2024 either way: this is now available as https://data-apis.org/array-api-extra/generated/array_api_extra.expand_dims.html
Given the ambiguity of supporting a tuple of axes in expand_dims, I wonder if there is room for an alternative API which avoids the ambiguity altogether. Namely,

```python
def spread_dims(x: array, ndims: int, axes: Tuple[int, ...]) -> array
```

which expands the shape of an input array x to have ndims dimensions, where the current dimensions of x are explicitly mapped to a unique list of axes in the resulting array. All unspecified axes must be singleton dimensions.
This essentially flips the problem into one in which you specify where you want the non-singleton dimensions, rather than where you want to insert the singleton dimensions.
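A rough sketch of how spread_dims could behave (the name and signature come only from the proposal above, not an existing API; the reshape-based strategy and the increasing-order restriction are my assumptions to keep the sketch simple):

```python
import numpy as np

def spread_dims(x, ndims, axes):
    """Place the dimensions of x at positions `axes` in an ndims-dimensional
    result; every unspecified position becomes a singleton dimension."""
    norm = [a % ndims for a in axes]  # normalize against the output rank
    if len(norm) != x.ndim or len(set(norm)) != x.ndim:
        raise ValueError("axes must map one-to-one onto the dimensions of x")
    if norm != sorted(norm):
        raise ValueError("axes must be increasing for a plain reshape")
    shape = [1] * ndims
    for a, d in zip(norm, x.shape):
        shape[a] = d
    return x.reshape(shape)

print(spread_dims(np.empty((2, 3)), 4, (0, 2)).shape)  # (2, 1, 3, 1)
```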
That sounds like a good idea if anybody takes issue with the interpretation chosen for xpx.expand_dims. But the milestone can probably be bumped (or removed until that happens).
It seems to me the behavior of expand_dims with multiple axes could be specified without any of the ambiguities mentioned above. Basically, for `y = expand_dims(x, axes)` when `axes` is a tuple:

- the output `y` must have dimension `x.ndim + len(axes)`
- each entry of `axes` must be unique, with negative indices normalized in relation to `y.ndim` (not `x.ndim`), or else a `ValueError` is raised.
- for each entry `axis` of `axes`, `y.shape[axis]` must be `1`.
- the remaining dimensions of `y` consist of the dimensions of `x`, in order.
This basically describes the existing behavior of NumPy, and handles all the ambiguities mentioned above:

- if `x` has shape `(2, 3, 4, 5)`, then `expand_dims(x, (3, -3))` fails because the axes list has duplicate entries (they are normalized to `(3, 3)`); a `ValueError` should be raised.
- if `x` has shape `(2,)`, then `expand_dims(x, (-1, 1))` will have shape `(2, 1, 1)`, as the indices are normalized to `(2, 1)`.
This behavior is semantically equivalent to calling expand_dims repeatedly with a single axis, provided the axes tuple is first normalized to nonnegative values using the final number of dimensions, sorted, and free of duplicates.
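For concreteness, those rules can be sketched as follows (a hypothetical helper mirroring NumPy's existing behavior, not a proposed implementation):

```python
import numpy as np

def expand_dims_spec(x, axes):
    """Multi-axis expand_dims following the rules above."""
    if isinstance(axes, int):
        axes = (axes,)
    out_ndim = x.ndim + len(axes)  # rule: y.ndim == x.ndim + len(axes)
    # Normalize negative entries in relation to y.ndim, not x.ndim.
    norm = sorted(a % out_ndim for a in axes)
    if len(set(norm)) != len(norm):
        raise ValueError("repeated axis")
    # Each listed axis gets a 1; the remaining dims of x fill in, in order.
    dims = iter(x.shape)
    shape = [1 if i in norm else next(dims) for i in range(out_ndim)]
    return x.reshape(shape)

print(expand_dims_spec(np.empty((2,)), (-1, 1)).shape)  # (2, 1, 1)
```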