IndicesBoundaryMasker in JAX hangs when used on multiple GPUs
The current implementation occasionally hangs across GPUs when padding is applied to the bmap. The main cause is that the function cannot be JIT-compiled.
In the JAX implementation of IndicesBoundaryMasker, several operations, such as Python conditional statements that depend on array values, are not supported inside JIT-compiled JAX code. The previous implementation was JIT-compatible and did not have these issues, so it should serve as the reference for fixing this problem.
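To illustrate the JIT constraint, here is a minimal sketch. It is not the actual IndicesBoundaryMasker code. It assumes a 2D integer bmap where 0 means "no boundary", and the function names `set_ids_python_branch` and `set_ids_traceable` are hypothetical. The first version branches in Python on an array value and fails to trace under `jax.jit`. The second expresses the same condition with `jnp.where` and traces cleanly. Note that the jittable version keeps an existing ID instead of raising, because a data-dependent `raise` cannot run inside a compiled function.

```python
import jax
import jax.numpy as jnp


def set_ids_python_branch(bmap, indices, boundary_id):
    # Hypothetical NON-jittable version: the Python `if` needs a concrete
    # bool, but under jax.jit `jnp.any(...)` is an abstract tracer, so
    # tracing fails with a ConcretizationTypeError.
    if jnp.any(bmap[indices] != 0):
        raise ValueError("boundary indices overlap an existing boundary")
    return bmap.at[indices].set(boundary_id)


@jax.jit
def set_ids_traceable(bmap, indices, boundary_id):
    # Hypothetical jittable version: the condition is expressed as an array
    # operation, so the whole function traces into a single XLA program.
    current = bmap[indices]
    # Only write into cells that are still unassigned (ID 0); cells that
    # already carry an ID keep it instead of raising an error.
    new_vals = jnp.where(current == 0, boundary_id, current)
    return bmap.at[indices].set(new_vals)


bmap = jnp.zeros((4, 4), dtype=jnp.int32)
indices = (jnp.array([0, 1]), jnp.array([2, 3]))  # cells (0, 2) and (1, 3)

print(set_ids_traceable(bmap, indices, 7))

try:
    jax.jit(set_ids_python_branch)(bmap, indices, 7)
except jax.errors.ConcretizationTypeError as err:
    print("cannot jit the Python branch:", type(err).__name__)
```

Both versions run eagerly. Only the `jnp.where` version can be wrapped in `jax.jit`, which is the property the previous implementation had.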
Can you create a repro, please?
Just run any example (e.g., flow over sphere) several times with both GPUs visible. It hangs about 90% of the time.
Please use the old implementation to fix this. Thanks.
(This is not related to the latest PR; it happens in the old version as well.)