tvm
tvm copied to clipboard
[Bug] Segfault in TVM when building a TIR module with a `pragma_unroll_explicit=None` loop annotation
When building a TIR module in which a loop carries a `pragma_unroll_explicit` annotation set to `None`, TVM crashes with a segmentation fault while running the `FlattenBuffer` pass.
Expected behavior: `pragma_unroll_explicit=None` should mean "use the compiler's default unrolling strategy", i.e. unrolling is neither forced nor prohibited.
Actual behavior
!!!!!!! Segfault encountered !!!!!!!
File "/build/glibc-LcI20x/glibc-2.31/signal/../sysdeps/unix/sysv/linux/x86_64/sigaction.c", line 0, in 0x00007f03f92df08f
File "<unknown>", line 0, in tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::FlattenBuffer(tvm::tir::PrimFunc)
File "<unknown>", line 0, in std::_Function_handler<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext), tvm::tir::transform::FlattenBuffer()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>::_M_invoke(std::_Any_data const&, tvm::tir::PrimFunc&&, tvm::IRModule&&, tvm::transform::PassContext&&)
File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
File "/usr/local/src/conda/python-3.12.3/Include/internal/pycore_ceval.h", line 89, in _PyEval_EvalFrame
File "/usr/local/src/conda/python-3.12.3/Python/ceval.c", line 1683, in _PyEval_Vector
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 419, in _PyFunction_Vectorcall
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 133, in _PyObject_FastCallDictTstate
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 508, in _PyObject_Call_Prepend
File "/usr/local/src/conda/python-3.12.3/Objects/typeobject.c", line 8770, in slot_tp_call
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
File "<unknown>", line 0, in tvm::transform::__TVMFFIStaticInitFunc4()::{lambda(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo)#1}::operator()(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo) const::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const
File "<unknown>", line 0, in std::_Function_handler<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext), tvm::transform::__TVMFFIStaticInitFunc4()::{lambda(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo)#1}::operator()(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo) const::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>::_M_invoke(std::_Any_data const&, tvm::IRModule&&, tvm::transform::PassContext&&)
File "<unknown>", line 0, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
File "/usr/local/src/conda/python-3.12.3/Include/internal/pycore_ceval.h", line 89, in _PyEval_EvalFrame
File "/usr/local/src/conda/python-3.12.3/Python/ceval.c", line 1683, in _PyEval_Vector
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 419, in _PyFunction_Vectorcall
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 133, in _PyObject_FastCallDictTstate
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 508, in _PyObject_Call_Prepend
File "/usr/local/src/conda/python-3.12.3/Objects/typeobject.c", line 8770, in slot_tp_call
File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
File "/usr/local/src/conda/python-3.12.3/Python/ceval.c", line 578, in PyEval_EvalCode
File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 1722, in run_eval_code_obj
File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 1743, in run_mod
File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 1643, in pyrun_file
File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 433, in _PyRun_SimpleFileObject
File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 78, in _PyRun_AnyFileObject
File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 360, in pymain_run_file_obj
File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 379, in pymain_run_file
File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 629, in pymain_run_python
File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 709, in Py_RunMain
File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 763, in Py_BytesMain
File "<unknown>", line 0, in 0xffffffffffffffff
Environment
tvm-latest(today)
Steps to reproduce
import tvm

# Minimal reproducer: building a module whose outer loop is annotated with
# pragma_unroll_explicit=None crashes TVM with a segfault inside the
# FlattenBuffer pass (see the stack trace above).
# TVMScript is indentation-sensitive Python syntax, so the nesting below must
# be preserved; otherwise from_source fails to parse before the crash is hit.
tir_str = """# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    @T.prim_func
    def main(lhs: T.Buffer((4, 5, 6), "int16"), rhs: T.Buffer((1,), "int16"), T_add: T.Buffer((4, 5, 6), "int16")):
        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": True})
        # with T.block("root"):
        for ax0 in T.serial(4, annotations={"pragma_unroll_explicit": None}):
            for ax1 in T.serial(5):
                for ax2 in T.serial(6):
                    with T.block("T_add"):
                        v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                        T.reads(lhs[v_ax0, v_ax1, v_ax2], rhs[0])
                        T.writes(T_add[v_ax0, v_ax1, v_ax2])
                        T_add[v_ax0, v_ax1, v_ax2] = lhs[v_ax0, v_ax1, v_ax2] + rhs[0]
"""
tir_mod = tvm.script.from_source(tir_str)
tir_mod.show()
# Reported to segfault in FlattenBuffer. Per the triage reply, use 1/0 or
# True/False (not None) as the pragma_unroll_explicit value.
tvm.build(tir_mod)
Triage
- needs-triage
`None` is not a supported value for `pragma_unroll_explicit`; use `1`/`0` or `True`/`False` instead (see the annotation handling in `unroll_loop.cc` linked below).
https://github.com/apache/tvm/blob/8ab96af40e9f966f579436ca801e08d41b569516/src/tir/transforms/unroll_loop.cc#L111-L112