mmrotate model fails with a batch_norm error in MCT.
Issue Type
Others
Source
source
MCT Version
2.0.0
OS Platform and Distribution
Ubuntu 18.04
Python version
3.10.8
Describe the issue
Hello.
I want to quantize an mmrotate model with MCT.
When I called mct.ptq.pytorch_post_training_quantization, I got the following runtime error:
RuntimeError: running_mean should contain 512 elements not 12
Could you tell me how to address this error?
Expected behaviour
I expect the conversion to succeed.
Code to reproduce the issue
docker image: pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
pip install -U openmim
mim install mmengine
mim install "mmcv==2.0.0rc4"
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
git checkout refs/tags/v3.1.0
pip install -v -e .
cd ../
git clone https://github.com/open-mmlab/mmrotate.git
cd mmrotate
git checkout dev-1.x
pip install -v -e .
cd ../
git clone https://github.com/sony/model_optimization.git local_mct
cd local_mct
git checkout refs/tags/v2.0.0
pip install -r requirements.txt
apt update
apt install libgl1-mesa-glx
apt-get install -y libglib2.0-0 libsm6 libxrender1 libxext6
pip install ipdb
[version]
mmcv 2.0.0rc4 https://github.com/open-mmlab/mmcv
mmdet 3.1.0 /workspace/mmdetection
mmengine 0.10.3 https://github.com/open-mmlab/mmengine
mmrotate 1.0.0rc1 /workspace/mmrotate
cd /workspace/mmrotate
wget https://download.openmmlab.com/mmrotate/v1.0/rotated_rtmdet/rotated_rtmdet_tiny-3x-dota_ms/rotated_rtmdet_tiny-3x-dota_ms-f12286ff.pth
python demo/image_demo_mct.py demo/demo.jpg configs/rotated_rtmdet/rotated_rtmdet_tiny-3x-dota_ms.py rotated_rtmdet_tiny-3x-dota_ms-f12286ff.pth --out-file result.jpg
image_demo_mct.py is derived from image_demo.py.
I added the following code after [model = init_detector()]:
model4quant = RTMDet4Quant(model)
cfg = model.cfg
jpg_list = glob.glob('./demo/*.jpg')
cfg = cfg.copy()
test_pipeline = get_test_pipeline_cfg(cfg)
test_pipeline = Compose(test_pipeline)

def get_representative_dataset(n_iter):
    def representative_dataset():
        for index in range(n_iter):
            img = jpg_list[index]
            data_ = dict(img_path=img, img_id=0)
            data_ = test_pipeline(data_)
            data_['inputs'] = [data_['inputs']]
            data_['data_samples'] = [data_['data_samples']]
            data = model.data_preprocessor(data_)
            yield data['inputs']
    return representative_dataset

quant_model, _ = mct.ptq.pytorch_post_training_quantization(model4quant, get_representative_dataset(2))
class RTMDet4Quant(torch.nn.Module):
    def __init__(self, in_rtmdet, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_module("base", in_rtmdet)

    def forward(self, x):
        features = self.base.backbone(x)
        features = self.base.neck(features)
        results = self.base.bbox_head.forward(features)
        return results
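As a quick sanity check (a minimal sketch, assuming the model4quant defined above), the wrapper itself should trace standalone; consistent with the log below, the error is only raised later, at MCT's shape-propagation stage:

import torch.fx

# Symbolic tracing of the wrapper succeeds on its own, so the
# control-flow restriction is not the blocker at this point.
traced = torch.fx.symbolic_trace(model4quant)
print(traced.graph)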
Log output
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 116, in run_node
result = super().run_node(n)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 243, in call_function
return target(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 2450, in batch_norm
return torch.batch_norm(
RuntimeError: running_mean should contain 256 elements not 12
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 116, in run_node
result = super().run_node(n)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 243, in call_function
return target(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 2450, in batch_norm
return torch.batch_norm(
RuntimeError: running_mean should contain 256 elements not 12
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/workspace/mmrotate/demo/image_demo_mct.py", line 118, in <module>
main(args)
File "/workspace/mmrotate/demo/image_demo_mct.py", line 91, in main
quant_model, _ = mct.ptq.pytorch_post_training_quantization(model4quant, get_representative_dataset(3))
File "/workspace/mmrotate/../local_mct/model_compression_toolkit/ptq/pytorch/quantization_facade.py", line 106, in pytorch_post_training_quantization
tg, bit_widths_config, _ = core_runner(in_model=in_module,
File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/runner.py", line 93, in core_runner
graph = graph_preparation_runner(in_model,
File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/graph_prep_runner.py", line 67, in graph_preparation_runner
graph = read_model_to_graph(in_model,
File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/graph_prep_runner.py", line 194, in read_model_to_graph
graph = fw_impl.model_reader(in_model,
File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/pytorch_implementation.py", line 145, in model_reader
return model_reader(_module, representative_data_gen, self.to_numpy, self.to_tensor)
File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/reader/reader.py", line 153, in model_reader
fx_model = fx_graph_module_generation(model, representative_data_gen, to_tensor)
File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/reader/reader.py", line 96, in fx_graph_module_generation
ShapeProp(symbolic_traced).propagate(*input_for_shape_infer)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 152, in propagate
return super().run(*args)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 130, in run
self.env[node] = self.run_node(node)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 119, in run_node
raise RuntimeError(
RuntimeError: ShapeProp error for: node=%batch_norm : [#users=1] = call_function[target=torch.nn.functional.batch_norm](args = (%base_backbone_stem_0_conv, %base_backbone_stem_0_bn_running_mean, %base_backbone_stem_0_bn_running_var), kwargs = {weight: %base_backbone_stem_0_bn_weight, bias: %base_backbone_stem_0_bn_bias, training: False, momentum: 0.1, eps: 1e-05}) with meta={}
While executing %batch_norm : [#users=1] = call_function[target=torch.nn.functional.batch_norm](args = (%base_backbone_stem_0_conv, %base_backbone_stem_0_bn_running_mean, %base_backbone_stem_0_bn_running_var), kwargs = {weight: %base_backbone_stem_0_bn_weight, bias: %base_backbone_stem_0_bn_bias, training: False, momentum: 0.1, eps: 1e-05})
Original traceback:
None
In image_demo_mct.py, the following modules are imported:

from mmdet.utils import get_test_pipeline_cfg
from mmcv.transforms import Compose
import sys
sys.path.insert(0, "../local_mct")
import model_compression_toolkit as mct
import glob
import torch
One of our experts will reach out ASAP; sorry for the delay due to the holiday season here.
Hi @YoshikiKato0220 , Looking at your log output, it appears that the error originates from "torch.fx", which MCT uses to convert a PyTorch model into a graph. This usually requires modifying the model to make it torch.fx compatible. However, could you verify that the input images provided to the model (the representative dataset) have consistent shapes? Thanks Idan
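For example, a quick check like this would do (a minimal sketch, reusing the get_representative_dataset from your reproduction code):

# Every batch yielded by the representative dataset should have the same
# shape (and a float dtype), e.g. torch.Size([1, 3, 1024, 1024]).
gen = get_representative_dataset(2)
for batch in gen():
    print(batch.shape, batch.dtype)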
Hi @Idan-BenAmi Thank you for your reply.
> However, could you verify that the input images provided to the model (representative dataset) have consistent shapes?

Yes, I verified the following:

ipdb> data['inputs'].shape
torch.Size([1, 3, 1024, 1024])

ipdb> test_pipeline
Compose(
    LoadImageFromFile(ignore_empty=False, to_float32=False, color_type='color', imdecode_backend='cv2', backend_args=None)
    Resize(scale=(1024, 1024), scale_factor=None, keep_ratio=True, clip_object_border=True), backend=cv2), interpolation=bilinear)
    LoadAnnotations(with_bbox=True, with_label=True, with_mask=False, with_seg=False, poly2mask=True, imdecode_backend='cv2', backend_args=None)
    ConvertBoxType(box_type_mapping={'gt_bboxes': 'rbox'})
    Pad(size=(1024, 1024), size_divisor=None, pad_to_square=False, pad_val={'img': (114, 114, 114)}), padding_mode=constant)
    PackDetInputs(meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
)
If I change the code
model4quant = RTMDet4Quant(model)
↓
model4quant = model
I get the following torch.fx error:
Original exception was:
Traceback (most recent call last):
  File "/workspace/mmrotate/../local_mct/model_compression_toolkit/core/pytorch/reader/reader.py", line 90, in fx_graph_module_generation
    symbolic_traced = symbolic_trace(pytorch_model)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1070, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/workspace/mmdetection/mmdet/models/detectors/base.py", line 91, in forward
    if mode == 'loss':
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 298, in __bool__
    return self.tracer.to_bool(self)
  File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 174, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
Hi @YoshikiKato0220, Thanks for your input. Please refer to this log output:
File "/workspace/mmdetection/mmdet/models/detectors/base.py", line 91, in forward if mode == 'loss':
This is a common problem "torch.fx" encounters when an "if" statement appears in the "forward" function. Models like this cannot be supported by MCT, since torch.fx is applied in the first stage. If possible, we suggest modifying the model to remove this "if" statement (as long as that is possible and the model keeps the functionality you need). Idan
Hi @Idan-BenAmi Thanks for your comment.
I recognize "if" and "range" are relevant to "torch.fx" restrictions. In these case, I can understand how to address from log messages.
eg.) "if mode == 'loss':" To avoid "torch.fx" restrictions, I modified the following
model4quant = model ↓ model4quant = RTMDet4Quant(model)
RTMDet4Quant is based on SDD4Quant. (https://github.com/sony/model_optimization/blob/v2.0.0/tutorials/notebooks/pytorch/ptq/example_pytorch_ssdlite_mobilenetv3.ipynb)
But I'm not sure how to address this case from log messages. Could you give me any advise?
Hi @YoshikiKato0220 , Sorry for the delay. I see that this is also a "torch.fx" error from your log message, but I can't tell exactly what the problem is in your model:
File "/opt/conda/lib/python3.10/site-packages/torch/fx/passes/shape_prop.py", line 119, in run_node raise RuntimeError( RuntimeError: ShapeProp error for: node=%batch_norm : [#users=1] = call_function[target=torch.nn.functional.batch_norm]...
Maybe it's the additional "base" module that you added? You may find a hint here: https://github.com/pytorch/pytorch/issues/99629, which suggests it might be related to an incorrect dtype for the input data, but I really can't say.
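If you want to rule the dtype out, one thing to try (a sketch based on your reproduction code; I can't confirm it resolves this particular error) is casting the yielded inputs to float32 explicitly:

def get_representative_dataset(n_iter):
    def representative_dataset():
        for index in range(n_iter):
            data_ = dict(img_path=jpg_list[index], img_id=0)
            data_ = test_pipeline(data_)
            data_['inputs'] = [data_['inputs']]
            data_['data_samples'] = [data_['data_samples']]
            data = model.data_preprocessor(data_)
            # Explicit float32 cast, in case the preprocessor returns a
            # different dtype (see the linked PyTorch issue).
            yield data['inputs'].float()
    return representative_dataset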
Hope it helps. I'll close this issue, but feel free to open a new one if you feel the problem is related to MCT. Thanks Idan