decoder MMHA kernel: support INT8 SCALE_Q_INSTEAD_OF_K and SCALE_P_INSTEAD_OF_V
Following the logic of MMHA_FP8_SCALE_Q_INSTEAD_OF_K and MMHA_FP8_SCALE_P_INSTEAD_OF_V, I implemented the INT8 version.
It is theoretically equivalent to the original compute logic without any numeric accuracy degradation.
I tested the speed on H20, A100, and 4090 GPUs. The results show that the average latency of the MMHA kernel decreased by about 5% to 8%, thanks to fewer FMUL instructions.
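For clarity, here is a minimal sketch of the Q·K side of the trick (standalone illustrative code with hypothetical names, not the kernel itself): pre-multiplying Q by the per-tensor K dequant scale once per decode step is mathematically the same as dequantizing every cached K element inside the dot-product loop, but removes one FMUL per K element.

```cpp
#include <cstdint>

// Baseline: dequantize each INT8 K element inside the loop (one extra FMUL per element).
float qk_dequant_k(const float* q, const int8_t* k_int8, float k_scale, int head_dim) {
    float acc = 0.f;
    for (int i = 0; i < head_dim; ++i)
        acc += q[i] * (static_cast<float>(k_int8[i]) * k_scale);
    return acc;
}

// SCALE_Q_INSTEAD_OF_K: q_prescaled[i] = q[i] * k_scale is computed once per decode step
// and reused for every cached K vector, so the inner loop needs no dequant multiply.
float qk_scale_q(const float* q_prescaled, const int8_t* k_int8, int head_dim) {
    float acc = 0.f;
    for (int i = 0; i < head_dim; ++i)
        acc += q_prescaled[i] * static_cast<float>(k_int8[i]);
    return acc;
}
```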
Nsight Compute Kernel Summary (screenshot)
Nsight Compute Instruction Statistics (screenshot)
@byshiue @Shixiaowei02 Could you please review and verify this PR? Thank you very much!
@lishicheng1996 moving the INT8 scale forward is not enabled because we observed an accuracy drop from it. Have you checked the accuracy after enabling it? I will give it another try locally.
Hi, thanks for your review~ I checked the MMLU accuracy score on Llama3.0-8B with INT8 SmoothQuant. With the INT8 scale moved forward, I didn't see an accuracy drop~
Thanks. I will run more tests on different models locally, merge this into the internal TRT-LLM, and release it in the coming weeks if everything looks good. Thanks again.
Hi, may I ask how the accuracy looked in your local tests? ^_^
@PerkzZheng , If you still remember, could you please share any test results you may have from evaluating this optimization?
@lishicheng1996 , How would you like to proceed with this PR? Are you still interested in moving forward with this optimization? Based on previous discussions, @PerkzZheng observed accuracy drop from this approach before, which might be the reason this PR didn't get updates. If you have ideas on how to address the accuracy concerns while maintaining the performance benefits, please let us know.
I'll mark this as "waiting for feedback" so it can be automatically marked as stale if no feedback is received within 14 days. Note: Simply leaving any comment will prevent it from being marked as stale.
I have an internal GitLab MR that seems to work (I have fixed several issues), but there were some failing tests (from a long time ago). I think this (INT8 MMHA) is not on the critical path anymore. We can ask someone else to resume the verification if needed.
@PerkzZheng, Thank you for the update!
Could you please confirm whether the failed tests were due to numerical differences rather than actual accuracy issues? If the SCALE_QP_INSTEAD_OF_KV approach is mathematically identical to the current implementation, then small numerical differences would be expected and acceptable.
If that's the case, I think this approach has good merit because:
- Reduces latency in INT8 KV cache dequantization
- Mathematically equivalent approach (just a different computation order; see the sketch below)
- No significant accuracy drop if differences are only numerical precision
Given the potential performance benefits, would it be possible for you to assign someone to help move this PR forward? The optimization could be valuable for users running INT8 KV cache configurations. But I'm still open to hearing your opinion~ Thank you.
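To illustrate the equivalence on the P·V side as well, here is a hypothetical sketch (standalone illustrative code, not the kernel itself): the per-tensor V dequant scale can be folded into each softmax probability (one FMUL per timestep) instead of being applied to every INT8 V element (head_dim FMULs per timestep); both orderings compute the same weighted sum in exact arithmetic.

```cpp
#include <cstdint>
#include <vector>

// Baseline: dequantize every INT8 V element (head_dim FMULs per cached timestep).
std::vector<float> pv_dequant_v(const std::vector<float>& p,
                                const std::vector<std::vector<int8_t>>& v_int8,
                                float v_scale, int head_dim) {
    std::vector<float> out(head_dim, 0.f);
    for (size_t t = 0; t < p.size(); ++t)
        for (int i = 0; i < head_dim; ++i)
            out[i] += p[t] * (static_cast<float>(v_int8[t][i]) * v_scale);
    return out;
}

// SCALE_P_INSTEAD_OF_V: fold the V dequant scale into each probability
// (one FMUL per timestep), leaving the inner loop free of dequant multiplies.
std::vector<float> pv_scale_p(const std::vector<float>& p,
                              const std::vector<std::vector<int8_t>>& v_int8,
                              float v_scale, int head_dim) {
    std::vector<float> out(head_dim, 0.f);
    for (size_t t = 0; t < p.size(); ++t) {
        const float p_scaled = p[t] * v_scale;
        for (int i = 0; i < head_dim; ++i)
            out[i] += p_scaled * static_cast<float>(v_int8[t][i]);
    }
    return out;
}
```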
@karljang
I remember it broke some other tests, but the INT8 trick itself should work. It would be great if you could help re-assign this to someone else, as I don't have the bandwidth recently. I will share the MR I created before with you.
PR has not received an update in over 14 days. Adding stale label.
This PR was closed because it has been 14 days without activity since it has been marked as stale.