array-api

Add `unstack`?

Open • shoyer opened this issue 3 years ago • 3 comments

TensorFlow and PyTorch have unstack() functions (PyTorch calls it unbind()) for converting an array into a Python sequence of arrays, unpacked along a dimension:

>>> import torch
>>> torch.unbind(torch.tensor([[1, 2, 3],
...                            [4, 5, 6],
...                            [7, 8, 9]]))
(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))

I think this could make sense in the API standard, especially because, unlike in NumPy, you cannot iterate over the first axis of arrays: https://github.com/data-apis/array-api/issues/481

shoyer, Sep 23 '22 16:09

This does seem useful indeed. NumPy kind of has this as well, via all the split APIs (split, array_split, hsplit, dsplit, vsplit):

>>> import numpy as np
>>> y = np.arange(9).reshape(3, 3) + 1
>>> y
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])
>>> np.split(y, y.shape[0])
[array([[1, 2, 3]]), array([[4, 5, 6]]), array([[7, 8, 9]])]

There are two things wrong with np.split:

  • The more general "split into equal chunks" behavior isn't very useful here.
  • It returns a list rather than a tuple.
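To make the difference concrete, here is a small NumPy comparison. `unstack_axis0` is a hypothetical helper written only for this illustration: it squeezes each piece produced by `np.split` and returns a tuple, which is roughly the behavior being asked for.

```python
import numpy as np

y = np.arange(9).reshape(3, 3) + 1

# np.split along axis 0 returns a *list*, and every piece keeps the split
# axis as a length-1 dimension, so each element has shape (1, 3).
pieces = np.split(y, y.shape[0])
print(type(pieces).__name__, [p.shape for p in pieces])

def unstack_axis0(x):
    """Hypothetical helper: split along axis 0, drop that axis, return a tuple."""
    return tuple(np.squeeze(p, axis=0) for p in np.split(x, x.shape[0]))

rows = unstack_axis0(y)
print(type(rows).__name__, [r.shape for r in rows])
```

The squeeze step is what distinguishes "unstacking" from "splitting": the unpacked axis disappears instead of surviving with length 1.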

I agree that unstack is the most logical name here, to match stack.

rgommers, Sep 25 '22 19:09

Since NumPy iterates along the first axis, I think the "obvious" solution is tuple(arr) or list(arr). These behave differently from splitting: they unpack the first axis completely (removing it) rather than splitting/chunking it up.

As mentioned on the other issue, I also like iteraxis (or similar). The difference is that the result is an iterator rather than a tuple/list.

The other addition I like is an axis= keyword. For a single axis this is not particularly important (you can call moveaxis first, although axis= is still convenient, I think). The main advantage is that it makes it easier to iterate over, or work with, multiple axes.
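For reference, a sketch of how this can be approximated in NumPy today: iterating over an array unpacks its first axis, and `np.moveaxis` brings any other axis to the front first. `unstack_via_moveaxis` is a hypothetical name used only here; it is not an existing NumPy function.

```python
import numpy as np

def unstack_via_moveaxis(x, axis=0):
    """Hypothetical sketch: unpack `x` along `axis` into a tuple of arrays.

    Iterating over a NumPy array yields its sub-arrays along axis 0, so the
    requested axis is moved to the front before iterating.
    """
    return tuple(np.moveaxis(x, axis, 0))

x = np.arange(6).reshape(2, 3)
print(unstack_via_moveaxis(x, axis=0))  # two arrays of shape (3,)
print(unstack_via_moveaxis(x, axis=1))  # three arrays of shape (2,)
```

This relies on NumPy's `__iter__`, which is exactly the behavior the standard does not specify, so it is not portable across array libraries as written.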

Another related thing, if we look at this from the iterator point of view, is the np.ndenumerate API, which tracks an index. That name might also suggest nditer as a good name (NumPy already has np.nditer, but it is used for a far too complicated/specific API).

Tracking an index could also be an optional argument, although I suppose that at this point the array API may want to go with a minimal solution (i.e., wait until libraries need it).
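A minimal sketch of what such an iterator could look like, in Python with NumPy. `iteraxis` and its `with_index` keyword are hypothetical names for illustration only; they are not an existing NumPy or array API function. For comparison, `np.ndenumerate` (which does exist) yields a full multi-index for every *element*, while this sketch yields one sub-array per position along a single axis.

```python
import numpy as np

def iteraxis(x, axis=0, *, with_index=False):
    """Hypothetical generator: lazily yield the sub-arrays of `x` along `axis`.

    With with_index=True it yields (i, sub_array) pairs instead, similar in
    spirit to np.ndenumerate but per slice rather than per element.
    """
    if not -x.ndim <= axis < x.ndim:
        raise ValueError(f"axis {axis} is out of bounds for a {x.ndim}-d array")
    axis %= x.ndim                   # normalize negative axes
    leading = (slice(None),) * axis  # one ':' for every axis before `axis`
    for i in range(x.shape[axis]):
        sub = x[leading + (i,)]
        yield (i, sub) if with_index else sub

x = np.arange(6).reshape(2, 3)
for i, col in iteraxis(x, axis=1, with_index=True):
    print(i, col)

# np.ndenumerate, by contrast, visits every element with its full index tuple.
for idx, value in np.ndenumerate(x):
    print(idx, value)
```

One consequence of the generator design: the axis check only runs once iteration starts, not when `iteraxis(...)` is called.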

seberg, Sep 26 '22 10:09

> Since NumPy iterates along the first axis, I think the "obvious" solution is tuple(arr) or list(arr).

That's kind of circular reasoning: for a new user who doesn't already know that it iterates along the first axis, I don't think it is possible to predict what that will do.

> As mentioned on the other issue, I also like iteraxis (or similar). The difference is that the result is an iterator rather than a tuple/list.

This is a decent alternative, I think. Returning an iterator rather than a tuple has some pros and cons, but otherwise it's very similar to unstack. unstack may still be a good name even if it returns an iterator.

This also raises the question of whether stack should accept an iterator as input.
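To illustrate one of the trade-offs with plain Python generators (the helper names below are hypothetical): an iterator is lazy but can only be consumed once and has no `len()`, whereas a tuple can be indexed, reused, and passed straight to `np.stack`, whose documented input is a sequence of arrays. Materializing with `tuple(...)` is the conservative way to feed an iterator to `stack`.

```python
import numpy as np

def unstack_tuple(x):
    """Hypothetical eager variant: returns a tuple of sub-arrays along axis 0."""
    return tuple(x[i, ...] for i in range(x.shape[0]))

def unstack_iter(x):
    """Hypothetical lazy variant: returns a one-shot iterator over axis 0."""
    return (x[i, ...] for i in range(x.shape[0]))

x = np.arange(6).reshape(2, 3)

eager = unstack_tuple(x)
print(len(eager), eager[-1])   # indexable, has a length, reusable

lazy = unstack_iter(x)
first_pass = list(lazy)
second_pass = list(lazy)       # already exhausted, so this is an empty list
print(len(first_pass), len(second_pass))

# Round trip: stack inverts unstack along the same axis. The iterator is
# materialized first so that stack receives a sequence.
restored = np.stack(tuple(unstack_iter(x)), axis=0)
print(np.array_equal(restored, x))
```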

> Another related thing, if we look at this from the iterator point of view, is the np.ndenumerate API, which tracks an index. That name might also suggest nditer as a good name (NumPy already has np.nditer, but it is used for a far too complicated/specific API).

> Tracking an index could also be an optional argument, although I suppose that at this point the array API may want to go with a minimal solution (i.e., wait until libraries need it).

These APIs are rarely used today and, IMHO, overly complicated. They seem like a poor fit for the standard.

rgommers, Sep 27 '22 10:09

Commenting here since Leo brought this up and it struck me as a possible usability issue: unstack may return views in libraries that support them, but copies in others (or copy-on-write?). Are the use cases we have in mind hampered by not knowing what the actual implementation will do?
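A small NumPy-only illustration of the concern (other libraries may behave differently, which is exactly the point being raised): pieces obtained by basic integer indexing along the first axis are views in NumPy, so an in-place write to one of them shows up in the original array, while an explicit copy is independent.

```python
import numpy as np

x = np.zeros((2, 3))
parts = tuple(x[i, ...] for i in range(x.shape[0]))  # basic indexing gives views

print(np.shares_memory(parts[0], x))  # the piece aliases x's buffer
parts[0][0] = 42.0                    # an in-place write through the view...
print(x[0, 0])                        # ...is visible in x

copies = tuple(p.copy() for p in parts)
copies[1][0] = -1.0
print(x[1, 0])                        # x is unaffected by writes to a copy
```

Code that never mutates the pieces in place gives the same results either way, which is why the outcome only matters once in-place mutation is involved.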

seberg, Oct 06 '22 17:10

As https://github.com/data-apis/array-api/issues/481#issuecomment-1256430676 discusses, unstack is a convenience API at this point. It also helps with design symmetry, since we already have stack.

The current way of doing this is with a comprehension:

>>> import numpy as np
>>> x = np.arange(6).reshape((2, 3))
>>> axis = 0
>>> tuple(x[i, ...] for i in range(x.shape[axis]))
(array([0, 1, 2]), array([3, 4, 5]))
>>> axis = 1
>>> tuple(x[:, i, ...] for i in range(x.shape[axis]))
(array([0, 3]), array([1, 4]), array([2, 5]))

Doing this generically, with an axis keyword, is not quite a trivial convenience function, though. The JAX unstack implementation gives a good idea of what is involved:

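from jax import lax

def _unstack(x):
  # Unpacks along axis 0 only; each piece drops that axis (keepdims=False).
  return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])]

# The helper as defined inside JAX; `core`, `slice_in_dim` and `Array` are JAX internals.
def index_in_dim(operand: Array, index: int, axis: int = 0,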
                 keepdims: bool = True) -> Array:
  """Convenience wrapper around slice to perform int indexing."""
  index, axis = core._canonicalize_dimension(index), int(axis)
  axis_size = operand.shape[axis]
  wrapped_index = index + axis_size if index < 0 else index
  if not 0 <= wrapped_index < axis_size:
    msg = 'index {} is out of bounds for axis {} with size {}'
    raise IndexError(msg.format(index, axis, axis_size))
  result = slice_in_dim(operand, wrapped_index, wrapped_index + 1, 1, axis)
  if keepdims:
    return result
  else:
    return lax.squeeze(result, (axis,))
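For comparison, a generic version with an `axis` keyword can also be written using only an integer index combined with full slices, which is the kind of indexing the standard specifies. This is a sketch under our own assumptions: the name `unstack`, the keyword-only `axis=0` default (mirroring `stack`), and returning a tuple are illustrative choices, and NumPy is used only to run it.

```python
import numpy as np

def unstack(x, /, *, axis=0):
    """Sketch: unpack `x` along `axis` into a tuple of arrays of rank x.ndim - 1.

    Uses only an integer index combined with full slices, e.g. for axis=1
    each piece is x[:, i, ...].
    """
    if not -x.ndim <= axis < x.ndim:
        raise IndexError(f"axis {axis} is out of bounds for array of dimension {x.ndim}")
    axis %= x.ndim                   # normalize negative axes
    leading = (slice(None),) * axis  # one ':' for every axis before `axis`
    return tuple(x[leading + (i, ...)] for i in range(x.shape[axis]))

x = np.arange(6).reshape((2, 3))
print(unstack(x, axis=0))
print(unstack(x, axis=1))
print(unstack(x, axis=-1))
```

Unlike the JAX helper, this does no explicit squeeze: integer indexing already drops the indexed axis, so the bounds check and the slice tuple are the only extra logic.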

> unstack may still be a good name even if it returns an iterator.

@seberg said yesterday that he is fine with unstack. iteraxis is still potentially interesting as a way to iterate over axes, but it is probably a separate API.

So it looks like adding unstack is fine to move ahead with, unless there are more concerns?

> Are the use-cases we have in mind hampered by not knowing what the actual implementation will do?

This concern is not specific to unstack. A view isn't a separate concept in the standard, and it should remain that way. There has been a lot of previous discussion on this repo about views and mutability, which resulted in https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html.

rgommers, Oct 07 '22 10:10