Is it possible to export a QAT model in AWQ Format?
I'm new to torchao and QAT, but I'm pretty comfortable with PTQ techniques like AWQ and GPTQ. My deployment pipeline requires the AWQ format (safetensors supported by autoawq or gptqmodel's new AWQ integration, with weights packed into uint32 like Int4PackingFormat.PLAIN_INT32). I want to train a model with Int4WeightOnlyConfig, but it's unclear how I would convert the final model into AWQ format; AWQ is supported in torchao, but is that only for PTQ? Unless I'm missing something, you can save to roughly the same format (PLAIN_INT32, but only on xpu?) AND there is AWQ support, yet there's no way to export to that format? If I wrap my Int4WeightOnlyConfig in an AWQConfig, will it be trainable, or only able to calibrate? Could I otherwise use something along the lines of the converter defined in this project?
cc @andrewor14
Hi @ambroser53, looks like today for QAT we only support Int4PackingFormat.PLAIN and PRESHUFFLED: https://github.com/pytorch/ao/blob/eb2a8fc064622dd83f502b5e1436989c548fbc47/torchao/quantization/qat/fake_quantize_config.py#L389-L392
But I don't see why it wouldn't work for PLAIN_INT32, maybe just need to add it to that list. Actually I think it may already work today if you just do the following, since PLAIN and PLAIN_INT32 should have similar (if not the same) PTQ numerics, just in a different layout. Also cc @jerryzh168 to confirm:
```python
quantize_(
    model,
    QATConfig(Int4WeightOnlyConfig(), step="prepare"),  # defaults to PLAIN
)
train(model)
quantize_(
    model,
    QATConfig(
        Int4WeightOnlyConfig(int4_packing_format=Int4PackingFormat.PLAIN_INT32),
        step="convert",
    ),
)
```
Hi @andrewor14, thanks, this works. The initial issue is that PLAIN_INT32 crashes on an assert if the device type is not "xpu"; commenting that out makes it work fine with cuda. The real issue now is that the AWQ format requires the zero_point tensors to be int4s packed into int32, in the same way as the weight tensors, whereas torchao's PLAIN_INT32 currently stores the zero_point in int8. How would one go about doing that within Int4PlainInt32Tensor.from_hp?
Just to illustrate, here are the qzeros from the AWQ representation of a layer in a model:
And here are the same qzeros (zero_point) from the torchao version of the exact same layer in the same model:
I'm also assuming there's no way to save Int4PlainInt32Tensor with safe_serialization? It's a relatively simple fix if not, but if there's something I'm missing it would save me writing a converter.
> save Int4PlainInt32Tensor with safe_serialization
we can support it, cc @liangel-02 is working on safetensor support currently
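In the meantime, a rough manual workaround is to flatten the quantized params into plain tensors before saving. This is just an untested sketch, and it assumes Int4PlainInt32Tensor exposes its components as qdata / scale / zero_point (check the tensor subclass definition in your torchao version):

```python
import torch
from safetensors.torch import save_file

def flatten_for_safetensors(model: torch.nn.Module) -> dict:
    """Flatten quantized tensor subclasses into plain tensors so safetensors
    (which rejects tensor subclasses) can serialize them.
    Attribute names qdata / scale / zero_point are assumptions."""
    tensors = {}
    for name, param in model.state_dict().items():
        if hasattr(param, "qdata"):  # looks like a torchao quantized tensor
            tensors[f"{name}.qdata"] = param.qdata.detach().cpu().contiguous()
            tensors[f"{name}.scale"] = param.scale.detach().cpu().contiguous()
            tensors[f"{name}.zero_point"] = param.zero_point.detach().cpu().contiguous()
        else:
            tensors[name] = param.detach().cpu().contiguous()
    return tensors

# usage: save_file(flatten_for_safetensors(quantized_model), "model.safetensors")
```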
Also wondering what the motivation is to use QAT and then convert to AWQ format, since QAT should be superior to AWQ and there is no need to convert to AWQ after QAT. Note also that QAT is specific to the type of quantization you use during QAT.
You're right that QAT is superior to AWQ. I don't actually want to use the AWQ algorithm for the quantisation, since I have training data to produce a much higher quality model through QAT, but exporting from pytorch into (non-executorch) edge deployment frameworks (such as ONNX) requires some antiquated quantisation formats, which is the only way they are supported. It doesn't need any of the AWQ weight conversion machinery, just the same style of format. As long as the zero_point is packed into int32 like the weights, it should work fine.
Oh I see, makes sense. Then it seems you could just try converting the QAT-converted model to AWQ with scale = 1 directly?
for "AWQ format requires the zero_point tensors to be int32 packed int4s in the same way as the main tensors" --> I think you can just do a manual conversion here to make sure the torchao int4 meets the requirement, would that work?
Torchao maintains scales in fp16, which I think is fine and matches the AWQ format perfectly, to my understanding. Is there any reason for setting the scale to 1?
A manual conversion would work, yes, but it would have to be a proper packing, which I don't understand how to do: just calling .to(torch.int32) only upcasts the int4 values, so the dimensions end up wrong (a packed dimension should shrink by 8x, since eight 4-bit values go into each int32). If you have any idea how to do it, that would be very helpful.
> Is there any reason for setting the scale to 1?
yeah it's because the model is from QAT, not from AWQ right? so setting scale to one (with the matching shape of weight/activation) will be the way to use AWQ format without actually doing the AWQ algorithm
> but it would have to be a proper packing
I think you'll need to look into the code to see how packing is done, does this work? https://github.com/gau-nernst/gemma3-int4/blob/92517e8cac07f5caa3e3c98f26931b9046a0fa38/convert_flax.py#L258
> yeah it's because the model is from QAT, not from AWQ right? so setting scale to one (with the matching shape of weight/activation) will be the way to use AWQ format without actually doing the AWQ algorithm
From my understanding, int32-packed int4s require scales anyway, and torchao will automatically compute its own scales from its optimised model when going from PLAIN (QAT) to PLAIN_INT32, in the same format as AWQ expects (float16 + same dimensions). I'm assuming that conversion causes no loss of information or quality? In which case that should all be fine.
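One way I plan to sanity-check that assumption (rough sketch; `model` and `sample_input` are placeholders here, and the Int4PackingFormat import path may differ across torchao versions) is to convert two copies of the trained model, one per packing format, and compare their outputs:

```python
import copy
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.quantization.qat import QATConfig
from torchao.quantization import Int4PackingFormat  # import path may differ by torchao version

# `model` is the QAT-prepared and trained model; `sample_input` is any
# representative batch (both placeholder names for this sketch).
model_plain = copy.deepcopy(model)
model_int32 = copy.deepcopy(model)
quantize_(model_plain, QATConfig(Int4WeightOnlyConfig(), step="convert"))
quantize_(
    model_int32,
    QATConfig(
        Int4WeightOnlyConfig(int4_packing_format=Int4PackingFormat.PLAIN_INT32),
        step="convert",
    ),
)
with torch.no_grad():
    diff = (model_plain(sample_input) - model_int32(sample_input)).abs().max()
print(diff)  # should be ~0 if the two layouts share the same numerics
```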
> I think you'll need to look into the code to see how packing is done, does this work? https://github.com/gau-nernst/gemma3-int4/blob/92517e8cac07f5caa3e3c98f26931b9046a0fa38/convert_flax.py#L258
I'm happy to look into this, but something more specific would be helpful. That code doesn't work as-is since it expects gemma's specific quantisation format, which I believe wasn't torchao and definitely wasn't torchao's PLAIN_INT32. Regardless, it only packs the weights and permanently sets the zero_point to 8 for everything (which I believe would degrade the quality of the QAT model in conversion).
Unless there's some loss of quality doing what @andrewor14 recommended:
```python
quantize_(
    model,
    QATConfig(Int4WeightOnlyConfig(), step="prepare"),  # defaults to PLAIN
)
train(model)
quantize_(
    model,
    QATConfig(
        Int4WeightOnlyConfig(int4_packing_format=Int4PackingFormat.PLAIN_INT32),
        step="convert",
    ),
)
```
Then all the data should be present and in the correct format once converted to PLAIN_INT32; it just needs the zero_point packing, as sketched below.
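For reference, this is roughly the zero_point packing I mean (untested sketch: it assumes zero_point is an integer tensor with values in [0, 15], that eight 4-bit values go into each int32 along the last dimension, and that the consumer wants AutoAWQ's interleaved GEMM pack order [0, 2, 4, 6, 1, 3, 5, 7]; the order, the packed dimension, and any +/-1 zero-point offset all need to be checked against whatever loads the checkpoint):

```python
import torch

def pack_int4_to_int32(x: torch.Tensor, order=(0, 2, 4, 6, 1, 3, 5, 7)) -> torch.Tensor:
    """Pack eight 4-bit values (held as int8/int32 in [0, 15]) into each int32
    along the last dimension. The default `order` is the interleaved pack order
    I believe AutoAWQ's GEMM kernels use (an assumption; verify against the
    target runtime, or pass order=range(8) for plain sequential packing)."""
    assert x.shape[-1] % 8 == 0, "last dim must be a multiple of 8"
    x = x.to(torch.int32)
    packed = torch.zeros(
        *x.shape[:-1], x.shape[-1] // 8, dtype=torch.int32, device=x.device
    )
    for shift, col in enumerate(order):
        packed |= (x[..., col::8] & 0xF) << (4 * shift)
    return packed

# e.g. qzeros = pack_int4_to_int32(weight_tensor.zero_point)
# (`zero_point` as an attribute of Int4PlainInt32Tensor is an assumption, and
# depending on its layout a transpose may be needed so the packed dim matches
# AWQ's qzeros layout)
```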