[ONNX] Exporter successfully exports an ONNX model that is unexecutable on ONNX Runtime due to failed shape inference
🐛 Describe the bug
The exporter runs an `onnx.shape_inference.infer_shapes` check before exporting, but it does not use `strict_mode` the way onnxruntime effectively does before executing the model. This discrepancy makes some models exportable but not executable.
This needs investigation:
- Is strict_mode necessary for onnxruntime?
- What is the difference between passing strict_mode or not?
- What kind of model hits this discrepancy?
https://github.com/pytorch/pytorch/blob/ba84e9662e1c284805cce1652235eac1e4c56768/torch/csrc/jit/serialization/export.cpp#L1363-L1374
The following is a repro.
```python
import io
from typing import Optional

import onnx
import onnxruntime as ort
import torch
from torch import Tensor


class LoopNoneInput(torch.nn.Module):
    def forward(self, x):
        y: Optional[Tensor] = None
        for _ in range(x.size(0)):
            y = x
        return y


f = io.BytesIO()
x = torch.ones(1)
dynamic_axis_name = "condition"
torch.onnx.export(
    torch.jit.script(LoopNoneInput()),
    (x,),
    f,
    opset_version=16,
    # Ensure condition is not constant
    dynamic_axes={"x": {0: dynamic_axis_name}},
    input_names=["x"],
)
model = onnx.load_model_from_string(f.getvalue())
onnx.checker.check_model(model)
ort_inputs = {"x": x.cpu().numpy()}
sess = ort.InferenceSession(
    model.SerializeToString(), ort.SessionOptions(), providers=["CPUExecutionProvider"]
)
out = sess.run(None, ort_inputs)
# Alternatively, the same outcome can be obtained without running ORT:
# onnx.shape_inference.infer_shapes(model, strict_mode=True)
```
Versions

PyTorch nightly, onnxruntime nightly
Related Issues
1. Is strict_mode necessary for onnxruntime?

onnxruntime doesn't run the strict-mode check; it simply extends the shape/type inference results from ONNX. That's why this invalid model can't be executed after all (the model has a serious type inference error, but the converter doesn't error out).

2. Difference between passing strict_mode or not

When strict mode is on, any error during ONNX shape/type inference is reported. Note that a custom op does not cause an error in ONNX shape/type inference, as it is not a valid ONNX node.

3. What kind of model hits this discrepancy

A model whose shape/type inference fails on a valid ONNX node. In this case, type inference fails because the input is a tensor type while the output is an optional type.
@AllenTiTaiWang a few questions that are not explained in the PR description

> 1. Is strict_mode necessary for onnxruntime?
>
> onnxruntime doesn't run the strict-mode check; it simply extends the shape/type inference results from ONNX. That's why this invalid model can't be executed after all (the model has a serious type inference error, but the converter doesn't error out).

a) What are the "serious" type inference errors? List all, preferably with examples.
b) Is it an actual error (wrong information on the graph), missing required information, or just missing optional information (such as rank, shape, or dtype)? For the first two, we need to fix the export process itself, not just raise an error on shape inference.
c) Is this error generated by the torch ONNX converter, or after the "extended shape inference" on ORT?
> 2. Difference between passing strict_mode or not
>
> When strict mode is on, any error during ONNX shape/type inference is reported. Note that a custom op does not cause an error in ONNX shape/type inference, as it is not a valid ONNX node.

IMO, if any information is wrong in the exported ONNX graph, we should focus on fixing the error. After all, ONNX conversion failing due to shape inference matters to ONNX converter engineers, but a working ONNX model is what our users really want :)
> 3. What kind of model hits this discrepancy
>
> A model whose shape/type inference fails on a valid ONNX node. In this case, type inference fails because the input is a tensor type while the output is an optional type.

Please elaborate with a Python script, torch IR, or ONNX IR as an example, to keep a record on the issue and ensure we are dealing with the right issue.
Overall, it seems this PR focuses on surfacing an error to the user as opposed to fixing it. Let's fix it FIRST. On the other hand, having a configurable full-check option is useful, and maybe we can use a `_GLOBALS.check_shape_inference` flag instead of leaving it as a constant True in the middle of the code.

This way, ORT users can force the check themselves, if they wish, through something like

```python
from torch.onnx import _globals

_globals.check_shape_inference = True
torch.onnx.export(model)
```

That will not change the default behavior, and they will get what they need.
> @AllenTiTaiWang a few questions that are not explained in the PR description
>
> a) What are the "serious" type inference errors? List all, preferably with examples. b) Is it an actual error (wrong information on the graph), missing required information, or just missing optional information (such as rank, shape, or dtype)? For the first two, we need to fix the export process itself, not just raise an error on shape inference. c) Is this error generated by the torch ONNX converter, or after the "extended shape inference" on ORT?
- The case I ran into here, after investigation, was an exporter issue. If we had a `strict_mode` shape/type inference check during export, we could have spotted it in the exporter instead of waiting until the model is run on onnxruntime. That's one of the reasons we need a refactor of `check_onnx_proto`, and it also fits our rule of spotting errors earlier.
- Specifically, according to @jcwchen, the current `onnx::shape_inference::InferShapes(model)` only catches a severely broken graph, and you can consider `onnx::shape_inference::InferShapes(model, strict_mode=true)` a must in the current situation. The case here is a type inference error, meaning the input and output types were mismatched, which leads to a failed shape/type inference; without `strict_mode`, this was simply ignored by the exporter.
> IMO, if any information is wrong in the exported ONNX graph, we should focus on fixing the error. After all, ONNX conversion failing due to shape inference matters to ONNX converter engineers, but a working ONNX model is what our users really want :)

Agreed. However, you and @BowenBao have mentioned that some of our users run the models with Caffe2, and they don't need a "valid" ONNX graph. Is this still the goal? If not, I suggest we make strict_mode mandatory and error out.
> Please elaborate with a Python script, torch IR, or ONNX IR as an example, to keep a record on the issue and ensure we are dealing with the right issue.
>
> Overall, it seems this PR focuses on surfacing an error to the user as opposed to fixing it. Let's fix it FIRST. On the other hand, having a configurable full-check option is useful, and maybe we can use a `_GLOBALS.check_shape_inference` flag instead of leaving it as a constant True in the middle of the code.
Using `_GLOBALS.check_shape_inference` is absolutely a great idea. I will check this out.
> This way, ORT users can force the check themselves, if they wish, through something like
>
> ```python
> from torch.onnx import _globals
>
> _globals.check_shape_inference = True
> torch.onnx.export(model)
> ```
>
> That will not change the default behavior, and they will get what they need.
There are two topics here:

- Do we still take care of users relying on an invalid ONNX graph? I remember I was advised to keep the current `check_onnx_proto` behavior, as it doesn't affect current users, and to add a warning for them as the current solution. If this suggestion doesn't stand anymore, I can let it error out in the related PR.
- The related PR exposes the shape/type inference error as a warning message, which we didn't have in the exporter before, and that helps us spot the potential error ahead of running on onnxruntime. Models that pass `strict_mode` can finally be considered valid ONNX graphs.