Fuse reductions with MLIR using multi-outputs
This issue has two parts.
The first part is to fuse reductions (including split reductions) with MLIR, including any pointwise ops.
The second part is to use multiple outputs when fusing, so we can fuse even when an intermediate result is used multiple times. For example, we may have `convolution -> add -> reduce`, but the `add` may be used more than once, so we will need to output that computation along with the output of the `reduce`.
We should also handle multiple reductions.
The goal of this is to fuse layernorm with two convs/gemms. #3010 will fuse what comes after the reductions with the following convolution, but we still need to fuse the reduction step. #3097 will split the reduction, but we then need to fuse it with MLIR.
Here is an example from unet where we want to do this fusion:
```python
p = migraphx.program()
mmain = p.get_main_module()
x_main_module_0 = mmain.add_literal(migraphx.create_argument(migraphx.shape(type="float_type", lens=[1]), [1e-4]))
x_sample = mmain.add_parameter("sample",migraphx.shape(type="float_type", lens=[2,4,64,64]))
x_main_module_4 = mmain.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[320,1,1]), 1))
x_main_module_5 = mmain.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[320,1,1]), 2))
x_main_module_6 = mmain.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[320,320,3,3]), 3))
x_main_module_7 = mmain.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[320]), 4))
x_main_module_8 = mmain.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[320,4,3,3]), 5))
x_main_module_9 = mmain.add_instruction(migraphx.op("convolution", padding=[1,1,1,1]), [x_sample, x_main_module_8])
x_main_module_10 = mmain.add_instruction(migraphx.op("broadcast", axis=1, out_lens=[2,320,64,64]), [x_main_module_7])
x_main_module_11 = mmain.add_instruction(migraphx.op("add"), [x_main_module_9, x_main_module_10])
x_main_module_12 = mmain.add_instruction(migraphx.op("contiguous"), [x_main_module_11])
x_main_module_13 = mmain.add_instruction(migraphx.op("reshape", dims=[0,32,-1]), [x_main_module_12])
x_main_module_14 = mmain.add_instruction(migraphx.op("reduce_mean", axes=[2]), [x_main_module_13])
x_main_module_15 = mmain.add_instruction(migraphx.op("multibroadcast", out_lens=[2,32,40960]), [x_main_module_14])
x_main_module_16 = mmain.add_instruction(migraphx.op("sqdiff"), [x_main_module_13, x_main_module_15])
x_main_module_17 = mmain.add_instruction(migraphx.op("reduce_mean", axes=[2]), [x_main_module_16])
x_main_module_18 = mmain.add_instruction(migraphx.op("sub"), [x_main_module_13, x_main_module_15])
x_main_module_19 = mmain.add_instruction(migraphx.op("multibroadcast", out_lens=[2,32,1]), [x_main_module_0])
x_main_module_20 = mmain.add_instruction(migraphx.op("add"), [x_main_module_17, x_main_module_19])
x_main_module_21 = mmain.add_instruction(migraphx.op("rsqrt"), [x_main_module_20])
x_main_module_22 = mmain.add_instruction(migraphx.op("multibroadcast", out_lens=[2,32,40960]), [x_main_module_21])
x_main_module_23 = mmain.add_instruction(migraphx.op("mul"), [x_main_module_18, x_main_module_22])
x_main_module_24 = mmain.add_instruction(migraphx.op("contiguous"), [x_main_module_23])
x_main_module_25 = mmain.add_instruction(migraphx.op("reshape", dims=[2,320,64,64]), [x_main_module_24])
x_main_module_26 = mmain.add_instruction(migraphx.op("multibroadcast", out_lens=[2,320,64,64]), [x_main_module_5])
x_main_module_27 = mmain.add_instruction(migraphx.op("mul"), [x_main_module_25, x_main_module_26])
x_main_module_28 = mmain.add_instruction(migraphx.op("multibroadcast", out_lens=[2,320,64,64]), [x_main_module_4])
x_main_module_29 = mmain.add_instruction(migraphx.op("add"), [x_main_module_27, x_main_module_28])
x_main_module_30 = mmain.add_instruction(migraphx.op("sigmoid"), [x_main_module_29])
x_main_module_31 = mmain.add_instruction(migraphx.op("mul"), [x_main_module_29, x_main_module_30])
mmain.add_instruction(migraphx.op("convolution", padding=[1,1,1,1]), [x_main_module_31, x_main_module_6])
```
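For reference, the reduction portion of this graph is a group normalization over 32 groups. A plain NumPy sketch of the math (illustrative only, with reduced spatial dims; this is not the MIGraphX execution path):

```python
import numpy as np

# Group norm over 32 groups, mirroring the reshape / reduce_mean /
# sqdiff / rsqrt sequence in the program above.
eps = 1e-4
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 320, 4, 4)).astype(np.float32)  # stand-in for the conv+bias output
g = x.reshape(2, 32, -1)                              # reshape dims=[0,32,-1]
mean = g.mean(axis=2, keepdims=True)                  # reduce_mean(axes=[2])
var = ((g - mean) ** 2).mean(axis=2, keepdims=True)   # sqdiff + reduce_mean
y = (g - mean) * (1.0 / np.sqrt(var + eps))           # sub, add eps, rsqrt, mul
y = y.reshape(2, 320, 4, 4)                           # reshape back to [2,320,64,64]-like shape
```

The two `reduce_mean` calls are the reductions this issue wants fused (and split) in MLIR; everything else is pointwise or a reshape.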
So this will need #3097 and `MIGRAPHX_DISABLE_LAYERNORM_FUSION=1 MIGRAPHX_ENABLE_SPLIT_REDUCE=1` to split the fused reduction. Then we need to update MLIR to fuse the pointwise with the reductions.
To do the entire fusion we will also need #3010 and #3113 with `MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1`, but this issue is focused on the first part of the fusion.
- [x] 1. Fuse reshapes first: `convolution + reshapes + pointwise` --> always enable this; use #3010 for the fuse module functionality. -- work completed with https://github.com/ROCm/AMDMIGraphX/pull/3280
- [ ] 2. Fuse reductions: `convolution + reduction`.
    - Measure performance with the MLIR reduction first and see if it is faster or slower.
    - [ ] Start with `convolution + reduce_mean` first and measure performance.
    - Gate with an env var.
- [ ] 3. Fuse `convolution + reshapes + pointwise + reductions` with multiple outputs. #3097 is necessary for multiple outputs from reduction.
    - [ ] Start with a single output and then enable multiple outputs.
    - The pointwise + reduction module will appear as a `fuse_reduce` module. This fusion happens across reshapes but not across transposes.
    - All the outputs from the fused `convolution + reshapes + pointwise + reductions` must be reductions, because otherwise a global sync is required; it is better to fuse the pointwise with the other convolution instead.
    - [x] Multi-outputs is supported in MLIR (confirmed with the rocMLIR team).
    - Look at #3097 to see how it computes the output shape inside `split_reduce.cpp`.
    - Make sure the output buffer/argument order for the tuple arg is the same between MIGraphX and MLIR when launching the kernel.
- [ ] 4. Enable multi-outputs from pointwise modules (future/independent work, not related to this issue).
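As background on why multiple reduction outputs matter for step 3: the mean and variance of a layernorm reduce the same input, so a single fused pass can produce both reduction results. A plain NumPy sketch (not MIGraphX code):

```python
import numpy as np

# One pass over the data produces both reduction outputs; mean and
# variance are then cheap elementwise follow-ups. A fused kernel of
# this shape needs two outputs, hence multi-output fusion.
def fused_sums(x, axis):
    s = x.sum(axis=axis)         # reduction output 1: sum
    sq = (x * x).sum(axis=axis)  # reduction output 2: sum of squares
    return s, sq

x = np.arange(8.0).reshape(2, 4)
s, sq = fused_sums(x, axis=1)
n = x.shape[1]
mean = s / n
var = sq / n - mean**2           # E[x^2] - E[x]^2
```

Without multi-output support, the two reductions would have to be launched as separate kernels, each re-reading the (possibly fused) input.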
- Fuse reshapes first: `convolution + reshapes + pointwise` --> always enable this, use https://github.com/ROCm/AMDMIGraphX/pull/3010 for the fuse module functionality.
This addresses issue #2822.
- [x] Multi-outputs is supported in MLIR (confirmed with the rocMLIR team).
rocMLIR has support for multiple outputs, but MIGraphX needs to create the MLIR module to leverage it. An example with two outputs inside MLIR is here:
https://github.com/ROCm/rocMLIR/blob/9fb6bacfdb1bb8d0991d417194cf9cf680f9602d/mlir/test/fusion/pr-e2e/multiple-outputs/migraphx-mbcast-two-outputs.mlir
Multiple outputs may not be tested well enough on the MLIR side, though.
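To illustrate the ordering concern noted in the checklist above: position `i` of the kernel's output tuple must map to the same buffer the runtime later reads back as tuple element `i`. A host-side sketch of the contract (the function names here are hypothetical, not real MIGraphX/rocMLIR APIs):

```python
import numpy as np

# Hypothetical host-side sketch: the kernel returns its outputs as a
# tuple in a fixed order, and the launcher copies tuple element i into
# output buffer i; get_tuple_elem(i) then reads buffer i. If MIGraphX
# and MLIR disagree on this order, outputs get swapped silently.
def launch_fused_kernel(kernel, inputs, out_buffers):
    results = kernel(*inputs)
    assert len(results) == len(out_buffers)
    for buf, res in zip(out_buffers, results):
        np.copyto(buf, res)  # tuple position i -> buffer i

def mean_and_var(x):
    mean = x.mean(axis=1)
    var = ((x - mean[:, None]) ** 2).mean(axis=1)
    return mean, var  # element 0 = mean, element 1 = variance

x = np.ones((2, 5))
out_mean = np.empty(2)
out_var = np.empty(2)
launch_fused_kernel(mean_and_var, [x], [out_mean, out_var])
```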
Have this case working with the `mlir-split-reduce` branch and https://github.com/ROCm/rocMLIR/pull/1590.
Summary:
```
gpu::code_object::mlir_convolution_reshape_mul_reshape_reduce_sum_reshape_mul_mul_reshape_reduce_sum_reshape: 1.89057ms / 1 = 1.89057ms, 62%
gpu::code_object::mlir_mul_sub_add_rsqrt_sub_mul_mul_add_sigmoid_mul_reshape_convolution: 1.10803ms / 1 = 1.10803ms, 37%
hip::fill: 0.0345808ms / 1 = 0.0345808ms, 2%
get_tuple_elem: 0.00572976ms / 3 = 0.00190992ms, 1%
multibroadcast: 0.0045914ms / 4 = 0.00114785ms, 1%
hip::hip_copy_literal: 0.004223ms / 5 = 0.0008446ms, 1%
reshape_lazy: 0.003228ms / 3 = 0.001076ms, 1%
load: 0.0029358ms / 1 = 0.0029358ms, 1%
broadcast: 0.0012496ms / 1 = 0.0012496ms, 1%
@param: 0.0009894ms / 2 = 0.0004947ms, 1%
hip::hip_allocate_memory: 0.00085ms / 1 = 0.00085ms, 1%
check_context::migraphx::gpu::context: 0.0007968ms / 1 = 0.0007968ms, 1%

Batch size: 1
Rate: 333.267 inferences/sec
Total time: 3.0006ms
Total instructions time: 3.05778ms
Overhead time: 0.0195354ms, -0.0571787ms
Overhead: 1%, -2%
[ MIGraphX Version: 2.11.0.a784df3a2 ] Complete: ./bin/driver perf ../test.py
```