array-api-compat

Automatically use the correct device in xp.clip with passed Python number literal as bounds

Open ogrisel opened this issue 1 year ago • 5 comments

I would like the following not to fail with PyTorch:

>>> import array_api_compat.torch as xp
>>> data = xp.linspace(0, 1, num=5, device="mps")
>>> xp.clip(data, 0.1, 0.9)
Traceback (most recent call last):
  Cell In[4], line 1
    xp.clip(data, 0.1, 0.9)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:317 in clip
    ia = (out < a) | xp.isnan(a)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

At the moment, we need to be quite verbose to use xp.clip with PyTorch on non-CPU tensors:

>>> from array_api_compat import device
>>> device_ = device(data)
>>> xp.clip(data, xp.asarray(0.1, device=device_), xp.asarray(0.9, device=device_))
tensor([0.1000, 0.2500, 0.5000, 0.7500, 0.9000], device='mps:0')
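A minimal sketch of how a wrapper could automate the verbose workaround above. It assumes the input exposes the standard `x.device` and `x.dtype` attributes and that `xp.asarray` accepts `dtype=` and `device=` as the array API standard specifies. `_bound_like` and `clip_on_input_device` are hypothetical names, not array-api-compat functions, and this is not necessarily how the library ended up fixing the issue.

```python
# Hypothetical helpers (not part of array-api-compat): turn Python-scalar
# bounds into 0-d arrays on x's device, with x's dtype, then call xp.clip.


def _bound_like(bound, x, xp):
    """Return `bound` as an array on x's device with x's dtype.

    None (no bound) and values that are already arrays pass through unchanged.
    """
    # bool is a subclass of int, so it is covered by this check as well.
    if isinstance(bound, (int, float)):
        # Assumption: the bound's kind matches x's dtype kind; a float bound
        # cast to an integer dtype may be truncated.
        return xp.asarray(bound, dtype=x.dtype, device=x.device)
    return bound


def clip_on_input_device(x, min=None, max=None, *, xp):
    """Call xp.clip after moving scalar bounds to x's device and dtype."""
    return xp.clip(x, _bound_like(min, x, xp), _bound_like(max, x, xp))
```

With the PyTorch example above, `clip_on_input_device(data, 0.1, 0.9, xp=xp)` would build both bounds on `mps:0` before calling `xp.clip`. Casting the bounds to `x.dtype` also avoids upcasting, at the cost of possibly truncating a float bound applied to an integer array, a mixed-kind case that, as far as I know, the standard leaves unspecified.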

ogrisel, Aug 08 '24 14:08

Note that I have not investigated whether other array API namespaces suffer from the same problem.

ogrisel, Aug 08 '24 15:08

Ah, I completely forgot about devices when I wrote this wrapper.

asmeurer, Aug 08 '24 18:08

BTW, is this something that should be made explicit in the spec itself? Or would that just make the spec unnecessarily verbose?

Maybe it could just be tested in array-api-tests.

ogrisel, Aug 09 '24 08:08

I think it's covered by the general design principles, e.g., https://data-apis.org/array-api/latest/design_topics/device_support.html: "This standard chooses to add support for method 3 (local control), with the convention that execution takes place on the same device where all argument arrays are allocated".

We could add an item to that bullet list stating that Python scalars and other objects that don't come from the array library should not influence device assignment.

And +1 for a test.
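The convention quoted above can be read as a small rule: collect the devices of the array arguments only, skip Python scalars, and refuse to mix devices. The sketch below is illustrative; `result_device` is a hypothetical helper (not from the spec or array-api-compat), and it assumes every non-scalar argument has a standard `.device` attribute.

```python
# Illustrative sketch: pick the execution device from the array arguments
# only, so Python scalars never influence device assignment.

_PY_SCALARS = (bool, int, float, complex)


def result_device(*args):
    """Return the common device of the array arguments.

    Python scalars and None are ignored. Returns None when no array is passed
    and raises ValueError when arrays live on different devices.
    """
    devices = []  # a list, so device objects need not be hashable
    for arg in args:
        if arg is None or isinstance(arg, _PY_SCALARS):
            continue  # scalars do not take part in device assignment
        if arg.device not in devices:
            devices.append(arg.device)
    if len(devices) > 1:
        raise ValueError(f"arrays are on different devices: {devices}")
    return devices[0] if devices else None
```

Under this rule, `clip(data, 0.1, 0.9)` resolves to `data`'s device, because the two bounds are skipped rather than defaulting to the CPU.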

rgommers, Aug 09 '24 10:08

Yes, unfortunately, device support is not tested at all in the test suite right now.
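A sketch of what such a test could check, written against any namespace `xp` and an input array `x` with a real floating-point dtype. The function name is hypothetical and not taken from array-api-tests; a real test there would also need to create inputs on a non-default device, which this sketch leaves to the caller.

```python
def check_clip_keeps_device(xp, x):
    """Clip x with Python-float bounds and check the result stays on x's device.

    Assumes x has a real floating-point dtype, so 0.0 and 1.0 are valid bounds.
    """
    out = xp.clip(x, 0.0, 1.0)
    assert out.device == x.device, f"expected {x.device}, got {out.device}"
    return out
```

With the bug reported here, PyTorch raises the `RuntimeError` inside `xp.clip` before the assertion is reached, so a real test would treat either outcome as a failure.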

asmeurer, Aug 09 '24 18:08

Note that the dtype should similarly be inferred from the first argument to avoid unwanted upcasting. That is, I would like the following to hold automatically:

import array_api_compat.torch as xp  # or any other array API namespace

a = xp.linspace(-1, 1, 10, dtype=xp.float32)
assert xp.clip(a, 0, 1).dtype == xp.float32

EDIT: it seems that dtype handling is part of #166.

ogrisel, Sep 02 '24 12:09

The clip wrapper has been a little annoying to get right, primarily because of the "no promotion" behavior combined with the fact that it accepts scalars. But I've hopefully ironed out all the wrinkles in #166 (except for a minor known issue: dask doesn't handle some cases with uint64 arrays correctly because of the way NumPy upcasts to float64).
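The dask case isn't spelled out here. As background, one well-known NumPy promotion rule that yields float64 from unsigned 64-bit integers is mixing `uint64` with `int64`; whether this is exactly the path dask hits is an assumption.

```python
import numpy as np

# No NumPy integer dtype can hold every uint64 and every int64 value,
# so NumPy promotes the pair to float64.
promoted = np.result_type(np.uint64, np.int64)
print(promoted)  # float64

# float64 has a 53-bit significand, so large uint64 values get rounded
# when converted: 2**63 + 1 becomes exactly 2**63.
big = np.uint64(2**63 + 1)
print(float(big) == 2.0**63)  # True
```

This is why a wrapper that must not promote has to keep uint64 inputs away from signed-integer or float intermediates.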

asmeurer, Sep 03 '24 20:09