mlir-hlo

Deduplicate reduction subcomputations when converting from MHLO to HLO

Open • hawkinsp opened this issue 3 years ago • 1 comment

See https://github.com/google/jax/issues/7654

We should deduplicate reducer computations when converting from MHLO to HLO. For example, compare the MHLO module:

In [1]: import jax

In [2]: import jax.numpy as jnp

In [3]: def f(x, y): return jnp.sum(x) + jnp.sum(y)

In [4]: print(jax.jit(f).lower(jnp.arange(10), jnp.arange(15)).compiler_ir())
module @jit_f.2 {
  func.func public @main(%arg0: tensor<10xi32>, %arg1: tensor<15xi32>) -> tensor<i32> {
    %0 = mhlo.constant dense<0> : tensor<i32>
    %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [0] : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
     reducer(%arg2: tensor<i32>, %arg3: tensor<i32>)  {
      %5 = mhlo.add %arg2, %arg3 : tensor<i32>
      "mhlo.return"(%5) : (tensor<i32>) -> ()
    }
    %2 = mhlo.constant dense<0> : tensor<i32>
    %3 = mhlo.reduce(%arg1 init: %2) across dimensions = [0] : (tensor<15xi32>, tensor<i32>) -> tensor<i32>
     reducer(%arg2: tensor<i32>, %arg3: tensor<i32>)  {
      %5 = mhlo.add %arg2, %arg3 : tensor<i32>
      "mhlo.return"(%5) : (tensor<i32>) -> ()
    }
    %4 = mhlo.add %1, %3 : tensor<i32>
    return %4 : tensor<i32>
  }
}

with the HLO it converts to:

In [6]: print(jax.jit(f).lower(jnp.arange(10), jnp.arange(15)).compiler_ir(dialect="hlo").as_hlo_text())
HloModule jit_f.4, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}

region_0.4 {
  Arg_0.5 = s32[] parameter(0)
  Arg_1.6 = s32[] parameter(1)
  ROOT add.7 = s32[] add(Arg_0.5, Arg_1.6)
}

region_1.9 {
  Arg_0.10 = s32[] parameter(0)
  Arg_1.11 = s32[] parameter(1)
  ROOT add.12 = s32[] add(Arg_0.10, Arg_1.11)
}

ENTRY main.15 {
  Arg_0.1 = s32[10]{0} parameter(0)
  constant.3 = s32[] constant(0)
  reduce.8 = s32[] reduce(Arg_0.1, constant.3), dimensions={0}, to_apply=region_0.4
  Arg_1.2 = s32[15]{0} parameter(1)
  reduce.13 = s32[] reduce(Arg_1.2, constant.3), dimensions={0}, to_apply=region_1.9
  ROOT add.14 = s32[] add(reduce.8, reduce.13)
}

It would be great to merge region_0.4 and region_1.9, which have identical bodies, to make the HLO more readable. Some computations end up with hundreds of reducers.

@cheshire

hawkinsp, Aug 19 '22 19:08

I think a simple identical-function-merging pass based on OperationEquivalence should be able to catch this.
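The suggestion can be illustrated with a toy model. The sketch below assumes a minimal Python stand-in for MLIR regions. `Value`, `Op`, `Region`, `regions_equivalent`, and `merge_identical` are hypothetical names, not MLIR's Python bindings. Two regions are compared structurally: they must have the same op names, attributes, and types, with operands matched through a value mapping built during the walk. The sketch follows the idea behind MLIR's `OperationEquivalence` but does not use its actual C++ interface. It also compares regions pairwise instead of hashing them first.

```python
# Toy model of structural region equivalence; all classes and functions here
# are hypothetical stand-ins, not MLIR APIs.
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False: values are compared and hashed by identity
class Value:
    type: str  # e.g. "tensor<i32>"


@dataclass(eq=False)
class Op:
    name: str       # e.g. "mhlo.add"
    operands: list  # list of Value
    results: list   # list of Value
    attrs: dict = field(default_factory=dict)


@dataclass(eq=False)
class Region:
    args: list  # block arguments, list of Value
    ops: list   # list of Op, in program order


def regions_equivalent(a, b):
    """True if b is a renaming of a: same ops, attrs, types, and dataflow."""
    if len(a.args) != len(b.args) or len(a.ops) != len(b.ops):
        return False
    mapping = {}  # value defined in a -> corresponding value in b
    for x, y in zip(a.args, b.args):
        if x.type != y.type:
            return False
        mapping[x] = y
    for op_a, op_b in zip(a.ops, b.ops):
        if op_a.name != op_b.name or op_a.attrs != op_b.attrs:
            return False
        if len(op_a.operands) != len(op_b.operands) or len(op_a.results) != len(op_b.results):
            return False
        # Values defined outside the region are not mapped and must be identical.
        if any(mapping.get(x, x) is not y for x, y in zip(op_a.operands, op_b.operands)):
            return False
        for x, y in zip(op_a.results, op_b.results):
            if x.type != y.type:
                return False
            mapping[x] = y
    return True


def merge_identical(regions):
    """Map each region name to the name of the first equivalent region."""
    kept, canon = [], {}
    for name, region in regions.items():
        match = next((k for k in kept if regions_equivalent(regions[k], region)), None)
        if match is None:
            kept.append(name)
        canon[name] = match or name
    return canon


def make_binary_reducer(op_name, type_="tensor<i32>"):
    """Build a reducer like the one in the issue: return(op(%arg2, %arg3))."""
    lhs, rhs, out = Value(type_), Value(type_), Value(type_)
    return Region([lhs, rhs], [Op(op_name, [lhs, rhs], [out]), Op("mhlo.return", [out], [])])
```

For example, two `mhlo.add` reducers built by `make_binary_reducer` map to the same name, while an `mhlo.multiply` reducer stays separate. Operands are matched positionally, so `add(%arg2, %arg3)` and `add(%arg3, %arg2)` count as different here. Catching that case would need commutativity-aware comparison.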

joker-eph, Aug 19 '22 22:08