
Quantization with torch FX [part 2]

Open michaelbenayoun opened this issue 3 years ago • 5 comments

What does this PR do?

This is the second PR in the process of adding PyTorch-based quantization features to optimum. It continues the work started in the previous PR #216 by adding:

  • Three quantization configuration classes. They directly mirror PyTorch's torch.ao.quantization.QConfig and qconfig_dict, but they are serializable and can therefore be uploaded to the Hugging Face Hub.
    • QConfigUnit: the optimum equivalent of torch.ao.quantization.Observer and torch.ao.quantization.FakeQuantize; it specifies how the quantization parameters are calibrated.
    • QConfig: maps directly to torch.ao.quantization.QConfig and does nothing more, except that it uses QConfigUnits instead of torch.ao.quantization.Observers or torch.ao.quantization.FakeQuantizes.
    • QuantizationConfig: specifies how a model should be quantized; it is the optimum equivalent of PyTorch's qconfig_dict. This is the class that can be uploaded to the Hub.
  • Two calibration methods: one for post-training static quantization and one for quantization aware training.
  • The quantize function, which takes a model as input and returns its quantized counterpart.
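To make the difference between calibrating (post-training static quantization) and fake-quantizing (quantization aware training) concrete, here is a minimal pure-Python sketch of what an observer and a fake-quantize step compute. The names MinMaxObserverSketch and fake_quantize are hypothetical, and this is not optimum's or PyTorch's implementation; it assumes per-tensor affine quantization to the uint8 range [0, 255] and, like PyTorch's MinMaxObserver, widens the observed range to include zero.

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple


@dataclass
class MinMaxObserverSketch:
    """Records the running min/max of every value it sees (per-tensor, uint8 range)."""
    quant_min: int = 0
    quant_max: int = 255
    min_val: Optional[float] = None
    max_val: Optional[float] = None

    def observe(self, values: Iterable[float]) -> None:
        # Post-training static calibration: only watch the data flowing through
        for v in values:
            self.min_val = v if self.min_val is None else min(self.min_val, v)
            self.max_val = v if self.max_val is None else max(self.max_val, v)

    def calculate_qparams(self) -> Tuple[float, int]:
        if self.min_val is None or self.max_val is None:
            # Nothing observed yet: return defaults (1.0 and 0 are this sketch's choice)
            return 1.0, 0
        # Widen the range to include 0.0 so that zero stays exactly representable
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = max((hi - lo) / (self.quant_max - self.quant_min), 1e-12)  # avoid a zero scale
        zero_point = self.quant_min - round(lo / scale)
        zero_point = max(self.quant_min, min(self.quant_max, zero_point))
        return scale, zero_point


def fake_quantize(x: float, scale: float, zero_point: int,
                  quant_min: int = 0, quant_max: int = 255) -> float:
    """Quantize then immediately dequantize, as fake-quantize modules do during QAT."""
    q = max(quant_min, min(quant_max, round(x / scale) + zero_point))
    return (q - zero_point) * scale
```

For example, after observing [-1.0, 0.5, 3.0] the sketch yields a scale of 4/255 and a zero point of 64, and fake_quantize maps 0.0 back to exactly 0.0.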

These functions are still agnostic to the optimum quantization configuration API and work directly with the PyTorch API. This lets users rely on them directly if needed; the plan is to compose everything into a class called FXQuantizer (or something similar) in upcoming PRs.
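The configuration classes are meant to be serializable stand-ins for PyTorch objects that can later be laid out as a qconfig_dict. The sketch below illustrates that idea with the standard library only: nested dataclasses that round-trip through JSON and can be arranged in the qconfig_dict layout ("" for the global qconfig, "module_name" for per-submodule overrides). All class names (ObserverSpec, QConfigSpec, QuantizationConfigSpec) and the `build` callback are hypothetical and are not optimum's API; turning a spec into an actual torch QConfig is left to the caller.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class ObserverSpec:
    """Hypothetical serializable stand-in for an observer / fake-quantize choice."""
    name: str                          # class name stored as text, e.g. "MinMaxObserver"
    dtype: str = "quint8"              # dtype name stored as text
    qscheme: str = "per_tensor_affine"


@dataclass
class QConfigSpec:
    """Hypothetical serializable counterpart of a QConfig."""
    activation: ObserverSpec
    weight: ObserverSpec

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "QConfigSpec":
        return cls(ObserverSpec(**d["activation"]), ObserverSpec(**d["weight"]))


@dataclass
class QuantizationConfigSpec:
    """Hypothetical serializable counterpart of a qconfig_dict."""
    global_qconfig: Optional[QConfigSpec] = None
    module_name: Dict[str, QConfigSpec] = field(default_factory=dict)

    def to_json(self) -> str:
        # asdict() recurses into the nested dataclasses, so the result is plain JSON
        return json.dumps(asdict(self), indent=2, sort_keys=True)

    @classmethod
    def from_json(cls, text: str) -> "QuantizationConfigSpec":
        raw = json.loads(text)
        glob = raw["global_qconfig"]
        return cls(
            global_qconfig=None if glob is None else QConfigSpec.from_dict(glob),
            module_name={k: QConfigSpec.from_dict(v) for k, v in raw["module_name"].items()},
        )

    def to_qconfig_dict(self, build: Callable[[QConfigSpec], Any]) -> Dict[str, Any]:
        # qconfig_dict layout: "" holds the global qconfig, "module_name" holds
        # (submodule name, qconfig) pairs; `build` would create real torch QConfigs
        out: Dict[str, Any] = {"": None if self.global_qconfig is None else build(self.global_qconfig)}
        if self.module_name:
            out["module_name"] = [(n, build(s)) for n, s in sorted(self.module_name.items())]
        return out


# Example: a global qconfig plus an override for a submodule named "classifier"
example = QuantizationConfigSpec(
    global_qconfig=QConfigSpec(
        activation=ObserverSpec("MinMaxObserver"),
        weight=ObserverSpec("MinMaxObserver", dtype="qint8", qscheme="per_tensor_symmetric"),
    ),
    module_name={"classifier": QConfigSpec(
        activation=ObserverSpec("HistogramObserver"),
        weight=ObserverSpec("PerChannelMinMaxObserver", dtype="qint8", qscheme="per_channel_symmetric"),
    )},
)
```

Storing class and dtype names as strings is what keeps the object JSON-friendly (and thus uploadable as a file); the trade-off is that resolving those names back to PyTorch classes has to happen at load time.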

michaelbenayoun · Jul 19 '22 14:07

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

This will be really useful for trying out graph modifications on top of quantization, I think!

fxmarty · Jul 20 '22 14:07

It looks nice, more flexible than onnxruntime!

The following code

from transformers import AutoModelForImageClassification

from torch.ao.quantization import get_default_qconfig
from optimum.fx.quantization import quantize
import torch

model_name = "google/vit-base-patch16-224"

model = AutoModelForImageClassification.from_pretrained(model_name)

model.eval()
qconfig = get_default_qconfig('fbgemm')
qconfig_dict = {"": qconfig}

quantized_model = quantize(
    model,
    approach="dynamic",
    qconfig_dict=qconfig_dict,
    input_names=["pixel_values"],
)

raises quite a few warnings:

/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(head_mask_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(labels_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(output_attentions_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(output_hidden_states_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(interpolate_pos_encoding_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/utils/fx.py:839: UserWarning: Could not compute metadata for call_function target <function _assert_is_none at 0x7fadb4f4be50>: No metadata was found for Proxy(return_dict_1)
  warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/ao/quantization/observer.py:176: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/ao/quantization/observer.py:1135: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point 
  warnings.warn(
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/nn/quantized/_reference/modules/utils.py:25: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/torch/nn/quantized/_reference/modules/utils.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(weight_qparams["zero_point"], dtype=zero_point_dtype, device=device))

Is it expected?
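If some of these warnings turn out to be harmless, they could be silenced selectively with the standard warnings module rather than muting everything. A minimal sketch, assuming the message prefixes copied from the log stay stable across transformers/PyTorch releases; the helper name quiet_fx_quantization_warnings is hypothetical. It deliberately leaves the "must run observer before calling calculate_qparams" warning visible, since its text says an observer returned default scale and zero point values without having seen data.

```python
import re
import warnings
from contextlib import contextmanager
from typing import Iterator

# Message prefixes copied from the log above; relying on this wording is an assumption
NOISY_PREFIXES = (
    "Could not compute metadata for",      # transformers/utils/fx.py tracing
    "Please use quant_min and quant_max",  # torch observer reduce_range notice
    "To copy construct from a tensor",     # torch reference-module utils
)


@contextmanager
def quiet_fx_quantization_warnings() -> Iterator[None]:
    """Ignore only the listed UserWarnings inside the `with` block."""
    with warnings.catch_warnings():  # restores the previous filters on exit
        for prefix in NOISY_PREFIXES:
            # `message` is a regex matched against the start of the warning text
            warnings.filterwarnings("ignore", message=re.escape(prefix), category=UserWarning)
        yield
```

Usage would be to wrap the call, e.g. `with quiet_fx_quantization_warnings(): quantized_model = quantize(...)`, so the filters only apply while quantizing.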

fxmarty · Aug 08 '22 12:08

Also, I could not get quantize() to work for static quantization on vision tasks, because of what appears to be a bug in datasets (though I could be doing something wrong): https://github.com/huggingface/datasets/issues/4802

fxmarty · Aug 08 '22 12:08

Another weird behavior I came across:

from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from transformers.utils.fx import symbolic_trace

from torch.ao.quantization import get_default_qconfig
from optimum.fx.quantization import quantize
from datasets import load_dataset


model_name = "google/vit-base-patch16-224"

model = AutoModelForImageClassification.from_pretrained(model_name)

raw = load_dataset("imagenet-1k", split="validation")
raw = raw.select(range(100))

model.eval()
qconfig = get_default_qconfig('fbgemm')
qconfig_dict = {"": qconfig}

"""
fx_model = quantize(
    model,
    approach="dynamic",
    qconfig_dict=qconfig_dict,
    input_names=["pixel_values"],
)
"""
fx_model = symbolic_trace(model, input_names=["pixel_values"])
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)

for i in range(10):
    inp = preprocessor(raw[i]["image"])
    inp.convert_to_tensors("pt")

    res = fx_model(**inp)

    print(res["logits"].argmax())

With symbolic_trace alone, I get:

tensor(91)
tensor(268)
tensor(979)
tensor(218)
tensor(19)
tensor(658)
tensor(893)
tensor(808)
tensor(305)
tensor(416)

while using the commented-out quantize() call instead gives:

tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)
tensor(415)

and the logits are always the same.

I am using the transformers dev version (47e1676255e5dd86b9541f734cd4f4bdcbb50f4a).
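One way to turn this kind of comparison into a quick check is to compare the top-1 predictions of the traced and quantized models on the same inputs, and to flag outputs that collapse onto a single class. A minimal sketch, assuming the predictions have already been collected as Python ints (for example with `res["logits"].argmax().item()` for each input); the helper names and the 0.9 threshold are arbitrary choices, not part of optimum.

```python
from collections import Counter
from typing import Sequence


def top1_agreement(reference: Sequence[int], candidate: Sequence[int]) -> float:
    """Fraction of inputs on which both models predict the same class."""
    if len(reference) != len(candidate) or not reference:
        raise ValueError("expected two non-empty prediction lists of equal length")
    return sum(r == c for r, c in zip(reference, candidate)) / len(reference)


def looks_collapsed(predictions: Sequence[int], threshold: float = 0.9) -> bool:
    """True if a single class accounts for at least `threshold` of the predictions."""
    if not predictions:
        raise ValueError("expected at least one prediction")
    top_count = Counter(predictions).most_common(1)[0][1]
    return top_count / len(predictions) >= threshold


# Class indices printed in the comment above for the first 10 validation images
traced = [91, 268, 979, 218, 19, 658, 893, 808, 305, 416]
quantized = [415] * 10
```

With these numbers the agreement is 0.0 and the quantized predictions are flagged as collapsed, whereas the traced-model predictions are not.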

fxmarty · Aug 08 '22 15:08