QAT model drops accuracy after converting with torch.ao.quantization.convert

Open tranngocduvnvp opened this issue 9 months ago • 3 comments

Hello everyone.

I am implementing QAT for a YOLOv8 model with 4-bit weights and 8-bit activations by setting quant_min and quant_max in the qconfig. During training and eval the model gives quite good results, but after I convert it with torch.ao.quantization.convert the evaluation results are very bad. Does anyone know how to solve this problem?
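For context, the flow is roughly the standard eager-mode QAT recipe (a simplified sketch; the real training loop and the YOLOv8 model code are omitted, and the qconfig helper is shown further down in this thread):

```python
import torch

# simplified sketch of the eager-mode QAT flow (real training loop omitted)
model.train()
model.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)  # qconfig helper shown below
torch.ao.quantization.prepare_qat(model, inplace=True)

# ... QAT training / eval, results look good here ...

model.eval()
quantized_model = torch.ao.quantization.convert(model, inplace=False)
# evaluating quantized_model is where the accuracy drops badly
```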

tranngocduvnvp avatar Apr 28 '25 01:04 tranngocduvnvp

cc @andrewor14

supriyar avatar Apr 29 '25 16:04 supriyar

Hi @tranngocduvnvp, can you share your prepare and convert flow? Are you using the APIs documented here? https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended. torchao QAT is not expected to work with torch.ao.quantization.convert from the pytorch/pytorch repo.
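For reference, the recommended flow in that README looks roughly like this (a sketch; exact names can differ between torchao releases, so please check the README for the version you have installed):

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
    from_intx_quantization_aware_training,
)

# prepare: insert fake quantization (int8 per-token activations, int4 grouped weights)
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# ... QAT fine-tuning ...

# convert: remove the fake quantization, then apply the real quantization
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
```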

andrewor14 avatar Apr 29 '25 19:04 andrewor14

Hi @andrewor14, thanks for your feedback!

I found the cause of the accuracy drop after convert. In my training loop, at the 3rd epoch I turned off fake quantization for some layers while leaving their observers enabled, so the observed scale values kept changing and no longer matched the weights when they were converted to int format.
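In case it helps anyone else, this is roughly what was going on (a sketch using the standard torch.ao enable/disable toggles; the layer names are just examples from my model):

```python
from torch.ao.quantization import disable_fake_quant, disable_observer

layers_to_freeze = {"model.net.p2.1", "model.net.p3.0"}  # example names from my model

for name, module in model.named_modules():
    if name in layers_to_freeze:
        module.apply(disable_fake_quant)
        # this second call is what I was missing: without it the observers keep
        # updating scale/zero_point, which no longer match the weights at convert time
        module.apply(disable_observer)
```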

But I have another question: even during training my model gives quite poor results, while another 4-bit YOLO QAT repo that uses the brevitas library (pip install brevitas) gives very good results. Could you help me understand the reason for the accuracy drop when using the torch.ao library? The quantization configuration code for my layers is as follows:

```python
def config_quant(bit_width_act, bit_width_weight, asym=False):
    my_qconfig = QConfig(
        activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-(2**(bit_width_act-1)),
            quant_max=2**(bit_width_act-1)-1,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,
            averaging_constant=0.01
        ) if asym == False else FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=2**(bit_width_act)-1,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,
            averaging_constant=0.01
        ),
        weight=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-(2**(bit_width_weight-1)),
            quant_max=2**(bit_width_weight-1)-1,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,
            # averaging_constant=0.01
        )
    )
    return my_qconfig

model.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)

for name, module in model.named_modules():
    print(name)

    if name == "model.net.p1":
        module.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)
    elif name == "model.net.p2.0":
        module.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)
    elif name == "model.net.p2.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True)
    elif name == "model.net.p3.0":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p3.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True)
    elif name == "model.net.p4.0":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p4.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p5.0":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p5.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p5.2":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.fpn.h1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    

```
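One thing I have been wondering about but have not verified yet: would switching the weight fake quant to per-channel help at 4 bit? Something like this sketch, as an alternative to the per-tensor weight config above:

```python
import torch
from torch.ao.quantization import FakeQuantize, MovingAveragePerChannelMinMaxObserver

bit_width_weight = 4

# per-channel symmetric weight fake quant (one per-output-channel scale, ch_axis=0)
per_channel_weight_fq = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-(2**(bit_width_weight-1)),
    quant_max=2**(bit_width_weight-1)-1,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    reduce_range=False,
)
```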

Thank you so much!!

tranngocduvnvp avatar May 05 '25 06:05 tranngocduvnvp