Skye Wanderman-Milne

66 comments by Skye Wanderman-Milne

@gnecula Unstarted. We still need to do the XLA plumbing to allow setting the `--xla_force_host_platform_device_count` flag programmatically. @neerajprad Sorry for not replying earlier! What do you mean by initializing across...
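For context, a common workaround while the flag cannot be set programmatically is to put it in the `XLA_FLAGS` environment variable before JAX initializes its backends. The following is a minimal sketch, assuming the CPU backend is in use; the device count of 8 is an arbitrary illustration, not something from the thread.

```python
import os

# Assumption: this runs before JAX initializes its backends, because the
# flag is only read at backend startup. The count of 8 is an arbitrary example.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

import jax

# The host (CPU) platform should now expose several virtual devices.
cpu_devices = jax.devices("cpu")
print(len(cpu_devices))
```

Setting the variable after the backend has started has no effect, which is why the assignment comes before `import jax` here.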

I don't think that would be simpler to implement. Beyond the XLA changes you mentioned, we'd also need to teach pmap that the CPU backend doesn't use multiple devices, and...

Ah ok, I didn't realize you'd also like multiple threads on each GPU (sorry if I forgot from NeurIPS). I don't think XLA has any interface for this currently, and...

Just to confirm, you're running four separate Python processes with the four different IP addresses, right? You should also set `config.FLAGS.jax_xla_backend = "tpu_driver"` if you're not already. (I'm not sure...

Glad to hear it! Please let us know if you run into anything else, or have questions or suggestions! Actually, did you already get ppermute working too? Feel free to...

I think you're close! Everything you're saying sounds right. I think your issue is that on each host, you're only providing that host's list of permutations to `ppermute`. You should...
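To illustrate what "the full list of permutations" can look like, here is a minimal single-process sketch, assuming a simple ring shift (the ring pattern, the axis name `"i"`, and the helper name `shift` are illustrative choices, not taken from the thread). The list passed to `ppermute` names every device taking part in the `pmap`, rather than only a subset.

```python
import os

# Assumption: ask for 4 virtual CPU devices if no XLA flags are set yet,
# so the sketch also runs on a machine without accelerators.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
from jax import lax

# In a single process the global and local device counts coincide.
n = jax.device_count()

# Full ring permutation: every (source, destination) pair across all devices.
perm = [(i, (i + 1) % n) for i in range(n)]

def shift(x):
    # Each device sends its value to the next device in the ring.
    return lax.ppermute(x, axis_name="i", perm=perm)

out = jax.pmap(shift, axis_name="i")(jnp.arange(jax.local_device_count()))
print(out)
```

In a multi-host run, the same idea would presumably apply with indices covering the devices of all hosts, which this single-process sketch does not exercise.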

I just looked at your stacktrace more carefully, and it looks like you're already providing the full permutation list, but you have something like `while_loop(pmap(ppermute))`. I don't expect a `while_loop`...

Another thought: be very careful running a `ppermute` inside a `while_loop`, because if different hosts end up with different trip counts, your outer `pmap` will hang or return incorrect results....
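One way to keep trip counts identical, sketched here under assumptions not stated in the thread, is to drive the loop condition from a value that has been all-reduced (here with `lax.pmax`), so every participant sees the same value and stops on the same iteration. The halving update, the `1e-3` threshold, and the names `cond_fn`, `body_fn`, and `run` are hypothetical, and the sketch runs in a single process rather than across hosts.

```python
import os

# Assumption: ask for 4 virtual CPU devices if no XLA flags are set yet.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
perm = [(i, (i + 1) % n) for i in range(n)]  # full ring permutation

def cond_fn(state):
    _, err = state
    # err comes from an all-reduce, so it is identical on every device
    # and all devices leave the loop on the same iteration.
    return err > 1e-3

def body_fn(state):
    x, _ = state
    x = lax.ppermute(x, axis_name="i", perm=perm) * 0.5  # hypothetical update
    err = lax.pmax(jnp.abs(x), axis_name="i")
    return x, err

def run(x):
    init = (x, lax.pmax(jnp.abs(x), axis_name="i"))
    return lax.while_loop(cond_fn, body_fn, init)

x0 = jnp.arange(n, dtype=jnp.float32) + 1.0
x_final, err_final = jax.pmap(run, axis_name="i")(x0)
print(err_final)
```

A fixed-trip-count loop such as `lax.fori_loop` would be another way to get the same guarantee when the number of iterations is known up front.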

Are you possibly missing a data dependence from the input of the `pmap` to the `ppermute`? If you're getting the same stack trace that includes `apply_primitive`, it may mean the...