brax
Improve the performance of broad-phase culling on TPU
Brax has a broad-phase culling algorithm that works pretty well on GPU: the Brax Multi-Agent colab demonstrates a scene with many bodies all interacting.
One challenge with broad-phase is that it introduces scatter/gather operations, which are inefficient on TPU because TPU lacks the fine-grained memory access that GPU has. We should investigate whether it helps to use jax.experimental.host_callback to isolate the few operations that are slow on TPU and run them on the host CPU instead.
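To make the idea concrete, here is a rough sketch of offloading one gather-heavy step to the host from inside a jitted function. It is not Brax's broad-phase code: `_host_nearest`, `candidate_pairs`, and the constant `K` are hypothetical names, and the "k nearest bodies" rule is only a stand-in for the real culling logic. It also uses `jax.pure_callback` rather than the `jax.experimental.host_callback` module named above, on the assumption that a side-effect-free callback with a fixed output shape is enough here.

```python
import jax
import jax.numpy as jnp
import numpy as np

K = 2  # hypothetical: number of candidate neighbors kept per body


def _host_nearest(pos):
    """Runs on the host CPU with NumPy, outside the compiled XLA program."""
    pos = np.asarray(pos)
    # Pairwise distances between all bodies, shape (n, n).
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    # A body is never its own collision candidate.
    np.fill_diagonal(dist, np.inf)
    # Indices of the K closest bodies for each body, shape (n, K).
    return np.argsort(dist, axis=1)[:, :K].astype(np.int32)


@jax.jit
def candidate_pairs(pos):
    """Returns, for each body, the indices of its K nearest bodies."""
    # The callback's output shape and dtype must be declared up front,
    # because the jitted program cannot inspect the host result at trace time.
    out_shape = jax.ShapeDtypeStruct((pos.shape[0], K), jnp.int32)
    return jax.pure_callback(_host_nearest, out_shape, pos)
```

Each call copies `pos` from the device to the host and the index array back, so whether this wins on TPU depends on how that transfer cost compares with the scatter/gather cost it replaces. That trade-off is what would need measuring.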