`jax.make_array_from_async_callback`
`jax.make_array_from_callback` loops sequentially over the devices attached to this host: https://github.com/google/jax/blob/ef40b85c8b2686f64bc9ca67de267a6b1a7935bb/jax/_src/array.py#L693. When the callback fetches from high-latency remote storage, this sequential loop can become latency-limited rather than throughput-limited on the network connection, because each request waits for the previous one to finish. In that case, issuing the network requests for all devices in parallel typically reduces the total load time.
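To illustrate the latency argument, here is a small stdlib-only simulation. It does not use JAX: `fetch_shard`, `LATENCY_S`, and `NUM_DEVICES` are made-up stand-ins for a remote read and the local device count, so treat it as a sketch of the effect rather than a benchmark.

```python
import asyncio
import time

LATENCY_S = 0.05   # assumed per-request latency (hypothetical value)
NUM_DEVICES = 4    # assumed number of local devices (hypothetical value)


async def fetch_shard(device_id):
    # Stand-in for a high-latency remote read of one device's shard.
    await asyncio.sleep(LATENCY_S)
    return device_id


async def fetch_sequentially():
    # One request at a time: total time is roughly NUM_DEVICES * LATENCY_S.
    return [await fetch_shard(d) for d in range(NUM_DEVICES)]


async def fetch_concurrently():
    # All requests in flight at once: total time is roughly LATENCY_S.
    results = await asyncio.gather(*[fetch_shard(d) for d in range(NUM_DEVICES)])
    return list(results)


def timed(coro_fn):
    # Run a coroutine function to completion and report wall-clock time.
    start = time.perf_counter()
    result = asyncio.run(coro_fn())
    return result, time.perf_counter() - start


sequential_result, sequential_s = timed(fetch_sequentially)
concurrent_result, concurrent_s = timed(fetch_concurrently)
print(f"sequential: {sequential_s:.3f}s, concurrent: {concurrent_s:.3f}s")
```

Both variants return the same shards in the same order; only the wall-clock time differs.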
As a user, I can already issue the requests in parallel by hand and assemble the result with `jax.make_array_from_single_device_arrays`, but that's a lower-level API. Ideally there would be a new function, `jax.make_array_from_async_callback`, that takes an async callback. Besides the changed type signature, the only implementation change would be to wrap the per-device calls at https://github.com/google/jax/blob/ef40b85c8b2686f64bc9ca67de267a6b1a7935bb/jax/_src/array.py#L693 in `await asyncio.gather(*...)`, making it:
    per_device_values = await asyncio.gather(
        *[data_callback(device_to_index_map[device]) for device in devices]
    )
Then jax.make_array_from_callback could wrap the async version.
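A minimal sketch of how this could look in user code today, built on `jax.make_array_from_single_device_arrays`. The function names `make_array_from_async_callback` and `make_array_from_callback_sync`, the `fetch_shard` callback, and the toy mesh are all hypothetical; none of this is an existing JAX API, and the real implementation in `array.py` may differ.

```python
import asyncio

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


async def make_array_from_async_callback(shape, sharding, data_callback):
    """Hypothetical helper, not part of JAX: async analogue of make_array_from_callback."""
    # Map each addressable device to the index (tuple of slices) it should hold.
    device_to_index_map = sharding.addressable_devices_indices_map(shape)
    devices = list(device_to_index_map)
    # Start every fetch before awaiting any of them, so the requests overlap.
    per_device_values = await asyncio.gather(
        *[data_callback(device_to_index_map[device]) for device in devices]
    )
    # Place each shard on its device, then assemble the global array.
    per_device_arrays = [
        jax.device_put(value, device)
        for value, device in zip(per_device_values, devices)
    ]
    return jax.make_array_from_single_device_arrays(shape, sharding, per_device_arrays)


def make_array_from_callback_sync(shape, sharding, data_callback):
    """Hypothetical sync wrapper; asyncio.run fails if an event loop is already running."""
    async def async_callback(index):
        return data_callback(index)

    return asyncio.run(make_array_from_async_callback(shape, sharding, async_callback))


# Toy usage: shard the rows of a small array across whatever devices are available.
devices = jax.devices()
global_data = np.arange(len(devices) * 4, dtype=np.float32).reshape(len(devices) * 2, 2)
mesh = Mesh(np.array(devices), ("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))


async def fetch_shard(index):
    await asyncio.sleep(0.01)  # stand-in for a remote storage read
    return global_data[index]


arr = asyncio.run(
    make_array_from_async_callback(global_data.shape, sharding, fetch_shard)
)
arr_sync = make_array_from_callback_sync(
    global_data.shape, sharding, lambda index: global_data[index]
)
```

The sync wrapper only shows the layering; with a blocking callback it gains nothing from `asyncio.gather`, since each call still runs to completion before the next starts.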
Is this what you are looking for? https://github.com/google/jax/blob/main/jax/experimental/array_serialization/serialization.py#L67-L78
You can write a wrapper like this for your code base too! Or if the above utility is helpful and what you were looking for, I can expose that.
This seems quite close to #10897