IndicesBoundaryMasker in JAX causes hangs in multi-GPU runs

Open · mehdiataei opened this issue 1 year ago · 3 comments

The current implementation occasionally causes the GPUs to hang when padding is applied to the bmap. The issue arises primarily because the function cannot be JIT-compiled.

In the JAX implementation of IndicesBoundaryMasker, several operations, such as Python conditional statements that depend on array values, cannot be traced under JIT compilation. The previous implementation was JIT-compatible and did not have these issues; it should serve as the reference for the fix.
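A minimal JAX sketch of the constraint described above, assuming a 2-D map of `uint8` boundary ids and a boolean mask of boundary cells. The function names `set_ids_python_if` and `set_ids_jit` and the `pad` parameter are hypothetical and are not XLB's API. It only illustrates why a value-dependent Python `if` breaks under `jax.jit` and one JIT-compatible rewrite (`jax.lax.cond` plus a static pad width); it is not a confirmed fix for the multi-GPU hang.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical, NOT XLB's API: a Python `if` on an array value.
# Works eagerly, but under jax.jit `jnp.any(mask)` is an abstract tracer,
# so converting it to a Python bool raises a ConcretizationTypeError.
def set_ids_python_if(bmap, mask, boundary_id):
    if jnp.any(mask):
        bmap = jnp.where(mask, boundary_id, bmap).astype(bmap.dtype)
    return bmap


# Hypothetical JIT-compatible version: the data-dependent branch becomes
# jax.lax.cond (both branches return the same shape and dtype), and the pad
# width is a static Python int because jnp.pad needs a concrete output shape.
@partial(jax.jit, static_argnames=("pad",))
def set_ids_jit(bmap, mask, boundary_id, pad=1):
    bmap = jax.lax.cond(
        jnp.any(mask),
        lambda b: jnp.where(mask, boundary_id, b).astype(b.dtype),
        lambda b: b,
        bmap,
    )
    # Pad the map with id 0 on every side (assumed to mean "no boundary").
    return jnp.pad(bmap, pad, mode="constant", constant_values=0)
```

Here `jnp.where` alone would give the same result; `jax.lax.cond` is shown because it is the direct replacement for a data-dependent Python `if`. Changing `pad` triggers a recompilation, since static arguments are part of the compilation cache key.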

mehdiataei, Dec 03 '24 20:12

Can you create a repro, please?

hsalehipour, Dec 03 '24 20:12

Just run any example (e.g., flow over a sphere) several times with both GPUs visible. It hangs about 90% of the time.

Please use the old implementation to fix this. Thanks.

mehdiataei, Dec 03 '24 20:12

(This is not related to the latest PR; it happens in the old version as well.)

mehdiataei, Dec 03 '24 20:12