Add SpinQuant to generate.py
- Add SpinQuant to `torchao/_models/llama/generate.py`
- Only import SpinQuant when necessary in `eval.py` and `generate.py` (no need to import the large Hadamard matrices required for SpinQuant otherwise)
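A rough sketch of the lazy-import pattern this describes (the wrapper function is hypothetical, and the import path is an assumption based on torchao's prototype layout):

```python
def apply_quantization(model, quantization: str):
    # Deferred import: the large Hadamard matrices SpinQuant needs are
    # only loaded when this quantization mode is actually requested.
    if "spinquant" in quantization:
        from torchao.prototype.spinquant import apply_spinquant
        apply_spinquant(model)
    return model
```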
Thanks, any results we can show?
@jerryzh168 I'm first fixing a torch.compile issue related to the Hadamard transform used in SpinQuant; after that I'll post some benchmark results here. If you want, we can keep this PR open and I'll push the changes here.
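For anyone curious what making a custom kernel visible to torch.compile can look like, here is a minimal, hypothetical sketch (not the actual fix in this PR) using `torch.library.custom_op`: the op is registered together with a fake (meta) implementation so the compiler can trace through it without graph breaks. The op namespace and the matmul fallback are illustrative.

```python
import torch

def _hadamard_matrix(n: int, dtype, device) -> torch.Tensor:
    # Sylvester construction; assumes n is a power of two.
    h = torch.ones(1, 1, dtype=dtype, device=device)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1),
                       torch.cat([h, -h], dim=1)], dim=0)
    return h

@torch.library.custom_op("spinquant_sketch::hadamard", mutates_args=())
def hadamard_transform(x: torch.Tensor) -> torch.Tensor:
    # Eager implementation; a real fused kernel would replace this matmul.
    n = x.shape[-1]
    return (x @ _hadamard_matrix(n, x.dtype, x.device)) / n**0.5

@hadamard_transform.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation only, so torch.compile can trace the op.
    return torch.empty_like(x)
```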
SpinQuant now also works with torch.compile. Benchmark results (llama-2-7b, tested on an A100):

**Baseline + torch.compile**
- Average tokens/sec: 114.08
- Average Bandwidth: 1507.58 GB/s
- Peak Memory Usage: 13.88 GB
- Model Size: 13.21 GB

**SpinQuant (R4) + torch.compile**
- Average tokens/sec: 109.59
- Average Bandwidth: 1448.61 GB/s
- Peak Memory Usage: 13.72 GB
- Model Size: 13.22 GB

**SpinQuant (R1+R2+R4) + torch.compile**

NB: R1 and R2 are fused into the linear weights before inference takes place, so it is expected that they add no overhead at inference time (see the sketch below these numbers).
- Average tokens/sec: 109.64
- Average Bandwidth: 1449.28 GB/s
- Peak Memory Usage: 14.90 GB
- Model Size: 13.22 GB
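A minimal sketch of the weight-fusion idea, assuming a hypothetical `fuse_rotation` helper and a random orthogonal rotation (illustrative, not the PR's actual code): because (W R)(Rᵀ x) = W x, the rotation can be folded into the weight offline, leaving the forward pass an ordinary matmul.

```python
import torch

@torch.no_grad()
def fuse_rotation(linear: torch.nn.Linear, rot: torch.Tensor) -> None:
    # Fold the orthogonal rotation R into the weight. Upstream activations
    # are rotated by R^T, so (W @ R) @ (R^T @ x) == W @ x and the layer's
    # outputs are unchanged, at zero extra cost per forward pass.
    linear.weight.copy_(linear.weight @ rot)

# Hypothetical usage: fuse a random orthogonal rotation into a layer.
lin = torch.nn.Linear(8, 8, bias=False)
rot, _ = torch.linalg.qr(torch.randn(8, 8))  # QR yields an orthogonal Q
fuse_rotation(lin, rot)
```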
Results without torch.compile:

**Baseline**
- Average tokens/sec: 27.33
- Average Bandwidth: 361.21 GB/s
- Peak Memory Usage: 13.62 GB
- Model Size: 13.21 GB

**SpinQuant (R4)**
- Average tokens/sec: 23.01
- Average Bandwidth: 304.10 GB/s
- Peak Memory Usage: 14.24 GB
- Model Size: 13.22 GB
SpinQuant now also works with torch.compile. Benchmark results (tested on an A100):

**Baseline + torch.compile**
- Average tokens/sec: 114.31
- Average Bandwidth: 1510.58 GB/s
- Peak Memory Usage: 13.88 GB
- Model Size: 13.21 GB

**SpinQuant (R4) + torch.compile**
- Average tokens/sec: 109.00
- Average Bandwidth: 1440.76 GB/s
- Peak Memory Usage: 13.98 GB
- Model Size: 13.22 GB
Thanks @tobiasvanderwerff, may I know which model you tested on, llama-2-7b?
Yep, llama-2-7b, I'll add that to the benchmark.
Can you add benchmark numbers for R1+R2 as well? I think R4 is only for activation quantization.
It would be good to add this info to a README file inside the spinquant dir.
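For contrast with the fused rotations above, a minimal sketch of an online rotation like R4 (the function name and its placement before the down projection are assumptions): the Hadamard multiply runs on every forward pass, which is why R4 shows up as runtime overhead while R1/R2 do not.

```python
import torch
from scipy.linalg import hadamard  # requires a power-of-two dimension

def r4_then_project(x: torch.Tensor, down_proj: torch.nn.Linear) -> torch.Tensor:
    # The normalized Hadamard multiply happens on every forward pass
    # (it makes the activations easier to quantize), unlike R1/R2,
    # which are folded into the weights ahead of time.
    n = x.shape[-1]
    h = torch.from_numpy(hadamard(n)).to(dtype=x.dtype, device=x.device)
    return down_proj((x @ h) / n**0.5)
```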
ready to merge?
Yep, this is ready @jerryzh168