Add weight tensor-wise scaling for INT8 quantized and mixed-precision training

Open • gau-nernst opened this issue 1 year ago • 1 comment

https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training

Currently, the INT8 training recipes only support row-wise scaling for weights. In terms of accuracy, row-wise scaling should be strictly better than (or at least the same as) tensor-wise scaling. However, row-wise scaling causes some issues in the backward pass, especially in FSDP2 if we want to support INT8 all-gather (cc https://github.com/pytorch/torchtitan/issues/578). Some pointers:

  • For pre-training, INT8 tensor-wise scaling for weights "should" be ok. This is basically SwitchBack. BitNet uses 1.58-bit weights with tensor-wise scaling and demonstrates good results.
  • For fine-tuning, it will be bad out of the box (imagine INT8 tensor-wise scaling for PTQ). It "might" be ok after fine-tuning. This will need some experiments.
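To make the accuracy argument concrete, here is a minimal sketch comparing the two scaling granularities using symmetric absmax quantization. The helper `absmax_quantize` is a hypothetical name used only for illustration, not the torchao implementation, and the outlier row is artificial.

```python
import torch

def absmax_quantize(w: torch.Tensor, rowwise: bool):
    """Symmetric INT8 quantization (hypothetical helper, not the torchao API).

    rowwise=True  -> one scale per row, shape (rows, 1)
    rowwise=False -> one scale for the whole tensor, shape ()
    """
    if rowwise:
        absmax = w.abs().amax(dim=1, keepdim=True)
    else:
        absmax = w.abs().max()
    scale = absmax.clamp(min=1e-12) / 127.0
    w_i8 = (w / scale).round().clamp(-128, 127).to(torch.int8)
    return w_i8, scale

torch.manual_seed(0)
w = torch.randn(4, 64)
w[0] *= 50  # artificial outlier row, so a single shared scale gets stretched

errors = {}
for name, rowwise in [("row-wise", True), ("tensor-wise", False)]:
    w_i8, scale = absmax_quantize(w, rowwise)
    # Mean absolute round-trip error after dequantizing
    errors[name] = (w_i8.float() * scale - w).abs().mean().item()
    print(f"{name}: mean abs error = {errors[name]:.5f}")
```

With an outlier row, the single tensor-wise scale is dominated by that row, so the other rows are quantized with a much coarser step. This is the PTQ-style failure mode mentioned for fine-tuning.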

Opening this issue to welcome new contributors. Shouldn't be too difficult I think.

For context, here is the key difference between quantized training and mixed-precision training:

  • INT8 quantized training: Only keeps the INT8 weight and does not keep a high-precision weight. Does not quantize activations. Uses stochastic rounding for the weight update.
  • INT8 mixed-precision training: Keeps the high-precision weight. Dynamically quantizes weights (and activations) to INT8 to use INT8 tensor cores.
    • For this new feature (INT8 tensor-wise scaling for weights), I think activations should still use row-wise scaling, since there doesn't seem to be any benefit to using tensor-wise scaling for activations.
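The mixed-precision forward pass with the proposed combination (row-wise scaling for activations, tensor-wise scaling for weights) could look roughly like the sketch below. It is an assumption-laden illustration, not the torchao code: `int8_linear_forward` is a hypothetical name, and the INT8 tensor-core matmul is emulated on CPU with an INT32 matmul.

```python
import torch

def int8_linear_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: x is (batch, in), w is (out, in), both high precision."""
    # Row-wise scale for activations: one scale per row of x, shape (batch, 1)
    x_scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
    # Tensor-wise scale for the weight: a single scalar
    w_scale = w.abs().max().clamp(min=1e-12) / 127.0

    x_i8 = (x / x_scale).round().clamp(-128, 127).to(torch.int8)
    w_i8 = (w / w_scale).round().clamp(-128, 127).to(torch.int8)

    # INT8 matmul with INT32 accumulation, emulated here with an INT32 matmul
    acc = x_i8.to(torch.int32) @ w_i8.to(torch.int32).T

    # Dequantize: the scales broadcast over the (batch, out) result
    return acc.float() * x_scale * w_scale

torch.manual_seed(0)
x = torch.randn(8, 64)
w = torch.randn(16, 64)
out = int8_linear_forward(x, w)
ref = x @ w.T
rel_err = ((out - ref).abs().max() / ref.abs().max()).item()
print(f"max relative error vs. full precision: {rel_err:.4f}")
```

Because the weight scale is a scalar, it factors out of the matmul regardless of whether the weight is used as-is or transposed, which is why tensor-wise scaling is simpler to handle in the backward pass.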

gau-nernst • Oct 04 '24 02:10

So basically, for quantize_int8_rowwise we would pass in a quantization granularity that could be set to either row-wise or tensor-wise. In the tensor-wise case, even though the scale is just one float, keeping it as a tensor means it can be broadcast, so the rest of the functions wouldn't really need to change (besides also adding the granularity param to from_float()).
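A rough sketch of that idea, under stated assumptions: the function name `quantize_int8`, the `granularity` parameter, and its string values are hypothetical, and the real quantize_int8_rowwise and from_float() signatures in torchao may differ.

```python
import torch

def quantize_int8(tensor: torch.Tensor, granularity: str = "row"):
    """Hypothetical generalization of row-wise INT8 quantization.

    The scale is always returned as a 2D tensor so downstream code that does
    `int_data * scale` works unchanged through broadcasting.
    """
    if granularity == "row":
        absmax = tensor.abs().amax(dim=1, keepdim=True)  # shape (rows, 1)
    elif granularity == "tensor":
        absmax = tensor.abs().max().reshape(1, 1)        # shape (1, 1)
    else:
        raise ValueError(f"unknown granularity: {granularity!r}")
    scale = absmax.clamp(min=1e-12) / 127.0
    int_data = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return int_data, scale

def dequantize(int_data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Same code path for both granularities: (rows, 1) and (1, 1) both broadcast
    return int_data.float() * scale

torch.manual_seed(0)
w = torch.randn(4, 8)
q_row, scale_row = quantize_int8(w, "row")
q_tensor, scale_tensor = quantize_int8(w, "tensor")
print(scale_row.shape, scale_tensor.shape)
```

Whether this is the whole change likely depends on how the scale is used in the backward pass and in FSDP2 all-gather, which this sketch does not cover.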

Seems easy to do, but was wondering if the change was more involved.

vayuda • Oct 09 '24 07:10