MPI nodes don't handle scalar args correctly
Describe the bug
A DaCe program fails to compile if the integer arguments of MPI nodes (`dst` and `tag`) are anything other than symbols, symbolic expressions, or numeric constants.
To Reproduce
```python
import dace
import dace as dc

N = dc.symbol('N')
rank = dc.symbol('rank', dtype=dc.int64)

@dc.program
def func(A: dc.int32[N]):
    # dace.comm.Send(A[0], rank - 1, 0)  # Works
    dace.comm.Send(A[0], abs(rank - 1), 0)  # Fails
```
The program fails to compile with the following error:
```
ValueError: Node type "Send" not supported for promotion
```
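The same error occurs when `dst` is first stored in a local variable, whether that variable holds a constant, the symbol, or an array element:
```python
@dc.program
def func(A: dc.int32[N]):
    a = 0
    # a = rank   (same error)
    # a = A[0]   (same error)
    dace.comm.Send(A[0], a, 0)
```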
Environment:
- Latest DaCe master branch
Possible fix:
When an argument is given as a `str`, the code below should check whether the corresponding `sdfg.arrays` entry is a `Scalar` and, if so, fall through to the last branch instead of treating the string as a symbolic expression.
https://github.com/spcl/dace/blob/f4b4d01f67cb089b3ef821673e0a12405c94f9b1/dace/frontend/common/distr.py#L424-L436
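A minimal sketch of the proposed check, written without DaCe so it runs standalone. `Scalar`, `Array`, and `FakeSDFG` are hypothetical stand-ins for `dace.data.Scalar`, `dace.data.Array`, and an SDFG's `arrays` dictionary, and `classify_int_arg` along with its return labels is a hypothetical helper. It does not reproduce the actual branches in `distr.py`. It only shows the ordering the fix suggests: a string naming a `Scalar` container is routed to the data path before any attempt to parse it as a symbol.

```python
from dataclasses import dataclass, field
from typing import Dict, Union


# Hypothetical stand-ins for DaCe data descriptors (the real classes are richer).
@dataclass
class Scalar:
    dtype: str


@dataclass
class Array:
    dtype: str
    shape: tuple


# Hypothetical stand-in for an SDFG: only the `arrays` name -> descriptor map.
@dataclass
class FakeSDFG:
    arrays: Dict[str, object] = field(default_factory=dict)


def classify_int_arg(sdfg: FakeSDFG, arg: Union[int, str]) -> str:
    """Decide how an integer MPI argument (e.g. dst or tag) should be handled.

    'constant'    -> a Python integer literal
    'scalar_data' -> a name whose sdfg.arrays entry is a Scalar (the fix)
    'symbolic'    -> any other string, assumed to be a symbol/expression
    """
    if isinstance(arg, int):
        return 'constant'
    if isinstance(arg, str):
        desc = sdfg.arrays.get(arg)
        if isinstance(desc, Scalar):
            # Proposed fix: pass the Scalar container as data (the last
            # branch) instead of parsing its name as a symbolic expression.
            return 'scalar_data'
        if desc is not None:
            # A non-scalar container cannot serve as dst or tag.
            raise ValueError(f'{arg!r} is not a scalar')
        return 'symbolic'
    raise TypeError(f'unsupported argument type: {type(arg).__name__}')


sdfg = FakeSDFG(arrays={'__tmp0': Scalar('int64'), 'A': Array('int32', (10,))})
print(classify_int_arg(sdfg, 0))
print(classify_int_arg(sdfg, 'rank - 1'))
print(classify_int_arg(sdfg, '__tmp0'))
```

The `__tmp0` entry stands in for a transient scalar (for example, one holding the result of `abs(rank - 1)`); its name is illustrative only.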