[MPS] Improve runtime complexity of `roi_align`
`roi_align` on MPS has significantly inflated runtime complexity due to a bug in the looping behavior of the kernel. I have not found any other correctness issues with the current implementation, which closely follows the CUDA implementation. This PR fixes the runtime complexity; otherwise the kernel is semantically identical to before.
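To see why a looping bug inflates total work, here is a hypothetical Python sketch (illustrative only, not the actual Metal kernel): a grid-stride loop does O(output_size) total work only when the stride equals the total number of threads in the grid; a wrong stride makes threads revisit elements many times over.

```python
def total_work(output_size, num_threads, stride):
    """Count loop iterations summed across all threads of a grid-stride loop.

    Hypothetical model: each thread tid starts at index tid and advances
    by `stride` until it passes output_size.
    """
    total = 0
    for tid in range(num_threads):
        i = tid
        while i < output_size:
            total += 1
            i += stride
    return total
```

With the correct stride (the grid size), the indices are partitioned exactly once across threads; with a too-small stride (e.g. a threadgroup size instead of the grid size), total work is a multiple of `output_size`.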
Note that this PR switches the dispatching to dispatchThreads, which has a tighter build target set than dispatchThreadgroups. See "Nonuniform threadgroup size" in Apple's Metal feature set tables.
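For context, the practical difference between the two dispatch styles can be sketched with illustrative Python arithmetic (the function names mirror the Metal API but this is not Metal code):

```python
import math

def grid_size_threadgroups(n, threadgroup_size):
    # dispatchThreadgroups rounds the grid up to whole threadgroups, so the
    # kernel needs a bounds check (or a stride loop) for the excess threads.
    return math.ceil(n / threadgroup_size) * threadgroup_size

def grid_size_threads(n):
    # dispatchThreads supports a nonuniform final threadgroup: exactly n
    # threads are launched, but it requires a newer GPU feature set.
    return n
```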
Some other MPS kernels in vision are likely also affected.
Running the example code from https://github.com/pytorch/pytorch/issues/124850#issue-2261644613 before the fix:
```
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls Input Shapes
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
model_inference 0.02% 6.412ms 100.00% 41.913s 41.913s 1 []
aten::where 0.00% 4.373us 80.19% 33.611s 8.403s 4 [[1000]]
aten::nonzero_numpy 0.00% 15.335us 80.19% 33.611s 8.403s 4 [[1000]]
aten::nonzero 80.18% 33.605s 80.19% 33.611s 8.403s 4 [[1000]]
aten::where 0.00% 7.375us 2.55% 1.067s 533.698ms 2 [[4507]]
aten::nonzero_numpy 0.00% 11.042us 2.55% 1.067s 533.695ms 2 [[4507]]
aten::nonzero 2.31% 969.133ms 2.55% 1.067s 533.679ms 2 [[4507]]
aten::topk 2.53% 1.062s 2.53% 1.062s 1.062s 1 [[1, 120000], [], [], [], []]
torchvision::nms 0.00% 52.208us 2.39% 1.004s 1.004s 1 [[21, 4], [21], []]
aten::sort 2.39% 999.630ms 2.39% 999.635ms 999.635ms 1 [[21], [], [], []]
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Self CPU time total: 41.913s
```
and after the fix:
```
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls Input Shapes
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
model_inference 0.88% 4.364ms 100.00% 493.862ms 493.862ms 1 []
torchvision::nms 15.95% 78.782ms 17.20% 84.925ms 84.925ms 1 [[4507, 4], [4507], []]
aten::where 0.00% 2.957us 11.38% 56.185ms 14.046ms 4 [[1000]]
aten::nonzero_numpy 0.00% 7.379us 11.38% 56.182ms 14.045ms 4 [[1000]]
aten::nonzero 10.26% 50.684ms 11.37% 56.146ms 14.036ms 4 [[1000]]
aten::conv2d 0.00% 5.417us 6.39% 31.548ms 31.548ms 1 [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], []]
aten::convolution 0.00% 9.041us 6.39% 31.543ms 31.543ms 1 [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], [], [], []]
aten::_convolution 0.00% 12.542us 6.39% 31.534ms 31.534ms 1 [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], [], [], [], [], [], [], []]
aten::_mps_convolution 6.38% 31.520ms 6.38% 31.521ms 31.521ms 1 [[1, 3, 800, 800], [64, 3, 7, 7], [], [], [], [], []]
torchvision::roi_align 5.88% 29.036ms 5.88% 29.047ms 29.047ms 1 [[1, 256, 200, 200], [960, 5], [], [], [], [], []]
-------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ --------------------------------------------------------------------------------
Self CPU time total: 493.862ms
```
One concern I have with the approach I'm proposing here is numeric overflow of the index variable for large input sizes.
Fixes https://github.com/pytorch/pytorch/issues/124850
cc @malfet @kulinseth @qqaatw
@Isalia20 @qqaatw Do you have time to review this?
Thanks a lot for the review!
Can you add a small script that directly measures the time difference between the old `roi_align` and the new one? The one in the main thread is a bit confusing to me, since the first section has no `roi_align` and the second one does.
I agree that the perf outputs from the first comment are a bit confusing. The culprit looks like it's `nonzero`, but that's a quirk of the profiler. The time is actually spent in `roi_align`, and the total execution time is 41.913s. In the second output you can see that the timings have improved significantly and the total execution time is 493.862ms.
I added a regression test, test_performance_mps, which checks the execution time against a threshold of 1000 ms. You can run this unit test on this branch and on main to see the difference in execution time; just set execution_time_ms_threshold = 0 and you'll get the timings on this branch too.
Have you tested it on larger input sizes and verified against CPU that this implementation produces equivalent results?
output_size is defined here: https://github.com/pytorch/vision/blob/6473b779bdb8ba02bab0fc9e0f4ef4661ebb632a/torchvision/csrc/ops/mps/roi_align_kernel.mm#L37
I've tested it with values generating an output_size of up to 2^31, and it produces the same results on CPU and MPS (verified with torch.testing.assert_close).
Above 2^31 I get a crash on CPU with the error `Fatal Python error: Segmentation fault`. If I try to print the whole tensor on MPS I get

```
/AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
Fatal Python error: Aborted
```
Indexing into the tensor I get valid output, e.g. for `out[0][0]`:

```
tensor([[73.9383, 61.6012, 74.1146, 72.1870, 71.5774, 81.3736, 68.8621],
        [73.6598, 65.7005, 76.3044, 68.5069, 72.9770, 75.2113, 74.2729],
        [68.8734, 75.3870, 69.6267, 79.9169, 74.0059, 81.7421, 79.3910],
        [73.7394, 72.1691, 64.8541, 68.3909, 78.4569, 75.4807, 76.2083],
        [82.0290, 70.3133, 69.1630, 70.7505, 80.5654, 65.7685, 79.4339],
        [70.2205, 76.4919, 68.9302, 66.3778, 74.3694, 77.7530, 66.5249],
        [88.4454, 65.3945, 83.0347, 66.1287, 63.6279, 66.8136, 84.1742]],
       device='mps:0')
```
but I don't trust the results to be numerically correct, especially since the index likely overflows here. And indexing with `out[-1][-1]` will again yield

```
/AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX'
Fatal Python error: Aborted
```
These errors can be triggered by setting

```python
num_rois = 171196  # < 2**31 -> good
num_rois = 171206  # > 2**31 -> errors
rois = self._make_rois(img_size, num_imgs, dtype, num_rois=num_rois)
```
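The boundary between those two values is consistent with output_size crossing 2^31, assuming output_size = num_rois * channels * pooled_h * pooled_w with channels = 256 and a 7x7 pooled output (an assumption about the test config, matching the shapes in the profile above). A quick check in plain Python:

```python
INT32_MAX = 2**31 - 1  # largest value a signed 32-bit index can hold

def output_size(num_rois, channels=256, pooled_h=7, pooled_w=7):
    # channels and pooled sizes are assumed to match the test configuration;
    # output_size is the total element count of the pooled output tensor.
    return num_rois * channels * pooled_h * pooled_w
```

With these assumed defaults, `output_size(171196)` still fits in a signed 32-bit index while `output_size(171206)` does not, matching the observed good/error boundary.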
Should we add a check on output_size against INT_MAX for MPS? We should probably add a similar check on CPU as well to prevent the crash there, but I consider that out of scope for this PR.
cc @Isalia20