[QST] EVT `Sqrt`
What is your question?
I'm trying to define an epilogue visitor tree (EVT) using the CUTLASS Python interface:
# Adam-style epilogue meant to be traced by cutlass.epilogue.trace.
# The tracer reads this function with inspect.getsource + ast.parse and visits
# every statement in the body, so only `#` comments are used here: a docstring
# would be one more body statement for the frontend to visit.
# NOTE(review): the output names D, E, norm_grad appear to need to match the
# keys of the example-tensor dict passed to trace -- confirm.
def adam_epilogue(accum, one, beta1, beta2, eps, exp_avg, exp_avg2):
    # `one` is a tensor input rather than the literal 1: using the constant 1
    # twice triggers a "duplicate immutable" error while parsing.
    # D = beta1 * exp_avg + (1 - beta1) * accum
    D = beta1 * exp_avg + (one - beta1) * accum
    # E = beta2 * exp_avg2 + (1 - beta2) * accum^2
    E = beta2 * exp_avg2 + (one - beta2) * accum * accum
    # NOTE(review): torch.sqrt has no entry in the EVT op mapping, so tracing
    # fails on this line with KeyError: None; `E ** .5` likewise has no
    # mapping for Pow. Supporting it needs a sqrt op in CUTLASS plus a
    # matching mapping in the Python frontend.
    denom = torch.sqrt(E) + eps
    norm_grad = (D / denom)
    # Return plain names only: the frontend raises SyntaxError if the
    # return value is an expression.
    return D, E, norm_grad
# Problem size (m x n) and the optimizer hyperparameters fed to the epilogue.
m, n = 4096, 128
beta1, beta2 = .9, .999
eps = 1e-8
# NOTE(review): type_C is not defined in this snippet -- presumably a torch
# dtype set earlier in the script; confirm.
exp_avg = torch.randn(m, n, dtype=type_C, device="cuda")
exp_avg2 = exp_avg ** 2
# Output buffers with the same shape/dtype/device as the inputs they mirror.
D = torch.empty_like(exp_avg)
E = torch.empty_like(exp_avg2)
norm_grad = torch.empty_like(exp_avg)
# Example inputs/outputs for tracing; keys correspond to adam_epilogue's
# parameter names and returned names.
examples_tensors = {
    # Only element type, shape and layout are given for accum (a FakeTensor);
    # presumably this describes the GEMM accumulator -- confirm.
    "accum": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    # Workaround: passed as a tensor because using the constant 1 twice in the
    # epilogue gives a "duplicate immutable" parse error.
    # NOTE(review): this scalar tensor is created on the default (CPU) device
    # while the other tensors are on "cuda" -- confirm the tracer accepts it.
    "one": torch.tensor(1.0, dtype=torch.float32),
    "beta1": beta1,
    "beta2": beta2,
    "eps": eps,
    "exp_avg": exp_avg,
    "exp_avg2": exp_avg2,
    "D": D,
    "E": E,
    "norm_grad": norm_grad
}
# Trace adam_epilogue into an epilogue visitor. As reported, this currently
# raises KeyError: None because torch.sqrt has no EVT op mapping.
epilogue_visitor = cutlass.epilogue.trace(adam_epilogue, examples_tensors)
I get an error because of a missing op mapping, most likely for `torch.sqrt` (I also tried `E ** .5` and got a missing mapping for `Pow`). Is there a way to use `sqrt` in an EVT?
Note: I'm defining `one` as a tensor as a workaround, since using the constant `1` twice causes a parsing error about a duplicate immutable.
Error log:
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/epilogue/epilogue.py:155, in trace(fn, example_tensors, **kwargs)
152 setattr(EpilogueFunctor, "__call__", staticmethod(fn))
154 epilogue_functor = EpilogueFunctor(**kwargs)
--> 155 epilogue_functor.trace(example_tensors)
156 return epilogue_functor
157 else:
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/frontend_base.py:108, in EVTFrontendBase.trace(self, *args, **kwargs)
106 def trace(self, *args, **kwargs):
107 # Parse the input
--> 108 self.parse(*args, **kwargs)
110 # Run the passes
111 self.pass_manager()
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:60, in PythonASTFrontend.parse(self, example_inputs)
58 self.source = textwrap.dedent(inspect.getsource(self.__call__))
59 self.ast = ast.parse(self.source)
---> 60 self.visit(self.ast)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
416 method = 'visit_' + node.__class__.__name__
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File /usr/lib/python3.11/ast.py:426, in NodeVisitor.generic_visit(self, node)
424 for item in value:
425 if isinstance(item, AST):
--> 426 self.visit(item)
427 elif isinstance(value, AST):
428 self.visit(value)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
416 method = 'visit_' + node.__class__.__name__
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:88, in PythonASTFrontend.visit_FunctionDef(self, node)
86 self.visit(arg)
87 for expr in node.body:
---> 88 self.visit(expr)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
416 method = 'visit_' + node.__class__.__name__
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:136, in PythonASTFrontend.visit_Assign(self, node)
134 def visit_Assign(self, node: ast.BinOp):
135 target = self.visit(node.targets[0])
--> 136 value = self.visit(node.value)
137 # Create the assign node
138 self.add_store_node(target)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
416 method = 'visit_' + node.__class__.__name__
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:123, in PythonASTFrontend.visit_BinOp(self, node)
121 if self.visiting_return:
122 raise SyntaxError("Return value cannot be an expression")
--> 123 lhs = self.visit(node.left)
124 rhs = self.visit(node.right)
125 op = self.ast_op_to_bindings(type(node.op))
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
416 method = 'visit_' + node.__class__.__name__
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:164, in PythonASTFrontend.visit_Call(self, node)
162 name = self.add_layout_node(op, kwargs)
163 else:
--> 164 op = self.ast_op_to_bindings(func)
165 name = self.add_compute_node(op)
167 # Add edges
File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:77, in PythonASTFrontend.ast_op_to_bindings(op)
65 @staticmethod
66 def ast_op_to_bindings(op):
67 mapping = {
68 ast.Add: FunctionalOp.Plus,
69 ast.Sub: FunctionalOp.Minus,
(...)
75 "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
76 }
---> 77 return mapping[op]
KeyError: None
@jackkosaian
Part of this is exactly what you mentioned: a missing mapping for `sqrt`. We would first need to add a `sqrt` op in `include/cutlass/functional.h`, and then also add the corresponding mapping used by `cutlass.epilogue.trace`.
@apuaaChen, can you comment on how this mapping would be added?
I traced through a couple of runs of `cutlass.epilogue.trace` and think I understand the basic workflow of how Python ops are mapped to their respective CUTLASS functions.
As a test, I will first try exposing existing CUTLASS ops in `functional.h` that are not yet available from Python. Then I will add `sqrt`, which requires an implementation on both the CUTLASS (C++) side and the Python side.
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.