
[QST] EVT `Sqrt`

Open jeromeku opened this issue 1 year ago • 4 comments

What is your question?

I'm trying to define an EVT (epilogue visitor tree) using the Python interface:


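import torch
import cutlass
# Assumed import: the alias used in the CUTLASS EVT example notebook.
from cutlass import Tensor as FakeTensor

# Assumed dtype; type_C was not defined in the original snippet.
type_C = torch.float16

def adam_epilogue(accum, one, beta1, beta2, eps, exp_avg, exp_avg2):
    # First- and second-moment updates, followed by the normalized gradient.
    D = beta1 * exp_avg + (one - beta1) * accum
    E = beta2 * exp_avg2 + (one - beta2) * accum * accum
    denom = torch.sqrt(E) + eps
    norm_grad = (D / denom)
    return D, E, norm_grad

m, n = 4096, 128
beta1, beta2 = .9, .999
eps = 1e-8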

exp_avg = torch.randn(m, n, dtype=type_C, device="cuda")
exp_avg2 = exp_avg ** 2
D = torch.empty_like(exp_avg)
E = torch.empty_like(exp_avg2)
norm_grad = torch.empty_like(exp_avg)

examples_tensors = {
    "accum": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    "one": torch.tensor(1.0, dtype=torch.float32),
    "beta1": beta1,
    "beta2": beta2,
    "eps": eps,
    "exp_avg": exp_avg,
    "exp_avg2": exp_avg2,
    "D": D,
    "E": E,
    "norm_grad": norm_grad
    }

epilogue_visitor = cutlass.epilogue.trace(adam_epilogue, examples_tensors)

Tracing fails because an op has no mapping, most likely `torch.sqrt` (I also tried `E ** .5` and got a missing mapping for `Pow`). Is there a way to use a sqrt in an EVT?

Note: I'm defining `one` as a tensor as a hack, since using the constant 1 twice gives a parsing error about a duplicate immutable.
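For reference, the arithmetic the epilogue is meant to perform can be sketched on plain scalars, independent of CUTLASS. This is only a minimal sketch for checking values by hand: the function name `adam_epilogue_scalar` and the sample inputs are illustrative assumptions, and `math.sqrt` stands in for the sqrt op that the EVT frontend does not yet map.

```python
import math

def adam_epilogue_scalar(accum, beta1, beta2, eps, exp_avg, exp_avg2):
    """Scalar version of the epilogue math (hypothetical helper, not CUTLASS API)."""
    # Exponential moving average of the gradient (first moment).
    d = beta1 * exp_avg + (1.0 - beta1) * accum
    # Exponential moving average of the squared gradient (second moment).
    e = beta2 * exp_avg2 + (1.0 - beta2) * accum * accum
    # Normalized gradient; this is the step that needs sqrt.
    norm_grad = d / (math.sqrt(e) + eps)
    return d, e, norm_grad

# Illustrative inputs only.
d, e, norm_grad = adam_epilogue_scalar(
    accum=2.0, beta1=0.9, beta2=0.999, eps=1e-8, exp_avg=1.0, exp_avg2=1.0
)
```

A traced epilogue that supports sqrt should agree with these values elementwise, up to the precision of the tensor dtype.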

Error log:

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/epilogue/epilogue.py:155, in trace(fn, example_tensors, **kwargs)
    152     setattr(EpilogueFunctor, "__call__", staticmethod(fn))
    154     epilogue_functor = EpilogueFunctor(**kwargs)
--> 155     epilogue_functor.trace(example_tensors)
    156     return epilogue_functor
    157 else:

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/frontend_base.py:108, in EVTFrontendBase.trace(self, *args, **kwargs)
    106 def trace(self, *args, **kwargs):
    107     # Parse the input
--> 108     self.parse(*args, **kwargs)
    110     # Run the passes
    111     self.pass_manager()

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:60, in PythonASTFrontend.parse(self, example_inputs)
     58 self.source = textwrap.dedent(inspect.getsource(self.__call__))
     59 self.ast = ast.parse(self.source)
---> 60 self.visit(self.ast)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    416 method = 'visit_' + node.__class__.__name__
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File /usr/lib/python3.11/ast.py:426, in NodeVisitor.generic_visit(self, node)
    424     for item in value:
    425         if isinstance(item, AST):
--> 426             self.visit(item)
    427 elif isinstance(value, AST):
    428     self.visit(value)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    416 method = 'visit_' + node.__class__.__name__
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:88, in PythonASTFrontend.visit_FunctionDef(self, node)
     86     self.visit(arg)
     87 for expr in node.body:
---> 88     self.visit(expr)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    416 method = 'visit_' + node.__class__.__name__
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:136, in PythonASTFrontend.visit_Assign(self, node)
    134 def visit_Assign(self, node: ast.BinOp):
    135     target = self.visit(node.targets[0])
--> 136     value = self.visit(node.value)
    137     # Create the assign node
    138     self.add_store_node(target)

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    416 method = 'visit_' + node.__class__.__name__
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:123, in PythonASTFrontend.visit_BinOp(self, node)
    121 if self.visiting_return:
    122     raise SyntaxError("Return value cannot be an expression")
--> 123 lhs = self.visit(node.left)
    124 rhs = self.visit(node.right)
    125 op = self.ast_op_to_bindings(type(node.op))

File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
    416 method = 'visit_' + node.__class__.__name__
    417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:164, in PythonASTFrontend.visit_Call(self, node)
    162     name = self.add_layout_node(op, kwargs)
    163 else:
--> 164     op = self.ast_op_to_bindings(func)
    165     name = self.add_compute_node(op)
    167 # Add edges

File ~/Cpp/CUDA/Cutlass/cutlass-nightly/python/cutlass/backend/evt/frontend/python_ast.py:77, in PythonASTFrontend.ast_op_to_bindings(op)
     65 @staticmethod
     66 def ast_op_to_bindings(op):
     67     mapping = {
     68         ast.Add: FunctionalOp.Plus,
     69         ast.Sub: FunctionalOp.Minus,
   (...)
     75         "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
     76     }
---> 77     return mapping[op]

KeyError: None
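One plausible reading of `KeyError: None` (rather than a `KeyError` naming `sqrt`) is that the callee of `torch.sqrt(E)` is an `ast.Attribute` node, and a visitor with no handler for that node type falls back to `ast.NodeVisitor.generic_visit`, which returns `None`. The sketch below reproduces that mechanism with a hypothetical `TinyFrontend` class and a made-up mapping table; it is not the CUTLASS frontend, and the assumption that the real code obtains the callee name through the visitor is inferred from the traceback only.

```python
import ast

class TinyFrontend(ast.NodeVisitor):
    # Hypothetical op table for illustration; not the CUTLASS mapping.
    MAPPING = {ast.Add: "plus", ast.Mult: "multiplies", "max": "maximum"}

    def visit_Name(self, node):
        # A bare identifier such as `max` resolves to its string name.
        return node.id

    def visit_Call(self, node):
        # For `torch.sqrt(E)`, node.func is an ast.Attribute. There is no
        # visit_Attribute here, so generic_visit runs and returns None.
        func = self.visit(node.func)
        return self.MAPPING[func]

def lookup(expr):
    """Parse a single call expression and look up its op."""
    call = ast.parse(expr, mode="eval").body
    return TinyFrontend().visit(call)

def failing_key(expr):
    """Return the key that the lookup failed on, or "ok" if it succeeded."""
    try:
        lookup(expr)
    except KeyError as err:
        return err.args[0]
    return "ok"
```

In this sketch `failing_key("torch.sqrt(E)")` yields `None`, while `failing_key("sqrt(E)")` yields the string `"sqrt"`, so the two call spellings fail in visibly different ways.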

jeromeku, Mar 26 '24 17:03

@jackkosaian

thakkarV, Mar 26 '24 19:03

Part of this is just what you mentioned: a missing mapping for sqrt. We would first need to add a sqrt functor in `include/cutlass/functional.h`. We'd then also need to add the mapping on the Python side, in the frontend used by `cutlass.epilogue.trace`.
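As a rough illustration of the Python-side half of that change, the sketch below resolves the callee name for both `sqrt(E)` and `torch.sqrt(E)` and looks it up in a table that has gained a `"sqrt"` entry. Everything here is hypothetical: `CALL_TABLE`, `register`, `callee_name`, and `resolve` are invented for the example, the table values are placeholder strings rather than real `FunctionalOp` members, and the actual frontend may be structured differently.

```python
import ast

# Hypothetical table: Python-side call name -> name of a device-side functor.
CALL_TABLE = {"max": "maximum", "min": "minimum"}

def register(name, functor):
    """Add a call-name mapping (illustrative helper, not a CUTLASS API)."""
    CALL_TABLE[name] = functor

def callee_name(call):
    """Return the function name for `f(x)` or `module.f(x)` call nodes."""
    func = call.func
    if isinstance(func, ast.Name):
        return func.id
    if isinstance(func, ast.Attribute):
        # `torch.sqrt` -> "sqrt"; the module prefix is ignored in this sketch.
        return func.attr
    raise TypeError(f"unsupported callee: {ast.dump(func)}")

def resolve(expr):
    """Parse one call expression and return the mapped functor name."""
    call = ast.parse(expr, mode="eval").body
    if not isinstance(call, ast.Call):
        raise TypeError("expected a call expression")
    return CALL_TABLE[callee_name(call)]

# The entry the discussion proposes; it presumes a matching device functor exists.
register("sqrt", "sqrt")
```

The table entry alone would not be enough: as noted above, the corresponding functor also has to exist on the C++ side for the generated epilogue to compile.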

@apuaaChen, can you comment on how this mapping would be added?

jackkosaian, Mar 26 '24 20:03

I traced through a couple of runs of `cutlass.epilogue.trace` and think I understand the basic workflow of how Python ops are mapped to their respective CUTLASS functions.

As a test, I'll first try exposing existing CUTLASS ops in `functional.h` that are not currently available in Python, and then add sqrt, which will require implementation on both the CUTLASS and Python sides.

jeromeku, Mar 27 '24 21:03

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], May 03 '24 19:05