
Add SpinQuant to generate.py

Open tobiasvanderwerff opened this issue 1 year ago • 2 comments

  • Add SpinQuant to torchao/_models/llama/generate.py
  • Only import SpinQuant when necessary in eval.py and generate.py (No need to import the large Hadamard matrices required for SpinQuant otherwise)
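The deferred-import idea can be sketched as follows. This is a minimal illustration, not the PR's actual code; the module path and `apply_spinquant` name are assumptions about the torchao layout:

```python
def quantize_model(model, quantization: str):
    # Hypothetical dispatcher mirroring how generate.py selects a
    # quantization scheme from a CLI flag.
    if quantization == "spinquant":
        # Deferred import: loading the SpinQuant module materializes the
        # large Hadamard matrices, so that cost is only paid when
        # SpinQuant is actually requested.
        from torchao.prototype.spinquant import apply_spinquant  # assumed path
        apply_spinquant(model)
    return model
```

Runs that never request SpinQuant (e.g. the baseline benchmarks below) skip the import entirely.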

tobiasvanderwerff avatar Oct 14 '24 17:10 tobiasvanderwerff

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1069

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit 1543e4f9baf686a53d526b05780dfe14d0bfc999 with merge base e7b33bc91c831d10249c1222c8b4b667f18f28b7 (image): :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Oct 14 '24 17:10 pytorch-bot[bot]

thanks, any results we can show?

jerryzh168 avatar Oct 14 '24 17:10 jerryzh168

@jerryzh168 I'm first fixing a torch.compile issue related to the Hadamard transform used in SpinQuant; after that I'll post some benchmark results here. If you want, we can keep this PR open and I'll push the changes here.

tobiasvanderwerff avatar Oct 15 '24 08:10 tobiasvanderwerff

SpinQuant now also works with torch.compile. Benchmark results (llama-2-7b, tested on an A100):

Baseline + torch.compile

Average tokens/sec: 114.08
Average Bandwidth: 1507.58 GB/s
Peak Memory Usage: 13.88 GB
Model Size: 13.21 GB

SpinQuant (R4) + torch.compile

Average tokens/sec: 109.59
Average Bandwidth: 1448.61 GB/s
Peak Memory Usage: 13.72 GB
Model Size: 13.22 GB

SpinQuant (R1+R2+R4) + torch.compile

NB: R1 and R2 are fused into the linear weights before inference takes place, so it is expected that they do not lead to additional overhead at inference time.

Average tokens/sec: 109.64
Average Bandwidth: 1449.28 GB/s
Peak Memory Usage: 14.90 GB
Model Size: 13.22 GB
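A minimal numpy sketch (not torchao code) of why the fused rotations are free at inference: for an orthogonal rotation R, folding R into the weight matrix offline and R^T into the upstream activations leaves the linear layer's output unchanged, so no extra work happens per token.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 4))  # stand-in for a linear layer's weight
x = rng.standard_normal(4)       # stand-in for an activation vector

# 4x4 normalized Hadamard matrix built from the 2x2 block: R is orthogonal,
# i.e. R @ R.T == I.
H2 = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2.0)
R = np.kron(H2, H2)

y_plain = W @ x
y_fused = (W @ R) @ (R.T @ x)  # R folded into W offline, R.T folded upstream

assert np.allclose(y_plain, y_fused)
```

The fused weight `W @ R` is computed once before inference, which matches the R1+R2 numbers above being essentially identical to R4 alone.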

tobiasvanderwerff avatar Oct 15 '24 08:10 tobiasvanderwerff

Results without torch.compile:

Baseline

Average tokens/sec: 27.33
Average Bandwidth: 361.21 GB/s
Peak Memory Usage: 13.62 GB
Model Size: 13.21 GB

SpinQuant (R4)

Average tokens/sec: 23.01
Average Bandwidth: 304.10 GB/s
Peak Memory Usage: 14.24 GB
Model Size: 13.22 GB

tobiasvanderwerff avatar Oct 15 '24 08:10 tobiasvanderwerff

> SpinQuant now also works with torch.compile. Benchmark results (tested on an A100):
>
> Baseline + torch.compile
>
> Average tokens/sec: 114.31
> Average Bandwidth: 1510.58 GB/s
> Peak Memory Usage: 13.88 GB
> Model Size: 13.21 GB
>
> SpinQuant (R4) + torch.compile
>
> Average tokens/sec: 109.00
> Average Bandwidth: 1440.76 GB/s
> Peak Memory Usage: 13.98 GB
> Model Size: 13.22 GB

Thanks @tobiasvanderwerff, may I know which model you tested on, llama-2-7b?

yiliu30 avatar Oct 15 '24 09:10 yiliu30

Yep, llama-2-7b, I'll add that to the benchmark.

tobiasvanderwerff avatar Oct 15 '24 09:10 tobiasvanderwerff

can you add benchmark numbers for R1+R2 as well? i think R4 is only for activation quantization
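For context on where R4's runtime cost comes from: unlike R1/R2, it is applied to activations online, typically via a fast Walsh-Hadamard transform. A self-contained numpy sketch of that transform (illustrative only, not the torchao implementation):

```python
import numpy as np

def fwht(x):
    """Fast Walsh-Hadamard transform, O(n log n); len(x) must be a power of 2."""
    x = np.asarray(x, dtype=float).copy()
    h = 1
    while h < len(x):
        # Butterfly pass: combine elements h apart, like an FFT stage.
        for i in range(0, len(x), h * 2):
            for j in range(i, i + h):
                a, b = x[j], x[j + h]
                x[j], x[j + h] = a + b, a - b
        h *= 2
    return x / np.sqrt(len(x))  # normalize so the transform is orthogonal

# Sanity check against the dense 4x4 normalized Hadamard matrix.
x = np.array([1.0, 2.0, 3.0, 4.0])
dense = np.kron([[1, 1], [1, -1]], [[1, 1], [1, -1]]) / 2.0
assert np.allclose(fwht(x), dense @ x)
```

The per-token cost of this transform is consistent with the modest tokens/sec drop measured for R4 above.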

HDCharles avatar Oct 15 '24 15:10 HDCharles

would be good to add this info into a readme file inside the spinquant dir

HDCharles avatar Oct 17 '24 13:10 HDCharles

ready to merge?

jerryzh168 avatar Oct 21 '24 23:10 jerryzh168

Yep, this is ready @jerryzh168

tobiasvanderwerff avatar Oct 22 '24 06:10 tobiasvanderwerff