
AWQ INT4 quantization errors on a Llama-2-13B-based model with AMMO

Hongbosherlock opened this issue.

Versions:

python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
0.7.1

nvidia-ammo~=0.5.0

I'm currently trying to use AMMO to quantize my model with INT4 AWQ. My custom model is based on Llama-2-13B, but its attention and MLP layers have biases, and it uses GQA (grouped-query attention). I have set attn_bias=True and mlp_bias=True here, and the model runs correctly in fp16.
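For reference, under GQA the K and V projections are narrower than Q, so their bias vectors are shorter too. The sketch below computes the projection widths. hidden_size=5120 and num_heads=40 match Llama-2-13B. num_kv_heads=8 is a hypothetical value, because the issue does not state the model's actual KV head count.

```python
def qkv_out_features(hidden_size: int, num_heads: int, num_kv_heads: int):
    """Return the output widths of q_proj, k_proj, v_proj and of a fused QKV projection.

    Under GQA, K and V use fewer heads than Q, so the fused QKV width is not
    3 * hidden_size. A biased projection carries a bias vector of exactly this width.
    """
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    head_dim = hidden_size // num_heads
    q = num_heads * head_dim
    k = v = num_kv_heads * head_dim
    return q, k, v, q + k + v


# Hypothetical GQA variant of a 13B-sized model: 40 query heads sharing 8 KV heads.
print(qkv_out_features(5120, 40, 8))  # (5120, 1024, 1024, 7168)
```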

Exporting the quantized model with AMMO failed with the following error:

Cannot export model to the model_config. The AMMO optimized model state_dict (including the quantization factors) is saved to /target/model/quantized_int4-awq/ammo_model.0.pth using torch.save for further inspection.
Detailed export error: 
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ammo/torch/export/model_config_export.py", line 181, in export_model_config
    for model_config in torch_to_model_config(
  File "/usr/local/lib/python3.10/dist-packages/ammo/torch/export/model_config_export.py", line 114, in torch_to_model_config
    build_decoder_config(layer, model_metadata_config, decoder_type, dtype)
  File "/usr/local/lib/python3.10/dist-packages/ammo/torch/export/layer_utils.py", line 782, in build_decoder_config
    config.attention = build_attention_config(layer, model_metadata_config, dtype, config)
  File "/usr/local/lib/python3.10/dist-packages/ammo/torch/export/layer_utils.py", line 598, in build_attention_config
    config.qkv = build_qkv(qkv_modules, model_metadata_config, dtype, ext_config)
  File "/usr/local/lib/python3.10/dist-packages/ammo/torch/export/layer_utils.py", line 436, in build_qkv
    assert not (hasattr(m, "bias") and m.bias is not None)
AssertionError
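The assertion fires because the export path for the QKV layers expects bias-free nn.Linear modules. To confirm which projections actually carry a bias, you can inspect the state_dict that AMMO saved above. The helper below is a hypothetical sketch that works on any mapping of parameter names. It assumes Hugging Face Llama-style names such as q_proj/k_proj/v_proj, and a custom model may use different ones.

```python
from typing import Iterable, List, Mapping


def find_attention_bias_keys(
    state_dict: Mapping[str, object],
    proj_names: Iterable[str] = ("q_proj", "k_proj", "v_proj"),
) -> List[str]:
    """List bias entries that belong to the attention Q/K/V projections.

    Assumes parameter names like 'model.layers.0.self_attn.q_proj.bias';
    a custom model may name its projections differently.
    """
    proj_names = tuple(proj_names)  # allow any iterable, consumed once per key
    return sorted(
        key
        for key in state_dict
        if key.endswith(".bias") and any(f".{name}." in key for name in proj_names)
    )


# Usage (assuming the saved file is a flat name -> tensor mapping):
#   import torch
#   sd = torch.load("/target/model/quantized_int4-awq/ammo_model.0.pth", map_location="cpu")
#   print(find_attention_bias_keys(sd))
```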

I suspected the problem was related to the biases, so I commented out the bias-related assertions:

#/usr/local/lib/python3.10/dist-packages/ammo/torch/export/layer_utils.py

        for m in qkv_modules:
            assert type(m) == nn.Linear
            # assert not (hasattr(m, "bias") and m.bias is not None)
        ...
        # assert (
        #     not qkv_modules[0].bias and not qkv_modules[1].bias and not qkv_modules[2].bias
        # ), "bias is not supported yet."

I also added code to read the Q, K, and V biases:

        q_bias = qkv_modules[0].bias
        k_bias = qkv_modules[1].bias
        v_bias = qkv_modules[2].bias
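Reading the three biases is only half the job. The export fuses Q, K, and V into one QKV layer, so the biases likely have to be combined in the same order as the fused weight. The sketch below assumes a [q; k; v] concatenation along the output dimension and checks the shapes against GQA. It is not AMMO's actual code, and with torch tensors, torch.cat would serve the same purpose.

```python
import numpy as np


def fuse_qkv_bias(q_bias, k_bias, v_bias, num_heads: int, num_kv_heads: int):
    """Concatenate separate Q/K/V biases into one fused QKV bias.

    Assumes the fused layout is [q; k; v] along the output dimension, i.e. the
    same order as the fused QKV weight. K/V lengths are validated for GQA.
    """
    q_bias, k_bias, v_bias = (np.asarray(b) for b in (q_bias, k_bias, v_bias))
    if q_bias.ndim != 1 or q_bias.shape[0] % num_heads != 0:
        raise ValueError("q_bias must be 1-D with length divisible by num_heads")
    head_dim = q_bias.shape[0] // num_heads
    expected_kv = num_kv_heads * head_dim
    if k_bias.shape != (expected_kv,) or v_bias.shape != (expected_kv,):
        raise ValueError(f"expected K/V bias of length {expected_kv}")
    return np.concatenate([q_bias, k_bias, v_bias])
```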

Since my model has "intermediate_size": 13632, which is divisible by 64 (13632 / 64 = 213) but not by 128 (13632 / 128 = 106.5), I can only use group_size=64 for quantization. My configuration is as follows:

{'*weight_quantizer': {'num_bits': 4, 'block_sizes': {-1: 64}, 'enable': True}, '*input_quantizer': {'enable': False}, 'default': {'enable': False}}
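The divisibility reasoning above can be checked mechanically. Group-wise quantization splits each weight row along its input dimension into blocks of group_size, so every input dimension must be a multiple of the group size. The sketch assumes the hidden size stays at Llama-2-13B's 5120 (the issue only gives intermediate_size). The candidate list (128, 64) is an assumption about which group sizes the toolchain accepts.

```python
def valid_group_sizes(dims, candidates=(128, 64)):
    """Return the candidate group sizes that evenly divide every input dimension.

    Per-group (blockwise) weight quantization splits the input dimension of each
    Linear weight into blocks of `group_size`, so that dimension must be a
    multiple of the group size.
    """
    return [g for g in candidates if all(d % g == 0 for d in dims)]


hidden_size = 5120          # assumed: Llama-2-13B hidden size (input dim of q/k/v/o, gate, up)
intermediate_size = 13632   # from the issue (input dim of down_proj)
print(valid_group_sizes([hidden_size, intermediate_size]))  # [64]
```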

With these changes and this configuration, the quantization itself went relatively smoothly.
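To make the block_sizes entry in the configuration concrete: as we read it, {-1: 64} with num_bits 4 means the last weight axis is cut into groups of 64, and each group gets its own scale. The sketch below shows only that per-group round-to-nearest step, using symmetric int4 in [-8, 7]. It is not AWQ's activation-aware scale search and not AMMO's implementation.

```python
import numpy as np


def quantize_int4_per_group(weight: np.ndarray, group_size: int = 64):
    """Symmetric round-to-nearest int4 quantization with one scale per group.

    The last axis is split into groups of `group_size`, each with its own scale.
    This omits AWQ's activation-aware scaling and shows only the rounding step.
    """
    out_features, in_features = weight.shape
    if in_features % group_size != 0:
        raise ValueError("in_features must be divisible by group_size")
    groups = weight.reshape(out_features, in_features // group_size, group_size)
    # Map each group's max |w| onto the largest positive int4 value, 7.
    scales = np.abs(groups).max(axis=-1, keepdims=True) / 7.0
    scales = np.where(scales == 0, 1.0, scales)  # avoid division by zero for all-zero groups
    q = np.clip(np.round(groups / scales), -8, 7).astype(np.int8)
    return q, scales


def dequantize(q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Undo the per-group scaling and restore the original 2-D weight shape."""
    out_features, n_groups, group_size = q.shape
    return (q * scales).reshape(out_features, n_groups * group_size)
```

Each reconstructed weight is then within half a group scale of the original, which is why a smaller group size (more scales) usually lowers the quantization error.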

Posted by Hongbosherlock on Jan 15 '24 12:01

After quantization, I built the engine:

python build.py --model_dir /target/model/hf_model_v15 \
                --quant_ckpt_path /target/model/quantized_int4-awq/llama_tp1_rank0.npz \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --use_weight_only \
                --weight_only_precision int4_awq \
                --per_group \
                --output_dir /target/model/trt_engines/int4_AWQ/1-gpu/ \
                --use_rmsnorm_plugin float16 \
                --max_batch_size 20 \
                --group_size 64
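Before building, it may be worth confirming that the exported checkpoint actually contains the bias tensors. The thread does not document the key names inside the .npz, so this sketch simply filters on the substring "bias". It uses only numpy.load and the NpzFile files attribute.

```python
import numpy as np


def list_bias_arrays(npz_path: str):
    """Return (name, shape) for every array in an .npz file whose name mentions 'bias'."""
    with np.load(npz_path) as data:  # NpzFile supports the context-manager protocol
        return [(name, data[name].shape) for name in data.files if "bias" in name]


# Usage:
#   for name, shape in list_bias_arrays(
#           "/target/model/quantized_int4-awq/llama_tp1_rank0.npz"):
#       print(name, shape)
```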

Then I ran inference:

python3 run.py --max_output_len=90 \
               --tokenizer_dir=/target/model/hf_model_v15 \
               --engine_dir=/target/model/trt_engines/int4_AWQ/1-gpu/

However, the output was completely meaningless: the engine generated the same token (id 572, decoded as "B") at every step.

input_ids: [1, 26349, 523, 10700, 15, 41775, 5829, 44882, 538, 1374, 492, 13542, 626, 493]
Input [Text 0]: "Born in north-east France, Soyer trained as a"
output_ids: [572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572, 572]
Output [Text 0 Beam 0]: "B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B B"
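When comparing the quantized engine against the fp16 one across several prompts, a quick degenerate-output check can help. The output above, where every id is 572, has a ratio of 1.0. This is a simple heuristic sketch, not a TensorRT-LLM utility. A ratio near 1.0 usually points to broken weights rather than a hard prompt.

```python
from collections import Counter
from typing import Sequence


def dominant_token_ratio(output_ids: Sequence[int]) -> float:
    """Fraction of generated tokens equal to the single most frequent token.

    Values close to 1.0 indicate the model is repeating one token.
    """
    if not output_ids:
        return 0.0
    _, count = Counter(output_ids).most_common(1)[0]
    return count / len(output_ids)
```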

I had previously quantized this model successfully with AutoAWQ, and those results looked promising.

Posted by Hongbosherlock on Jan 15 '24 12:01

Hello, have you solved this issue? I also encountered the same issue.

Posted by Time-Limit on Feb 01 '24 03:02

> Hello, have you solved this issue? I also encountered the same issue.

I have solved this. You should modify the load_from_awq_llama function in weight.py.
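The reply does not spell out the change. Given the bias workaround earlier in the issue, one plausible cause is that the loader never reads the exported bias tensors, so the engine runs without them. A hypothetical way to check this is to wrap the checkpoint in a mapping that records which entries are read, then list the ones that were never touched. How to route weight.py's checkpoint through such a wrapper is an assumption here, because the loader's internals are not shown in this thread.

```python
from collections.abc import Mapping
from typing import List, Set


class AccessTrackingDict(Mapping):
    """Read-only mapping that records which keys were looked up via [].

    Membership tests ('key in d') are not recorded; iterating values()/items()
    does record every key, since those go through __getitem__.
    """

    def __init__(self, data: Mapping):
        self._data = dict(data)
        self.accessed: Set[str] = set()

    def __getitem__(self, key):
        self.accessed.add(key)
        return self._data[key]

    def __contains__(self, key):
        return key in self._data  # do not count a membership check as a read

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def unused_keys(self) -> List[str]:
        """Keys present in the checkpoint that the loader never read."""
        return sorted(set(self._data) - self.accessed)
```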

Posted by Hongbosherlock on Feb 20 '24 03:02