Llama-2 13B SmoothQuant W8A8 Per-Tensor TP-4 performance is poor in v0.9.0 release
System Info
GPUs: 4x A100 (40 GB)
Release: tensorrt-llm 0.9.0
Who can help?
@Tracin
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
- Install tensorrt-llm 0.9.0
- Create a Llama-2 13B chat, TP-4, SmoothQuant 0.5, Per-Tensor checkpoint and engine
- Create a Llama-2 13B chat, TP-4, SmoothQuant 0.5, Per-Channel + Per-Token checkpoint and engine
- Run mmlu.py
Expected behavior
Similar performance on MMLU between Per-Tensor and Per-Channel + Per-Token
actual behavior
- Llama-2 13B chat, SmoothQuant 0.5, TP-4, Per-Channel + Per-Token: Average accuracy 54.52 (STEM 43.31, Humanities 49.8, Social Science 62.04, Misc 60.24)
- Llama-2 13B chat, SmoothQuant 0.5, TP-4, Per-Tensor: Average accuracy 29.41 (STEM 29.56, Humanities 25.65, Social Science 28.31, Misc 31.77)
additional notes
n/a
Why do you expect the accuracy of Per-Tensor and Per-Channel + Per-Token to be close? Per-Channel + Per-Token is expected to have higher accuracy.
Is a 24% drop in MMLU 5-shot accuracy for Llama-2 13B expected?
It is hard to say whether it is expected, because it depends on the quantization workflow and the model. Per-Channel + Per-Token is the suggested mode and preserves accuracy well. Could you explain why you want to use Per-Tensor?
May I ask how the Per-Token scale is computed on the fly? Can you point out where the code is?
Here is an example.
As far as I know, per-token is generally used together with SmoothQuant. I noticed that the SmoothQuant plugin includes per-token-plugin. What is the relationship between the per-token plugin code here and the code you referred to?
thanks!
The code you referred to is used to quantize the input tensor from higher precision to int8 before it enters the SmoothQuant GEMM.
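To illustrate the step described above, here is a hedged NumPy sketch of dynamic per-token activation quantization: one scale is computed per row (token) of the activation at runtime from its absolute maximum, then the row is rounded into int8 before the GEMM. Function and variable names are illustrative; the real plugin does this inside a fused CUDA kernel:

```python
import numpy as np

def per_token_quantize(act: np.ndarray):
    """Dynamically quantize activations to int8 with one scale per token (row).

    Illustrative sketch only -- not the actual TensorRT-LLM plugin code.
    """
    # amax over the hidden dimension gives one scale per token
    scales = np.abs(act).max(axis=-1, keepdims=True) / 127.0
    scales = np.maximum(scales, 1e-8)  # guard against all-zero rows
    q = np.clip(np.round(act / scales), -127, 127).astype(np.int8)
    return q, scales

# tokens x hidden: the second token has a much larger dynamic range,
# but its scale does not affect the first token.
act = np.array([[0.1, -0.2, 0.05],
                [10.0, -30.0, 5.0]], dtype=np.float32)
q, scales = per_token_quantize(act)

# Dequantize to check the rounding error stays within half a step per token.
deq = q.astype(np.float32) * scales
print(q)
print(np.abs(act - deq).max())
```

Because the scale is derived from each token's own range at inference time, no activation calibration data is needed for this tensor, which is why per-token is described as computed "on the fly".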