add axiswise scaling to Float8Linear
Summary:
This PR: support axiswise scaling for all arguments of all gemms, and ensure that end-to-end training with axiswise scaling works.
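A minimal sketch of what axiswise scaling means for a single gemm. It is written in pure Python so it runs without PyTorch, which is an assumption for illustration; the real code operates on torch tensors and casts to `torch.float8_e4m3fn`. For `out = a @ b` with `a` of shape (M, K) and `b` of shape (K, N), tensorwise scaling uses one scale per operand, while axiswise scaling uses one scale per row of `a` and one per column of `b`, each computed from the amax over the reduction axis K. Because each scale is constant along K, it factors out of every dot product and can be undone on the (M, N) output. All helper names are hypothetical, and the float8 rounding itself is not modeled (only scaling and saturation).

```python
FLOAT8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn
EPS = 1e-12  # avoids division by zero for all-zero rows or columns


def tensorwise_scale(x):
    """One scale for the whole matrix: fp8 max / amax(|x|)."""
    amax = max(abs(v) for row in x for v in row)
    return FLOAT8_E4M3_MAX / max(amax, EPS)


def row_scales(a):
    """One scale per row of a (M, K); amax is taken over the reduction axis K."""
    return [FLOAT8_E4M3_MAX / max(max(abs(v) for v in row), EPS) for row in a]


def col_scales(b):
    """One scale per column of b (K, N); amax is taken over the reduction axis K."""
    n = len(b[0])
    return [FLOAT8_E4M3_MAX / max(max(abs(row[j]) for row in b), EPS) for j in range(n)]


def clamp(v):
    """Saturate to the representable float8 range (rounding is not modeled)."""
    return max(-FLOAT8_E4M3_MAX, min(FLOAT8_E4M3_MAX, v))


def axiswise_scaled_matmul(a, b):
    """Compute a @ b with per-row scales for a and per-column scales for b."""
    sa, sb = row_scales(a), col_scales(b)
    a_s = [[clamp(v * sa[i]) for v in row] for i, row in enumerate(a)]
    b_s = [[clamp(v * sb[j]) for j, v in enumerate(row)] for row in b]
    m, k, n = len(a), len(b), len(b[0])
    # sa[i] and sb[j] are constant along K, so they factor out of each dot
    # product and are undone on the (M, N) output.
    return [
        [sum(a_s[i][p] * b_s[p][j] for p in range(k)) / (sa[i] * sb[j]) for j in range(n)]
        for i in range(m)
    ]


a = [[1.0, -2.0], [0.001, 0.004]]
b = [[3.0, 0.5], [1.0, -0.25]]
print(tensorwise_scale(a))  # one scale shared by both rows of a
print(row_scales(a))        # the small-magnitude second row gets a much larger scale
print(axiswise_scaled_matmul(a, b))
```

In `Float8Linear` the reduction axis differs per gemm (in_features for the forward output, out_features for grad_input, and the token/batch axis for grad_weight), so under this reading the same tensor, such as the weight, needs different axiswise scales in the forward and backward gemms.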
Future PRs: support more granular configurability, optimize performance, and add docs.
Feel free to ignore the UX introduced in this PR; it is just an intermediate step. See the next PR for the real UX.
Test Plan:
// tests pass
./test/float8/test_everything.sh
// sanity check on torchtitan with Llama 3 8B on 4 H100s with float8:
// 1. verify performance does not regress with tensorwise scaling
// 2. smoke test that axiswise scaling works and numerics are sane; performance is not yet competitive
// logs: https://gist.github.com/vkuzo/70fa5eb3c23375f307d11e7bae48682f
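The smoke test above reports that numerics are sane without naming a metric. One common way to quantify this when comparing a float8 run against a high-precision reference is the signal-to-quantization-noise ratio (SQNR). The sketch below is an assumption for illustration, not the check behind the linked logs; `sqnr_db` is a hypothetical helper and the sample values are made up.

```python
import math


def sqnr_db(reference, actual):
    """Signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - actual||)."""
    signal = math.sqrt(sum(r * r for r in reference))
    noise = math.sqrt(sum((r - a) ** 2 for r, a in zip(reference, actual)))
    if noise == 0.0:
        return math.inf  # identical inputs, no quantization noise
    return 20.0 * math.log10(signal / noise)


# Hypothetical flattened outputs from a high-precision run and a float8 run.
ref = [1.0, -2.0, 0.5, 4.0]
fp8 = [1.01, -1.98, 0.5, 3.97]
print(f"SQNR: {sqnr_db(ref, fp8):.1f} dB")  # higher means closer to the reference
```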
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/920
:white_check_mark: No Failures
As of commit d70326c4a3fe2900e482d0a2c1f251975bc4b781 with merge base 52d27a164d5f20e5095e8d35444c744c5a504f5d, there are no failures yet.