# Quantization with torch FX [part 2]
## What does this PR do?
This is the second PR in the process of adding PyTorch-based quantization features to optimum. This PR continues what was started in the previous PR #216 by adding:
- 3 quantization configuration classes: these directly map what is being done in PyTorch with `torch.ao.quantization.QConfig` and the `qconfig_dict`, but they are serializable, and thus can be uploaded to the Hugging Face Hub (see the PyTorch-side sketch after this list):
  - `QConfigUnit`: this is the `optimum` equivalent of `torch.ao.quantization.Observer` and `torch.ao.quantization.FakeQuantize`; it specifies how quantization values will be calibrated.
  - `QConfig`: it directly maps to `torch.ao.quantization.QConfig`, and does nothing more, except that it uses `QConfigUnit`s instead of `torch.ao.quantization.Observer`s or `torch.ao.quantization.FakeQuantize`s.
  - `QuantizationConfig`: this is the class that specifies how a model should be quantized, the `optimum` equivalent of the `qconfig_dict` in PyTorch. This is the class that can be uploaded to the Hub.
- 2 calibration methods: one for post-training static quantization, and the other for quantization-aware training.
- The `quantize` function, which takes a model as input and returns its quantized counterpart.
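For reference, here is a minimal sketch of the PyTorch constructs these classes map to, using only the standard `torch.ao.quantization` API (not the new `optimum` classes, whose exact constructors are introduced by this PR):

```python
import torch
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

# An observer partial plays the role of a QConfigUnit: it describes how the
# quantization parameters of a tensor are calibrated.
activation_observer = MinMaxObserver.with_args(dtype=torch.quint8)
weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)

# A QConfig pairs an activation observer with a weight observer.
qconfig = QConfig(activation=activation_observer, weight=weight_observer)

# The qconfig_dict plays the role of QuantizationConfig: the "" key applies
# globally, and "module_name" entries override it for specific submodules
# (here, hypothetically leaving a submodule named "classifier" in float).
qconfig_dict = {
    "": qconfig,
    "module_name": [("classifier", None)],
}
```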
These functions are still agnostic to the optimum quantization configuration API, and work directly with the PyTorch API. This enables users to use them directly if needed, and the plan is to compose all of them together in a class called `FXQuantizer`, or something like that, in the next PRs.
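As an illustration of that raw PyTorch flow, here is a minimal sketch of post-training static quantization with FX graph mode. The `prepare_fx` signature varies across PyTorch versions (recent ones require an `example_inputs` argument), and `model` / `calibration_batches` are assumed here to be an FX-traceable module and an iterable of representative input dicts:

```python
import copy

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model.eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# Insert observers at the quantization points. Note that prepare_fx traces
# the model with vanilla torch.fx, which is what the transformers tracer used
# in this PR works around.
prepared = prepare_fx(copy.deepcopy(model), qconfig_dict)

# Calibration: run representative data through the prepared model so the
# observers can record activation ranges.
with torch.no_grad():
    for batch in calibration_batches:
        prepared(**batch)

# Convert the calibrated model into an actual quantized model.
quantized = convert_fx(prepared)
```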
This will be really useful to try out graph modifications on top of quantization, I think!
It looks nice, more flexible than onnxruntime!
The following code:

```python
from transformers import AutoModelForImageClassification
from torch.ao.quantization import get_default_qconfig
from optimum.fx.quantization import quantize
import torch

model_name = "google/vit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()

qconfig = get_default_qconfig('fbgemm')
qconfig_dict = {"": qconfig}

quantized_model = quantize(
    model,
    approach="dynamic",
    qconfig_dict=qconfig_dict,
    input_names=["pixel_values"],
)
```

raises a fair bunch of warnings:
```
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(head_mask_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(labels_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(output_attentions_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(output_hidden_states_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(interpolate_pos_encoding_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(return_dict_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/ao/quantization/observer.py:176: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/ao/quantization/observer.py:1135: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point
  warnings.warn(
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/nn/quantized/_reference/modules/utils.py:25: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/nn/quantized/_reference/modules/utils.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(weight_qparams["zero_point"], dtype=zero_point_dtype, device=device))
```
Is it expected?
Also, I could not get `quantize()` to work for static quantization on vision tasks, related to what appears to me to be a bug in `datasets` (but I could be doing it wrong): https://github.com/huggingface/datasets/issues/4802
Another weird behavior I came upon:
```python
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from transformers.utils.fx import symbolic_trace
from torch.ao.quantization import get_default_qconfig
from optimum.fx.quantization import quantize
from datasets import load_dataset

model_name = "google/vit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)

raw = load_dataset("imagenet-1k", split="validation")
raw = raw.select(range(100))

model.eval()

qconfig = get_default_qconfig('fbgemm')
qconfig_dict = {"": qconfig}

"""
fx_model = quantize(
    model,
    approach="dynamic",
    qconfig_dict=qconfig_dict,
    input_names=["pixel_values"],
)
"""
fx_model = symbolic_trace(model, input_names=["pixel_values"])

preprocessor = AutoFeatureExtractor.from_pretrained(model_name)

for i in range(10):
    inp = preprocessor(raw[i]["image"])
    inp.convert_to_tensors("pt")
    res = fx_model(**inp)
    print(res["logits"].argmax())
```
With only `symbolic_trace`, I get:

```
tensor(91)
tensor(268)
tensor(979)
tensor(218)
tensor(19)
tensor(658)
tensor(893)
tensor(808)
tensor(305)
tensor(416)
```
while using the commented-out code with `quantize()` gives:

```
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
```
and the logits are always the same.
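Given the "must run observer before calling calculate_qparams. Returning default scale and zero point" warning above, my guess is that the conversion fell back to default quantization parameters. A hypothetical way to check this, assuming the converted model uses the reference quantized modules from the warnings (which register `weight_scale` / `weight_zero_point` buffers):

```python
# Hypothetical check (buffer names assumed from the reference-module warnings
# above): print the weight qparams of the converted modules. Scales that all
# read 1.0 with zero points of 0 would mean the observers never ran, which
# would explain the constant logits.
for name, module in fx_model.named_modules():
    scale = getattr(module, "weight_scale", None)
    if scale is not None:
        print(name, scale, getattr(module, "weight_zero_point", None))
```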
I am using the transformers dev version (47e1676255e5dd86b9541f734cd4f4bdcbb50f4a).