
torch.compile graph breaks at `forward`

irhum opened this issue 1 year ago (status: open, 0 comments)

When using torch.compile, we observe graph breaks like the following at every TransformerEngine component. TorchDynamo then appears to perform a large number of lookups for each of the resulting subgraphs, so the compiled model ends up slower overall:

  Break Reason 148:
    Reason: inline in skipfiles: Linear.forward  | _fn /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
    User Stack:
      <FrameSummary file /mnt/main0/home/ishafkat/testenv/projects/esm3/oss/attention.py, line 67 in <resume in forward>>
      <FrameSummary file /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
  Break Reason 149:
    Reason: inline in skipfiles: LayerNormMLP.forward  | _fn /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
    User Stack:
      <FrameSummary file /mnt/main0/home/ishafkat/testenv/projects/esm3/oss/blocks.py, line 32 in <resume in forward>>
      <FrameSummary file /mnt/main0/home/ishafkat/micromamba/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
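In a long explain log, it helps to tally which functions trigger the "inline in skipfiles" breaks. The sketch below is a hypothetical helper (`count_skipfile_breaks` is not part of PyTorch or TransformerEngine). It assumes only the `Reason: inline in skipfiles: <name>` line format seen in the log above, ignores any other break reasons, and does not fix the breaks themselves.

```python
import re
from collections import Counter

# Matches lines such as
#   "Reason: inline in skipfiles: Linear.forward  | _fn .../eval_frame.py"
# and captures the qualified name of the skipped function ("Linear.forward").
_SKIPFILE_RE = re.compile(r"Reason:\s*inline in skipfiles:\s*([\w.]+)")


def count_skipfile_breaks(log_text: str) -> Counter:
    """Count graph breaks per skipped function in a TorchDynamo break log.

    Hypothetical helper: it only recognizes the "inline in skipfiles"
    reason format; lines with any other reason are ignored.
    """
    counts = Counter()
    for line in log_text.splitlines():
        match = _SKIPFILE_RE.search(line)
        if match:
            counts[match.group(1)] += 1
    return counts


if __name__ == "__main__":
    # Two entries copied from the log above, with the file paths shortened.
    sample = """\
  Break Reason 148:
    Reason: inline in skipfiles: Linear.forward  | _fn eval_frame.py
  Break Reason 149:
    Reason: inline in skipfiles: LayerNormMLP.forward  | _fn eval_frame.py
"""
    for name, n in count_skipfile_breaks(sample).most_common():
        print(f"{name}: {n}")
```

Sorting with `most_common()` puts the modules responsible for the most breaks first.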

Is this expected behavior? This is with torch 2.1.2 (`torch.__version__` = 2.1.2) and TransformerEngine 1.1.0+cf6fc898.

irhum, Mar 08 '24 21:03