
mmrotate model gets a batch_norm error.

Open YoshikiKato0220 opened this issue 1 year ago • 7 comments

Issue Type

Others

Source

source

MCT Version

2.0.0

OS Platform and Distribution

Ubuntu 18.04

Python version

3.10.8

Describe the issue

Hello.
I want to convert an mmrotate model with MCT.
When I called mct.ptq.pytorch_post_training_quantization, I got the following runtime error:
RuntimeError: running_mean should contain 512 elements not 12

Could you tell me how to address this error?

Expected behaviour

I expect the conversion to succeed.

Code to reproduce the issue

docker image: pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel

pip install -U openmim
mim install mmengine
mim install "mmcv==2.0.0rc4"

git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
git checkout refs/tags/v3.1.0
pip install -v -e .
cd ../

git clone https://github.com/open-mmlab/mmrotate.git
cd mmrotate
git checkout dev-1.x
pip install -v -e .
cd ../

git clone https://github.com/sony/model_optimization.git local_mct
cd local_mct
git checkout refs/tags/v2.0.0
pip install -r requirements.txt

apt update
apt install libgl1-mesa-glx
apt-get install -y libglib2.0-0 libsm6 libxrender1 libxext6

pip install ipdb

[version]
mmcv       2.0.0rc4   https://github.com/open-mmlab/mmcv
mmdet      3.1.0      /workspace/mmdetection
mmengine   0.10.3     https://github.com/open-mmlab/mmengine
mmrotate   1.0.0rc1   /workspace/mmrotate

cd /workspace/mmrotate
wget https://download.openmmlab.com/mmrotate/v1.0/rotated_rtmdet/rotated_rtmdet_tiny-3x-dota_ms/rotated_rtmdet_tiny-3x-dota_ms-f12286ff.pth

python demo/image_demo_mct.py demo/demo.jpg configs/rotated_rtmdet/rotated_rtmdet_tiny-3x-dota_ms.py rotated_rtmdet_tiny-3x-dota_ms-f12286ff.pth --out-file result.jpg

image_demo_mct.py is derived from image_demo.py.

I added the following code after model = init_detector():

    model4quant = RTMDet4Quant(model)

    cfg = model.cfg
    jpg_list = glob.glob('./demo/*.jpg')
    cfg = cfg.copy()
    test_pipeline = get_test_pipeline_cfg(cfg)
    test_pipeline = Compose(test_pipeline)

    def get_representative_dataset(n_iter):
        def representative_dataset():
            for index in range(n_iter):
                img = jpg_list[index]
                data_ = dict(img_path=img, img_id=0)
                data_ = test_pipeline(data_)
                data_['inputs'] = [data_['inputs']]
                data_['data_samples'] = [data_['data_samples']]
                data = model.data_preprocessor(data_)
                yield data['inputs']

        return representative_dataset

    quant_model, _ = mct.ptq.pytorch_post_training_quantization(model4quant, get_representative_dataset(2))


class RTMDet4Quant(torch.nn.Module):
    def __init__(self, in_rtmdet, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_module("base", in_rtmdet)

    def forward(self, x):
        features = self.base.backbone(x)
        features = self.base.neck(features)
        results = self.base.bbox_head.forward(features)
        return results
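As a debugging aid, the two torch.fx stages MCT runs (symbolic tracing, then shape propagation) can be exercised directly on a wrapped model before calling MCT, to localize failures like the one below. A minimal sketch; ToyBackbone is a hypothetical stand-in, not mmrotate code:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class ToyBackbone(torch.nn.Module):
    # Hypothetical stand-in for the RTMDet backbone.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

model = ToyBackbone().eval()
traced = symbolic_trace(model)                         # stage 1: torch.fx tracing
ShapeProp(traced).propagate(torch.rand(1, 3, 32, 32))  # stage 2: shape inference

# After propagate(), each node carries its inferred tensor shape.
for node in traced.graph.nodes:
    meta = node.meta.get('tensor_meta')
    print(node.name, getattr(meta, 'shape', None))
```

If stage 1 raises, the model needs a tracing-friendly wrapper; if stage 2 raises (as in this issue), the traced graph and the input shapes/dtypes disagree.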

Log output

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 116, in run_node
    result = super().run_node(n)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 243, in call_function
    return target(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 256 elements not 12
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 116, in run_node
    result = super().run_node(n)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 243, in call_function
    return target(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 256 elements not 12

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/workspace/mmrotate/demo/image_demo_mct.py", line 118, in <module>
    main(args)
  File "/workspace/mmrotate/demo/image_demo_mct.py", line 91, in main
    quant_model, _ = mct.ptq.pytorch_post_training_quantization(model4quant, get_representative_dataset(3))
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/ptq/pytorch/quantization_facade.py", line 106, in pytorch_post_training_quantization
    tg, bit_widths_config, _ = core_runner(in_model=in_module,
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/runner.py", line 93, in core_runner
    graph = graph_preparation_runner(in_model,
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/graph_prep_runner.py", line 67, in graph_preparation_runner
    graph = read_model_to_graph(in_model,
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/graph_prep_runner.py", line 194, in read_model_to_graph
    graph = fw_impl.model_reader(in_model,
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/pytorch_implementation.py", line 145, in model_reader
    return model_reader(_module, representative_data_gen, self.to_numpy, self.to_tensor)
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/reader/reader.py", line 153, in model_reader
    fx_model = fx_graph_module_generation(model, representative_data_gen, to_tensor)
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/reader/reader.py", line 96, in fx_graph_module_generation
    ShapeProp(symbolic_traced).propagate(*input_for_shape_infer)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 152, in propagate
    return super().run(*args)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 130, in run
    self.env[node] = self.run_node(node)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 119, in run_node
    raise RuntimeError(
RuntimeError: ShapeProp error for: node=%batch_norm : [#users=1] = call_function[target=torch.nn.functional.batch_norm](args = (%base_backbone_stem_0_conv, %base_backbone_stem_0_bn_running_mean, %base_backbone_stem_0_bn_running_var), kwargs = {weight: %base_backbone_stem_0_bn_weight, bias: %base_backbone_stem_0_bn_bias, training: False, momentum: 0.1, eps: 1e-05}) with meta={}

While executing %batch_norm : [#users=1] = call_function[target=torch.nn.functional.batch_norm](args = (%base_backbone_stem_0_conv, %base_backbone_stem_0_bn_running_mean, %base_backbone_stem_0_bn_running_var), kwargs = {weight: %base_backbone_stem_0_bn_weight, bias: %base_backbone_stem_0_bn_bias, training: False, momentum: 0.1, eps: 1e-05})
Original traceback:
None

YoshikiKato0220 avatar Apr 25 '24 02:04 YoshikiKato0220

in image_demo_mct.py the following module is imported.

from mmdet.utils import get_test_pipeline_cfg
from mmcv.transforms import Compose
import sys
sys.path.insert(0, "../local_mct")
import model_compression_toolkit as mct
import glob
import torch

YoshikiKato0220 avatar Apr 25 '24 03:04 YoshikiKato0220

Hi YoshikiKato0220

One of our experts will reach out ASAP; sorry for the delay due to the holiday season here.

ServiAmirPM avatar May 06 '24 10:05 ServiAmirPM

Hi @YoshikiKato0220,
Looking at your log output, it appears that the error originates from torch.fx, which MCT uses to convert a PyTorch model into a graph. Usually this requires modifying the model to make it torch.fx-compatible. However, could you first verify that the input images provided to the model (the representative dataset) have consistent shapes?
Thanks,
Idan

Idan-BenAmi avatar May 06 '24 17:05 Idan-BenAmi
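The shape check suggested above can be done by iterating the representative dataset once and collecting the shapes and dtypes it yields. A minimal sketch, with a stub generator of random tensors standing in for the real get_representative_dataset(n_iter)():

```python
import torch

def representative_dataset():
    # Stub standing in for get_representative_dataset(n_iter)(): yields random
    # tensors with the shape reported later in this thread.
    for _ in range(3):
        yield torch.rand(1, 3, 1024, 1024)

shapes, dtypes = set(), set()
for batch in representative_dataset():
    shapes.add(tuple(batch.shape))
    dtypes.add(batch.dtype)

print("shapes:", shapes, "dtypes:", dtypes)
assert len(shapes) == 1, f"inconsistent input shapes: {shapes}"
assert all(dt.is_floating_point for dt in dtypes), f"non-float inputs: {dtypes}"
```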

Hi @Idan-BenAmi Thank you for your reply.

You asked: "could you verify that the input images provided to the model (representative dataset) have consistent shapes?"

Yes, I verified the following:

ipdb> data['inputs'].shape
torch.Size([1, 3, 1024, 1024])

ipdb> test_pipeline
Compose(
    LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2', backend_args=None)
    Resize(scale=(1024, 1024), scale_factor=None, keep_ratio=True, clip_object_border=True), backend=cv2), interpolation=bilinear)
    LoadAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, poly2mask=True, imdecode_backend='cv2', backend_args=None)
    ConvertBoxType(box_type_mapping={'gt_bboxes': 'rbox'})
    Pad(size=(1024, 1024), size_divisor=None, pad_to_square=False, pad_val={'img': (114, 114, 114)}), padding_mode=constant)
    PackDetInputs(meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
)

YoshikiKato0220 avatar May 07 '24 02:05 YoshikiKato0220

If I change the code from

model4quant = RTMDet4Quant(model)

to

model4quant = model

I get a torch.fx error.

Original exception was:
Traceback (most recent call last):
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/reader/reader.py", line 90, in fx_graph_module_generation
    symbolic_traced = symbolic_trace(pytorch_model)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1070, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/workspace/mmdetection/mmdet/models/detectors/base.py", line 91, in forward
    if mode == 'loss':
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 298, in __bool__
    return self.tracer.to_bool(self)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 174, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

YoshikiKato0220 avatar May 07 '24 02:05 YoshikiKato0220

Hi @YoshikiKato0220, thanks for your input. Please refer to this part of the log output:

  File "/workspace/mmdetection/mmdet/models/detectors/base.py", line 91, in forward
    if mode == 'loss':

This is a common problem torch.fx encounters when an "if" statement appears in the "forward" function. Such PyTorch models cannot be supported by MCT, since torch.fx is applied in the first stage. If possible, we suggest modifying the model to remove this "if" statement (as long as the model keeps the functionality you need). Idan

Idan-BenAmi avatar May 09 '24 09:05 Idan-BenAmi
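The failure mode described above can be reproduced with a toy model: tracing a forward() that branches on one of its arguments raises the same TraceError, while a thin wrapper that pins the execution path (the approach RTMDet4Quant takes) traces cleanly. Detector and Wrapper below are illustrative stand-ins, not mmdet code:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError

class Detector(torch.nn.Module):
    # Toy stand-in for mmdet's BaseDetector: forward() branches on `mode`.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x, mode='predict'):
        if mode == 'loss':  # traced args become Proxies, so this `if` raises
            return self.linear(x) * 0
        return self.linear(x)

err = None
try:
    symbolic_trace(Detector())
except TraceError as e:
    err = e
print("trace failed:", err)

class Wrapper(torch.nn.Module):
    # Pin the execution path by calling the needed submodules directly,
    # the same idea as RTMDet4Quant above.
    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        return self.base.linear(x)

traced = symbolic_trace(Wrapper(Detector()))  # no control flow left to trace

# Alternatively, torch.fx can bake the flag in via concrete_args:
traced2 = symbolic_trace(Detector(), concrete_args={'mode': 'predict'})
```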

Hi @Idan-BenAmi Thanks for your comment.

I recognize that "if" and "range" are subject to the torch.fx restrictions. In those cases, I can understand how to address them from the log messages.

e.g. for "if mode == 'loss':", to avoid the torch.fx restrictions, I made the following modification:

model4quant = model
↓
model4quant = RTMDet4Quant(model)

RTMDet4Quant is based on SSD4Quant. (https://github.com/sony/model_optimization/blob/v2.0.0/tutorials/notebooks/pytorch/ptq/example_pytorch_ssdlite_mobilenetv3.ipynb)

But in this case, I'm not sure how to address the error from the log messages. Could you give me any advice?

YoshikiKato0220 avatar May 10 '24 05:05 YoshikiKato0220

Hi @YoshikiKato0220, sorry for the delay. I see that this is also a torch.fx error from your log message, but I can't tell exactly what the problem in your model is.

  File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 119, in run_node
    raise RuntimeError(
RuntimeError: ShapeProp error for: node=%batch_norm : [#users=1] = call_function[target=torch.nn.functional.batch_norm]...

Maybe it's the additional "base" module that you added? You may find a hint here: https://github.com/pytorch/pytorch/issues/99629, which suggests the problem might be an incorrect dtype for the input data, but I really can't say.

Hope this helps. I'll close this issue, but feel free to open a new one if you think the problem is related to MCT. Thanks, Idan

Idan-BenAmi avatar May 27 '24 09:05 Idan-BenAmi
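For the dtype hypothesis from the linked issue, a defensive cast of the representative inputs to float32 is cheap to try before yielding them to MCT. A minimal sketch; the data dict here is a hypothetical stand-in for the output of model.data_preprocessor(...):

```python
import torch

# Stand-in for the dict returned by model.data_preprocessor(...); suppose the
# images arrived as uint8 rather than float32.
data = {'inputs': torch.randint(0, 255, (1, 3, 1024, 1024), dtype=torch.uint8)}

inputs = data['inputs']
if not inputs.dtype.is_floating_point:
    inputs = inputs.float()  # cast uint8 -> float32 before yielding to MCT
print(inputs.dtype, inputs.shape)
```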