
Provide better error message when using device_put_sharded inside jitted code

Open • GeorgOstrovski opened this issue 4 years ago • 1 comment

Calling `jax.device_put_sharded` inside jitted code currently results in

```
TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
```

By contrast, the regular `jax.device_put` effectively does nothing inside jitted code, but it doesn't fail either. The error message for `device_put_sharded` could be more helpful here.
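For comparison with the failure above, here is a minimal sketch of the `device_put` behaviour described: inside `jax.jit` the argument is a tracer, and `jax.device_put` accepts it rather than raising. The function name `with_device_put` is illustrative only.

```python
import jax
import jax.numpy as jnp

@jax.jit
def with_device_put(x):
  # x is a tracer here; jax.device_put accepts it instead of raising.
  y = jax.device_put(x)
  return y * 2.0

out = with_device_put(jnp.arange(3.0))
print(out)  # [0. 2. 4.]
```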

GeorgOstrovski, Mar 25 '21 17:03

Hi @GeorgOstrovski,

To illustrate the use of `device_put_sharded` within JIT-compiled code, I created the following example and ran it with JAX 0.4.23 on both CPU and GPU:

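```python
import jax
import jax.numpy as jnp

@jax.jit
def sharded_computation():
  devices = jax.local_devices()
  # Inside jit, each jnp.ones(5) is an abstract tracer, not a concrete array.
  x = [jnp.ones(5) for _ in devices]
  # Reproduces the issue: device_put_sharded fails when given tracers.
  y = jax.device_put_sharded(x, devices)
  # jnp.allclose (not np.allclose) is the traceable comparison inside jit.
  return jnp.allclose(y, jnp.stack(x))

result = sharded_computation()
print(result)
```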

It produced the following error message:

```
XlaRuntimeError: INVALID_ARGUMENT: Not supported: The C++ jax jit execution path, only accepts DeviceArray, Numpy arrays scalars of supported types (see implementation), or Python scalars. Got type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
```

However, running the same code with JAX 0.3.25 on TPU still produced the error you originally reported:

```
TypeError: No canonicalize_dtype handler for type: <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
```
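Since the message differs across versions and backends, one user-side option is to check for tracers before calling `device_put_sharded` and raise a descriptive error. The sketch below is hypothetical: `checked_device_put_sharded` and its message text are not part of JAX, and it assumes `jax.core.Tracer` is available in your JAX version.

```python
import jax
import jax.numpy as jnp

def checked_device_put_sharded(shards, devices):
  """Hypothetical wrapper around jax.device_put_sharded.

  Raises a descriptive TypeError when any shard is a tracer, i.e. when the
  call happens inside jax.jit or another transformation.
  """
  leaves = jax.tree_util.tree_leaves(shards)
  if any(isinstance(leaf, jax.core.Tracer) for leaf in leaves):
    raise TypeError(
        "device_put_sharded was called on traced values (e.g. inside "
        "jax.jit). Call it outside the transformed function and pass the "
        "resulting array in as an argument instead.")
  return jax.device_put_sharded(shards, devices)

@jax.jit
def sharded_computation(x):
  devices = jax.local_devices()
  # x is a tracer here, so the wrapper raises before reaching JAX internals.
  return checked_device_put_sharded([x for _ in devices], devices)

try:
  sharded_computation(jnp.ones(5))
  error_message = None
except TypeError as e:
  error_message = str(e)

print(error_message)
```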

For further details, please refer to the linked gists for the CPU/GPU and TPU runs.
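Because the failure comes from calling `device_put_sharded` on tracers, a minimal sketch of a working variant, assuming the goal is the same comparison as in the example, moves the sharding step outside `jax.jit` and passes the sharded array into the jitted function as an argument.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
x = [jnp.ones(5) for _ in devices]

# Shard outside of jit, where the shards are concrete arrays.
y = jax.device_put_sharded(x, devices)

@jax.jit
def sharded_computation(y, x_stacked):
  # Only traceable jnp operations run inside the jitted function.
  return jnp.allclose(y, x_stacked)

result = sharded_computation(y, jnp.stack(x))
print(result)  # True
```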

selamw1, Feb 15 '24 00:02