Failed to build the serialized network due to incorrect shape inference for the LayerNormalization operator
Description
The following ONNX model can be imported by the ONNX parser in TensorRT; however, it fails to build the serialized network. The following error message is produced:
[06/03/2025-21:13:02] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [shape.cpp:~op_constraints_msg_streamer_t:143]
Error during shape inference of
layer_norm_output : f32[2, 3, 1, 1] = move(node_of_layer_norm_output_normalizationBiased), name=node_of_layer_norm_output
Error is:
In op 'node_of_layer_norm_output(u_pw:move)', expected shape [2,3,1,3] but got [2,3,1,1]
[06/03/2025-21:13:02] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[node_of_sin_output...node_of_output]}.)
From the figure, we can see that the output shape of LayerNormalization is [2,3,1,1], which matches the result of onnx.shape_inference.infer_shapes(). However, TensorRT infers shape [2,3,1,3] for the LayerNormalization output.
Environment
TensorRT Version: 10.11.0.33
NVIDIA GPU: GeForce RTX 3080
NVIDIA Driver Version: 535.183.01
CUDA Version: 12.2
CUDNN Version: none
Operating System: Ubuntu 20.04
Python Version (if applicable): 3.12.9
Steps To Reproduce
This bug can be reproduced with the code below and the model in the attachment. As shown in the code, the model executes successfully under ONNXRuntime.
import sys
import pickle

import onnx
import onnxruntime
import tensorrt as trt


def test():
    onnx_model = onnx.load('1111.onnx')
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run([], inputs)
    except Exception as e:
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)
    print("ONNXRuntime:\n", ort_output)

    # --------------------------------------------------------
    trt_logger = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(trt_logger, '')
    builder = trt.Builder(trt_logger)
    network = builder.create_network(flags=1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, trt_logger)
    with open('1111.onnx', 'rb') as model_file:
        if not parser.parse(model_file.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            sys.exit(1)
    config = builder.create_builder_config()
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        sys.exit(1)


if __name__ == "__main__":
    test()
Commands or scripts:
Have you tried the latest release?: yes
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): the model can be executed by ONNXRuntime.
Can you attach the ONNX file?
@LeoZDong Thank you for your reply! The ONNX file is in the following attachment.
Issue has not received an update in over 14 days. Adding stale label. Please note the issue will be closed in 14 days after being marked stale if there is no update.
This issue was closed because it has been 14 days without activity since it has been marked as stale.