Vasiliy Kuznetsov
This is an example of implementing basic fp8 support with a Python tensor subclass. tl;dr: 1. FP8Tensor is the Python object which contains the raw fp8 data (as torch.bits8), a scale,...
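A minimal sketch of what such a subclass could look like, assuming the payload is stored as raw bits plus an fp32 scale; the class and attribute names here are illustrative, not the PR's exact API:

```
import torch

class FP8Tensor(torch.Tensor):
    # Minimal sketch: a wrapper subclass whose payload is raw fp8 bits
    # (e.g. torch.bits8) plus an fp32 scale, while the wrapper itself
    # advertises the original high-precision dtype and shape.
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor, orig_dtype=torch.float32):
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )
        self._data = data    # raw fp8 payload
        self._scale = scale  # fp32 scale used to dequantize
        return self

    def __repr__(self):
        return f"FP8Tensor(shape={tuple(self.shape)}, scale={self._scale.item()})"
```

A real implementation would also define `__torch_dispatch__` to route ops to dequantize-or-fp8 kernels; this sketch only covers construction.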
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #126573 Summary: Uses the `seq_nr` field (introduced to aot_autograd nodes in https://github.com/pytorch/pytorch/pull/103129) to map the aot_autograd fx bw nodes to the corresponding...
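A hedged sketch of the kind of mapping this enables, assuming each aot_autograd node carries its autograd sequence number in `node.meta["seq_nr"]` (the helper name below is hypothetical):

```
from collections import defaultdict

def map_bw_to_fw_nodes(fw_graph, bw_graph):
    # Group forward nodes by seq_nr, then look each backward node up by the
    # same seq_nr; nodes without a seq_nr are skipped.
    fw_by_seq = defaultdict(list)
    for node in fw_graph.nodes:
        seq_nr = node.meta.get("seq_nr")
        if seq_nr is not None:
            fw_by_seq[seq_nr].append(node)

    bw_to_fw = {}
    for node in bw_graph.nodes:
        seq_nr = node.meta.get("seq_nr")
        if seq_nr is not None:
            bw_to_fw[node] = fw_by_seq.get(seq_nr, [])
    return bw_to_fw
```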
Summary: This is a lightweight example of how a `Float8Tensor` could be built out of core, and how it could hook up with a scaling UX based on module swapping....
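A sketch of the module-swapping UX, assuming the float8 module exposes a `from_float` constructor (a common pattern for such swaps; the actual API in the example may differ):

```
import torch.nn as nn

def swap_linear_with_float8_linear(model: nn.Module, float8_linear_cls) -> nn.Module:
    # Replace every nn.Linear in `model` with a float8 variant built from the
    # original module; recurse into children to cover nested submodules.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, float8_linear_cls.from_float(child))
        else:
            swap_linear_with_float8_linear(child, float8_linear_cls)
    return model
```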
Summary: Test Plan: ``` with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te ```
Summary: For now, this is not intended to land; it is just saving work and starting a discussion. We need to calculate max(abs(tensor)) for each float8 gemm input when using per-tensor scaling....
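For reference, the per-tensor scale computation under discussion boils down to one amax reduction per gemm input; a minimal sketch (448.0 is the representable max of float8_e4m3fn):

```
import torch

FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def per_tensor_scale(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # One scalar max(abs(x)) per gemm input, turned into a scale.
    amax = torch.max(torch.abs(x))
    return FP8_E4M3_MAX / torch.clamp(amax, min=eps)
```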
Summary: This is a copy-paste of https://github.com/pytorch-labs/float8_experimental/pull/352, which never landed.
Summary: This PR adds support for scaling all arguments of all gemms axiswise, and ensures that training with axiswise scaling works e2e. Future PR: support more granular configurability and...
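A minimal sketch of what axiswise (rowwise) scaling of a gemm argument looks like, in contrast to the per-tensor case; the function name and the use of e4m3 are illustrative:

```
import torch

FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def axiswise_scale(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    # One scale per row along `dim` instead of a single per-tensor scale.
    amax = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    return FP8_E4M3_MAX / torch.clamp(amax, min=eps)

# Illustrative use: quantize an activation rowwise before a gemm.
x = torch.randn(16, 64)
x_fp8 = (x * axiswise_scale(x)).to(torch.float8_e4m3fn)
```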
Summary: This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet....
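A sketch of what per-gemm-argument configurability could look like as a config object; the field and enum names are made up for illustration and are not the finalized UX:

```
from dataclasses import dataclass, field
from enum import Enum

class ScalingGranularity(Enum):
    TENSORWISE = "tensorwise"
    AXISWISE = "axiswise"

@dataclass
class Float8GemmConfig:
    # Granularity chosen separately for each argument of a single gemm.
    input: ScalingGranularity = ScalingGranularity.TENSORWISE
    weight: ScalingGranularity = ScalingGranularity.TENSORWISE

@dataclass
class Float8LinearConfig:
    # A linear layer performs three gemms across fwd+bwd (output, grad_input,
    # grad_weight); per-gemm-argument configurability gives each its own config.
    gemm_output: Float8GemmConfig = field(default_factory=Float8GemmConfig)
    gemm_grad_input: Float8GemmConfig = field(default_factory=Float8GemmConfig)
    gemm_grad_weight: Float8GemmConfig = field(default_factory=Float8GemmConfig)
```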
This is a brain dump of what is missing from `torchao.float8` to support training with rowwise scaling, to help anyone who wants to jump in and build this. ## already...
When AC (activation checkpointing) is on for Float8Linear, what I would expect is: 1. the forward gemm is recomputed in the backward pass (it is not being recomputed now) 2. max(abs(activation)) and max(abs(weight))...
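For context, a minimal sketch of the AC setup in question: wrapping a block in `torch.utils.checkpoint` makes its forward (including the gemm) run again during backward, and the open question above is which float8-specific pieces, such as the max(abs(...)) reductions, should be recomputed along with it. The plain nn.Linear below stands in for Float8Linear.

```
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

linear = nn.Linear(64, 64)  # stand-in for Float8Linear

def block(x):
    return torch.relu(linear(x))

x = torch.randn(8, 64, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()  # the forward of `block`, and thus its gemm, is recomputed here
```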