Provide better error message when using device_put_sharded inside jitted code
Calling device_put_sharded inside a jitted function currently results in:
TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
By contrast, the regular device_put does nothing in jitted code, but it doesn't fail either. The error message could be more helpful here.
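As a rough illustration of the kind of check that could produce a clearer message, here is a minimal sketch. It assumes a hypothetical wrapper, `checked_device_put_sharded` (not part of JAX's API), that detects tracers with `jax.core.Tracer` and raises before ever reaching `jax.device_put_sharded`. The wording of the message is also just an example.

```python
import jax
import jax.numpy as jnp


def checked_device_put_sharded(shards, devices):
    """Hypothetical wrapper: fail early with a clear message under tracing."""
    # Under jax.jit the shards are tracers, not concrete arrays.
    if any(isinstance(s, jax.core.Tracer) for s in shards):
        raise TypeError(
            "device_put_sharded cannot be called inside jax.jit (or another "
            "transformation): it needs concrete arrays but received a tracer. "
            "Call it outside the jitted function and pass the result in."
        )
    # Outside of tracing, defer to the real JAX function.
    return jax.device_put_sharded(shards, devices)


@jax.jit
def traced_call(x):
    devices = jax.local_devices()
    # x is a tracer here, so the wrapper raises during tracing.
    return checked_device_put_sharded([x for _ in devices], devices)


try:
    traced_call(jnp.ones(5))
except TypeError as err:
    print(err)
```

The check runs at trace time, so the user sees the explanatory TypeError instead of a low-level dtype or C++ dispatch error.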
Hi @GeorgOstrovski
To illustrate the use of device_put_sharded within JIT-compiled code, I created the following example:

```python
import jax
import jax.numpy as jnp


@jax.jit
def sharded_computation():
    devices = jax.local_devices()
    # One shard per local device; under jit these are tracers, not concrete arrays.
    x = [jnp.ones(5) for _ in devices]
    # This is the call that fails when traced inside jax.jit.
    y = jax.device_put_sharded(x, devices)
    return jnp.allclose(y, jnp.stack(x))


result = sharded_computation()
print(result)
```

When executed with JAX version 0.4.23 on both CPU and GPU, it produced the following error message:
XlaRuntimeError: INVALID_ARGUMENT: Not supported: The C++ jax jit execution path, only accepts DeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
However, running the same code with JAX version 0.3.25 on TPU still produced the error message you originally reported:
TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
For further details, please refer to the provided gists for CPU/GPU and TPU executions.
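Until the message improves, one way around the error is to do the device placement outside the traced function: call device_put_sharded eagerly on concrete arrays, then pass the sharded result into the jitted code. A minimal sketch, reusing the `jnp.ones(5)` shards from the example above; the function name `compare` is illustrative only.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
shards = [jnp.ones(5) for _ in devices]

# Eager call outside jit: the shards are concrete arrays, so this succeeds.
y = jax.device_put_sharded(shards, devices)


@jax.jit
def compare(sharded, stacked):
    # Only array math happens under jit; no device placement here.
    return jnp.allclose(sharded, stacked)


result = compare(y, jnp.stack(shards))
print(result)
```

The jitted function now only sees an ordinary array argument, so no tracer ever reaches device_put_sharded.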