
Multi-device support meta-thread

Open · crusaderky opened this issue 10 months ago · 1 comment

This is a tracker of the current state of support for more than one device at once in the Array API, its helper libraries, and the libraries that implement it.

Supporting multiple devices at the same time is substantially more fragile than pinning one of the available devices at interpreter level and then using it exclusively, which generally works as intended.

Array API

  • Dictates that the device of the output arrays must always follow from that of the input(s), unless explicitly overridden by a device= kwarg, where allowed.
  • The priority of input-array devices over the global/context device is frequently misinterpreted: https://github.com/data-apis/array-api/pull/919
  • There is controversy over what __array_namespace_info__().default_device() should return: https://github.com/data-apis/array-api/issues/835

array-api-strict

  • Supports three hardcoded devices, "cpu", "device1", "device2". This is fit for purpose when testing downstream device-propagation bugs.

array-api-tests

  • Devices are untested: https://github.com/data-apis/array-api-tests/issues/302

array-api-compat

  • Adds the device= parameter to NumPy 1, CuPy, PyTorch, and Dask wrappers (see below).
  • Implements helper functions device() and to_device() to work around non-compliance of the wrapped libraries.

array-api-extra

  • Full support and testing for non-default devices, using array-api-strict only. Actual support on real backends depends entirely on the libraries below.

NumPy

  • Supports a single dummy device, "cpu" (NumPy ≥ 2.0).
  • array-api-compat backports the device= parameter to NumPy 1.x.

CuPy

  • Non-compliant support for multiple devices.
  • array-api-compat adds a dummy device= parameter to functions.
  • A compatibility layer is being added at the time of writing by https://github.com/data-apis/array-api-compat/pull/293. [EDIT] it can't work, as array-api-compat can't patch methods.
  • As it doesn't have a "cpu" device, it's impossible to test multi-device ops without access to a dual-GPU host.

PyTorch

  • Fully supported (with array-api-compat shims)
  • However there's a bug that hampers testing on GPU CI: https://github.com/pytorch/pytorch/issues/150199

JAX

  • Bugs in __array_namespace_info__: https://github.com/jax-ml/jax/issues/27606
  • Inside jax.jit, input-to-output device propagation works, but it's impossible to call creation functions (empty, zeros, full, etc.) on a non-default device: https://github.com/jax-ml/jax/issues/26000

Dask

  • Dask doesn't have a concept of device
  • array-api-compat adds stub support, which returns "cpu" when wrapping around numpy and a dummy DASK_DEVICE otherwise. Notably, this is stored nowhere and does not survive a round-trip (device(to_device(x, d)) == d can fail).
  • This is a non-issue when wrapping around numpy, or when wrapping around cupy with both client and workers mounting a single GPU.
  • Multi-GPU Dask+CuPy support could be achieved by starting separate worker processes on the same host and pinning the GPU at interpreter level. This is extremely inefficient, as it incurs IPC overhead and possibly memory duplication. If a user does so, the client and array-api-compat will never know.
  • dask-cuda may improve the situation (did not investigate).
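For concreteness, a hypothetical launch sketch of the "one worker process per GPU" workaround described above; each worker sees exactly one GPU via CUDA_VISIBLE_DEVICES, so all CuPy allocations inside it are implicitly pinned (the scheduler address and worker flags are illustrative; dask-cuda automates this pattern):

```shell
# Pin one GPU per worker process at interpreter level.
# Neither the client nor array-api-compat is aware of this topology.
dask scheduler &
CUDA_VISIBLE_DEVICES=0 dask worker localhost:8786 --nworkers 1 --nthreads 1 &
CUDA_VISIBLE_DEVICES=1 dask worker localhost:8786 --nworkers 1 --nthreads 1 &
```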

SciPy

  • Largely untested. Initial attempt to test: https://github.com/scipy/scipy/pull/22756

crusaderky avatar Mar 31 '25 12:03 crusaderky

At yesterday's consortium meeting, everyone was in agreement that in

```python
with cp.cuda.Device(1):
    y = cp.asarray(1, device=cp.cuda.Device(0))  # y should be on Device(0)
    z = y + 1  # device of z?
```

z should be on device 0. #919 makes that clear. However, CuPy maintainers were absent and did not get an opportunity to voice their opinion.

crusaderky avatar Apr 18 '25 09:04 crusaderky