Removing graph breaks in transforms
This issue tracks progress on removing graph breaks in the v2 transforms. We restrict to pure tensor inputs (images) for now; we can figure out the TVTensors and arbitrary structures later.
Kernels
The low-level kernels are almost all fine. Only 4 kernels are problematic.
```python
import torch
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F
img = torch.rand(3, 256, 256)
# These kernels don't have graph breaks
# -------------------------------------
# torch.compile(F.get_dimensions_image, fullgraph=True)(img)
# torch.compile(F.get_num_channels_image, fullgraph=True)(img)
# torch.compile(F.get_size_image, fullgraph=True)(img)
# torch.compile(F.erase_image, fullgraph=True)(img, 0, 0, 10, 10, v=torch.tensor(0.5))
# torch.compile(F.adjust_brightness_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_contrast_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_gamma_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_hue_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_saturation_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_sharpness_image, fullgraph=True)(img, .5)
# torch.compile(F.autocontrast_image, fullgraph=True)(img)
# torch.compile(F.invert_image, fullgraph=True)(img)
# torch.compile(F.permute_channels_image, fullgraph=True)(img, [2, 1, 0])
# torch.compile(F.posterize_image, fullgraph=True)(img, bits=3)
# torch.compile(F.rgb_to_grayscale_image, fullgraph=True)(img)
# torch.compile(F.solarize_image, fullgraph=True)(img, .4)
# torch.compile(F.affine_image, fullgraph=True)(img, angle=20, translate=[1, 4], scale=1.3, shear=[0, 0])
# torch.compile(F.center_crop_image, fullgraph=True)(img, output_size=(223, 223))
# torch.compile(F.crop_image, fullgraph=True)(img, 0, 10, 10, 10)
# torch.compile(F.elastic_image, fullgraph=True)(img, displacement=torch.randn(1, *img.shape[-2:], 2))
# torch.compile(F.five_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.horizontal_flip_image, fullgraph=True)(img)
# torch.compile(F.pad_image, fullgraph=True)(img, [2, 2, 2, 2])
# torch.compile(F.rotate_image, fullgraph=True)(img, angle=30)
# torch.compile(F.ten_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.vertical_flip_image, fullgraph=True)(img)
# torch.compile(F.gaussian_blur_image, fullgraph=True)(img, kernel_size=3)
# torch.compile(F.normalize_image, fullgraph=True)(img, mean=0, std=1)
# torch.compile(F.to_dtype_image, fullgraph=True)(img, dtype=torch.uint8, scale=True)
# These ones have breaks
# torch.compile(F.perspective_image, fullgraph=False)(img, None, None, coefficients=torch.rand(8))
# torch.compile(F.resize_image, fullgraph=False)(img, size=(223, 223))
# torch.compile(F.resized_crop_image, fullgraph=False)(img, 0, 12, 10, 34, (223, 223))
# This one doesn't even compile
# torch.compile(F.equalize_image, fullgraph=False)(img)
```
Weird thing: resize_image and resized_crop_image both break on https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_geometry.py#L228, but when calling them both consecutively, one of them starts breaking on https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_geometry.py#L234 as well. I have no idea why.
Functionals
As @pmeier noted offline, the functionals break on
https://github.com/pytorch/vision/blob/68161e98aaeaeca02166063d19de92e81ea00c3b/torchvision/transforms/v2/functional/_utils.py#L99
which, technically, can probably be avoided since the dict entry should be constant across one execution. (We still need to make sure this won't affect custom kernels that users register, or whether it changes anything if we eventually want to allow users to override our default kernels.)
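For reference, here is a minimal sketch of the pattern involved (names and structure simplified on my side; the real registry and lookup live in `_utils.py`):

```python
import torch

# Simplified, hypothetical version of the v2 dispatch mechanism: a module-level
# registry keyed by (functional, input type), filled at import time.
_KERNEL_REGISTRY = {}  # {functional: {input_type: kernel}}

def _register_kernel(functional, input_type, kernel):
    _KERNEL_REGISTRY.setdefault(functional, {})[input_type] = kernel

def horizontal_flip_image(image):
    return image.flip(-1)

def horizontal_flip(inpt):
    # The dict lookup keyed by the functional and type(inpt) is where dynamo
    # currently graph-breaks.
    kernel = _KERNEL_REGISTRY[horizontal_flip][type(inpt)]
    return kernel(inpt)

_register_kernel(horizontal_flip, torch.Tensor, horizontal_flip_image)
```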
TODO: figure out whether the call to log_api_usage_once() introduces a break.
Transforms
The transforms also break where the functionals break.
On top of that, the random transforms seem to break on the call to `if rand() < self.p`, although I don't see those breaks when using TORCH_LOGS="graph_breaks"; I only see them when using _dynamo.explain(). And _dynamo.explain() in turn doesn't show the graph breaks that happen on the _KERNEL_REGISTRY. :man_shrugging:
TODO: figure out which one we should trust, and also assess the rest of the transforms more systematically with a script similar to the one above.
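For the systematic assessment, something along these lines could work (a sketch; the `_dynamo.explain` output details vary between PyTorch versions):

```python
import torch
import torch._dynamo as dynamo
from torchvision.transforms import v2

img = torch.rand(3, 256, 256)
transform = v2.RandomHorizontalFlip(p=0.5)

# Option 1: run the compiled transform with TORCH_LOGS="graph_breaks" set in
# the environment and read the logged break reasons.
torch.compile(transform)(img)

# Option 2: ask dynamo directly; printing the result summarizes the number of
# graphs, graph breaks, and their reasons.
explanation = dynamo.explain(transform)(img)
print(explanation)
```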
CC @pmeier @vfdev-5
I've run a few quick benchmarks to check whether it is useful to compile kernels in the first place. I've used a simple classification pipeline (random_resized_crop, horizontal_flip, to_dtype, normalize) and pure tensor input:
|            | eager | compiled |
|------------|------:|---------:|
| kernel     |   279 |      225 |
| functional |   280 |      328 |

Times are in microseconds (us).
- Compiling the kernels works without graph breaks (iff we include a fix for the AVX2 graph break in resize, as reported in https://github.com/pytorch/vision/issues/8056#issue-1955064062) and leads to a ~20% speedup
- Compiling the functionals leads to a ~20% slowdown
The slowdown in the functionals stems from the graph break mentioned above in `_get_kernel`, which is at the heart of our dispatch mechanism and thus present in every functional. If we hardcode the kernel, e.g.
```python
# kernel = _get_kernel(horizontal_flip, type(inpt))
kernel = horizontal_flip_image
```
we get the following results
|            | eager | compiled |
|------------|------:|---------:|
| kernel     |   270 |      228 |
| functional |   270 |      225 |

Times are in microseconds (us).
Meaning, if we can somehow resolve the graph break, compiling the functionals will net us the same speedup as compiling the kernels directly. Note that, for now, this only applies to pure tensors and thus image-only pipelines.
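For reference, a rough sketch of how such a pipeline benchmark can be set up (the parameters and benchmark settings below are my assumptions, not the exact script used for the numbers above):

```python
import torch
import torch.utils.benchmark as benchmark
import torchvision.transforms.v2.functional as F

def pipeline(img):
    # resized crop -> horizontal flip -> to_dtype -> normalize, with fixed
    # parameters so eager and compiled runs see identical work.
    img = F.resized_crop_image(img, 0, 0, 200, 200, size=(224, 224), antialias=True)
    img = F.horizontal_flip_image(img)
    img = F.to_dtype_image(img, dtype=torch.float32, scale=True)
    return F.normalize_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
compiled_pipeline = torch.compile(pipeline)
compiled_pipeline(img)  # warm-up so compilation time is excluded

for label, fn in [("eager", pipeline), ("compiled", compiled_pipeline)]:
    timer = benchmark.Timer(stmt="fn(img)", globals={"fn": fn, "img": img})
    print(label, timer.timeit(1000))
```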
I'll be working on this item:
> This one doesn't even compile: `torch.compile(F.equalize_image, fullgraph=False)(img)`
=> PR on pytorch: https://github.com/pytorch/pytorch/pull/112753
EDIT: Wrong conclusion:
~~Additional torch compile failures for boxes and seg masks:~~
...
torch.compile doesn't yet handle tensor subclasses. From this error message

> Argument displacement shape should be (1, 1, 4, 2), but given torch.Size([1, 17, 11, 2])

you can see that a tensor image likely made its way into a bounding box kernel.
What exactly are you testing there? That bounding box / mask inputs work properly on a compiled functional?
Well, I was running tests from https://github.com/pytorch/vision/pull/8092/ and it is partially my fault, as I was running dispatched functions on tensors instead of subclasses... Now, the problem is a recursion error due to `tv_tensors.wrap`, which we can temporarily decorate to skip it from compilation.
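A minimal sketch of what that temporary workaround could look like, assuming `torch._dynamo.disable` is used to exclude the helper from tracing (whether this is the right long-term fix is a separate question):

```python
import torch
from torchvision import tv_tensors

# Hypothetical temporary workaround: make dynamo fall back to eager for the
# helper instead of trying to trace it and hitting the recursion error.
wrap = torch._dynamo.disable(tv_tensors.wrap)
```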
There are two sources of graph breaks in the way we currently dispatch:

1. We use the dispatcher and the input type directly as dictionary keys:
   https://github.com/pytorch/vision/blob/15c166ac127db5c8d1541b3485ef5730d34bb68a/torchvision/transforms/v2/functional/_utils.py#L15-L16
   This is currently not supported by dynamo. However, pytorch/pytorch#111196 opens up dictionary keys to types other than primitives. If that is merged, we should be able to send a small fix to allow our use case as well.
2. Inlining functions that use types, which is what happens when dynamo hits `_get_kernel` the first time, is not properly supported. I have pytorch/pytorch#113340 to address this.
Apart from that, nothing needs to change on our side. Dynamo is fine with all the other things we worried about, e.g. global dicts, MRO traversal, ... :tada:
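A standalone repro of the dictionary-key limitation, outside of torchvision (a sketch; the exact break reason depends on the PyTorch version):

```python
import torch
import torch._dynamo as dynamo

def kernel(x):
    return x + 1

# Dict keyed by a function object and a type, mirroring our registry.
REGISTRY = {kernel: {torch.Tensor: kernel}}

def dispatch(x):
    # At the time of writing this lookup graph-breaks; pytorch/pytorch#111196
    # (non-primitive dict keys) and pytorch/pytorch#113340 (inlining functions
    # that use types) are meant to lift the restrictions described above.
    return REGISTRY[kernel][type(x)](x)

print(dynamo.explain(dispatch)(torch.rand(3)))
```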
I've re-run my benchmark with fixes for the points above, and this is what I got:
|            | eager | compiled |
|------------|------:|---------:|
| kernel     |   265 |      230 |
| functional |   270 |      240 |

Times are in microseconds (us).
I've re-run it a couple of times and the 10µs gap between compiled kernels and functionals is reproducible. Meaning, the compiled functionals don't fully get to the same level as the kernels, but they still outperform their eager counterparts.
One thing that I noticed while playing around with the benchmarks is that dynamo does not give us a strict improvement for individual ops.
random_resized_crop
|            | eager | compiled |
|------------|------:|---------:|
| kernel     |   178 |      206 |
| functional |   178 |      207 |

Times are in microseconds (us).
horizontal_flip
|            | eager | compiled |
|------------|------:|---------:|
| kernel     |    22 |     36.4 |
| functional |    24 |     41.7 |

Times are in microseconds (us).
to_dtype
|            | eager | compiled |
|------------|------:|---------:|
| kernel     |  65.2 |     54.6 |
| functional |  67.0 |     59.3 |
to_dtype and normalize
|            | eager | compiled |
|------------|------:|---------:|
| kernel     |   170 |     61.4 |
| functional |   180 |     67.5 |
- resizing and `horizontal_flip` are slower in the compiled version than in eager
- `to_dtype` is marginally faster
- `normalize` (with a prefixed `to_dtype`, since `normalize` requires floating point input) is massively faster. IIUC, the high values in eager come from the fact that we are inputting an image with CHW memory layout, and that hurts `normalize`. In the full pipeline this is mitigated by the preceding resize, which produces an artificial HWC layout. The compiled version seems to handle this natively (one way to check the layout effect is sketched below).
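A sketch of how that layout explanation could be sanity-checked (the layout construction is just my assumption about what resize produces; I haven't verified the resulting numbers):

```python
import torch
import torch.utils.benchmark as benchmark
import torchvision.transforms.v2.functional as F

# Same values, two memory layouts: plain contiguous CHW vs. a CHW view over
# HWC storage (roughly the "artificial HWC layout" a preceding resize yields).
img_chw = torch.rand(3, 256, 256)
img_hwc = img_chw.permute(1, 2, 0).contiguous().permute(2, 0, 1)

for label, img in [("CHW storage", img_chw), ("HWC storage", img_hwc)]:
    timer = benchmark.Timer(
        stmt="F.normalize_image(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])",
        globals={"F": F, "img": img},
    )
    print(label, timer.timeit(1000))
```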
Note that what's going to be great for torchvision is that I expect pretty much any combination of transformations to be fused into one kernel. That is where the main speed-ups will come from.
To this end, it'd be useful to benchmark a list of transformations applied one after the other. As I told Victor, I expect these wins to heavily outweigh the slight regression in resize and flips.
On a different note, I'd expect the flip issue to be fixable.
Thanks a lot for this great investigation, Philip.
@lezcano I tend to have a different intuition from yours: if resize is much faster than compiled(resize), then perhaps the speed-up gained from not compiling resize will outweigh the speed-up coming from fusing resize with the op just before and the one just after (keeping the rest of the transforms compiled / fused as well). But we'll see with benchmarks. Regardless, we probably don't need to worry too much about benchmarks for now; the main goal of this issue is to remove graph breaks as a first step.
A few other findings on failing tests when kernels are compiled with variable input shapes: https://gist.github.com/vfdev-5/5b2733b5641d08c6889a17eda6267aba (the logs contain 32k lines in total, so the browser may hang for a few seconds while loading...)