
Per-token dynamic quantization

Open anmarques opened this pull request on Oct 11, 2023 • 0 comments

This PR adds support for per-token dynamic quantization: quantization scales and zero points are computed "on the fly" for each new tensor, with one scale and one zero point per token. The PR is motivated by the difficulty of quantizing activations in some LLMs, such as Llama2.
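For intuition, here is a minimal sketch of the per-token computation: for each token (a row along the hidden dimension), the min and max are reduced on the fly and turned into an asymmetric scale/zero-point pair. The function name, 8-bit default, and asymmetric scheme are illustrative assumptions, not the PR's exact code.

```python
import torch

def per_token_quant_params(x: torch.Tensor, bits: int = 8):
    """Sketch: dynamic (on-the-fly) asymmetric quantization parameters,
    one scale/zero-point per token. Assumed helper, not the PR's code."""
    qmin, qmax = 0, 2 ** bits - 1
    # Reduce over the hidden dimension -> one value per token.
    min_val = x.amin(dim=-1, keepdim=True)
    max_val = x.amax(dim=-1, keepdim=True)
    scale = (max_val - min_val).clamp(min=1e-8) / (qmax - qmin)
    zero_point = (qmin - torch.round(min_val / scale)).clamp(qmin, qmax)
    return scale, zero_point
```

Because the parameters are recomputed from each incoming tensor rather than calibrated in advance, no running statistics need to be stored, which is what makes the scheme dynamic.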

I attempted to keep in line with the other quantization schemes supported by PyTorch. Hence I created a new observer (PerTokenDynamicObserver) and a new fake-quantization module (DynamicFakeQuantize), which inherit from ObserverBase and FakeQuantizeBase, respectively.

A new observer is needed because no existing observer supports per-token computation of quantization scales (only per-tensor or per-channel). The new fake-quantization module is likewise needed to execute per-token fake-quantization.
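Below is a hedged sketch of what such a pair of classes might look like, following only the structure the description gives (an observer inheriting from ObserverBase and a fake-quantize module inheriting from FakeQuantizeBase). All method bodies, the quint8/0–255 range, and attribute names are illustrative assumptions; the actual PR code may differ.

```python
import torch
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.fake_quantize import FakeQuantizeBase

class PerTokenDynamicObserver(ObserverBase):
    """Sketch: recomputes per-token qparams on every forward pass
    instead of accumulating running statistics (hence "dynamic")."""
    def __init__(self, dtype=torch.quint8):
        super().__init__(dtype=dtype)
        self.qmin, self.qmax = 0, 255  # assumed 8-bit unsigned range
        self.scale, self.zero_point = None, None

    def forward(self, x):
        # No state is kept across batches, unlike MinMaxObserver and friends.
        min_val = x.amin(dim=-1, keepdim=True)
        max_val = x.amax(dim=-1, keepdim=True)
        self.scale = (max_val - min_val).clamp(min=1e-8) / (self.qmax - self.qmin)
        self.zero_point = (self.qmin - torch.round(min_val / self.scale)).clamp(
            self.qmin, self.qmax
        )
        return x

    def calculate_qparams(self):
        return self.scale, self.zero_point

class DynamicFakeQuantize(FakeQuantizeBase):
    """Sketch: applies a quantize->dequantize round trip using the
    per-token qparams produced by the observer above."""
    def __init__(self, observer=PerTokenDynamicObserver, **observer_kwargs):
        super().__init__()
        self.activation_post_process = observer(**observer_kwargs)

    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, x):
        self.activation_post_process(x)
        scale, zero_point = self.calculate_qparams()
        # Simulate quantization with per-token parameters (assumed 0..255).
        q = torch.clamp(torch.round(x / scale) + zero_point, 0, 255)
        return (q - zero_point) * scale
```

The key design point, as the description notes, is that the observer derives fresh qparams from every tensor it sees rather than calibrating once, so each token's row gets its own scale and zero point at inference time.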

anmarques · Oct 11 '23 20:10