Clarify whether `__array_namespace_info__().default_device()` can be None
The spec only says it returns an object corresponding to the default device (https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.default_device.html#array_api.info.default_device).
`jax.numpy` returns `None`, so the question is whether `None` corresponds to the default device or not.
```
In [5]: import jax.numpy as jnp
In [6]: jnp.__array_namespace_info__().default_device() is None
Out[6]: True
```
Hi - for what it's worth, we made the deliberate choice to return None here in order to make JAX's existing device placement semantics work with the specifications of the array API standard.
The problem is that JAX's existing device placement does not entirely align with the model that the authors of the spec had in mind. For example, under JIT, there is no default device, because the array referenced in the Python API may not ever physically exist. Here's a silly example:
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.arange(10)  # dead code: y never contributes to the output
    return x
```
What device is `y` on here? That question cannot be answered, because the compiler will recognize that `y = jnp.arange(10)` is dead code and eliminate it from the program: `y` will never exist as an actual buffer on a device.
Let's modify this slightly:
```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.arange(len(x))  # placement of y is decided by the compiler
    return x + y
```
What device will y be on now? Here this will be determined contextually by the compiler: because y only interacts with x, the compiler will allocate its buffer on the same device (or devices) as x.
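The contextual placement described above can be modelled with a toy, stdlib-only sketch. This is not how JAX or XLA actually implement placement; `LazyArray`, `arange` and `add` are hypothetical stand-ins, where `device=None` means "not placed yet" and placement is resolved from whichever operands already have a concrete device.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LazyArray:
    """Toy stand-in for a traced array; device=None means "not placed yet"."""
    values: List[int]
    device: Optional[str] = None


def arange(n: int) -> LazyArray:
    # No device is chosen at creation time: placement is deferred.
    return LazyArray(list(range(n)))


def add(a: LazyArray, b: LazyArray) -> LazyArray:
    # Resolve placement from the operands that already have a concrete device.
    devices = {d for d in (a.device, b.device) if d is not None}
    if len(devices) > 1:
        raise ValueError(f"operands on different devices: {sorted(devices)}")
    placed = devices.pop() if devices else None
    return LazyArray([p + q for p, q in zip(a.values, b.values)], placed)


def f(x: LazyArray) -> LazyArray:
    y = arange(len(x.values))  # y has no device of its own...
    return add(x, y)           # ...and inherits x's device here
```

For example, `f(LazyArray([10, 20, 30], device="gpu:1"))` yields values `[10, 21, 32]` placed on `"gpu:1"`: `y` never needed a default device, because its placement was only fixed once it met `x`.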
Neither of these situations is compatible with the idea of a global default device, so the very notion of `default_device` as envisioned by the array API specification is flawed, and not applicable to frameworks like JAX. Given that, we thought returning `None` from `default_device` would be the least bad approach. After all, `None` is a valid argument to `device` in all cases, and explicitly passing `device=None` results in the same behavior as not passing `device` at all; that behavior seemed to align with the notion of a "default".
If you have other suggestions, I'm open to hearing them! If the specification were changed such that `default_device` could not return `None`, I suppose our best option would be to define some `NoDevice` singleton with the same semantics that `None` has currently.
IMHO, I'm happy for JAX to return `None` as the default device. But yes, I think it needs to be spelled out explicitly.
> ```python
> def f(x):
>     y = jnp.arange(len(x))
>     return x + y
> ```
> What device will `y` be on now? Here this will be determined contextually by the compiler: because `y` only interacts with `x`, the compiler will allocate its buffer on the same device (or devices) as `x`.
It's worth pointing out that this behaviour, while definitely desirable and nice to read, is possible only on lazy backends. In fact, the snippet above will crash on PyTorch if `x` does not reside on the default device (CuPy has blocking design issues on this). As a result, the current best practice for array API-agnostic functions is:
```python
from array_api_compat import array_namespace, device

def f(x):
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1], device=device(x))
    return x + y
```
The array-api-compat shims are necessary to support NumPy 1.x, Dask, Sparse, and JAX itself.
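To see why the eager case crashes without `device=device(x)`, here is a toy, stdlib-only model of an eager backend. `EagerArray`, `arange` and `DEFAULT_DEVICE` are hypothetical stand-ins, not PyTorch's API; the only behaviour borrowed from eager libraries is that `device=None` means the global default and that mixing devices raises.

```python
from dataclasses import dataclass
from typing import List

DEFAULT_DEVICE = "cpu"  # hypothetical global default of the toy eager backend


@dataclass
class EagerArray:
    """Toy eager array: every array is placed on a concrete device at creation."""
    values: List[int]
    device: str

    def __add__(self, other: "EagerArray") -> "EagerArray":
        # Like an eager library, refuse to mix devices implicitly.
        if self.device != other.device:
            raise RuntimeError(f"expected same device, got {self.device} and {other.device}")
        return EagerArray([a + b for a, b in zip(self.values, other.values)], self.device)


def arange(n: int, device=None) -> EagerArray:
    # device=None falls back to the global default, not to the caller's inputs.
    return EagerArray(list(range(n)), DEFAULT_DEVICE if device is None else device)


def f_naive(x: EagerArray) -> EagerArray:
    y = arange(len(x.values))  # lands on DEFAULT_DEVICE
    return x + y


def f_portable(x: EagerArray) -> EagerArray:
    y = arange(len(x.values), device=x.device)  # follows the input's device
    return x + y
```

With `x = EagerArray([1, 2, 3], device="cuda:0")`, `f_portable(x)` returns `[1, 3, 5]` on `"cuda:0"`, while `f_naive(x)` raises `RuntimeError`; both agree only when `x` already sits on the default device.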
This pattern follows the guideline of prioritizing input->output propagation over the global and context devices: https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics
> Preserve device assignment as much as possible (e.g. output arrays from a function are expected to be on the same device as input arrays to the function).
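One way to read this precedence (an explicit `device=` first, then propagation from the inputs, then a context-manager setting, then the global default) is the following stdlib-only sketch. `use_device`, `resolve_device` and `GLOBAL_DEFAULT_DEVICE` are hypothetical names; this is not code from the standard or from any library.

```python
import contextlib
import contextvars

GLOBAL_DEFAULT_DEVICE = "cpu"  # hypothetical library-wide default
_context_device = contextvars.ContextVar("context_device", default=None)


@contextlib.contextmanager
def use_device(device):
    """Hypothetical context manager that sets a scoped default device."""
    token = _context_device.set(device)
    try:
        yield
    finally:
        _context_device.reset(token)


def resolve_device(explicit=None, input_devices=()):
    """Pick a device using explicit > inputs > context > global precedence."""
    # 1. An explicit device= argument always wins.
    if explicit is not None:
        return explicit
    # 2. Otherwise preserve the inputs' device, refusing implicit transfers.
    distinct = set(input_devices)
    if len(distinct) > 1:
        raise ValueError(f"inputs live on different devices: {sorted(distinct)}")
    if distinct:
        return distinct.pop()
    # 3. Then a scoped context setting, if any.
    scoped = _context_device.get()
    if scoped is not None:
        return scoped
    # 4. Finally the global default.
    return GLOBAL_DEFAULT_DEVICE
```

Under this reading, a library like JAX under `jax.jit` effectively has no meaningful step 4, which is why a `None` default device fits it naturally.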
> ```python
> @jax.jit
> def f(x):
>     y = jnp.arange(10)
>     return x
> ```
> What device is `y` on here? That question cannot be answered
The answer here is that no one cares. A much more interesting example would be:
```python
from array_api_compat import array_namespace, device, to_device

def f(x):
    """Return x+arange, prioritizing the default device over x.device"""
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1])
    return to_device(x, device(y)) + y
```
Here, we have some peculiar behaviour:
- PyTorch returns an object on whatever device was set with `torch.set_default_device`;
- Eager JAX returns an object on `jax_default_device`;
- Jitted JAX currently crashes because `to_device` doesn't expect `device` to return `None`, which is the array-api-compat hack around https://github.com/jax-ml/jax/issues/26000. Realistically, though, it would make sense if it returned `x.device`.
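The outcome suggested for jitted JAX (falling back to `x.device`) can be sketched with a `to_device` variant that treats `None` as "do not move". This is a hypothetical plain-Python illustration, not array-api-compat's `to_device`; `ToyArray`, `to_device_tolerant` and `add` are made-up names, and `device=None` stands in for what `device()` reports under `jax.jit`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyArray:
    values: List[int]
    device: Optional[str]  # None models "placement left to the compiler"


def to_device_tolerant(x: ToyArray, device: Optional[str]) -> ToyArray:
    """Hypothetical to_device that treats device=None as "leave placement alone"."""
    if device is None or device == x.device:
        return x  # nothing to move
    return ToyArray(list(x.values), device)


def add(a: ToyArray, b: ToyArray) -> ToyArray:
    # Keep whichever concrete device is present; None defers to the other operand.
    if a.device is not None and b.device is not None and a.device != b.device:
        raise RuntimeError("operands on different devices")
    placed = a.device if a.device is not None else b.device
    return ToyArray([p + q for p, q in zip(a.values, b.values)], placed)


def f(x: ToyArray) -> ToyArray:
    """Toy version of the example above: prefer y's device over x's."""
    y = ToyArray(list(range(len(x.values))), device=None)  # as under jit: no concrete device
    return add(to_device_tolerant(x, y.device), y)
```

With `x = ToyArray([5, 5, 5], device="gpu:0")`, `f(x)` stays on `"gpu:0"`, which is the "return `x.device`" behaviour described in the last bullet.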
To me, `None` seems like a good choice in the JAX context and a perfectly "legal" choice within the array API, because `None` is an object corresponding to the default device.
It is also not a problem that the jax-onic code gives the compiler more options than the code you have to write to conform to the array API.
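```python
import jax
import jax.numpy as jnp
from array_api_compat import array_namespace, device

# jax-onic code
@jax.jit
def f(x):
    y = jnp.arange(len(x))
    return x + y

# array API compatible
def f(x):
    xp = array_namespace(x)
    y = xp.arange(x.shape[-1], device=device(x))
    return x + y
```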
Though I do wonder whether, in the array API version of `f(x)`, the call to `device(x)` could return `None`, basically indicating "do what you want with the placement of `y`"? I think in vanilla Python that would not be possible, but maybe with a `@jit` decorator it is?
Either way, my main point is that it is fine that "not all valid jax/numpy/cupy code is valid array API code".
> Though I do wonder if in the array API version of `f(x)` the call to `device(x)` could return `None` - basically indicating "do what you want with the placement of `y`"?
It's what it does today.
It looks like everyone is in agreement (as am I) that returning `None` is fine. We had a look at this issue in the community meeting and agreed that it can be closed by adding to the spec that returning `None` is allowed, with a note that this may be useful if the default device isn't predictable due to, for example, JIT behavior or device placement rules implemented by the library.