
Possible bug with a custom layer...

Open lgyStoic opened this issue 1 year ago • 3 comments

🐞Describing the bug

  • This bug is quite hard to describe...
  • TL;DR:
  • I created a custom layer and it converts successfully with coremltools, but CoreML.framework prints a lot of warning logs, as shown in the image below.

To Reproduce

  • I wrote a minimal demo to reproduce it:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import coremltools
from collections import OrderedDict

import coremltools.proto.FeatureTypes_pb2 as ft
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import (
    _get_inputs as mil_get_inputs, is_symbolic,_get_scales_from_output_size
)
from coremltools.converters.mil import (
    register_torch_op
)
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op
from coremltools.converters.mil.mil import (
    Operation,
    types
)
from coremltools.converters.mil.mil.input_type import (
    InputSpec,
    TensorInputType,
)


@register_torch_op(torch_alias=['grid_sample'], override=True)
def grid_sampler(context, node):
    # https://github.com/pytorch/pytorch/blob/00d432a1ed179eff52a9d86a0630f623bf20a37a/aten/src/ATen/native/GridSampler.h#L10-L11
    inputs = mil_get_inputs(context, node, expected=5)
    x = mb.custom_op(
        x=inputs[0],
        coordinates=inputs[1],
        name=node.name,
    )
    context.add(x)

@register_op(is_custom_op=True)
class custom_op(Operation):
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        coordinates=TensorInputType(type_domain="T"),
    )

    type_domains = {
        "T": (types.fp16, types.fp32),
        "U": (types.int32,),
    }
    bindings = {'class_name': 'CustomGridSample',
                'input_order': ['coordinates', 'x'],
                'description': "custom grid sampler!"
                }

    def __init__(self, **kwargs):
        super(custom_op, self).__init__(**kwargs)

    def type_inference(self):
        input_shape = self.x.shape
        coord_shape = self.coordinates.shape


        ret_shape = list(input_shape)
        ret_shape[2] = coord_shape[1]  # Output height
        ret_shape[3] = coord_shape[2]  # Output width
        return types.tensor(self.x.dtype, ret_shape)
########################################################################
######################## Test ml model ################################


IN_WH = 512
GRID_WH = 256


class TestModel(nn.Module):

    def __init__(self):
        super(TestModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)

    def forward(self, x, grid):
        x = F.relu(self.conv1(x))

        x = F.grid_sample(x, grid)
        x = F.relu(self.conv2(x))
        return x


########################################################################
########################################################################

def convert(output_path):
    torch_model = TestModel()
    # torch_model = torch.jit.load('./flow_480x272_250103.pt', map_location='cpu')
    example_input = torch.rand(1, 3, IN_WH, IN_WH)
    example_grid = torch.ones(1, GRID_WH, GRID_WH, 2)
    # example_input = torch.rand(1, 1, 272, 480)
    # traced_model = torch.jit.trace(torch_model, (example_input, example_input))
    traced_model = torch.export.export(torch_model, (example_input, example_grid))
    mlmodel = coremltools.convert(
        traced_model,
        inputs=[
            coremltools.TensorType(name="input0", shape=example_input.shape),
            coremltools.TensorType(name="input1", shape=example_grid.shape),
        ],
        convert_to="neuralnetwork",
        # convert_to="milinternal",
        # convert_to="mlprogram",
        minimum_deployment_target=coremltools.target["iOS13"]
    )
    print(mlmodel)
    mlmodel_path = output_path + ".mlmodel"
    mlmodel.save(mlmodel_path)


    print(f"Saved to {output_path}")


def main():
    convert('test')


if __name__ == "__main__":
    main()
```

Running this code generates a minimal neural network `.mlmodel`. Then, loading it in an Objective-C project with just this API call:

```objc
NSError *error = nil;
// modelUrl points to the compiled model (.mlmodelc)
MLModel *model = [MLModel modelWithContentsOfURL:modelUrl
                                            error:&error];
```

causes this error log to be dumped to the console: [image]

I don't know what's wrong with inference on this network. I also can't tell whether this is a coremltools bug, a Core ML framework bug, or a bug in my custom op.

System environment (please complete the following information):

  • coremltools version: tried 7.2, 8.0, 8.1
  • PyTorch version: 2.4.0, 2.4.1
  • OS: tried 14.4, 14.5

Additional context

@YifanShenSZ I'm not sure if there are any bugs in my toy code, but if you have some free time, would you mind reviewing it for me?

lgyStoic · Jan 07 '25 08:01

More: converting to MIL shows that every shape is correct:

```
main[CoreML3](%x: (1, 3, 512, 512, fp32)(Tensor),
              %grid: (1, 256, 256, 2, fp32)(Tensor)) {
  block0() {
    %x_to_fp16: (1, 3, 512, 512, fp16)(Tensor) = cast(x=%x, dtype="fp16", name="cast_2")
    %conv2d_cast_fp16: (1, 16, 510, 510, fp16)(Tensor) = conv(x=%x_to_fp16, weight=%p_conv1_weight_to_fp16, bias=%p_conv1_bias_to_fp16, strides=[1, 1], pad_type="valid", pad=[0, 0, 0, 0], dilations=[1, 1], groups=1, name="conv2d_cast_fp16")
    %relu_cast_fp16: (1, 16, 510, 510, fp16)(Tensor) = relu(x=%conv2d_cast_fp16, name="relu_cast_fp16")
    %grid_to_fp16: (1, 256, 256, 2, fp16)(Tensor) = cast(x=%grid, dtype="fp16", name="cast_1")
    %grid_sampler_cast_fp16: (1, 16, 256, 256, fp16)(Tensor) = custom_op(x=%relu_cast_fp16, coordinates=%grid_to_fp16, name="grid_sampler_cast_fp16")
    %conv2d_1_cast_fp16: (1, 32, 254, 254, fp16)(Tensor) = conv(x=%grid_sampler_cast_fp16, weight=%p_conv2_weight_to_fp16, bias=%p_conv2_bias_to_fp16, strides=[1, 1], pad_type="valid", pad=[0, 0, 0, 0], dilations=[1, 1], groups=1, name="conv2d_1_cast_fp16")
    %relu_1_cast_fp16: (1, 32, 254, 254, fp16)(Tensor) = relu(x=%conv2d_1_cast_fp16, name="relu_1_cast_fp16")
    %relu_1: (1, 32, 254, 254, fp32)(Tensor) = cast(x=%relu_1_cast_fp16, dtype="fp32", name="cast_0")
  } -> (%relu_1)
}
```
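As a quick cross-check of the shapes in that dump, the arithmetic can be reproduced in plain Python. This is a sketch assuming stride 1, no padding and dilation 1 for the convs (as in the dump), and the same rule the custom op's `type_inference` uses (keep N and C from `x`, take H and W from the grid). The helper names are hypothetical.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]  # (N, C, H, W) for images, (N, H_out, W_out, 2) for grids


def conv2d_valid_shape(x: Shape, out_channels: int, kernel: int) -> Shape:
    # stride 1, "valid" padding, dilation 1: each spatial dim shrinks by kernel - 1
    n, _, h, w = x
    return (n, out_channels, h - kernel + 1, w - kernel + 1)


def grid_sample_shape(x: Shape, grid: Shape) -> Shape:
    # output keeps N and C from x and takes H_out, W_out from the grid
    n, c, _, _ = x
    gn, h_out, w_out, last = grid
    if gn != n or last != 2:
        raise ValueError("grid must have shape (N, H_out, W_out, 2)")
    return (n, c, h_out, w_out)


x = (1, 3, 512, 512)
grid = (1, 256, 256, 2)
s1 = conv2d_valid_shape(x, out_channels=16, kernel=3)   # conv1
s2 = grid_sample_shape(s1, grid)                         # custom_op / grid_sample
s3 = conv2d_valid_shape(s2, out_channels=32, kernel=3)   # conv2
print(s1, s2, s3)
```

These match the dump: (1, 16, 510, 510), (1, 16, 256, 256) and (1, 32, 254, 254). Correct static shapes only show that the converted graph is consistent; they say nothing about whether the runtime can actually execute `custom_op`.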

lgyStoic · Jan 07 '25 08:01

There is some misunderstanding about how custom ops work.

Principle Overview

If you want a custom op that goes beyond the op set provided by Core ML, then in principle you need to:

  1. Define the custom op in coremltools (ahead of time, AOT)
  2. Define the custom kernel and register it with the Core ML framework (runtime)

Practice

Starting with mlprogram, I'm not sure whether the Core ML framework still provides custom kernel registration, so the custom op also needs to be decomposed ahead of time (AOT).
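A toy illustration of why decomposing ahead of time sidesteps the missing runtime kernel. Everything here is hypothetical and is not the coremltools API: a mini IR (`Op`), a "runtime" (`run`) that only has standard `mul` and `add` kernels, and a pass (`decompose_scaled_add`) that rewrites a made-up custom op `scaled_add(x, y, alpha) = alpha * x + y` into those standard ops.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Op:
    op_type: str
    inputs: Dict[str, str]  # argument name -> variable name
    output: str             # name of the single output variable


# Kernels available at "runtime": standard ops only, no custom kernel registered.
STANDARD_KERNELS: Dict[str, Callable[[Dict[str, float]], float]] = {
    "mul": lambda a: a["x"] * a["y"],
    "add": lambda a: a["x"] + a["y"],
}


def run(ops: List[Op], env: Dict[str, float]) -> Dict[str, float]:
    """Execute ops in order; fail if an op has no registered kernel."""
    env = dict(env)
    for op in ops:
        kernel = STANDARD_KERNELS.get(op.op_type)
        if kernel is None:
            raise RuntimeError(f"no kernel registered for {op.op_type!r}")
        env[op.output] = kernel({k: env[v] for k, v in op.inputs.items()})
    return env


def decompose_scaled_add(ops: List[Op]) -> List[Op]:
    """AOT pass: rewrite scaled_add(x, y, alpha) as mul followed by add."""
    out: List[Op] = []
    for op in ops:
        if op.op_type != "scaled_add":
            out.append(op)
            continue
        tmp = op.output + "_scaled"
        out.append(Op("mul", {"x": op.inputs["alpha"], "y": op.inputs["x"]}, tmp))
        # Reusing the original output name means downstream ops need no rewiring.
        out.append(Op("add", {"x": tmp, "y": op.inputs["y"]}, op.output))
    return out


program = [Op("scaled_add", {"x": "a", "y": "b", "alpha": "s"}, "out")]
inputs = {"a": 2.0, "b": 3.0, "s": 4.0}

try:
    run(program, inputs)
except RuntimeError as err:
    print("before pass:", err)

print("after pass:", run(decompose_scaled_add(program), inputs)["out"])
```

The original program fails because the runtime has no kernel for the custom op, while the decomposed program runs using standard kernels only.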

There are 2 ways you can achieve it:

  1. (Recommended) Directly use the decomposition in the torch op translation:
```python
@register_torch_op(torch_alias=['grid_sample'], override=True)
def grid_sampler(context, node):
    # ... create the output `x` with existing MIL ops ...
    context.add(x)
```
  2. Define a custom MIL op, plus a graph pass that decomposes it into standard MIL ops
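Whichever route is taken, the decomposition has to reproduce what torch's `grid_sample` computes. Below is a plain-Python reference for a single channel under the defaults used by `TestModel` (bilinear mode, zeros padding, `align_corners=False`), which could serve as a numerical check for a decomposition. The name `grid_sample_bilinear` is hypothetical, and other modes and padding options are not covered.

```python
import math
from typing import List, Tuple

Image = List[List[float]]                   # one channel, indexed [row][col]
Grid = List[List[Tuple[float, float]]]      # [out_row][out_col] -> (x, y) in [-1, 1]


def grid_sample_bilinear(img: Image, grid: Grid) -> Image:
    h, w = len(img), len(img[0])

    def pixel(r: int, c: int) -> float:
        # zeros padding: taps that fall outside the image contribute 0
        return img[r][c] if 0 <= r < h and 0 <= c < w else 0.0

    out: Image = []
    for grid_row in grid:
        out_row = []
        for gx, gy in grid_row:
            # align_corners=False: -1 and +1 map to the outer edges of the border pixels
            ix = ((gx + 1.0) * w - 1.0) / 2.0
            iy = ((gy + 1.0) * h - 1.0) / 2.0
            x0, y0 = math.floor(ix), math.floor(iy)
            wx1, wy1 = ix - x0, iy - y0          # weights of the right / bottom taps
            wx0, wy0 = 1.0 - wx1, 1.0 - wy1      # weights of the left / top taps
            out_row.append(
                wy0 * wx0 * pixel(y0, x0) + wy0 * wx1 * pixel(y0, x0 + 1)
                + wy1 * wx0 * pixel(y0 + 1, x0) + wy1 * wx1 * pixel(y0 + 1, x0 + 1)
            )
        out.append(out_row)
    return out


img = [[1.0, 2.0], [3.0, 4.0]]
print(grid_sample_bilinear(img, [[(0.0, 0.0), (1.0, 1.0), (-1.0, -1.0)]]))
```

With these defaults, the all-ones grid from the repro (`torch.ones(1, GRID_WH, GRID_WH, 2)`) maps every output pixel to (W - 0.5, H - 0.5), so each output value is a quarter of the bottom-right input pixel.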

YifanShenSZ · Jan 08 '25 20:01

Btw, are you hitting a "cannot convert torch.grid_sample" issue? If so, could you please create a minimal reproduction so we can help you implement it?

I'll @ you in the PR so you can learn how to implement a torch op translation. That should help with https://github.com/apple/coremltools/issues/2415 if you cannot find a reproduction for torch.maxpool that would let us help.

YifanShenSZ · Jan 08 '25 20:01