Skye Wanderman-Milne

66 comments by Skye Wanderman-Milne

Hi, sorry for the delay on this! I filed an internal XLA bug (b/187117499 for Googlers) and will keep this thread updated.

Also, can you provide an example command and its output in the PR description?

Can you try running `JAX_DEBUG_LOG_MODULES=jax._src.xla_bridge python -c 'import jax; print(jax.devices())'` and paste the output here?

Ah. This is supposed to be raised as an exception instead of falling back to CPU. That functionality must have regressed. Now to figure out why...

This is the assert that fires, btw: https://github.com/jax-ml/jax/blob/bbcc3eef3c1fedda3c0eef48c8bd49fd34a313c9/jax/_src/custom_partitioning.py#L462

@pschuh, do you know why we don't allow consts in the custom partitioning jaxpr?