We should ensure that activation checkpointing (AC) composes optimally with Float8Linear.
When AC is on for Float8Linear, what I would expect is:
- the forward gemm is recomputed in the backward (it is currently not being recomputed)
- max(abs(activation)) and max(abs(weight)) are NOT recomputed; since these are tiny (single-element) tensors, it is much better to always reuse the saved values (it seems like one of them is currently being recomputed); see the sketch after this list
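For context, here is a minimal sketch of tensorwise float8 scaling (illustrative only, not the exact torchao code path) showing why the max(abs(...)) values are cheap to save: they are single-element tensors, while recomputing them requires a full reduction over the activation or weight.

```python
import torch

# Illustrative tensorwise float8 scaling, not the exact torchao code path.
x = torch.randn(4096, 4096, dtype=torch.bfloat16)
amax = x.abs().max().float()                          # single-element tensor, cheap to save
scale = torch.finfo(torch.float8_e4m3fn).max / amax   # tensorwise scale
x_fp8 = (x * scale).to(torch.float8_e4m3fn)           # cast to float8
```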
Let's figure out why this isn't happening today and what we should do about it. Note: the reproductions below require https://github.com/pytorch/ao/pull/892
bfloat16 linear fwd/bwd with activation checkpointing on
repro command
python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter bfloat16 --enable_activation_checkpointing True
trace snippet
We see 1 gemm in the forward and 3 in the backward, as expected.
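For reference, a shape-only sketch (hypothetical shapes) of the three matmuls we expect in the backward region when AC recomputes the forward for y = x @ w.t():

```python
import torch

# Shapes are illustrative. With AC on, the forward gemm is re-run inside the
# backward region, so we expect these three matmuls there.
x, w, grad_y = torch.randn(16, 32), torch.randn(64, 32), torch.randn(16, 64)
y = x @ w.t()            # recomputed forward gemm
grad_x = grad_y @ w      # input gradient gemm
grad_w = grad_y.t() @ x  # weight gradient gemm
```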
Float8Linear fwd/bwd with activation checkpointing on
repro command
python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter float8 --enable_activation_checkpointing True
trace snippet
- issue 1: there are only two gemms in the backward instead of the expected three
- issue 2: there are extra kernels in the backward which are recomputing max(abs(activation)) and max(abs(weight))
The torch._scaled_mm behavior seems fine.
The max(abs(tensor)) behavior seems suboptimal, and we can do better with custom AC settings. I wrote up https://github.com/pytorch/torchtitan/pull/580 with initial findings and will follow up with more after the conferences this week.
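As a rough illustration of the "custom AC settings" direction (a sketch only; the actual changes are in the torchtitan PR above, and the exact op list here is an assumption), a selective checkpointing policy can force the tiny abs/max results to be saved while everything else, including the gemms, is recomputed:

```python
import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Assumed op list: the abs/max ops that produce the tiny max(abs(tensor))
# values used for float8 scaling. Saving them is nearly free; everything
# else (including the gemms) is left to be recomputed.
_ops_to_save = [
    torch.ops.aten.abs.default,
    torch.ops.aten.max.default,
]

def _policy_fn(ctx, op, *args, **kwargs):
    if op in _ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def checkpointed_forward(module, *inputs):
    # context_fn must return fresh context managers on each call
    context_fn = functools.partial(create_selective_checkpoint_contexts, _policy_fn)
    return checkpoint(module, *inputs, use_reentrant=False, context_fn=context_fn)
```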