DeformConv2d MPS Support
This PR features a full MPS implementation of the `DeformConv2d` operator.
**General notes**
For consistency and maintainability, this implementation closely follows the CPU and CUDA implementations, apart from the differences required by their respective frameworks. The Metal part of the implementation is likewise similar to the kernels for the "ROI"-related operators.
**Tests**
The implementation passes all tests in `test_ops.py::TestDeformConv` except for two `optests.generate_opcheck_tests` tests: `test_autograd_registration` and `test_aot_dispatch_dynamic`.
`test_autograd_registration` fails due to the missing MPS dispatch key in `torch/testing/_internal/optests/autograd_registration.py`. I have addressed this issue in a separate PR: torch.testing._internal.optests - MPS Support #151758.
`test_aot_dispatch_dynamic` fails for reasons that are not yet clear to me.
**Issues to be fixed**
The CPU implementation (`deform_conv2d_kernel.cpp`) uses the in-place torch operator `addmm_`. However, for reasons unknown to me, using `addmm_` in the MPS implementation (`deform_conv2d_kernel.mm`) returns zero-valued tensors after the first iteration of the convolution loop. As a temporary workaround, I have used the out-of-place version, `addmm`, instead. This is not ideal and should be fixed.
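The distinction between the two variants, shown here in plain PyTorch rather than the C++/MPS code of the PR: `addmm_` accumulates into the existing tensor in place, while `addmm` allocates and returns a new tensor.

```python
import torch

a = torch.ones(2, 3)
b = torch.ones(3, 2)

# In-place variant: accumulates a @ b into M directly (what the CPU kernel uses).
M = torch.zeros(2, 2)
M.addmm_(a, b)

# Out-of-place variant: returns a new tensor (the temporary MPS workaround).
M2 = torch.zeros(2, 2).addmm(a, b)

print(torch.equal(M, M2))  # True: both compute M + a @ b
```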
**MSL implementation (`mps_kernels.h` and `mps_helpers.h`)**
The `bilinear_interpolate` function used by the "ROI"-related kernels differs from the one used by the CPU and CUDA implementations of the `deform_conv2d` operator. For now, I have kept both implementations in `mps_kernels.h` and named the new function `bilinear_interpolate_deform_conv2`.
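For reference, the bilinear interpolation logic that the CPU/CUDA `deform_conv2d` kernels follow can be sketched in Python roughly as below (this is an illustrative reimplementation, not the Metal code from the PR; the clamping and boundary conventions mirror the reference kernels):

```python
import math
import torch

def bilinear_interpolate(im: torch.Tensor, y: float, x: float) -> float:
    # Sample a 2D image at a fractional (y, x) position.
    h, w = im.shape
    # Samples entirely outside the image contribute zero.
    if y <= -1 or y >= h or x <= -1 or x >= w:
        return 0.0
    y0, x0 = math.floor(y), math.floor(x)
    y1, x1 = y0 + 1, x0 + 1
    ly, lx = y - y0, x - x0        # fractional parts
    hy, hx = 1.0 - ly, 1.0 - lx    # complementary weights

    def at(r: int, c: int) -> float:
        # Reads just outside the image contribute zero.
        if 0 <= r < h and 0 <= c < w:
            return im[r, c].item()
        return 0.0

    return (hy * hx * at(y0, x0) + hy * lx * at(y0, x1)
            + ly * hx * at(y1, x0) + ly * lx * at(y1, x1))
```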
**Suggestions**
I suggest that `mps_kernels.h` be split into separate kernel files, one for each "ROI"-related operator and one for the `deform_conv2d` operator, and that future operator implementations get their own kernel files. This would not only be in keeping with the implementation design elsewhere in PyTorch but also lead to safer and more maintainable code in the long run. I also suggest moving any common utility functions and constants in `mps_kernels.h` into `mps_helpers.h`, and perhaps renaming `mps_helpers.h` to something more generic.
Didn't find following labels among repository labels: topic: not user facing
@pytorchbot label "enhancement" @pytorchbot label "module: ops" @pytorchbot label "module: vision"
@pytorchbot label "topic: not user facing"
I tested it, and it is very slow. I built it from source.
Running on cpu: Time: 0.005s, Output: torch.Size([1, 8, 128, 128])
Running on mps: Time: 0.292s, Output: torch.Size([1, 8, 128, 128])