cli_demo_quantization.py does not work with latest torchao (git)
System Info
torch 2.4.1 / diffusers 0.30.2 / Ubuntu 22.04.4 LTS / CUDA driver 12.6
Information
- [X] The official example scripts
- [ ] My own modified scripts
Reproduction
pip install git+https://github.com/pytorch/ao.git
python cli_demo_quantization.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --quantization_scheme fp8 --dtype bfloat16
Expected behavior
The script no longer works with torchao built from the current git tree:
Traceback (most recent call last):
  File "/home/x/CogVideo/inference/cli_demo_quantization.py", line 26, in <module>
    from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
ImportError: cannot import name 'ActivationCasting' from 'torchao.float8.inference'
It seems these names were moved or removed in this commit: https://github.com/pytorch/ao/commit/848e123e37df7e7033f26619b02562525404c2b5
Is there any up-to-date example of how to do inference with quantization?
I'm not up to date with all the recent changes in torchao, but for now I think it would be best to use a torchao version from before these modifications.
Same problem with torchao 0.6.1.
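Because whether the old import works depends on which torchao build is actually installed, it can help to print the installed versions before reporting. A minimal sketch using the standard library's importlib.metadata; the helper name installed_versions is hypothetical, and the package list simply mirrors the System Info above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "diffusers", "torchao")):
    """Hypothetical helper: map each distribution name to its installed
    version string, or None if the distribution is not installed."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Print one line per package, e.g. for pasting into an issue report.
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

A git install of torchao reports its development version string here, which makes it easy to tell a git build apart from a release such as 0.6.1.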
This issue should be reopened so that someone can submit a PR to update the script.
As an intern at Zhipu AI, I'd like to share a solution for the FP8 quantization issue that users have been encountering with CogVideoX:
When using newer versions of TorchAO, users are encountering an ImportError with the FP8 quantization feature. Here's the working solution I've verified:
Instead of using:
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
quantize_to_float8(part, QuantConfig(ActivationCasting.DYNAMIC))
You should use:
from torchao.quantization import quantize_, float8_weight_only
quantize_(model, float8_weight_only())
I've tested this solution with the CogVideoX-5B model. This workaround addresses the API compatibility issue for users with newer versions of TorchAO.
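The modification provides FP8 quantization while being compatible with the updated TorchAO API structure. Note that float8_weight_only() quantizes only the weights, whereas the old QuantConfig(ActivationCasting.DYNAMIC) also cast activations to FP8 dynamically, so the two are not strictly identical. Please make sure you have the latest version of TorchAO installed.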
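For a script that has to run against both older and newer torchao releases, one option is to detect which of the two import paths quoted above is available and fall back accordingly. This is a sketch under that assumption only: detect_torchao_fp8_api and quantize_fp8 are hypothetical helpers, not part of torchao or CogVideoX, and, as noted above, the new branch is weight-only while the old branch also casts activations dynamically.

```python
import importlib


def detect_torchao_fp8_api():
    """Hypothetical helper: report which torchao FP8 API is importable.

    Returns "quantize_" (newer torchao.quantization API),
    "quantize_to_float8" (older torchao.float8.inference API),
    or None if neither can be imported.
    """
    try:
        quant = importlib.import_module("torchao.quantization")
        if hasattr(quant, "quantize_") and hasattr(quant, "float8_weight_only"):
            return "quantize_"
    except ImportError:
        pass
    try:
        inference = importlib.import_module("torchao.float8.inference")
        if hasattr(inference, "ActivationCasting"):
            return "quantize_to_float8"
    except ImportError:
        pass
    return None


def quantize_fp8(module):
    """Hypothetical helper: quantize `module` in place with whichever
    torchao FP8 API is installed, preferring the newer one."""
    api = detect_torchao_fp8_api()
    if api == "quantize_":
        from torchao.quantization import float8_weight_only, quantize_

        # Newer API: FP8 weight-only quantization, applied in place.
        quantize_(module, float8_weight_only())
    elif api == "quantize_to_float8":
        from torchao.float8.inference import (
            ActivationCasting,
            QuantConfig,
            quantize_to_float8,
        )

        # Older API: the call the demo script originally used.
        quantize_to_float8(module, QuantConfig(ActivationCasting.DYNAMIC))
    else:
        raise ImportError("no supported torchao FP8 quantization API found")
    return module
```

In the demo script, module would be whatever model component was previously passed to quantize_to_float8 (the original call names it part).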