
float8 training axiswise scaling support with per-gemm-argument configuration


Summary:

This PR finalizes the UX for axiswise scaling for float8 training and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet. Specifically, the additional combination we now support and test is a recipe from @lw, where we do the following:

output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp

Key characteristics of this recipe:

  1. increased accuracy for grad_weight, which is important for real workloads
  2. output and weight now only need to be scaled axiswise across a single dim, compared to vanilla all-axiswise scaling, which is more amenable to fast kernels (the sketch below shows the difference between a tensorwise and an axiswise scale)
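
To make the granularity difference concrete, here is a minimal sketch in plain PyTorch (not the torchao internals; the helper names are illustrative only) of how a tensorwise scale differs from an axiswise scale for an e4m3 cast:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def tensorwise_scale(x: torch.Tensor) -> torch.Tensor:
    # a single scale for the whole tensor, derived from the global amax
    amax = x.abs().max()
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

def axiswise_scale(x: torch.Tensor, dim: int) -> torch.Tensor:
    # one scale per slice, derived from the amax reduced along `dim`
    amax = x.abs().amax(dim=dim, keepdim=True)
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(4096, 4096)
s_tensorwise = tensorwise_scale(x)       # scalar
s_axiswise = axiswise_scale(x, dim=1)    # shape (4096, 1), one scale per row
x_fp8_tensorwise = (x * s_tensorwise).to(torch.float8_e4m3fn)
x_fp8_axiswise = (x * s_axiswise).to(torch.float8_e4m3fn)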

Here is how a user can configure this:

#
# short form
#

config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

#
# or, long form
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(keep_original_precision=True)
cc_go_gw = CastConfig(keep_original_precision=True)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8Config(
    cast_config_input=cc_i,
    cast_config_weight=cc_w,
    cast_config_grad_output=cc_go,
    cast_config_input_for_grad_weight=cc_i_gw,
    cast_config_weight_for_grad_output=cc_w_go,
    cast_config_grad_output_for_grad_weight=cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)
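
Either way, the resulting config is then applied to a model. Here is a hedged sketch of how that wiring might look; the converter entry point convert_to_float8_training and its config kwarg are assumptions here and may differ from the final API in this PR:

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 4096, bias=False),
).cuda().to(torch.bfloat16)

# swap nn.Linear modules for Float8Linear using the recipe configured above
convert_to_float8_training(model, config=config)

# compile so the casts and scaling logic get fused into efficient kernels
model = torch.compile(model)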

performance

Below we provide basic performance characteristics of axiswise scaling in general, and the all-axiswise and lw recipes.

gemm performance of torch._scaled_mm

baseline: tensorwise scaling

> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000006     0.573115
1         True     1    512    512    512    0.000005    0.000007     0.659333
2         True     2   1024   1024   1024    0.000011    0.000010     1.080664
3         True     3   2048   2048   2048    0.000028    0.000017     1.596239
4         True     4   4096   4096   4096    0.000210    0.000082     2.551705
5         True     5   8192   8192   8192    0.001671    0.000680     2.457972
6         True     6  16384  16384  16384    0.015030    0.006498     2.313032
7         True     7  32768  32768  32768    0.103236    0.048097     2.146411
8        False     0    256    256    256    0.000004    0.000006     0.630061
9        False     1    512    512    512    0.000005    0.000007     0.767236
10       False     2   1024   1024   1024    0.000012    0.000008     1.391347
11       False     3   2048   2048   2048    0.000029    0.000020     1.457922
12       False     4   4096   4096   4096    0.000211    0.000101     2.100081
13       False     5   8192   8192   8192    0.001676    0.000788     2.128628
14       False     6  16384  16384  16384    0.014933    0.006351     2.351209
15       False     7  32768  32768  32768    0.103457    0.049498     2.090134                

experiment: axiswise-scaling

> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise

    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000004     0.966772
1         True     1    512    512    512    0.000005    0.000004     1.095791
2         True     2   1024   1024   1024    0.000011    0.000006     1.988363
3         True     3   2048   2048   2048    0.000027    0.000015     1.890065
4         True     4   4096   4096   4096    0.000210    0.000082     2.552356
5         True     5   8192   8192   8192    0.001674    0.001092     1.533132
6         True     6  16384  16384  16384    0.015114    0.008785     1.720480
7         True     7  32768  32768  32768    0.103286    0.071456     1.445439
8        False     0    256    256    256    0.000004    0.000004     0.899054
9        False     1    512    512    512    0.000005    0.000005     1.005340
10       False     2   1024   1024   1024    0.000011    0.000006     1.692868
11       False     3   2048   2048   2048    0.000028    0.000049     0.567655
12       False     4   4096   4096   4096    0.000210    0.000341     0.616193
13       False     5   8192   8192   8192    0.001678    0.002640     0.635541
14       False     6  16384  16384  16384    0.015051    0.021557     0.698212
15       False     7  32768  32768  32768    0.103497    0.169797     0.609533
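
For reference, a rough sketch of the two torch._scaled_mm call shapes being compared above: scalar scales for tensorwise, and per-row/per-column scale vectors for axiswise. torch._scaled_mm is a private API whose signature has changed across PyTorch versions, so the argument names and layout requirements below are assumptions based on recent releases:

import torch

M = K = N = 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# the second operand is expected in column-major layout
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# tensorwise: one fp32 scale per tensor
out_tensorwise = torch._scaled_mm(
    a, b,
    scale_a=torch.tensor(1.0, device="cuda"),
    scale_b=torch.tensor(1.0, device="cuda"),
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)

# axiswise (rowwise): one scale per row of `a` and one per column of `b`
out_axiswise = torch._scaled_mm(
    a, b,
    scale_a=torch.ones(M, 1, device="cuda"),
    scale_b=torch.ones(1, N, device="cuda"),
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
)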

performance on microbenchmark of ln -> linear -> sigmoid

Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. For fp8_dynamic_axiswise, the gap vs. tensorwise seems to be mostly due to the gemm performance being behind tensorwise.
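
For context, a minimal sketch of the kind of module this microbenchmark exercises; the actual definition lives in benchmarks/float8/float8_roofline.py and may differ in details such as the class name and bias handling:

import torch
import torch.nn as nn

class LNLinearSigmoid(nn.Module):
    # ln -> linear -> sigmoid, matching the shape pattern in the table below
    def __init__(self, fwd_K: int, fwd_N: int):
        super().__init__()
        self.ln = nn.LayerNorm(fwd_K)
        self.fc = nn.Linear(fwd_K, fwd_N, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.fc(self.ln(x)))

m = LNLinearSigmoid(4096, 4096).cuda().to(torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = m(x)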

> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv

   fwd_M  fwd_K  fwd_N  bf16_gemm_s  fp8_gemm_s  fp8_axs_gemm_time_s      fp8_oh_dyn_limit  ... fp8_del_s fp8_dyn_axs_s  fp8_lw_s  fp8_dyn_sp  fp8_del_sp  fp8_dyn_axs_sp  fp8_lw_sp
0    256    256    256     0.000011    0.000018             0.000012   6.50457971014493e-6  ...  0.000043      0.000049  0.000030    0.465634    0.457907        0.398357   0.643088
1    512    512    512     0.000014    0.000020             0.000013   8.01831884057971e-6  ...  0.000047      0.000054  0.000034    0.489556    0.493467        0.432643   0.685842
2   1024   1024   1024     0.000033    0.000026             0.000017   1.40732753623188e-5  ...  0.000060      0.000063  0.000050    0.734123    0.741467        0.705941   0.891199
3   2048   2048   2048     0.000081    0.000055             0.000044   3.82931014492754e-5  ...  0.000147      0.000159  0.000142    0.815678    0.800811        0.739865   0.827441
4   4096   4096   4096     0.000632    0.000274             0.000247  0.000135172405797101  ...  0.000602      0.000622  0.000662    1.236320    1.261848        1.221755   1.147678
5   8192   8192   8192     0.005027    0.002216             0.003292  0.000522689623188406  ...  0.003665      0.004776  0.005720    1.432213    1.513035        1.161130   0.969448
6  16384  16384  16384     0.045113    0.018975             0.025706   0.00207275849275362  ...  0.024664      0.032254  0.038051    1.803456    1.883291        1.440118   1.220738
7  32768  32768  32768     0.312459    0.147255             0.214492   0.00827303397101449  ...  0.182645      0.240962  0.270973    1.696376    1.766307        1.338827   1.190552

performance on torchtitan LLaMa 3 8B on 8 H100 GPUs, float8 compute only:

  • baseline (bf16 + compile): 6,294 wps
  • f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
  • f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
  • LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)

So it looks like we have performance work to do on LW_AXISWISE_WITH_GW_HP in future PRs.

accuracy

I did a quick check that the loss curves for bf16 / f8_tensorwise / f8_axiswise / f8_lw look good over the first 0.5k iterations of torchtitan LLaMa 3 8B pretraining on 8 H100 GPUs. I will leave longer accuracy verifications for future work.

[Screenshot: loss curves for bf16 / f8_tensorwise / f8_axiswise / f8_lw, captured 2024-10-04]


vkuzo · Sep 24 '24

Stack from ghstack (oldest at bottom):

  • -> #940

vkuzo · Sep 24 '24

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/940

Note: Links to docs will display an error until the docs builds have been completed.

No Failures

As of commit b536435ec381717daa09ce5ed66a6981a4d03aa0 with merge base e76db70ec14ff1ff6fc9f1944c904d4247c05de9: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Sep 24 '24