lightning-thunder
Add input information to fusion definitions for trace inspection and debugging
Before submitting
- [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [x] Did you read the contributor guideline, Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
What does this PR do?
Fixes #387. This PR records information about the inputs of each fusion definition so that it can be retrieved by inspecting the trace. It is also groundwork for #205, where a tutorial on how to read this information will be published.
As a quick example, running the following on a trace:
trace = thunder.last_traces(fn)[-1]
trace_ctx = trace.python_ctx()
print(trace_ctx['nvFusion0'].last_inputs())
will print something like:
inputs = [
torch.randn((2048,), dtype=torch.float32, device='cuda:0').as_strided((1, 2048), (2048, 1)),
torch.randn((4096,), dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 4096), (4096, 0, 1))
]
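Each printed entry can be read as "allocate a flat buffer, then view it with the recorded shape and strides". The following is a minimal sketch of how such a string could be assembled from tensor metadata; it is not necessarily how this PR builds it. `storage_numel` and `describe_input` are hypothetical helper names, and passing the dtype as a plain string is an assumption made to keep the sketch free of a `torch` dependency.

```python
def storage_numel(shape, strides):
    """Smallest flat buffer that a view with this shape and strides can address."""
    if any(size == 0 for size in shape):
        return 0
    # The largest reachable offset is sum((size - 1) * stride); add 1 for the count.
    return 1 + sum((size - 1) * stride for size, stride in zip(shape, strides))


def describe_input(shape, strides, dtype, device):
    """Hypothetical helper: build a repro expression for one fusion input."""
    numel = storage_numel(shape, strides)
    return (
        f"torch.randn(({numel},), dtype={dtype}, device={device!r})"
        f".as_strided({tuple(shape)}, {tuple(strides)})"
    )


# The two inputs from the example output above.
print(describe_input((1, 2048), (2048, 1), "torch.float32", "cuda:0"))
print(describe_input((1, 2048, 4096), (4096, 0, 1), "torch.bfloat16", "cuda:0"))
```

The stride of 0 in the second input marks a broadcast dimension, which is why a buffer of only 4096 elements backs a view of shape `(1, 2048, 4096)`.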
I'm open to changing the output from a string to a list of tensors; however, for debugging, a string is usually enough.