Bug Report: macOS 15 Beta - PyTorch gridsample Not Utilizing Apple Neural Engine on MacBook Pro M2

Open vinayak-sharan opened this issue 1 year ago • 4 comments

Dear Team,

After updating to the macOS 15 beta, PyTorch’s grid_sample op (converted with coremltools) no longer executes on the Apple Neural Engine (ANE) on our MacBook Pro M2.

To help diagnose the problem, I am attaching:

•	A minimal code snippet to reproduce the issue
•	The relevant performance screenshot for macOS 15
•	The relevant performance screenshot for macOS 14

Let me know if further information is required.

Best regards, Vinayak

To Reproduce

import torch
import coremltools as ct
import torch.nn as nn
import torch.nn.functional as F


class PytorchGridSample(torch.nn.Module):
    def __init__(self, grids):
        super(PytorchGridSample, self).__init__()
        self.upsample1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.upsample2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.upsample3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.upsample4 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.upsample5 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)
        self.grids = grids

    def forward(self, x):
        x = self.upsample1(x)

        x = F.grid_sample(x, self.grids[0], padding_mode='reflection', align_corners=False)

        x = self.upsample2(x)

        x = F.grid_sample(x, self.grids[1], padding_mode='reflection', align_corners=False)

        x = self.upsample3(x)

        x = F.grid_sample(x, self.grids[2], padding_mode='reflection', align_corners=False)

        x = self.upsample4(x)

        x = F.grid_sample(x, self.grids[3], padding_mode='reflection', align_corners=False)

        x = self.upsample5(x)

        x = F.grid_sample(x, self.grids[4], padding_mode='reflection', align_corners=False)

        return x


def convert_to_coreml(model, input_):
    traced_model = torch.jit.trace(
        model, example_inputs=input_, strict=False)

    coreml_model = ct.converters.convert(traced_model,
                                         inputs=[ct.TensorType(shape=input_.shape),
                                                 ],
                                         compute_precision=ct.precision.FLOAT16,
                                         minimum_deployment_target=ct.target.macOS14,
                                         compute_units=ct.ComputeUnit.ALL)
    return coreml_model


def main(pt_model, input_):
    """
    Convert a PyTorch model to CoreML
    """

    coreml_model = convert_to_coreml(pt_model, input_)
    coreml_model.save("grid_sample.mlpackage")


if __name__ == "__main__":
    input_tensor = torch.randn(1, 512, 4, 4)
    grids = []
    res = [4, 8, 16, 32, 64, 128]
    for i in res:
        grids.append(torch.randn(1, 2*i, 2*i, 2))
    pt_model = PytorchGridSample(grids)
    main(pt_model, input_tensor)
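
The tensor shapes in the script above can be sanity-checked without PyTorch. This is a minimal sketch, assuming the standard ConvTranspose2d output-size formula (dilation 1, no output padding) and that grid_sample takes its output spatial size from the grid. The helpers `conv_transpose2d_out` and `trace_shapes` are hypothetical names and are not part of the original script.

```python
def conv_transpose2d_out(size, kernel=4, stride=2, padding=1):
    """Output size of ConvTranspose2d, assuming dilation=1 and output_padding=0."""
    return (size - 1) * stride - 2 * padding + (kernel - 1) + 1


def trace_shapes(input_size=4, res=(4, 8, 16, 32, 64, 128), num_stages=5):
    """Walk the upsample -> grid_sample stages and record spatial sizes."""
    grid_sizes = [2 * r for r in res]  # mirrors torch.randn(1, 2*i, 2*i, 2)
    size = input_size
    stages = []
    for i in range(num_stages):
        size = conv_transpose2d_out(size)  # feature map after upsample{i+1}
        grid = grid_sizes[i]
        stages.append((size, grid))
        size = grid  # grid_sample output has the grid's spatial size
    unused = grid_sizes[num_stages:]  # grids that forward() never indexes
    return stages, unused


stages, unused = trace_shapes()
for i, (feat, grid) in enumerate(stages):
    print(f"stage {i}: upsample -> {feat}x{feat}, grid -> {grid}x{grid}")
print("unused grid sizes:", unused)
```

Under these assumptions, each grid has the same spatial size as the feature map it resamples, and the sixth grid built from `res` is never used by `forward`.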

System environment:

  • coremltools version: 7.2 and 8.0b1
  • OS (e.g. macOS version or Linux type): macOS 15
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.1.0

macOS 14 performance report

[Image: Performance_report_on_macOS14]

macOS 15 performance report

[Image: Performance_report_on_macOS15]

vinayak-sharan avatar Aug 07 '24 15:08 vinayak-sharan

Thank you for reporting this! We are investigating why it falls off the ANE.

junpeiz avatar Aug 08 '24 16:08 junpeiz

I have also always had problems with grid_sample on macOS. It is so annoying that you cannot control whether an operation is executed on the ANE, the CPU, or the GPU!

ivan-alles avatar Aug 16 '24 10:08 ivan-alles

The cause is the resample op falling off the ANE, which is triggered by some channel validation logic. We are working on it. Thank you for your patience!

junpeiz avatar Aug 26 '24 18:08 junpeiz

Thank you for the update. I am looking forward to the fix, hopefully before the official release of macOS 15 :)

I was wondering whether this is due to a bug in the Core ML library or in macOS 15?

vinayak-sharan avatar Aug 28 '24 09:08 vinayak-sharan

It's a bug in ANE op validation (not in CoreML). We are working on a fix and will keep you posted when it's available in the OS. Thanks!

junpeiz avatar Aug 28 '24 16:08 junpeiz

@junpeiz Apologies for closing by mistake :) Looking forward to the fix!

vinayak-sharan avatar Sep 23 '24 06:09 vinayak-sharan

Thank you for your patience! The fix has landed in macOS 15.1 beta 5.

I will close this issue. Feel free to re-open it if you still see problems on macOS 15.1 beta 5. Thanks!

junpeiz avatar Oct 07 '24 17:10 junpeiz
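
One rough way to re-check the behavior after updating is to time the saved model under different compute-unit restrictions. The sketch below assumes coremltools is installed on a Mac and that `grid_sample.mlpackage` from the script above exists. `time_prediction` and `compare_compute_units` are hypothetical helpers, and the input name must be looked up from the model rather than assumed. Timing is only an indirect signal: the performance reports attached to the original post remain the direct way to see where each op runs.

```python
import time


def time_prediction(predict, inputs, runs=20):
    """Return mean wall-clock seconds per call of predict(inputs)."""
    predict(inputs)  # warm-up call, excluded from the measurement
    start = time.perf_counter()
    for _ in range(runs):
        predict(inputs)
    return (time.perf_counter() - start) / runs


def compare_compute_units(path, inputs,
                          unit_names=("CPU_ONLY", "CPU_AND_NE", "ALL")):
    """Load the model once per compute-unit setting and time predictions."""
    import coremltools as ct  # imported lazily; prediction requires macOS

    results = {}
    for name in unit_names:
        model = ct.models.MLModel(path, compute_units=ct.ComputeUnit[name])
        results[name] = time_prediction(model.predict, inputs)
    return results


# Example usage (not executed here; the input name is an assumption and
# should be read from model.get_spec().description.input[0].name):
#
#   import numpy as np
#   x = np.random.rand(1, 512, 4, 4).astype(np.float32)
#   print(compare_compute_units("grid_sample.mlpackage", {"x": x}))
```

If `CPU_AND_NE` is not clearly faster than `CPU_ONLY`, that may hint that the resample ops are still falling back, but it is not proof either way.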