Skye Wanderman-Milne
Hi, sorry for the delay on this! I filed an internal XLA bug (b/187117499 for Googlers) and will keep this thread updated.
Hi, sorry, I was on vacation!
Also, can you provide an example command and its output in the PR description?
Can you try running `JAX_DEBUG_LOG_MODULES=jax._src.xla_bridge python -c 'import jax; print(jax.devices())'` and paste the output here?
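If it's easier to run from a script than from the shell, here is a rough Python equivalent. This is a sketch, and it assumes the variable is set before `jax` is imported for the first time, so JAX sees it the same way it would see a shell-exported variable.

```python
import os

# Must happen before the first `import jax` in this process, so that JAX's
# config picks it up exactly as if it had been exported in the shell.
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.xla_bridge"

import jax  # noqa: E402  (import deliberately placed after setting the env var)

# Backend initialization happens here. Debug messages from jax._src.xla_bridge
# (e.g. why a GPU/TPU backend failed to initialize) should go to stderr.
devices = jax.devices()
print(devices)
```

Please paste both the stderr log lines and the printed device list.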
Ah. This is supposed to be raised as an exception instead of falling back to CPU. That functionality must have regressed. Now to figure out why...
This is the assert that fires, btw: https://github.com/jax-ml/jax/blob/bbcc3eef3c1fedda3c0eef48c8bd49fd34a313c9/jax/_src/custom_partitioning.py#L462 @pschuh, do you know why we don't allow consts in the custom partitioning jaxpr?