
[RFC] Adding support for non-constant dims for aten.view

[Open] gpetters94 opened this issue · 12 comments

I'm working on lowering OPT, and I'm running into the following:

error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
note: see current operation: %932 = "torch.aten.view"(%898, %931) : (!torch.vtensor<[1,12,7,64],f32>, !torch.list<int>) -> !torch.vtensor<[12,7,64],f32>

Inspecting the lowering of aten.view, it looks like the output shape comes out as -1, -1, 64 because the first two input dims aren't constants. The solution I'd like to write is to recursively follow each dim up the def-use tree, verifying that all the producing ops are either constants, no-ops (e.g. NumToTensor), or math ops (e.g. multiplication, addition), and then performing the math statically to determine the output shape. Does this sound like how we want to implement this?
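A rough sketch of that recursive folding, in Python purely for illustration (the op encoding and function names here are hypothetical, not the actual torch-mlir C++ API):

```python
# Toy model of the proposed approach: walk the producers of each size
# operand and fold them to a static integer when every op on the path
# is a constant, a no-op wrapper (e.g. NumToTensor), or simple math.
# Ops are modeled as tuples: ("kind", operands...).

def fold_dim(op):
    """Return a static int for this dim expression, or None if unknown."""
    kind, *args = op
    if kind == "constant":           # torch.constant.int
        return args[0]
    if kind == "num_to_tensor":      # no-op wrapper: recurse through it
        return fold_dim(args[0])
    if kind in ("mul", "add"):       # fold math ops over folded operands
        lhs, rhs = (fold_dim(a) for a in args)
        if lhs is None or rhs is None:
            return None
        return lhs * rhs if kind == "mul" else lhs + rhs
    return None                      # anything else: the dim stays dynamic

# Example: (12 * 7) wrapped in NumToTensor folds to 84.
expr = ("num_to_tensor", ("mul", ("constant", 12), ("constant", 7)))
```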

gpetters94 · Aug 01 '22 23:08

The output shape looks like [12, 7, 64] in your snippet and not [-1, -1, 64]. Can you show the actual IR snippet you are dealing with?

silvasean · Aug 01 '22 23:08

In the actual processing of aten.view, the lowering checks whether each input dim is a constant. If not, it assigns kUnknownDim to it, and in this case the first two inputs are not constants. The code is here.
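In pseudocode, the behavior being described is roughly this (a toy Python model; kUnknownDim stands in for the sentinel used in the linked lowering code):

```python
kUnknownDim = -1  # sentinel assigned when a size operand is not a constant

def infer_out_dims(size_operands):
    # Each entry models one size operand: a Python int for a
    # torch.constant.int, or None for a value produced by non-constant
    # ops (as with the first two dims in the OPT trace above).
    return [d if isinstance(d, int) else kUnknownDim for d in size_operands]
```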

gpetters94 · Aug 02 '22 01:08

Can you show the IR before the pass?

silvasean · Aug 02 '22 01:08

(for future reference, it's usually important to show a reduced, fully valid IR example with any bug reports like this)

silvasean · Aug 02 '22 01:08

Here's the IR after failure: https://gist.github.com/gpetters94/af96b032acb0e6c6274af9aff62ec5e3

The relevant part is:

  %136 = torch.aten.mul.Tensor %123, %71 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
  %137 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %138 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %139 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %140 = torch.prim.ListConstruct %int1, %int7, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %141 = torch.aten.view %126, %140 : !torch.vtensor<[1,7,768],f32>, !torch.list<int> -> !torch.vtensor<[1,7,12,64],f32>
  %142 = torch.aten.transpose.int %141, %int1, %int2 : !torch.vtensor<[1,7,12,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
  %143 = torch.aten.contiguous %142, %int0 : !torch.vtensor<[1,12,7,64],f32>, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
  %144 = torch.aten.numel %143 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
  %145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
  %146 = torch.aten.div.Tensor_mode %145, %136, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %147 = torch.aten.div.Tensor_mode %146, %70, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
  %149 = torch.prim.ListConstruct %139, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %150 = torch.aten.view %143, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>

gpetters94 · Aug 02 '22 01:08

Here's the distilled version:

func.func @forward(%arg0: !torch.vtensor<[1,12,7,64],f32>) -> !torch.vtensor<[12,7,64],f32> {
  %str = torch.constant.str "floor"
  %int7 = torch.constant.int 7
  %int12 = torch.constant.int 12
  %int64 = torch.constant.int 64
  %144 = torch.aten.numel %arg0 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
  %145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
  %tensor7 = torch.prim.NumToTensor.Scalar %int7 : !torch.int -> !torch.vtensor<[],si64>
  %tensor64 = torch.prim.NumToTensor.Scalar %int64 : !torch.int -> !torch.vtensor<[],si64>
  %146 = torch.aten.div.Tensor_mode %145, %tensor7, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %147 = torch.aten.div.Tensor_mode %146, %tensor64, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
  %149 = torch.prim.ListConstruct %int12, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %150 = torch.aten.view %arg0, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>
  return %150 : !torch.vtensor<[12,7,64],f32>
}
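Written out in Python, the shape arithmetic in this distilled IR is just (floor division matching the "floor" rounding mode passed to torch.aten.div.Tensor_mode):

```python
# Shape arithmetic from the distilled IR above.
numel = 1 * 12 * 7 * 64   # %144: torch.aten.numel on a [1,12,7,64] tensor
d146 = numel // 7          # %146: floor-div by the NumToTensor'd %int7
d147 = d146 // 64          # %147: floor-div by %int64; %148 extracts this int
```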

gpetters94 · Aug 02 '22 04:08

It looks like we have already done all the shape math statically, because the result shape is inferred as !torch.vtensor<[12,7,64],f32>. So I don't want to add any special local logic here for that.

You should be able to extend https://github.com/llvm/torch-mlir/pull/935 for torch.aten.div.Tensor_mode to do more folding here if that is useful as well.

silvasean · Aug 02 '22 23:08

So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?

gpetters94 · Aug 03 '22 00:08

> So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?

That would make sense to me. Actually, I would add a canonicalization that replaces the view sizes operand with a constant list if the result shape is static (and the operand is not already a constant list).
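A toy model of that canonicalization, in Python purely for illustration (the real change would be a C++ canonicalization pattern in TorchOps.cpp; the data structures and names here are hypothetical):

```python
def canonicalize_view_sizes(result_shape, sizes_operand):
    """Toy model: if the view's result shape is fully static and the sizes
    operand is not already a constant list, return a replacement constant
    list built from the result shape; otherwise return None (no rewrite)."""
    kDynamic = -1  # stands in for a dynamic dim in the result type
    if sizes_operand.get("is_constant_list"):
        return None                        # already canonical; nothing to do
    if any(d == kDynamic for d in result_shape):
        return None                        # result shape not fully static
    return {"is_constant_list": True, "sizes": list(result_shape)}

# Example: a view whose result type is [12,7,64] but whose sizes list is
# computed dynamically (as in the OPT trace) gets a constant list instead.
```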

silvasean · Aug 03 '22 19:08

Sure, I can do that. Where are canonicalizations added?

gpetters94 · Aug 03 '22 23:08

TorchOps.cpp -- you also need to add `let hasCanonicalizer = 1;` to the ODS definition.
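The ODS side of that change looks roughly like this (a sketch; the actual generated op definition has more fields):

```tablegen
def Torch_AtenViewOp : Torch_Op<"aten.view", [...]> {
  // ... existing summary, arguments, and results ...
  let hasCanonicalizer = 1;  // enables AtenViewOp::getCanonicalizationPatterns
}
```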

silvasean · Aug 03 '22 23:08

See here for more info: https://mlir.llvm.org/docs/Canonicalization/

silvasean · Aug 03 '22 23:08

Implemented this in #1337

gpetters94 · Sep 02 '22 07:09