lightning-thunder

Add input information to fusion definitions for trace inspection and debugging

Open · riccardofelluga opened this issue 1 year ago · 0 comments

Before submitting
  • [ ] Was this discussed/approved via a GitHub issue? (not needed for typo fixes and docs improvements)
  • [x] Did you read the Pull Request section of the contributor guidelines?
  • [ ] Did you make sure to update the docs?
  • [ ] Did you write any new necessary tests?

What does this PR do?

Fixes #387. This PR records information about the inputs of each fusion definition so that it can be retrieved by inspecting the trace. The PR is also in preparation for #205, as part of which a tutorial on reading this information will be published.

For example, given the last trace of a jitted function:

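```python
import thunder

# `fn` is assumed to be the callable returned by thunder.jit, already called at least once
trace = thunder.last_traces(fn)[-1]
trace_ctx = trace.python_ctx()
print(trace_ctx['nvFusion0'].last_inputs())
```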

will print something like:

```python
inputs = [
    torch.randn((2048,), dtype=torch.float32, device='cuda:0').as_strided((1, 2048), (2048, 1)),
    torch.randn((4096,), dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 4096), (4096, 0, 1))
]
```
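The storage sizes in this output are consistent with each input being rebuilt from the smallest 1-D buffer that its shape and strides can address. For shape `(1, 2048, 4096)` with strides `(4096, 0, 1)`, the largest reachable offset is 0·4096 + 2047·0 + 4095·1 = 4095, so 4096 elements suffice, and the stride-0 dimension broadcasts them. The pure-Python sketch below reproduces the example output from shape, stride, dtype, and device metadata alone. It is a hypothetical illustration, not this PR's implementation: the helper names `min_storage_numel`, `input_repr`, and `inputs_repr` are made up, and it assumes a storage offset of 0.

```python
from typing import Sequence


def min_storage_numel(shape: Sequence[int], stride: Sequence[int]) -> int:
    """Smallest 1-D buffer length for which as_strided(shape, stride)
    stays in bounds, assuming a storage offset of 0."""
    if any(size == 0 for size in shape):
        return 0  # an empty view needs no storage
    # Highest reachable index is sum((size - 1) * stride); +1 turns it into a count.
    return 1 + sum((size - 1) * st for size, st in zip(shape, stride))


def input_repr(shape, stride, dtype: str, device: str) -> str:
    """Hypothetical helper: format one input in the style shown above."""
    numel = min_storage_numel(shape, stride)
    return (
        f"torch.randn(({numel},), dtype=torch.{dtype}, device='{device}')"
        f".as_strided({tuple(shape)}, {tuple(stride)})"
    )


def inputs_repr(metas) -> str:
    """Hypothetical helper: join several inputs into an `inputs = [...]` string."""
    lines = [f"    {input_repr(*meta)}" for meta in metas]
    return "inputs = [\n" + ",\n".join(lines) + "\n]"


# Metadata taken from the example output above.
print(inputs_repr([
    ((1, 2048), (2048, 1), "float32", "cuda:0"),
    ((1, 2048, 4096), (4096, 0, 1), "bfloat16", "cuda:0"),
]))
```

Emitting a string of Python source rather than tensors keeps the result easy to paste into a standalone reproducer, where `torch.randn` fills the buffer with fresh random data of the recorded dtype and layout.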

I'm open to changing the output from a string to a list of tensors; however, for debugging, a string is usually enough.

riccardofelluga · May 08 '24 19:05