tvm
[Relax] Fix the parser to avoid treating a list as an integer
When a tensor shape is written as (10), the TVMScript parser crashes unexpectedly with the traceback below. This PR adds a rule that converts (10) to (10,) to avoid the crash.
Traceback (most recent call last):
  File "/share_container/optfuzz/res/ut_ut_test/res_executions/14832_test.py", line 9, in <module>
    class Module:
  File "/share_container/optfuzz/res/ut_ut_test/res_executions/14832_test.py", line 11, in Module
    def main(x: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,), dtype="float32"))) -> R.Tensor((10), dtype="float32"):
                                                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/script/parser/relax/entry.py", line 266, in Tensor
    if shape is not None and not isinstance(shape, Var) and len(shape) == 0:
                                                            ^^^^^^^^^^
TypeError: object of type 'int' has no len()
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    # Note: (10) is the int 10, not a one-element tuple; this is what triggers the crash.
    def main(x: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10), dtype="float32"))) -> R.Tensor((10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((10,), dtype="float32") = x[0]
            R.output(lv)
        return lv

mod = Module
mod.show()
cc @tqchen @Hzfengsy @Lunderberg @yongwww
Thanks for the contribution!
First, this is not really a bug. In Python syntax, (10) is just the integer 10, whereas [10] and (10,) are a list and a tuple, respectively. The parser expects a list (or tuple) here, so the error is expected given the current implementation.
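The Python-syntax point above is easy to confirm in a plain interpreter, without TVM: parentheses around a single value only group it, and it is the trailing comma that creates a tuple.

```python
# Parentheses around a single value are just grouping; the comma makes a tuple.
a = (10)
b = (10,)
c = [10]

print(type(a).__name__)  # int
print(type(b).__name__)  # tuple
print(type(c).__name__)  # list

# len() works on the tuple and the list but not on the int,
# which is the same TypeError seen in the traceback.
print(len(b), len(c))  # 1 1
try:
    len(a)
except TypeError as err:
    print(err)  # object of type 'int' has no len()
```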
Second, this would still be valuable syntactic sugar for end users. If you pursue it, please:
- consider supporting not only integers but also PrimExpr values, symbolic shapes, etc.
- add tests to the parser test file.
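One possible form for the requested generalization, written as a TVM-free sketch so it runs on its own. `normalize_shape`, `is_single_dim` and `SymDim` are hypothetical names, not part of TVM's parser: `SymDim` only stands in for a symbolic dimension such as a PrimExpr, and the sketch assumes a symbolic dimension may also be written as a string name like "n". The real parser would need its own type checks for TVM's expression classes. The test functions follow pytest naming but also run as plain calls.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymDim:
    """Hypothetical stand-in for a symbolic dimension (e.g. a PrimExpr)."""
    name: str


def is_single_dim(value):
    # A lone dimension is an int (but not a bool), a symbolic name string,
    # or a symbolic expression; lists and tuples are already full shapes.
    if isinstance(value, bool):
        return False
    return isinstance(value, (int, str, SymDim))


def normalize_shape(shape):
    """Hypothetical helper: wrap a single dimension into a 1-tuple."""
    if shape is not None and is_single_dim(shape):
        return (shape,)
    return shape


# pytest-style tests; they also run when called directly.
def test_single_int():
    assert normalize_shape(10) == (10,)


def test_single_symbolic():
    assert normalize_shape("n") == ("n",)
    assert normalize_shape(SymDim("n")) == (SymDim("n"),)


def test_existing_forms_unchanged():
    assert normalize_shape((10,)) == (10,)
    assert normalize_shape([10, "n"]) == [10, "n"]
    assert normalize_shape(()) == ()
    assert normalize_shape(None) is None


if __name__ == "__main__":
    test_single_int()
    test_single_symbolic()
    test_existing_forms_unchanged()
    print("all tests passed")
```

Wrapping strings matters in this sketch because a Python str is itself a sequence: left unwrapped, a multi-character name would be iterated character by character instead of being treated as one dimension.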