QAT model drops accuracy after converting with torch.ao.quantization.convert
Hello everyone.
I am implementing QAT for a YOLOv8 model with 4-bit weights and 8-bit activations, which I configure by setting quant_min and quant_max in the qconfig. The model gives quite good results during training and evaluation, but after I convert it with torch.ao.quantization.convert, its evaluation results are very bad. Does anyone know how to solve this problem?
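For reference, here is a minimal sketch of the eager-mode torch.ao flow described above: prepare_qat, then training, then convert. `TinyNet`, the toy training loop and the 8-bit-activation / 4-bit-weight qconfig are illustrative assumptions. They are not the poster's YOLOv8 setup, and a real model also needs QuantStub/DeQuantStub placed at its own float/int boundaries.

```python
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import (
    DeQuantStub,
    FakeQuantize,
    MovingAverageMinMaxObserver,
    QConfig,
    QuantStub,
    convert,
    disable_observer,
    prepare_qat,
)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model: one Linear between quant stubs."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized boundary
        self.fc = nn.Linear(16, 4)
        self.dequant = DeQuantStub()  # quantized -> float boundary

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


# 8-bit unsigned activations, 4-bit symmetric weights stored in a qint8 container.
qconfig = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
        dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    weight=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver, quant_min=-8, quant_max=7,
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
)

torch.manual_seed(0)
model = TinyNet()
model.qconfig = qconfig
model.train()                                   # prepare_qat expects training mode
qat_model = prepare_qat(model, inplace=False)   # swaps nn.Linear for its QAT version

# Toy training loop on random data (illustration only).
opt = torch.optim.SGD(qat_model.parameters(), lr=0.01)
for _ in range(20):
    loss = qat_model(torch.randn(32, 16)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Freeze observers so eval and convert both use the scales observed during training.
qat_model.apply(disable_observer)
qat_model.eval()
int_model = convert(qat_model, inplace=False)   # fake-quant modules -> real int modules

x_test = torch.randn(8, 16)
with torch.no_grad():
    y_fake = qat_model(x_test)
    y_int = int_model(x_test)
print(type(int_model.fc).__name__, (y_fake - y_int).abs().max().item())
```

The printed difference is one quick sanity check: with observers frozen, the fake-quantized and converted outputs should stay close. A large gap suggests the qparams changed between training and conversion.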
cc @andrewor14
Hi @tranngocduvnvp, can you share your prepare and convert flow? Are you using the APIs documented here? https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended. torchao QAT is not expected to work with torch.ao.quantization.convert from the pytorch/pytorch repo.
Hi @andrewor14, thanks for your feedback!
I found the cause of the accuracy drop after convert. In my training loop, at the 3rd epoch I turned off fake quantization (FakeQuantize) for some layers but left their observers enabled. The observers kept updating the scale, so the scale used when converting the weights to int format had changed.
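The mechanism described above can be reproduced on a single FakeQuantize module in isolation. This is a minimal sketch: the 4-bit range [-8, 7], the averaging constant and the input tensors are illustrative choices. Disabling fake quantization alone leaves the observer updating the scale. Applying `disable_observer` as well freezes it.

```python
import torch
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    disable_fake_quant,
    disable_observer,
)


def make_fq():
    # 4-bit symmetric per-tensor fake quant (illustrative values).
    return FakeQuantize(
        observer=MovingAverageMinMaxObserver,
        quant_min=-8,
        quant_max=7,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        averaging_constant=0.01,
    )


calib = torch.linspace(-1.0, 1.0, 100)       # "training-time" range
outlier = torch.linspace(-100.0, 100.0, 100)  # much wider range seen later

# Problematic pattern: fake quant off, observer still running.
buggy = make_fq()
buggy(calib)
scale_before_buggy = buggy.scale.clone()
buggy.apply(disable_fake_quant)
buggy(outlier)                               # observer keeps moving min/max
scale_after_buggy = buggy.scale.clone()

# Frozen pattern: observer disabled as well, so the scale is kept.
fixed = make_fq()
fixed(calib)
scale_before_fixed = fixed.scale.clone()
fixed.apply(disable_observer)
fixed.apply(disable_fake_quant)
fixed(outlier)
scale_after_fixed = fixed.scale.clone()

print("observer on :", scale_before_buggy.item(), "->", scale_after_buggy.item())
print("observer off:", scale_before_fixed.item(), "->", scale_after_fixed.item())
```

On a full model the same functions are applied with `model.apply(...)`. The PyTorch QAT tutorial, for example, freezes observers this way after a few epochs.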
I have another question, though: even during training my model still gives fairly poor results. I tried another repo that does 4-bit QAT of YOLO with the Brevitas library (pip install brevitas), and it gives very good results. Could you help me understand why accuracy drops when I use the torch.ao library? The quantization configuration for my layers is as follows:
```python
import torch
from torch.ao.quantization import FakeQuantize, MovingAverageMinMaxObserver, QConfig


def config_quant(bit_width_act, bit_width_weight, asym=False):
    if not asym:
        # Symmetric signed activations: [-(2^(b-1)), 2^(b-1) - 1]
        activation = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-(2 ** (bit_width_act - 1)),
            quant_max=2 ** (bit_width_act - 1) - 1,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,
            averaging_constant=0.01,
        )
    else:
        # Asymmetric unsigned activations: [0, 2^b - 1]
        activation = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=2 ** bit_width_act - 1,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,
            averaging_constant=0.01,
        )
    # Weights: always symmetric, one scale per tensor
    weight = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-(2 ** (bit_width_weight - 1)),
        quant_max=2 ** (bit_width_weight - 1) - 1,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=False,
        # averaging_constant=0.01
    )
    return QConfig(activation=activation, weight=weight)
```
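The config_quant function above gives each layer's weights a single per-tensor scale. One change often tried for low-bit weights is per-channel weight quantization, with one scale per output channel. It is offered here only as an assumption, not as a confirmed cause of the gap with Brevitas. `config_quant_per_channel` below is a hypothetical variant, and the toy weight tensor is illustrative. It uses the public `MovingAveragePerChannelMinMaxObserver` with `torch.per_channel_symmetric`.

```python
import torch
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    QConfig,
)


def config_quant_per_channel(bit_width_act, bit_width_weight):
    """Hypothetical variant of config_quant: asymmetric per-tensor activations,
    symmetric per-channel weights (one scale per output channel, ch_axis=0)."""
    activation = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=2 ** bit_width_act - 1,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
        averaging_constant=0.01,
    )
    weight = FakeQuantize.with_args(
        observer=MovingAveragePerChannelMinMaxObserver,
        quant_min=-(2 ** (bit_width_weight - 1)),
        quant_max=2 ** (bit_width_weight - 1) - 1,
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
    )
    return QConfig(activation=activation, weight=weight)


qconfig = config_quant_per_channel(bit_width_act=4, bit_width_weight=4)

# Toy conv weight whose output channels have very different magnitudes.
torch.manual_seed(0)
w = torch.randn(4, 3, 3, 3) * torch.tensor([0.01, 0.1, 1.0, 10.0]).view(4, 1, 1, 1)

weight_fq = qconfig.weight()   # instantiate the weight FakeQuantize
w_fq = weight_fq(w)            # observe + fake-quantize

print("per-channel scales:", weight_fq.scale)
for c in range(w.shape[0]):
    # At 4 bits each channel can use at most 16 distinct values.
    print("channel", c, "levels used:", torch.unique(w_fq[c]).numel())
```

With a single per-tensor scale, the small-magnitude channels in this toy tensor would collapse to very few levels. Whether this helps on YOLOv8 is not established by this thread.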
```python
# Default for every layer: 8-bit activations and weights
model.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)

# Layers kept at 8 bits
layers_8bit = {"model.net.p1", "model.net.p2.0"}
# Layers at 4 bits (asymmetric activations)
layers_4bit = {
    "model.net.p2.1",
    "model.net.p3.0", "model.net.p3.1",
    "model.net.p4.0", "model.net.p4.1",
    "model.net.p5.0", "model.net.p5.1", "model.net.p5.2",
    "model.fpn.h1",
}

for name, module in model.named_modules():
    print(name)
    if name in layers_8bit:
        module.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)
    elif name in layers_4bit:
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True)
```
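Because the per-layer overrides above depend on exact module names, it can help to check which weight range each layer actually ends up with after prepare_qat. This is a minimal sketch on a hypothetical toy model. `make_qconfig` is a simplified stand-in for config_quant, not the poster's code.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    QConfig,
    prepare_qat,
)


def make_qconfig(bits):
    # Hypothetical, simplified stand-in for config_quant(): symmetric
    # per-tensor fake quant for both activations and weights.
    fake_quant = FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-(2 ** (bits - 1)),
        quant_max=2 ** (bits - 1) - 1,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
    )
    return QConfig(activation=fake_quant, weight=fake_quant)


# Toy model; its names in named_modules() are "0", "1", "2".
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
model.qconfig = make_qconfig(8)     # default for every layer
model[2].qconfig = make_qconfig(4)  # per-layer override

model.train()
prepared = prepare_qat(model, inplace=False)

# Collect the weight quant range of every QAT module that has one.
ranges = {}
for name, mod in prepared.named_modules():
    fq = getattr(mod, "weight_fake_quant", None)
    if fq is not None:
        obs = fq.activation_post_process
        ranges[name] = (obs.quant_min, obs.quant_max)
        print(name, type(mod).__name__, ranges[name])
```

A misspelled name in the override list silently leaves that layer at the default, so a printout like this makes such mistakes visible before training.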
Thank you so much!