TensorRT fails to infer the shape of the output for a valid onnx model.

coffezhou opened this issue 8 months ago

Description

For the following valid ONNX model, TensorRT fails to infer the shape of the output correctly. The shape of final_output reported by TensorRT is:

(1, 0, 32, 7)

However, when I execute this model using onnxruntime, the shape of the final_output is:

(1, 1, 32, 7)

This further leads to an error during GPU memory allocation:

 device_mem = cuda.mem_alloc(host_mem.nbytes)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pycuda._driver.LogicError: cuMemAlloc failed: invalid argument
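A zero in any dimension makes the tensor empty, so a host buffer sized from the TensorRT shape (1, 0, 32, 7) has nbytes == 0. The CUDA driver API documents that cuMemAlloc returns CUDA_ERROR_INVALID_VALUE when bytesize is 0, which pycuda reports as "invalid argument". The sketch below illustrates the arithmetic with NumPy only; host_buffer is a hypothetical helper, not part of TensorRT or pycuda, and it merely flags empty shapes before any device allocation would be attempted.

```python
import numpy as np


def host_buffer(shape, dtype=np.float32):
    """Hypothetical helper: allocate a host array for an engine I/O tensor.

    Returns the array plus a flag telling whether a device allocation of the
    same size is possible (a 0-byte cuMemAlloc fails with CUDA_ERROR_INVALID_VALUE).
    """
    buf = np.empty(tuple(shape), dtype=dtype)
    needs_device_mem = buf.nbytes > 0
    return buf, needs_device_mem


trt_buf, trt_ok = host_buffer((1, 0, 32, 7))  # shape reported by TensorRT
ort_buf, ort_ok = host_buffer((1, 1, 32, 7))  # shape reported by onnxruntime
print(trt_buf.nbytes, trt_ok)  # 0 False
print(ort_buf.nbytes, ort_ok)  # 896 True
```

With float32 data the onnxruntime shape needs 1 * 1 * 32 * 7 * 4 = 896 bytes, while the TensorRT shape needs none, which is why the allocation call is the first place the shape mismatch surfaces.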

Environment

TensorRT Version: 10.11.0.33

NVIDIA GPU: GeForce RTX 3080

NVIDIA Driver Version: 535.183.01

CUDA Version: 12.2

CUDNN Version: none

Operating System: Ubuntu 20.04

Python Version (if applicable): 3.12.9

Tensorflow Version (if applicable): none

PyTorch Version (if applicable): none

Baremetal or Container (if so, version): none

Steps To Reproduce

This bug can be reproduced by the following code with the model in the attachment. As shown in the code, the model can be executed by onnxruntime.

import pickle
import sys

import onnx
import onnxruntime
import tensorrt as trt


def test():
    onnx_model = onnx.load("1111.onnx")

    # Input feeds (a dict of input name -> numpy array) shipped with the test case.
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    # Reference run on the ONNX Runtime CPU provider.
    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run([], inputs)
    except Exception as e:
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)

    print("ONNXRuntime:\n", ort_output[0].shape)

    #--------------------------------------------------------
    # Build a TensorRT engine from the same ONNX file.
    trt_logger = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(trt_logger, '')
    builder = trt.Builder(trt_logger)
    # TensorRT 10 networks are always explicit-batch; the EXPLICIT_BATCH flag is deprecated.
    network = builder.create_network(0)

    parser = trt.OnnxParser(network, trt_logger)
    with open("1111.onnx", 'rb') as model_file:
        if not parser.parse(model_file.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            sys.exit(1)

    config = builder.create_builder_config()
    serialized_engine = builder.build_serialized_network(network, config)

    if serialized_engine is None:
        sys.exit(1)

    with open("engine.trt", "wb") as f:
        f.write(serialized_engine)

    with open("engine.trt", "rb") as f, trt.Runtime(trt_logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())

    #------------------------------------------------------------
    # Print the shape TensorRT reports for every I/O tensor.
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        print(name, engine.get_tensor_shape(name))


if __name__ == "__main__":
    test()

testcast.zip

Commands or scripts:

Have you tried the latest release?: yes

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): the model can be executed by onnxruntime.

coffezhou, May 29 '25 03:05

Can you please share the onnx model?

yuanyao-nv, Jun 04 '25 18:06

Sure, the onnx model is in the attachment 'testcase.zip'.

coffezhou, Jun 05 '25 03:06

Thanks for sharing the model. I checked the nodes and it seems like the first difference occurs after MaxPool. ORT thinks it should output [1,32,7,1] but in fact it should be [1,32,7,0]. You can verify with this ONNX formula:

output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1)
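The formula can be evaluated per spatial axis. The sketch below implements it for ceil_mode=0, where pad_total is the sum of the begin and end pads on that axis. The attribute values used in the example (input size 1, kernel 2, stride 2, no padding, dilation 1) are hypothetical, since the thread does not state the actual MaxPool attributes of 1111.onnx. As a general arithmetic point, the numerator can be negative, and then floor division and division truncated toward zero give different results; this is not a claim about how onnxruntime computes the shape.

```python
def pool_output_dim(in_size, kernel, stride=1, pad_total=0, dilation=1):
    """ONNX MaxPool output size on one spatial axis (ceil_mode=0).

    floor((in + pad - dilation * (kernel - 1) - 1) / stride + 1)
    Python's // floors toward negative infinity, so it matches floor()
    even when the numerator is negative.
    """
    numerator = in_size + pad_total - dilation * (kernel - 1) - 1
    return numerator // stride + 1


def pool_output_dim_truncating(in_size, kernel, stride=1, pad_total=0, dilation=1):
    """Same expression, but int() truncates toward zero instead of flooring."""
    numerator = in_size + pad_total - dilation * (kernel - 1) - 1
    return int(numerator / stride) + 1


# Hypothetical attributes: input size 1, kernel 2, stride 2, no padding.
print(pool_output_dim(1, 2, stride=2))             # 0
print(pool_output_dim_truncating(1, 2, stride=2))  # 1
```

Here the numerator is 1 + 0 - 1 - 1 = -1; floor(-1 / 2) + 1 = 0, whereas truncating -0.5 to 0 and adding 1 gives 1.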

yuanyao-nv, Jun 30 '25 18:06

@yuanyao-nv Thanks! I have verified that the shape of the MaxPool output should be [1,32,7,0] according to the formula. This may be a defect in onnxruntime. I tried to find the cause in onnxruntime, but failed.

coffezhou, Jul 01 '25 06:07

Maybe try filing an issue on the onnxruntime repo to get help? I will close this issue for now. Feel free to reopen if more discussion is needed. Thanks.

yuanyao-nv, Jul 01 '25 17:07