Static quantization of a self-attention module does not work
Describe the issue
I am testing the inference performance of a model based on multi-head self-attention. After turning on static quantization, I found that performance dropped instead of improving. I then wrote a simple test and found that the self-attention graph looks strange after static quantization.
Here is the simple reproduction code:
import math
import time
import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, quantize_static, CalibrationDataReader
from onnxruntime.quantization.quant_utils import QuantFormat


# multi head self attn by pytorch
class SelfAttn(nn.Module):
    def __init__(self, hidden_size, num_attn_heads):
        super().__init__()
        attn_head_size = int(hidden_size / num_attn_heads)
        all_head_size = num_attn_heads * attn_head_size
        self.query = nn.Linear(hidden_size, all_head_size)
        self.key = nn.Linear(hidden_size, all_head_size)
        self.value = nn.Linear(hidden_size, all_head_size)
        self.attn_head_size = attn_head_size
        self.num_attn_heads = num_attn_heads

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attn_heads, self.attn_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attn_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # (batch, seq_len, num_heads, dim)
        new_context_layer_shape = context_layer.size()[:-2] + (self.num_attn_heads * self.attn_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)  # (batch, seq_len, hidden_size)
        return context_layer


def export_model():
    model = SelfAttn(128, 2)
    x = torch.randn(2, 10, 128)
    tuple_input = (x,)
    torch.onnx.export(
        model,
        tuple_input,
        f="attn.onnx",
        input_names=['input_tensor'],
        output_names=['output_tensor'],
        dynamic_axes={'input_tensor': {0: 'batch_size', 1: 'seq_len'},
                      'output_tensor': {0: 'batch_size', 1: 'seq_len'}}
    )
    quantize_dynamic("attn.onnx", "attn_dynamic_quant.onnx")


class FakeDataReader(CalibrationDataReader):
    def __init__(self, model_path):
        session = onnxruntime.InferenceSession(model_path)
        self.input_name = session.get_inputs()[0].name
        fake_data = torch.randn(4, 1, 10, 128).numpy()
        self.datasize = fake_data.shape[0]
        self.fake_data = iter(
            [{self.input_name: fake_data[i]} for i in range(self.datasize)]
        )

    def get_next(self):
        return next(self.fake_data, None)


def export_static_quant_model():
    dr = FakeDataReader("attn_preprocess.onnx")
    quantize_static("attn_preprocess.onnx", "attn_static_quant.onnx", dr, quant_format=QuantFormat.QDQ)


def run(name, model_path, x):
    ort_session = onnxruntime.InferenceSession(model_path)
    ort_inputs = {ort_session.get_inputs()[0].name: x}
    for _ in range(5):
        outputs = ort_session.run(output_names=None, input_feed=ort_inputs)
    t0 = time.time()
    for _ in range(100):
        outputs = ort_session.run(output_names=None, input_feed=ort_inputs)
    t1 = time.time()
    print(f"{name}: {t1 - t0}")
    return outputs[0]


def performance():
    x = torch.randn(4, 10, 128).numpy()
    run("attn", "./attn.onnx", x)
    run("attn_dynamic_quant", "./attn_dynamic_quant.onnx", x)
    run("attn_static_quant", "./attn_static_quant.onnx", x)


if __name__ == "__main__":
    export_model()
    export_static_quant_model()
    performance()
I wrote a simple self-attention module and exported it to an ONNX model plus a dynamically quantized model. Then I ran the onnxruntime preprocessing tool:
python -m onnxruntime.quantization.preprocess --input attn.onnx --output attn_preprocess.onnx
and statically quantized the preprocessed model. Finally I ran all the models and measured inference time; the static quant model took the longest.
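For reference, here is a minimal sketch of the end-to-end order of operations, assuming the preprocess command shown above is run between the two export steps (the subprocess call is just one way to invoke it from the same script; running the command manually works the same way):

import subprocess
import sys

# 1. Export the float ONNX model and the dynamically quantized model.
export_model()

# 2. Run the onnxruntime preprocessing tool to produce attn_preprocess.onnx,
#    which export_static_quant_model() expects as its input.
subprocess.run(
    [sys.executable, "-m", "onnxruntime.quantization.preprocess",
     "--input", "attn.onnx", "--output", "attn_preprocess.onnx"],
    check=True,
)

# 3. Statically quantize the preprocessed model and benchmark all three models.
export_static_quant_model()
performance()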
In my understanding, onnxruntime optimizes the graph during session initialization. It calls TransformGraph to optimize the graph, including fusing QDQ nodes. So I printed the graph after optimization:
// onnxruntime/core/session/inference_session.cc: Initialize()
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, saving_ort_format));
std::cout << "After transform:\n" << graph;
Here is one of the MatMul nodes:
("/key/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output": tensor(float),"onnx::MatMul_91_DequantizeLinear_Output": tensor(float),) -> ("/key/MatMul_output_0": tensor(float),)
The inputs of this MatMul node are all fp32 tensors, so I think it runs an fp32 GEMM rather than an int8 GEMM.
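As a cross-check that does not require adding print statements to the onnxruntime source, the post-optimization graph can also be saved from Python via SessionOptions.optimized_model_filepath (a sketch; the output file name is arbitrary):

import onnx
import onnxruntime

# Ask ORT to serialize the graph after its optimization passes have run.
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
sess_options.optimized_model_filepath = "attn_static_quant_optimized.onnx"  # arbitrary output name
onnxruntime.InferenceSession("attn_static_quant.onnx", sess_options)

# Inspect which op types survived optimization (e.g. MatMul vs. QLinearMatMul).
optimized = onnx.load("attn_static_quant_optimized.onnx")
for node in optimized.graph.node:
    print(node.op_type, node.name)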
I have two questions:
- Am I quantizing the model correctly?
- Why does onnxruntime not call an int8 GEMM?
Here is the whole graph:
After transform:
Inputs:
"input_tensor": tensor(float)
Nodes:
("input_tensor_QuantizeLinear", QuantizeLinear, "", 13) : ("input_tensor": tensor(float),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_QuantizeLinear_Output": tensor(int8),)
("key.bias_DequantizeLinear", DequantizeLinear, "", 13) : ("key.bias_quantized": tensor(int8),"ortshared_1_0_1_13_token_177": tensor(float),"key.bias_zero_point": tensor(int8),) -> ("key.bias_DequantizeLinear_Output": tensor(float),)
("onnx::MatMul_90_DequantizeLinear", DequantizeLinear, "", 13) : ("onnx::MatMul_90_quantized": tensor(int8),"ortshared_1_0_1_6_token_168": tensor(float),"onnx::MatMul_90_zero_point": tensor(int8),) -> ("onnx::MatMul_90_DequantizeLinear_Output": tensor(float),)
("onnx::MatMul_91_DequantizeLinear", DequantizeLinear, "", 13) : ("onnx::MatMul_91_quantized": tensor(int8),"ortshared_1_0_1_8_token_170": tensor(float),"onnx::MatMul_91_zero_point": tensor(int8),) -> ("onnx::MatMul_91_DequantizeLinear_Output": tensor(float),)
("onnx::MatMul_92_DequantizeLinear", DequantizeLinear, "", 13) : ("onnx::MatMul_92_quantized": tensor(int8),"ortshared_1_0_1_1_token_163": tensor(float),"onnx::MatMul_92_zero_point": tensor(int8),) -> ("onnx::MatMul_92_DequantizeLinear_Output": tensor(float),)
("query.bias_DequantizeLinear", DequantizeLinear, "", 13) : ("query.bias_quantized": tensor(int8),"ortshared_1_0_1_14_token_178": tensor(float),"query.bias_zero_point": tensor(int8),) -> ("query.bias_DequantizeLinear_Output": tensor(float),)
("value.bias_DequantizeLinear", DequantizeLinear, "", 13) : ("value.bias_quantized": tensor(int8),"ortshared_1_0_1_0_token_162": tensor(float),"value.bias_zero_point": tensor(int8),) -> ("value.bias_DequantizeLinear_Output": tensor(float),)
("input_tensor_DequantizeLinear", DequantizeLinear, "", 13) : ("input_tensor_QuantizeLinear_Output": tensor(int8),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_DequantizeLinear_Output": tensor(float),)
("/key/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output": tensor(float),"onnx::MatMul_91_DequantizeLinear_Output": tensor(float),) -> ("/key/MatMul_output_0": tensor(float),)
("/query/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output/duplicated": tensor(float),"onnx::MatMul_90_DequantizeLinear_Output": tensor(float),) -> ("/query/MatMul_output_0": tensor(float),)
("/value/MatMul", MatMul, "", 13) : ("input_tensor_DequantizeLinear_Output/duplicated_token_0": tensor(float),"onnx::MatMul_92_DequantizeLinear_Output": tensor(float),) -> ("/value/MatMul_output_0": tensor(float),)
("/key/MatMul_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/key/MatMul_output_0": tensor(float),"ortshared_1_0_1_10_token_173": tensor(float),"qdq_s8_to_u8_zp_conversion_token_188": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_189": tensor(uint8),)
("/query/MatMul_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/query/MatMul_output_0": tensor(float),"ortshared_1_0_1_15_token_179": tensor(float),"qdq_s8_to_u8_zp_conversion": tensor(uint8),) -> ("qdq_s8_to_u8_quant": tensor(uint8),)
("/value/MatMul_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/value/MatMul_output_0": tensor(float),"ortshared_1_0_1_4_token_166": tensor(float),"qdq_s8_to_u8_zp_conversion_token_202": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_203": tensor(uint8),)
("/key/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_189": tensor(uint8),"ortshared_1_0_1_10_token_173": tensor(float),"qdq_s8_to_u8_zp_conversion_token_188": tensor(uint8),) -> ("/key/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
("/query/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant": tensor(uint8),"ortshared_1_0_1_15_token_179": tensor(float),"qdq_s8_to_u8_zp_conversion": tensor(uint8),) -> ("/query/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
("/value/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_203": tensor(uint8),"ortshared_1_0_1_4_token_166": tensor(float),"qdq_s8_to_u8_zp_conversion_token_202": tensor(uint8),) -> ("/value/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
("/key/Add", Add, "", 14) : ("key.bias_DequantizeLinear_Output": tensor(float),"/key/MatMul_output_0_DequantizeLinear_Output": tensor(float),) -> ("/key/Add_output_0": tensor(float),)
("/query/Add", Add, "", 14) : ("query.bias_DequantizeLinear_Output": tensor(float),"/query/MatMul_output_0_DequantizeLinear_Output": tensor(float),) -> ("/query/Add_output_0": tensor(float),)
("/value/Add", Add, "", 14) : ("value.bias_DequantizeLinear_Output": tensor(float),"/value/MatMul_output_0_DequantizeLinear_Output": tensor(float),) -> ("/value/Add_output_0": tensor(float),)
("/key/Add_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/key/Add_output_0": tensor(float),"ortshared_1_0_1_16_token_180": tensor(float),"qdq_s8_to_u8_zp_conversion_token_190": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_191": tensor(uint8),)
("/query/Add_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/query/Add_output_0": tensor(float),"ortshared_1_0_1_5_token_167": tensor(float),"qdq_s8_to_u8_zp_conversion_token_182": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_183": tensor(uint8),)
("/value/Add_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/value/Add_output_0": tensor(float),"ortshared_1_0_1_17_token_181": tensor(float),"qdq_s8_to_u8_zp_conversion_token_204": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_205": tensor(uint8),)
("/Reshape_1", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_191": tensor(uint8),"ortshared_7_1_4_0_token_171": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_193": tensor(uint8),)
("/Reshape", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_183": tensor(uint8),"ortshared_7_1_4_0_token_171": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_185": tensor(uint8),)
("/Reshape_2", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_205": tensor(uint8),"ortshared_7_1_4_0_token_171": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_207": tensor(uint8),)
("/Transpose_2", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_193": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_195": tensor(uint8),)
("/Transpose", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_185": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_187": tensor(uint8),)
("/Transpose_1", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_207": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_209": tensor(uint8),)
("/MatMul_output_0_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_197": tensor(uint8),"ortshared_1_0_1_7_token_169": tensor(float),"qdq_s8_to_u8_zp_conversion_token_196": tensor(uint8),) -> ("/MatMul_output_0_DequantizeLinear_Output": tensor(float),)
("/Div", Div, "", 14) : ("/MatMul_output_0_DequantizeLinear_Output": tensor(float),"ortshared_1_0_1_3_token_165": tensor(float),) -> ("/Div_output_0": tensor(float),)
("/Div_output_0_QuantizeLinear", QuantizeLinear, "", 13) : ("/Div_output_0": tensor(float),"ortshared_1_0_1_2_token_164": tensor(float),"qdq_s8_to_u8_zp_conversion_token_198": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_199": tensor(uint8),)
("/Transpose_3", Transpose, "", 13) : ("qdq_s8_to_u8_quant_token_211": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_213": tensor(uint8),)
("/Reshape_3", Reshape, "", 14) : ("qdq_s8_to_u8_quant_token_213": tensor(uint8),"ortshared_7_1_3_0_token_176": tensor(int64),) -> ("qdq_s8_to_u8_quant_token_215": tensor(uint8),)
("output_tensor_DequantizeLinear", DequantizeLinear, "", 13) : ("qdq_s8_to_u8_quant_token_215": tensor(uint8),"ortshared_1_0_1_11_token_174": tensor(float),"qdq_s8_to_u8_zp_conversion_token_214": tensor(uint8),) -> ("output_tensor": tensor(float),)
("input_tensor_DequantizeLinear/duplicated", DequantizeLinear, "", 13) : ("input_tensor_QuantizeLinear_Output": tensor(int8),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_DequantizeLinear_Output/duplicated": tensor(float),)
("input_tensor_DequantizeLinear/duplicated_token_1", DequantizeLinear, "", 13) : ("input_tensor_QuantizeLinear_Output": tensor(int8),"ortshared_1_0_1_12_token_175": tensor(float),"input_tensor_zero_point": tensor(int8),) -> ("input_tensor_DequantizeLinear_Output/duplicated_token_0": tensor(float),)
("/MatMul", QLinearMatMul, "", 10) : ("qdq_s8_to_u8_quant_token_187": tensor(uint8),"ortshared_1_0_1_5_token_167": tensor(float),"qdq_s8_to_u8_zp_conversion_token_186": tensor(uint8),"qdq_s8_to_u8_quant_token_195": tensor(uint8),"ortshared_1_0_1_16_token_180": tensor(float),"qdq_s8_to_u8_zp_conversion_token_194": tensor(uint8),"ortshared_1_0_1_7_token_169": tensor(float),"qdq_s8_to_u8_zp_conversion_token_196": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_197": tensor(uint8),)
("/Softmax", QLinearSoftmax, "com.microsoft", 1) : ("qdq_s8_to_u8_quant_token_199": tensor(uint8),"ortshared_1_0_1_2_token_164": tensor(float),"qdq_s8_to_u8_zp_conversion_token_198": tensor(uint8),"ortshared_1_0_1_9_token_172": tensor(float),"qdq_s8_to_u8_zp_conversion_token_200": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_201": tensor(uint8),)
("/MatMul_1", QLinearMatMul, "", 10) : ("qdq_s8_to_u8_quant_token_201": tensor(uint8),"ortshared_1_0_1_9_token_172": tensor(float),"qdq_s8_to_u8_zp_conversion_token_200": tensor(uint8),"qdq_s8_to_u8_quant_token_209": tensor(uint8),"ortshared_1_0_1_17_token_181": tensor(float),"qdq_s8_to_u8_zp_conversion_token_208": tensor(uint8),"ortshared_1_0_1_11_token_174": tensor(float),"qdq_s8_to_u8_zp_conversion_token_210": tensor(uint8),) -> ("qdq_s8_to_u8_quant_token_211": tensor(uint8),)
Outputs:
"output_tensor": tensor(float)
To reproduce
Run the Python code above.
Urgency
No response
Platform
Linux
OS Version
ubuntu
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
main 61a79436e22892bdd91a905389f12e0aee68132e
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
We recommend using dynamic quantization for transformer models on CPU. If you use static quantization, you can limit the op types to quantize to MatMul only. https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#method-selection https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#transformer-based-models
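For example, restricting static quantization to MatMul could look roughly like this (a sketch, not a verified fix; it assumes the same FakeDataReader and attn_preprocess.onnx from the script above, and op_types_to_quantize as the parameter name in recent onnxruntime releases; the output file name is made up):

from onnxruntime.quantization import quantize_static
from onnxruntime.quantization.quant_utils import QuantFormat

dr = FakeDataReader("attn_preprocess.onnx")
quantize_static(
    "attn_preprocess.onnx",
    "attn_static_quant_matmul_only.onnx",  # hypothetical output name
    dr,
    quant_format=QuantFormat.QDQ,
    op_types_to_quantize=["MatMul"],       # quantize only MatMul ops
)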
Thank you for your reply.
"recommend to use dynamic quantization for transformer models on CPU"
What is the reason for this? Is it that the current support for static quantization of transformer models is not good, or that in practice dynamic quantization gives better prediction results?
The reason I think something is wrong with this static quantization is that the fully connected layers that compute "query", "key" and "value" are not quantized correctly.
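To check which ops actually ended up quantized, one can count op types in the produced model (a small sketch using the onnx package; attn_static_quant.onnx is the file produced by the script above):

from collections import Counter
import onnx

model = onnx.load("attn_static_quant.onnx")
counts = Counter(node.op_type for node in model.graph.node)
print(counts)  # e.g. how many MatMul / QuantizeLinear / DequantizeLinear nodes are present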
Any resolution to this issue? Any exported model using self-attention hits this, so dynamic shapes cannot be used for seq_length! Thanks
Applying stale label due to no activity in 30 days
Closing issue due to no activity in 30 days