Automatically use the correct device in xp.clip when Python scalars are passed as bounds
I would like the following not to fail with PyTorch:
>>> import array_api_compat.torch as xp
>>> data = xp.linspace(0, 1, num=5, device="mps")
>>> xp.clip(data, 0.1, 0.9)
Traceback (most recent call last):
Cell In[4], line 1
xp.clip(data, 0.1, 0.9)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
return f(*args, xp=xp, **kwargs)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:317 in clip
ia = (out < a) | xp.isnan(a)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!
At the moment, using xp.clip with PyTorch on non-CPU tensors requires an overly verbose workaround:
>>> from array_api_compat import device
>>> device_ = device(data)
>>> xp.clip(data, xp.asarray(0.1, device=device_), xp.asarray(0.9, device=device_))
tensor([0.1000, 0.2500, 0.5000, 0.7500, 0.9000], device='mps:0')
Note that I have not investigated if other array API namespaces suffer from the same problem.
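Here is a minimal sketch of how a wrapper could do this automatically. It assumes the array library follows the array API standard: arrays expose `.device` and `.dtype`, and `xp.asarray` accepts `dtype=` and `device=`. The name `clip_scalar_bounds` is hypothetical and not part of array_api_compat. The wrapper converts Python scalar bounds into 0-d arrays that share the device and dtype of `x`, then delegates to `xp.clip`:

```python
def clip_scalar_bounds(x, min=None, max=None, *, xp):
    """Hypothetical wrapper: clip ``x``, placing Python scalar bounds on ``x``'s device.

    Assumes ``xp`` is an array-API-style namespace and that scalar bounds are
    representable in ``x.dtype`` (out-of-range scalars are not handled here).
    """
    # Standard-conforming arrays expose `.device`; older libraries may not,
    # in which case `xp.asarray` falls back to its default device.
    dev = getattr(x, "device", None)

    def as_bound(b):
        # Leave omitted bounds (None) and existing arrays untouched.
        if b is None or not isinstance(b, (bool, int, float)):
            return b
        kwargs = {"dtype": x.dtype}
        if dev is not None:
            kwargs["device"] = dev
        # A 0-d array on x's device, with x's dtype (no promotion from the scalar).
        return xp.asarray(b, **kwargs)

    return xp.clip(x, as_bound(min), as_bound(max))
```

With `import array_api_compat.torch as xp`, calling `clip_scalar_bounds(data, 0.1, 0.9, xp=xp)` should then behave like the verbose `xp.asarray(..., device=device_)` call shown earlier.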
Ah, I completely forgot about devices when I wrote this wrapper.
BTW, is this something that should be made explicit in the spec itself? Or would that just make the spec unnecessarily verbose?
Maybe it could just be tested in array-api-tests.
I think it's covered by the general design principles, e.g., https://data-apis.org/array-api/latest/design_topics/device_support.html, "This standard chooses to add support for method 3 (local control), with the convention that execution takes place on the same device where all argument arrays are allocated"
We could add a bullet to that list stating that Python scalars and other non-array-library objects should not influence device assignment.
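The quoted convention, extended with the proposed bullet, amounts to a small device-inference rule: only array arguments decide the execution device, Python scalars and omitted (`None`) arguments are ignored, and arrays on different devices are an error. The helper name `infer_execution_device` is hypothetical, and the sketch assumes arrays expose a `.device` attribute as the standard requires:

```python
def infer_execution_device(*args):
    """Hypothetical helper: return the device shared by all array arguments.

    Python scalars and None carry no device information and are skipped.
    Returns None if no argument is an array.
    """
    found = None
    for arg in args:
        if arg is None or isinstance(arg, (bool, int, float, complex)):
            continue  # non-array objects must not influence device assignment
        dev = arg.device
        if found is None:
            found = dev
        elif dev != found:
            raise ValueError(f"arguments live on different devices: {found} and {dev}")
    return found
```

Under this rule, `xp.clip(data, 0.1, 0.9)` from the original report would resolve to the device of `data` instead of mixing `mps:0` and `cpu`.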
And +1 for a test.
Yes, unfortunately, device support is not tested at all in the test suite right now.
Note that the dtype should similarly be inferred from the first argument to avoid unwanted upcasting. That is, I would like the following to hold automatically:
a = xp.linspace(-1, 1, 10, dtype=xp.float32)
assert xp.clip(a, 0, 1).dtype == xp.float32
EDIT: it seems that dtype handling is part of #166.
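A test in the spirit of array-api-tests could check the device and dtype requirements together. This is a hypothetical sketch, not code from that suite: `check_clip_scalar_bounds` assumes an array-API-compatible namespace `xp` and an input `x` already allocated on the device under test. The default bounds 0 and 1 are chosen so that they are representable in both integer and floating-point dtypes:

```python
def check_clip_scalar_bounds(xp, x, lo=0, hi=1):
    """Hypothetical test helper: clipping with Python scalar bounds must keep
    the input's device and dtype (no promotion from the scalars)."""
    out = xp.clip(x, lo, hi)
    # Arrays lacking `.device` (e.g. NumPy < 2.0) compare equal as None,
    # so for them only the dtype is effectively checked.
    x_dev = getattr(x, "device", None)
    assert getattr(out, "device", None) == x_dev, "device changed"
    assert out.dtype == x.dtype, f"dtype changed: {x.dtype} -> {out.dtype}"
    # Also check a one-sided bound, since either min or max may be omitted.
    out_hi = xp.clip(x, None, hi)
    assert getattr(out_hi, "device", None) == x_dev and out_hi.dtype == x.dtype
    return out
```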
The clip wrapper has been a little annoying to get right, primarily because of the "no promotion" behavior plus the fact that it accepts scalars. But I've hopefully ironed out all the wrinkles in #166 (except for a minor known issue that dask won't handle some cases with uint64 arrays correctly because of the way NumPy upcasts to float64).
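The uint64 caveat comes down to precision: NumPy promotes uint64 combined with int64 to float64, and float64 has a 53-bit significand, so large uint64 values cannot survive a round trip through it. The snippet below only illustrates that loss of precision in plain Python. It does not reproduce the dask code path, and whether that path hits exactly this promotion is an assumption. `round_trips_through_float64` is a hypothetical name:

```python
def round_trips_through_float64(n: int) -> bool:
    """Return True if the integer ``n`` is unchanged after conversion to float64.

    Python floats are IEEE 754 doubles (float64) with a 53-bit significand,
    so not every integer above 2**53 is exactly representable.
    """
    return int(float(n)) == n


UINT64_MAX = 2**64 - 1  # largest value a uint64 can hold
```

For example, `float(UINT64_MAX)` rounds up to `2.0**64`, which no longer fits in uint64 at all.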