TransformerEngine
torch.compile graph breaks at `forward`
When using `torch.compile`, we observe graph breaks at the `forward` of every TransformerEngine module we use (e.g. `Linear`, `LayerNormMLP`). This appears to cause a large number of TorchDynamo lookups across the resulting subgraphs, which produces a net slowdown. Representative break reasons:
```
Break Reason 148:
Reason: inline in skipfiles: Linear.forward | _fn /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
User Stack:
<FrameSummary file /mnt/main0/home/ishafkat/testenv/projects/esm3/oss/attention.py, line 67 in <resume in forward>>
<FrameSummary file /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
Break Reason 149:
Reason: inline in skipfiles: LayerNormMLP.forward | _fn /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
User Stack:
<FrameSummary file /mnt/main0/home/ishafkat/testenv/projects/esm3/oss/blocks.py, line 32 in <resume in forward>>
<FrameSummary file /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
```
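With well over a hundred break reasons in a full dump, it can help to see which modules account for them. The helper below is a hypothetical sketch (the name `count_skipfile_breaks` is illustrative, not part of PyTorch or TransformerEngine) that assumes the plain-text layout shown in the break-reason log above and tallies the "inline in skipfiles" reasons by the function Dynamo declined to trace. Other kinds of break reasons are simply ignored.

```python
import re
from collections import Counter

# Matches lines such as (leading indentation is tolerated):
#   Reason: inline in skipfiles: Linear.forward | _fn /path/to/eval_frame.py
SKIPFILES_RE = re.compile(r"Reason:\s*inline in skipfiles:\s*([\w.]+)")


def count_skipfile_breaks(log_text: str) -> Counter:
    """Tally 'inline in skipfiles' graph breaks by the skipped function name.

    Assumes the break-reason text layout shown in this issue; any other
    break reasons (e.g. unsupported ops) are not counted.
    """
    counts: Counter = Counter()
    for line in log_text.splitlines():
        match = SKIPFILES_RE.search(line)
        if match:
            counts[match.group(1)] += 1
    return counts


if __name__ == "__main__":
    # Shortened copy of the two break reasons quoted above.
    sample = """\
Break Reason 148:
Reason: inline in skipfiles: Linear.forward | _fn /site-packages/torch/_dynamo/eval_frame.py
Break Reason 149:
Reason: inline in skipfiles: LayerNormMLP.forward | _fn /site-packages/torch/_dynamo/eval_frame.py
"""
    for name, n in count_skipfile_breaks(sample).most_common():
        print(f"{name}: {n}")
```

Feeding it the complete dump would show whether the breaks are concentrated in a few TransformerEngine module types or spread across all of them.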
Is this expected behavior? Versions: `torch.__version__` 2.1.2, TransformerEngine 1.1.0+cf6fc898.