float8 training axiswise scaling support with per-gemm-argument configuration
Summary:
This PR finalizes the UX for axiswise scaling for float8 training and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet; specifically, the additional combination we now support and test is a recipe from @lw, where we do the following:
```
output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp
```
Key characteristics of this recipe:
- increased accuracy for `grad_weight`, which is important for real workloads
- `output` and `weight` now only need to be scaled axiswise across a single dim compared to vanilla all-axiswise, which is more amenable to fast kernels
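For intuition, tensorwise scaling uses a single scale for the whole tensor, while axiswise scaling computes one scale per slice along a chosen dim. A minimal sketch of the idea (not the torchao cast code; the amax-based scale and `torch.float8_e4m3fn` target are just the usual conventions):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def tensorwise_scale(x: torch.Tensor) -> torch.Tensor:
    # one scale for the entire tensor
    amax = x.abs().max()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

def axiswise_scale(x: torch.Tensor, dim: int) -> torch.Tensor:
    # one scale per slice, reducing amax along `dim`
    amax = x.abs().amax(dim=dim, keepdim=True)
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(4096, 4096)
x_fp8_tensorwise = (x * tensorwise_scale(x)).to(torch.float8_e4m3fn)       # scalar scale
x_fp8_axiswise = (x * axiswise_scale(x, dim=1)).to(torch.float8_e4m3fn)    # per-row scales, shape (4096, 1)
```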
Here is how a user can configure this:
```python
# NOTE: import locations below are an assumption based on current torchao; adjust as needed
import torchao.float8.config
from torchao.float8.config import (
    CastConfig,
    Float8GemmConfig,
    Float8LinearConfig,
    Float8LinearRecipeName,
    ScalingGranularity,
)

#
# short form
#
config = torchao.float8.config.recipe_name_to_linear_config(
    Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP
)

#
# or, long form
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(keep_original_precision=True)
cc_go_gw = CastConfig(keep_original_precision=True)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8LinearConfig(
    cast_config_input=cc_i,
    cast_config_weight=cc_w,
    cast_config_grad_output=cc_go,
    cast_config_input_for_grad_weight=cc_i_gw,
    cast_config_weight_for_grad_input=cc_w_gi,
    cast_config_grad_output_for_grad_weight=cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)
```
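As a usage sketch, the resulting config can then be applied to a model; this assumes the standard torchao.float8 conversion entry point (`convert_to_float8_training`) accepts the config via its `config` argument, as in recent torchao:

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# toy model; any model with nn.Linear children is handled the same way
model = nn.Sequential(
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
).cuda().to(torch.bfloat16)

# swap nn.Linear -> Float8Linear using the recipe configured above
convert_to_float8_training(model, config=config)

model = torch.compile(model)
```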
performance
Below we provide basic performance characteristics of axiswise scaling in general, and of the all-axiswise and lw recipes.
gemm performance of torch._scaled_mm
baseline: tensorwise scaling
```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True

    fast_accum  name       M       K       N  ref_time_s  fp8_time_s  fp8_speedup
 0        True     0     256     256     256    0.000004    0.000006     0.573115
 1        True     1     512     512     512    0.000005    0.000007     0.659333
 2        True     2    1024    1024    1024    0.000011    0.000010     1.080664
 3        True     3    2048    2048    2048    0.000028    0.000017     1.596239
 4        True     4    4096    4096    4096    0.000210    0.000082     2.551705
 5        True     5    8192    8192    8192    0.001671    0.000680     2.457972
 6        True     6   16384   16384   16384    0.015030    0.006498     2.313032
 7        True     7   32768   32768   32768    0.103236    0.048097     2.146411
 8       False     0     256     256     256    0.000004    0.000006     0.630061
 9       False     1     512     512     512    0.000005    0.000007     0.767236
10       False     2    1024    1024    1024    0.000012    0.000008     1.391347
11       False     3    2048    2048    2048    0.000029    0.000020     1.457922
12       False     4    4096    4096    4096    0.000211    0.000101     2.100081
13       False     5    8192    8192    8192    0.001676    0.000788     2.128628
14       False     6   16384   16384   16384    0.014933    0.006351     2.351209
15       False     7   32768   32768   32768    0.103457    0.049498     2.090134
```
experiment: axiswise-scaling
```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise

    fast_accum  name       M       K       N  ref_time_s  fp8_time_s  fp8_speedup
 0        True     0     256     256     256    0.000004    0.000004     0.966772
 1        True     1     512     512     512    0.000005    0.000004     1.095791
 2        True     2    1024    1024    1024    0.000011    0.000006     1.988363
 3        True     3    2048    2048    2048    0.000027    0.000015     1.890065
 4        True     4    4096    4096    4096    0.000210    0.000082     2.552356
 5        True     5    8192    8192    8192    0.001674    0.001092     1.533132
 6        True     6   16384   16384   16384    0.015114    0.008785     1.720480
 7        True     7   32768   32768   32768    0.103286    0.071456     1.445439
 8       False     0     256     256     256    0.000004    0.000004     0.899054
 9       False     1     512     512     512    0.000005    0.000005     1.005340
10       False     2    1024    1024    1024    0.000011    0.000006     1.692868
11       False     3    2048    2048    2048    0.000028    0.000049     0.567655
12       False     4    4096    4096    4096    0.000210    0.000341     0.616193
13       False     5    8192    8192    8192    0.001678    0.002640     0.635541
14       False     6   16384   16384   16384    0.015051    0.021557     0.698212
15       False     7   32768   32768   32768    0.103497    0.169797     0.609533
```
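For context, the axiswise rows above exercise rowwise-scaled gemms. A rough sketch of such a call (assuming the rowwise-scaling support in `torch._scaled_mm` on H100; exact dtype/layout requirements can vary across PyTorch versions, and the scales here are placeholders rather than amax-derived values):

```python
import torch

M, K, N = 4096, 4096, 4096

# A is row-major (M, K); B must be column-major (K, N), hence the transpose
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# axiswise (rowwise) scales: one scale per row of A and per column of B, in fp32
scale_a = torch.ones(M, 1, device="cuda", dtype=torch.float32)
scale_b = torch.ones(1, N, device="cuda", dtype=torch.float32)

out = torch._scaled_mm(
    a,
    b,
    scale_a=scale_a,
    scale_b=scale_b,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)
```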
performance on microbenchmark of ln -> linear -> sigmoid
Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. For fp8_dynamic_axiswise, the gap from tensorwise seems to be mostly due to the axiswise gemm being slower than the tensorwise gemm.
```
> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv

  fwd_M fwd_K fwd_N bf16_gemm_s fp8_gemm_s fp8_axs_gemm_time_s fp8_oh_dyn_limit ... fp8_del_s fp8_dyn_axs_s fp8_lw_s fp8_dyn_sp fp8_del_sp fp8_dyn_axs_sp fp8_lw_sp
0   256   256   256    0.000011   0.000018            0.000012  6.50457971014493e-6 ... 0.000043 0.000049 0.000030 0.465634 0.457907 0.398357 0.643088
1   512   512   512    0.000014   0.000020            0.000013  8.01831884057971e-6 ... 0.000047 0.000054 0.000034 0.489556 0.493467 0.432643 0.685842
2  1024  1024  1024    0.000033   0.000026            0.000017  1.40732753623188e-5 ... 0.000060 0.000063 0.000050 0.734123 0.741467 0.705941 0.891199
3  2048  2048  2048    0.000081   0.000055            0.000044  3.82931014492754e-5 ... 0.000147 0.000159 0.000142 0.815678 0.800811 0.739865 0.827441
4  4096  4096  4096    0.000632   0.000274            0.000247  0.000135172405797101 ... 0.000602 0.000622 0.000662 1.236320 1.261848 1.221755 1.147678
5  8192  8192  8192    0.005027   0.002216            0.003292  0.000522689623188406 ... 0.003665 0.004776 0.005720 1.432213 1.513035 1.161130 0.969448
6 16384 16384 16384    0.045113   0.018975            0.025706  0.00207275849275362  ... 0.024664 0.032254 0.038051 1.803456 1.883291 1.440118 1.220738
7 32768 32768 32768    0.312459   0.147255            0.214492  0.00827303397101449  ... 0.182645 0.240962 0.270973 1.696376 1.766307 1.338827 1.190552
```
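For reference, the ln -> linear -> sigmoid microbenchmark corresponds roughly to a module of the following shape (a sketch, not the exact code in float8_roofline.py; dimensions are illustrative):

```python
import torch
import torch.nn as nn

class LNLinearSigmoid(nn.Module):
    # ln -> linear -> sigmoid
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.ln = nn.LayerNorm(dim_in)
        self.fc = nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x):
        return torch.sigmoid(self.fc(self.ln(x)))

m = LNLinearSigmoid(4096, 4096).cuda().to(torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()
```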
performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only:
- baseline (bf16 + compile): 6,294 wps
- f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
- f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
- LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)
So it looks like there is performance work to do on LW_AXISWISE_WITH_GW_HP in future PRs.
accuracy
I did a quick check that loss curves for torchtitan LLaMa 3 8B pretraining on 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw over 0.5k iterations. I will leave longer accuracy verification for future work.