Multi-device support meta-thread
This is a tracker of the current state of support for more than one device at once in the Array API, its helper libraries, and the libraries that implement it.
Supporting multiple devices at the same time is typically far more fragile than pinning one of the available devices at interpreter level and then using it exclusively, which generally works as intended.
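For CUDA-based backends such as CuPy and PyTorch, pinning a device at interpreter level usually means setting the `CUDA_VISIBLE_DEVICES` environment variable before CUDA is initialized. The CUDA runtime honors this variable, so the process sees exactly one GPU, which it reports as device 0. The sketch below only builds per-process environments with the standard library. The helpers `pinned_env` and `launch_pinned` are hypothetical, and the sketch assumes each child imports its GPU library only after the variable is set.

```python
import os
import subprocess
import sys


def pinned_env(gpu_index: int) -> dict[str, str]:
    """Hypothetical helper: copy the environment, exposing a single GPU.

    The CUDA runtime reads CUDA_VISIBLE_DEVICES at initialization, so the
    selected GPU shows up as device 0 inside the child process.
    """
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env


def launch_pinned(script: str, n_gpus: int) -> list[subprocess.Popen]:
    """Hypothetical helper: start one interpreter per GPU, each pinned to a different one."""
    return [
        subprocess.Popen([sys.executable, "-c", script], env=pinned_env(i))
        for i in range(n_gpus)
    ]


if __name__ == "__main__":
    # Each child only prints the value it sees; no GPU library is imported here.
    code = "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"
    for proc in launch_pinned(code, n_gpus=2):
        proc.wait()
```

The array libraries cannot see this kind of pinning: inside each process, the single visible GPU is simply "the" device.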
Array API
- Dictates that the device of the output arrays must always follow from that of the input(s), unless explicitly overridden by a `device=` kwarg, where allowed.
- Is frequently misinterpreted when it comes to the priority of input arrays vs. the global/context device: https://github.com/data-apis/array-api/pull/919
- There is controversy over what `__array_namespace_info__().default_device()` should return: https://github.com/data-apis/array-api/issues/835
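The propagation rule can be made concrete with a toy model in plain Python. It is not any library's implementation: `ToyArray`, `asarray`, `zeros_like` and `add` are hypothetical stand-ins, the device names mirror array-api-strict's `"device1"`/`"device2"`, and the toy assumes that mixing arrays from different devices raises an error.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical array: a tuple of values plus the device it lives on."""
    values: tuple
    device: str


def asarray(obj, device: str = "cpu") -> ToyArray:
    """Creation from Python data: no input array to infer from, so use `device`."""
    return ToyArray(tuple(obj), device)


def zeros_like(x: ToyArray, device: Optional[str] = None) -> ToyArray:
    """The output device is inferred from `x` unless `device=` overrides it."""
    out_device = x.device if device is None else device
    return ToyArray((0,) * len(x.values), out_device)


def add(x: ToyArray, y: ToyArray) -> ToyArray:
    """Elementwise op: no `device=` kwarg, so the output inherits the inputs' device."""
    if x.device != y.device:
        raise ValueError(f"device mismatch: {x.device!r} vs {y.device!r}")
    return ToyArray(tuple(a + b for a, b in zip(x.values, y.values)), x.device)


x = asarray([1, 2, 3], device="device1")
assert add(x, x).device == "device1"                        # follows the input
assert zeros_like(x).device == "device1"                    # inferred from x
assert zeros_like(x, device="device2").device == "device2"  # explicit override
```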
array-api-strict
- Supports three hardcoded devices, "cpu", "device1", and "device2". This is sufficient for testing downstream device-propagation bugs.
array-api-tests
- Devices are untested: https://github.com/data-apis/array-api-tests/issues/302
array-api-compat
- Adds a `device` parameter to NumPy 1.x, CuPy, PyTorch, and Dask (see below).
- Implements the helper functions `device()` and `to_device()` to work around non-compliance of the wrapped libraries.
array-api-extra
- Full support and testing for non-default devices, using array-api-strict only. Actual support with real backends depends entirely on the state of each backend, described below.
NumPy
- It supports a single dummy device, "cpu".
- array-api-compat backports it to NumPy 1.x.
CuPy
- Non-compliant support for multiple devices.
- array-api-compat adds a dummy `device=` parameter to functions.
- A compatibility layer is being added at the moment of writing by https://github.com/data-apis/array-api-compat/pull/293. [EDIT] It can't work, as array-api-compat can't patch methods.
- As it doesn't have a "cpu" device, it's impossible to test multi-device ops without access to a dual-GPU host.
PyTorch
- Fully supported (with array-api-compat shims)
- However there's a bug that hampers testing on GPU CI: https://github.com/pytorch/pytorch/issues/150199
JAX
- Bugs in `__array_namespace_info__`: https://github.com/jax-ml/jax/issues/27606
- Inside `jax.jit`, input-to-output device propagation works, but it's impossible to call creation functions (`empty`, `zeros`, `full`, etc.) on a non-default device: https://github.com/jax-ml/jax/issues/26000
Dask
- Dask doesn't have a concept of devices.
- array-api-compat adds stub support that returns "cpu" when wrapping around NumPy and a dummy `DASK_DEVICE` otherwise. Notably, this is stored nowhere and does not survive a round trip (`device(to_device(x, d)) == d` can fail).
- This is a non-issue when wrapping around NumPy, or when wrapping around CuPy with both client and workers mounting a single GPU.
- Multi-GPU Dask+CuPy support could be achieved by starting separate worker processes on the same host and pinning a different GPU at interpreter level in each. This is extremely inefficient, as it incurs IPC overhead and possibly memory duplication. If a user does so, neither the client nor array-api-compat will ever know.
- `dask-cuda` may improve the situation (did not investigate).
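The round-trip problem described above can be illustrated with a toy model. This is not array-api-compat's code: `StubDaskArray`, `DASK_DEVICE`, `device` and `to_device` are hypothetical stand-ins that only assume the behaviour described in the list. In particular, `to_device()` stores the requested device nowhere, and `device()` derives its answer solely from the wrapped backend.

```python
from dataclasses import dataclass

DASK_DEVICE = "DASK_DEVICE"  # hypothetical stand-in for the dummy device object


@dataclass(frozen=True)
class StubDaskArray:
    """Hypothetical lazy array: only remembers which library backs its chunks."""
    meta_backend: str  # e.g. "numpy" or "cupy"


def device(x: StubDaskArray) -> str:
    """Stateless lookup: the answer is derived from the chunk type alone."""
    return "cpu" if x.meta_backend == "numpy" else DASK_DEVICE


def to_device(x: StubDaskArray, d: str) -> StubDaskArray:
    """No-op 'move': the requested device is accepted but stored nowhere."""
    if d not in ("cpu", DASK_DEVICE):
        raise ValueError(f"unsupported device {d!r}")
    return x


x = StubDaskArray(meta_backend="numpy")
d = DASK_DEVICE
print(device(to_device(x, d)) == d)  # False: the requested device is lost
```

In this toy, the only way to make the round trip hold would be to record `d` on the returned object, which a stateless stub cannot do.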
SciPy
- Largely untested. Initial attempt to test: https://github.com/scipy/scipy/pull/22756
At yesterday's consortium meeting, everyone agreed that in the following snippet:

```python
import cupy as cp

with cp.cuda.Device(1):
    y = cp.asarray(1, device=cp.cuda.Device(0))  # y should be on Device(0)
    z = y + 1  # device of z?
```

z should be on device 0. #919 makes that clear. However, CuPy maintainers were absent and did not get an opportunity to voice their opinion.