Replace `int` with `SupportsIndex` in indexing methods hints
Resolves https://github.com/data-apis/array-api/issues/383
Thanks @honno.
Adding a quick clarifying example:
>>> import numpy as np
>>> import torch
>>>
>>> x = np.arange(1, 4)
>>> y = torch.arange(1, 4)
>>>
>>> class Ix:
...     def __index__(self):
...         return 1
...
>>> x[Ix()]
2
>>> y[Ix()]
tensor(2)
@honno would you be able to open a PR with a test for this to array-api-tests? Doesn't have to be merged before this will be merged, but that will help smoke out if any known/tested library does not yet support this feature.
Good shout, I opened https://github.com/data-apis/array-api-tests/pull/247 to check this all out. From the looks of it:
- NumPy (+ array_api_strict) and PyTorch support this behaviour.
- JAX doesn't support it, and has a general note to open a feature request on indexing modes (... that's probably referring to advanced integer indexing).
- CuPy seems to only support advanced integer indexing and doesn't play nice with 0-D integer arrays as indexes.
- Dask at least accepts 0-D integer input, but array_api_compat.dask is not playing nice with the test suite, so I'll have to triage that and explore further.
Let me know what you find out. We do run it on CI with some skips and xfails, and also the max-examples is set to 5. I haven't looked at the Dask xfails too closely, and obviously if we can remove any of those that would be great. CC @lithomas1
@honno Were you able to triage the Dask issues with the test suite?
Got Dask working[^1]. The latest release, at least, doesn't support indexables for either get or set items right now (it raises a TypeError from an internal inequality check that assumes indexes are ints).
[^1]: Turns out I should have been using array_api_compat.dask.array heh, although there's another unrelated issue Dask has with one of our utilities that I'll have to explore.
By the way, for implementers, the generally correct behavior is to call operator.index() to normalize index objects that aren't one of the other supported index types like Ellipsis, slice, or array. My guess is that dask.array isn't doing that.
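For illustration, here is a minimal sketch of that normalization step. The `normalize_index` helper and the `Ix` class are hypothetical names made up for this example, not code from Dask or any other library, and array indices are deliberately left out.

```python
import operator


def normalize_index(key):
    """Hypothetical helper: coerce a single index object to a canonical form."""
    # Ellipsis and slices are index types in their own right; pass them through.
    if key is Ellipsis or isinstance(key, slice):
        return key
    # A real library would also handle array indices here (omitted in this sketch).
    # Anything else must implement __index__; operator.index() raises TypeError otherwise.
    return operator.index(key)


class Ix:
    def __index__(self):
        return 1


idx = normalize_index(Ix())

# idx is now a plain int, so a bounds check comparing it against ints is safe.
in_bounds = -3 <= idx < 3
```

Normalizing up front means later comparisons only ever see built-in ints, which would avoid the kind of TypeError mentioned above for objects that define `__index__` but no ordering methods.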
@honno Is there anything more that we need to do with this PR?
I think I'm happy with the PR; just be mindful that when I last checked in March, only NumPy and PyTorch supported "indexables", whereas JAX/CuPy/Dask didn't, so they'd need updating to support this.
Example of what I mean by an indexable:
class AwkwardIndexable:
    def __init__(self, value: int):
        self._value = value

    def __int__(self):
        raise TypeError("__int__() should not be called")

    def __index__(self):
        return self._value
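A rough sketch of how such an object behaves with the standard library alone, which might help when checking a library by hand. The class is repeated so the snippet runs standalone; the variable names are just for illustration.

```python
import operator


class AwkwardIndexable:
    def __init__(self, value: int):
        self._value = value

    def __int__(self):
        raise TypeError("__int__() should not be called")

    def __index__(self):
        return self._value


ix = AwkwardIndexable(1)

# operator.index() goes through __index__ and never touches __int__.
resolved = operator.index(ix)

# Built-in sequences already accept such objects as indices.
item = [10, 20, 30][ix]

# int() tries __int__ first, which this class deliberately rejects.
try:
    int(ix)
except TypeError:
    int_raised = True
else:
    int_raised = False
```

A library that coerces indices with `int()` rather than `operator.index()` would trip the `__int__` guard here, which is what makes this class useful as a test input.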