
[ONNX] Exporter successfully exports an ONNX model, but the model is not executable on ONNX Runtime due to failed shape inference

titaiwangms opened this issue 3 years ago · 1 comment

🐛 Describe the bug

In the exporter, there is an onnx.shape_inference.infer_shapes check before export, but it does not use strict_mode the way onnxruntime does before executing the model. This discrepancy makes some models exportable but not executable.

Needs investigation:

  1. Is strict_mode necessary for onnxruntime?
  2. What is the difference between passing strict_mode or not?
  3. What kind of model hits this discrepancy?

https://github.com/pytorch/pytorch/blob/ba84e9662e1c284805cce1652235eac1e4c56768/torch/csrc/jit/serialization/export.cpp#L1363-L1374

The following is a repro.

import io
from typing import Optional

import onnx
import onnxruntime as ort
import torch
from torch import Tensor


class LoopNoneInput(torch.nn.Module):
    def forward(self, x):
        y: Optional[Tensor] = None
        for _ in range(x.size(0)):
            y = x
        return y


f = io.BytesIO()
x = torch.ones(1)
dynamic_axis_name = "condition"
torch.onnx.export(
    torch.jit.script(LoopNoneInput()),
    (x,),
    f,
    opset_version=16,
    # Ensure the loop trip count is not constant
    dynamic_axes={"x": {0: dynamic_axis_name}},
    input_names=["x"],
)

model = onnx.load_model_from_string(f.getvalue())
onnx.checker.check_model(model)  # passes: the default check misses the error

ort_inputs = {"x": x.cpu().numpy()}
sess = ort.InferenceSession(
    model.SerializeToString(), ort.SessionOptions(), providers=["CPUExecutionProvider"]
)
out = sess.run(None, ort_inputs)  # fails here
# Alternatively, the same failure can be reproduced without running ORT:
# onnx.shape_inference.infer_shapes(model, strict_mode=True)

Versions

PyTorch nightly, onnxruntime nightly

Related Issues

Input/Output type mismatch

titaiwangms avatar Aug 01 '22 18:08 titaiwangms

1. Is strict_mode necessary for onnxruntime?

onnxruntime doesn't run the strict-mode check; it simply extends the shape/type inference results from ONNX. That's why this invalid model can't be executed: the model has a serious type inference error, but the converter doesn't error out.

2. Difference between passing strict_mode or not

When strict mode is on, any error during ONNX shape/type inference is reported. Note that a custom op does not cause an error in ONNX shape/type inference, since it's not a valid ONNX node.

3. What kind of model hitting this discrepancy

A model with failed shape/type inference on a valid ONNX node. In this case, type inference fails because the input is tensor type while the output is optional type.
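The strict vs. non-strict behavior described above can be illustrated with a toy sketch in plain Python (this is NOT the real onnx.shape_inference API; the function, edge format, and error class are all invented for illustration). Non-strict inference silently skips a node whose producer and consumer types disagree, while strict mode raises, mirroring the tensor-vs-optional mismatch in the repro:

```python
# Toy sketch of strict vs. non-strict type inference (illustrative only,
# not the onnx.shape_inference implementation).

class TypeInferenceError(Exception):
    pass


def infer_types(edges, strict_mode=False):
    """edges: list of (name, produced_type, expected_type) tuples."""
    inferred = {}
    for name, produced, expected in edges:
        if produced != expected:
            msg = f"{name}: producer emits {produced!r} but consumer expects {expected!r}"
            if strict_mode:
                # strict mode: surface the error at export time
                raise TypeInferenceError(msg)
            # non-strict: silently skip the node, like the exporter today
            continue
        inferred[name] = produced
    return inferred


edges = [
    ("loop_body", "tensor(float)", "tensor(float)"),
    # the repro above: the Loop output is optional while the input is tensor
    ("loop_out", "optional(tensor(float))", "tensor(float)"),
]

print(infer_types(edges))  # non-strict: the mismatch is ignored
try:
    infer_types(edges, strict_mode=True)
except TypeInferenceError as e:
    print("strict_mode caught:", e)
```

This is why the model passes onnx.checker.check_model but still fails at runtime: only the strict pass treats the mismatch as fatal.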

titaiwangms avatar Aug 09 '22 22:08 titaiwangms

@AllenTiTaiWang a few questions that are not explained in the PR description:

1. Is strict_mode necessary for onnxruntime?

onnxruntime doesn't run the strict-mode check; it simply extends the shape/type inference results from ONNX. That's why this invalid model can't be executed: the model has a serious type inference error, but the converter doesn't error out.

a) What are the "serious" type inference errors? List them all, preferably with examples.
b) Is it an actual error (wrong information on the graph), missing required information, or just missing optional information (such as rank, shape, or dtype)? For the first two, we need to fix the export process itself, not just raise an error on shape inference.
c) Is this error generated by the torch ONNX converter, or after the "extended shape inference" on ORT?

2. Difference between passing strict_mode or not

When strict mode is on, any error during ONNX shape/type inference is reported. Note that a custom op does not cause an error in ONNX shape/type inference, since it's not a valid ONNX node.

IMO, if any information is wrong in the exported ONNX graph, we should focus on fixing the error. After all, ONNX conversion failing due to shape inference is important for ONNX converter engineers, but a working ONNX model is what our users really want :)

3. What kind of model hitting this discrepancy

A model with failed shape/type inference on a valid ONNX node. In this case, type inference fails because the input is tensor type while the output is optional type.

Elaborate with a Python script, torch IR, or ONNX IR as an example, to keep a record on the issue and ensure we are dealing with the right problem.

Overall, it seems this PR focuses on surfacing an error to the user as opposed to fixing it. Let's fix it FIRST. On the other hand, having a configurable full check is useful, and maybe we can use a _GLOBALS.check_shape_inference flag instead of leaving it as a constant True in the middle of the code.

This way, ORT users can force the check themselves, if they wish, through something like

from torch.onnx import _globals
_globals.check_shape_inference = True

torch.onnx.export(model)

That will not change the default behavior, and they will get what they need.

thiagocrepaldi avatar Aug 25 '22 17:08 thiagocrepaldi

@AllenTiTaiWang a few questions that are not explained in the PR description:

a) What are the "serious" type inference errors? List them all, preferably with examples.
b) Is it an actual error (wrong information on the graph), missing required information, or just missing optional information (such as rank, shape, or dtype)? For the first two, we need to fix the export process itself, not just raise an error on shape inference.
c) Is this error generated by the torch ONNX converter, or after the "extended shape inference" on ORT?

  1. The case I ran into here turned out, after investigation, to be an exporter issue. If we had a strict_mode shape/type inference check during export, we could have spotted it in the exporter instead of waiting until the model is run on onnxruntime. That's one of the reasons we need a refactor of check_onnx_proto, and it also fits our rule of spotting errors earlier.
  2. Specifically, according to @jcwchen, the current onnx::shape_inference::InferShapes(model) only catches a severely broken graph, so you can consider onnx::shape_inference::InferShapes(model, strict_mode=true) a must in the current situation. The case here is a type inference error, meaning the input and output types were mismatched; it leads to a failed shape/type inference, but without strict_mode this was simply ignored in the exporter.

IMO, if any information is wrong in the exported ONNX graph, we should focus on fixing the error. After all, ONNX conversion failing due to shape inference is important for ONNX converter engineers, but a working ONNX model is what our users really want :)

Agreed. However, as you and @BowenBao have mentioned, some of our users run the models with Caffe2, and they don't need a "valid" ONNX graph. Is this still the goal? If not, I suggest we make strict_mode mandatory and error out.

Elaborate with a Python script, torch IR, or ONNX IR as an example, to keep a record on the issue and ensure we are dealing with the right problem.

Overall, it seems this PR focuses on surfacing an error to the user as opposed to fixing it. Let's fix it FIRST. On the other hand, having a configurable full check is useful, and maybe we can use a _GLOBALS.check_shape_inference flag instead of leaving it as a constant True in the middle of the code.

Using _GLOBALS.check_shape_inference is absolutely a great idea. I will check it out.

This way, ORT users can force the check themselves, if they wish to through something like

from torch.onnx import _globals
_globals.check_shape_inference = True

torch.onnx.export(model)

That will not change the default behavior and they will get what they need

There are two topics here:

  1. Do we still take care of users relying on invalid ONNX graphs? I remember I was advised to keep the current check_onnx_proto behavior, as it doesn't affect current users, and to emit a warning for them as the current solution. If that suggestion no longer stands, I can let it error out in the related PR.
  2. The related PR exposes shape/type inference errors as a warning message, which we didn't have in the exporter before, and that helps us spot potential errors before running on onnxruntime. Models that pass strict_mode can finally be considered valid ONNX graphs.

titaiwangms avatar Aug 25 '22 18:08 titaiwangms