tiatoolbox

⚡️Add `torch.compile` Functionality

Open · Abdol opened this issue 2 years ago · 21 comments

This draft PR integrates PyTorch 2.0's torch.compile functionality to demonstrate performance improvements in the toolbox's PyTorch code. It focuses on adding torch.compile to PatchPredictor.
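A minimal sketch of what "adding torch.compile to a predictor" amounts to, assuming the predictor holds a model object and a compile-mode string. `compile_model` and its `compile_fn` parameter are hypothetical names for illustration, not TIAToolbox's API; `compile_fn` defaults to `torch.compile` and only exists so the wrapping logic can be exercised without PyTorch or a GPU.

```python
from typing import Any, Callable, Optional


def compile_model(
    model: Any,
    mode: str = "default",
    compile_fn: Optional[Callable[..., Any]] = None,
) -> Any:
    """Return `model` wrapped by torch.compile, or unchanged if mode is "disable".

    Hypothetical helper: `compile_fn` defaults to torch.compile and can be
    replaced by a stub in tests.
    """
    if mode == "disable":
        # Eager execution: hand back the original module untouched.
        return model
    if compile_fn is None:
        import torch  # imported lazily so "disable" works without PyTorch

        compile_fn = torch.compile
    # torch.compile accepts mode="default", "reduce-overhead" or "max-autotune".
    return compile_fn(model, mode=mode)
```

Usage would be along the lines of `model = compile_model(model, mode="reduce-overhead")` before the inference loop. Note that `torch.compile` returns a wrapper immediately; the actual compilation happens on the first forward pass.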

Notes:

  • According to the documentation, noticeable performance gains can be achieved on modern NVIDIA GPUs (H100, A100, or V100)
  • ~~Python 3.11+ is not yet supported for torch.compile~~ UPDATE 1: Python 3.11 support was added in PyTorch 2.1. UPDATE 2: Python 3.12 support was added in PyTorch 2.4.
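The version notes above can be turned into a small guard that decides whether to attempt compilation at all. This is a sketch: the 3.11 → 2.1 and 3.12 → 2.4 rows come from the notes, while the Python 3.8–3.10 rows are an assumption (torch.compile first shipped in PyTorch 2.0). Python versions not in the table are conservatively reported as unsupported. `torch_compile_supported` and `MIN_TORCH_FOR_PYTHON` are hypothetical names.

```python
from typing import Tuple

# Minimum PyTorch (major, minor) for torch.compile per Python (major, minor).
# 3.11 -> 2.1 and 3.12 -> 2.4 come from the notes above; the 3.8-3.10 rows
# are an assumption (torch.compile shipped with PyTorch 2.0).
MIN_TORCH_FOR_PYTHON = {
    (3, 8): (2, 0),
    (3, 9): (2, 0),
    (3, 10): (2, 0),
    (3, 11): (2, 1),
    (3, 12): (2, 4),
}


def major_minor(version: str) -> Tuple[int, int]:
    """Parse "2.4.0+cu121" or "3.11.4" into (2, 4) or (3, 11)."""
    core = version.split("+")[0]  # drop local build tags such as "+cu121"
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def torch_compile_supported(python_version: str, torch_version: str) -> bool:
    """Conservatively report whether torch.compile should work.

    Python versions missing from the table return False (unknown).
    """
    minimum = MIN_TORCH_FOR_PYTHON.get(major_minor(python_version))
    if minimum is None:
        return False
    return major_minor(torch_version) >= minimum
```

At runtime it would be called as `torch_compile_supported(platform.python_version(), torch.__version__)`.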

TODO:

  • [x] Resolve compilation errors related to using torch.compile in running models
  • [x] Initial config
  • [x] Add to patch predictor
  • [x] Add to registration
  • [x] Add to segmentation
  • [ ] Test on custom models
  • [ ] Test on torch.compile compatible GPUs

Abdol, Sep 29 '23

PyTorch 2.1 now supports torch.compile in Python 3.11: https://github.com/pytorch/pytorch/releases/tag/v2.1.0

shaneahmed, Oct 05 '23

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 99.88%. Comparing base (9113996) to head (020d9ef). Report is 14 commits behind head on develop.

@@           Coverage Diff            @@
##           develop     #716   +/-   ##
========================================
  Coverage    99.88%   99.88%           
========================================
  Files           69       69           
  Lines         8702     8717   +15     
  Branches      1148     1149    +1     
========================================
+ Hits          8692     8707   +15     
  Misses           4        4           
  Partials         6        6           


codecov[bot], Oct 06 '23

After adding torch.compile to the SemanticSegmentor, I'm getting this error from test_feature_extractor.py:

self = <torch._dynamo.output_graph.OutputGraph object at 0x7fbbb97ee880>
gm = GraphModule(
  (self_backbone_conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (...ize=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (self_clf): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)

    @dynamo_timed(phase_name="backend_compile")
    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        tot = 0
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
        torch._dynamo.utils.increment_op_count(tot)
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            # WrapperBackend needs real inputs, for now, to verify correctness
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn, self.example_inputs())
    
            # NOTE: [Real Tensors in Accuracy Evaluation]
            #
            # Today, tensors are passed to backends as fake at compile time. See the .fake_example_inputs()
            # call to compiler_fn below. At runtime, backends use real tensors.
            #
            # This should be a strong invariant we hold across all backends,
            # and generally, it is. However, for accuracy evaluation, we need real tensors at compile time,
            # for now, due to the unfortunate setup described below.
            #
            # Due to the nature of how we invoke comparison as a backend in two different ways:
            #
            # (1) Less bad, but still worth rewriting, WrapperBackend above, which takes
            # real inputs for its ctor. see the config.verify_correctnes above.
            #
            # (2) More bad, and very worth rewriting, the minifier installs accuracy comparison as
            # a true backend, and therefore needs to be compiled with real inputs. This is made trickier
            # by the fact that the minifier will spawn new processes during minification. As such, we have
            # created a global flag, MINIFIER_SPAWNED, that should be set IF AND ONLY IF this run was spawned
            # as part of accuracy minification. This flag is not a contract, and ideally will not be here long.
            #
            # The longer term PoR is to:
            # (A) Rewrite the minifier accuracy evaluation and verify_correctness code to share the same
            # correctness and accuracy logic, so as not to have two different ways of doing the same thing.
            #
            # (B) Refactor minifier accuracy backend to do its comparison fully at runtime, so as not to need to
            # pass real tensors to it at compile time.
            is_top_level_minifying = (
                config.repro_after is not None and config.repro_level == 4
            )
            if torch._dynamo.debug_utils.MINIFIER_SPAWNED or is_top_level_minifying:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            elif config.DO_NOT_USE_legacy_non_fake_example_inputs:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            else:
                compiled_fn = compiler_fn(gm, self.fake_example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except Exception as e:
            compiled_fn = gm.forward
>           raise BackendCompilerFailed(self.compiler_fn, e) from e
E           torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised **RuntimeError: Inference tensors do not track version counter.**
E           
E           While executing %unsqueeze_1 : [#users=1] = call_method[target=unsqueeze](args = (%self_upsample2x_unpool_mat, 0), kwargs = {})
E           Original traceback:
E             File "/home/u2271662/tia/projects/tiatoolbox/code/tiatoolbox/tiatoolbox/models/architecture/utils.py", line 136, in forward
E               mat = self.unpool_mat.unsqueeze(0)  # 1xshxsw
E            |   File "/home/u2271662/tia/projects/tiatoolbox/code/tiatoolbox/tiatoolbox/models/architecture/unet.py", line 408, in forward
E               x_ = self.upsample2x(x)
E           
E           
E           Set torch._dynamo.config.verbose=True for more information
E           
E           
E           You can suppress this exception and fall back to eager by setting:
E               torch._dynamo.config.suppress_errors = True

It seems that the way the models are implemented, as well as their inference code, needs further adjustment for torch.compile to work.

I would appreciate any feedback on how to overcome this error.

UPDATE: the error doesn't occur when using PyTorch 2.1.0.

Abdol, Oct 20 '23


UPDATE: moved to description.

TODO in this PR:

  • [x] Resolve compilation errors related to using torch.compile in running models
  • [ ] Break into smaller PRs for each specific task (initial config, patch prediction, segmentation, registration)
  • [ ] Test on custom models
  • [ ] Test torch.compile on compatible GPUs

Abdol, Nov 17 '23

I have managed to try torch.compile on a compatible GPU, namely the V100 on the DGX-2. I ran some tests from the PyTorch tutorial and tried the three main compile modes (default, reduce-overhead, and max-autotune). So far, eager execution is as fast as or slightly faster than any of the torch.compile modes. Here is a sample of the tests (with the reduce-overhead compile mode):

Test 4: DenseNet on torch.compile. Eager: 2.552837158203125; compile: 94.6605

Test 5: Performance comparison (DenseNet), train time per iteration:

| Iteration | Eager train time | Compile train time |
|---|---|---|
| 0 | 0.3816151123046875 | 678.564875 |
| 1 | 0.05365350341796875 | 0.06348083114624023 |
| 2 | 0.04904755020141602 | 0.05787955093383789 |
| 3 | 0.04890316772460938 | 0.05670809555053711 |
| 4 | 0.04844646453857422 | 0.05652070236206055 |
| 5 | 0.04913971328735352 | 0.05630156707763672 |
| 6 | 0.047429630279541016 | 0.05644083023071289 |
| 7 | 0.04739583969116211 | 0.056220672607421876 |
| 8 | 0.04733849716186524 | 0.056223743438720705 |
| 9 | 0.04743577575683594 | 0.05616230392456055 |
| 10 | 0.04716339111328125 | 0.05612851333618164 |
| 11 | 0.04967731094360352 | 0.05630771255493164 |
| 12 | 0.04951347351074219 | 0.05731942367553711 |
| 13 | 0.04763750457763672 | 0.0573675537109375 |
| 14 | 0.04769279861450195 | 0.05924966430664062 |
| 15 | 0.048020481109619144 | 0.05782425689697265 |
| 16 | 0.04770406341552735 | 0.05713100814819336 |
| 17 | 0.04748185729980469 | 0.057957374572753906 |
| 18 | 0.04775219345092773 | 0.056390655517578124 |
| 19 | 0.047424510955810545 | 0.057071617126464844 |
| 20 | 0.04731391906738281 | 0.056787967681884766 |
| 21 | 0.04680704116821289 | 0.05672550582885742 |
| 22 | 0.04696985626220703 | 0.05709209442138672 |
| 23 | 0.04726681518554687 | 0.056461311340332034 |
| 24 | 0.047058944702148435 | 0.057278465270996094 |

Next, I will run tests from the toolbox to investigate whether torch.compile can yield significant performance gains there.
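When reading logs like Test 5, it helps to separate the one-off first iteration (which includes compilation for the compiled model) from the steady-state per-iteration time. A minimal sketch, assuming a single warm-up iteration is enough to exclude compilation; `summarize_timings` is a hypothetical helper, and the sample values are the first six rows of the Test 5 table.

```python
from statistics import median
from typing import Dict, Sequence


def summarize_timings(
    eager: Sequence[float], compiled: Sequence[float], warmup: int = 1
) -> Dict[str, float]:
    """Separate first-iteration cost from steady-state per-iteration time.

    The first `warmup` iterations are excluded from the medians because, for
    the compiled model, they include the one-off compilation.
    """
    eager_steady = median(eager[warmup:])
    compiled_steady = median(compiled[warmup:])
    return {
        "eager_first": eager[0],
        "compiled_first": compiled[0],
        "eager_median": eager_steady,
        "compiled_median": compiled_steady,
        # > 1 means compiled is faster per iteration, < 1 means slower.
        "steady_state_speedup": eager_steady / compiled_steady,
    }


# First six DenseNet training iterations from Test 5 (copied from the log).
eager_times = [0.3816151123046875, 0.05365350341796875, 0.04904755020141602,
               0.04890316772460938, 0.04844646453857422, 0.04913971328735352]
compiled_times = [678.564875, 0.06348083114624023, 0.05787955093383789,
                  0.05670809555053711, 0.05652070236206055, 0.05630156707763672]
summary = summarize_timings(eager_times, compiled_times)
```

On this slice the steady-state speed-up comes out at roughly 0.86, consistent with the observation that eager execution was faster in these runs.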

Abdol, Jan 19 '24

Updating PR to fix issues with codecov integration.

shaneahmed, Apr 29 '24

@Abdol With the PyTorch 2.3 release, can you please check whether it is compatible with Python 3.12?

shaneahmed, Jun 21 '24

@shaneahmed According to the release notes, PyTorch 2.3 still does not fully support Python 3.12, including torch.compile. Support is planned for 2.4.

Abdol, Jun 24 '24

@Abdol @shaneahmed, PyTorch 2.4 now supports Python 3.12 for torch.compile(): https://github.com/pytorch/pytorch/releases/tag/v2.4.0

GeorgeBatch, Jul 28 '24

I will get back to this PR starting next week.

Abdol, Sep 20 '24

No update this week. Will have a look at the error next week.

Abdol, Oct 04 '24

I'm getting a DeepSource error related to code outside this PR: link to error. Can anyone have a look and let me know if this needs fixing in this PR (or in another one)? Thanks!

Abdol, Oct 10 '24

As discussed before, you can ignore this error.

shaneahmed, Oct 10 '24

The coverage discrepancy is due to temporarily disabling torch.compile in helper_tile_info() in test_nucleus_instance_segmentor.py (lines 46-53):

def helper_tile_info() -> list:
    """Helper function for tile information."""
    # Temporarily disable torch.compile: it fails to optimise model "A".
    torch._dynamo.reset()
    current_torch_compile_mode = rcParam["torch_compile_mode"]
    rcParam["torch_compile_mode"] = "disable"
    predictor = NucleusInstanceSegmentor(model="A")
    # Restore the previous compile mode for the remaining tests.
    torch._dynamo.reset()
    rcParam["torch_compile_mode"] = current_torch_compile_mode
...

torch.compile failed to optimise the model defined as "A" (see the snippet above), so for that case only I disable torch.compile before the model is defined and restore the previous mode afterwards.
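The disable-then-restore pattern in the test helper can be packaged as a context manager so the previous mode is restored even if model construction raises. This is a sketch: `temporarily_set` is a hypothetical name, `params` stands in for TIAToolbox's `rcParam`, and `reset` would be `torch._dynamo.reset` in the real test; both are injected so the pattern can be shown with a plain dict.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, MutableMapping, Optional


@contextmanager
def temporarily_set(
    params: MutableMapping[str, Any],
    key: str,
    value: Any,
    reset: Optional[Callable[[], None]] = None,
) -> Iterator[None]:
    """Set params[key] = value inside the block and always restore it.

    `reset` is called before the change and again before the restore,
    mirroring the two torch._dynamo.reset() calls in the helper above.
    """
    previous = params[key]
    if reset is not None:
        reset()
    params[key] = value
    try:
        yield
    finally:
        # Runs even if the body raises, so later tests see the original mode.
        if reset is not None:
            reset()
        params[key] = previous
```

In the test it would read as `with temporarily_set(rcParam, "torch_compile_mode", "disable", torch._dynamo.reset): predictor = NucleusInstanceSegmentor(model="A")`. The `try`/`finally` is the main design choice: a failure inside the block cannot leave compilation disabled for every later test.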

Abdol, Oct 10 '24

The PR is almost done now. torch.compile has been added to the patch predictor, registration, and segmentation (both semantic and instance). Since the instance segmentor is a subclass of the semantic segmentor, I only added torch.compile to the semantic segmentor, and the instance segmentor should inherit it.

Before merging, I suggest testing with some real-world scenarios and data, particularly patch prediction and segmentation, and I need some help with that. Could anybody pull this branch and try torch.compile on some existing code they have? (It should be run on a compatible GPU such as an H100, A100, or V100 for best results.) @measty / @mostafajahanifar / @GeorgeBatch, could you have a look? Happy to help with setting up the config.

Abdol, Oct 11 '24

@Abdol, I am happy to try it out in the next few days. What did you mean by setting up the config?

GeorgeBatch, Oct 17 '24

I will test it on a V100 GPU. I can't get the environment to work on our A100; I think the driver is too old and needs updating.

Jiaqi-Lv, Oct 18 '24

Hi @GeorgeBatch, thanks for volunteering to help test this. You can enable torch.compile by setting `rcParam` in TIAToolbox:

import torch
from tiatoolbox import rcParam
torch._dynamo.reset() # include this line every time `torch.compile` mode is to be changed
rcParam["torch_compile_mode"] = "default" # other modes: "reduce-overhead", "max-autotune", or "disable"

Then run any existing TIAToolbox-based scripts that involve patch prediction, segmentation, or registration. You can also look at or try out the unit tests written for torch.compile for each corresponding functionality (see the Files changed tab).
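Because a typo in the mode string would otherwise only surface later inside `torch.compile`, a small setter can validate the value and perform the `torch._dynamo.reset()` step from the snippet in one place. This is a sketch, not TIAToolbox code: `set_torch_compile_mode` is a hypothetical helper, `params` stands in for `rcParam`, and `reset` defaults to `torch._dynamo.reset`.

```python
from typing import Any, Callable, MutableMapping, Optional

# The four values listed in the rcParam snippet above.
VALID_COMPILE_MODES = ("default", "reduce-overhead", "max-autotune", "disable")


def set_torch_compile_mode(
    params: MutableMapping[str, Any],
    mode: str,
    reset: Optional[Callable[[], None]] = None,
) -> None:
    """Validate `mode`, reset dynamo's state, then store the new mode."""
    if mode not in VALID_COMPILE_MODES:
        raise ValueError(
            f"unknown torch_compile_mode {mode!r}; expected one of {VALID_COMPILE_MODES}"
        )
    if reset is None:
        import torch  # only needed when no reset hook is supplied

        reset = torch._dynamo.reset
    reset()  # clear previously compiled state before the mode changes
    params["torch_compile_mode"] = mode
```

With the real toolbox this would be `set_torch_compile_mode(rcParam, "max-autotune")`; validation happens before `reset()`, so an invalid mode leaves both the cache and the setting untouched.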

Abdol, Oct 21 '24

@Abdol, I only had scripts for feature extraction from patches for now, so that will test patch prediction functionality.

Do you know if I should expect any speedup on NVIDIA A6000 nodes? Or does it only apply to V100, A100, H100 nodes?

If it works well and speeds up the process, I might introduce it into my workflow before the next release. Can you please answer my question about the units and resolution I should use for patch extraction (issue https://github.com/TissueImageAnalytics/tiatoolbox/issues/874), so I do not waste compute time with incorrectly set parameters?

GeorgeBatch, Oct 21 '24

The A100 has just been updated. I will test on it too.

Jiaqi-Lv, Oct 23 '24

Hi @GeorgeBatch, thank you for your reply. I don't think the A6000 will produce significant speed-ups, but you can try it anyway if you like.

Regarding issue #874, I think @measty / @Jiaqi-Lv can clarify patch parameters better than me 🙂

Abdol, Oct 25 '24

I am currently benchmarking PatchPredictor on an A100. The test data is a very large WSI, each comparison was run 3 times, and the models are TIAToolbox pretrained models.

Patch size = 224x224; Number of patches = 29632

  • Model: ResNet101-kather100k; Batch size: 64; Compile Method: default; Average Speed Up: 1.24x ±0.11; Average time with compile: 590 s; Average time without compile: 724 s
  • Model: wide_resnet101_2-kather100k; Batch size: 16; Compile Method: default; Average Speed Up: 1.22x ±0.03; Average time with compile: 2190 s; Average time without compile: 2662 s
  • ...

Patch size = 96x96; Number of patches = 39872

  • Model: resnext101_32x8d-pcam; Batch size: 32; Compile Method: default; Average Speed Up: 0.96x
  • Model: wide_resnet101_2-pcam; Batch size: 32; Compile Method: default; Average Speed Up: 0.99x
  • ...
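The thread does not say exactly how "Average Speed Up ±" was computed over the 3 repeats; a natural reading, assumed here, is the mean and sample standard deviation of the per-run ratios (time without compile / time with compile). `speedup_stats` is a hypothetical helper, and the numbers in the test are illustrative, not from the benchmark.

```python
from statistics import mean, stdev
from typing import Sequence, Tuple


def speedup_stats(
    times_without: Sequence[float], times_with: Sequence[float]
) -> Tuple[float, float]:
    """Mean and sample standard deviation of per-run speed-ups.

    Speed-up for run i is times_without[i] / times_with[i], so values
    above 1 mean torch.compile was faster.
    """
    if len(times_without) != len(times_with) or len(times_with) < 2:
        raise ValueError("need at least two paired runs")
    ratios = [w / c for w, c in zip(times_without, times_with)]
    return mean(ratios), stdev(ratios)
```

The mean of per-run ratios need not equal the ratio of the average times: for ResNet101-kather100k, 724 s / 590 s ≈ 1.23, while the reported average is 1.24x.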

The first forward pass with torch.compile is always extremely slow because that is when compilation happens; subsequent forward passes are faster. If the number of iterations is too small, torch.compile does not produce a speed-up.
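The "too few iterations" effect can be quantified with a simple cost model (an assumption, ignoring variance): a one-off compile overhead C plus per-iteration times t_e (eager) and t_c (compiled). Compilation pays off once n · t_e ≥ C + n · t_c, i.e. n ≥ C / (t_e − t_c); if t_c ≥ t_e, as in the DenseNet Test 5 runs earlier in the thread, it never does. `break_even_iterations` is a hypothetical helper.

```python
import math
from typing import Optional


def break_even_iterations(
    compile_overhead: float, eager_per_iter: float, compiled_per_iter: float
) -> Optional[int]:
    """Smallest iteration count at which the compiled run is no slower overall.

    Models eager total as n * eager_per_iter and compiled total as
    compile_overhead + n * compiled_per_iter. Returns None if compiled
    iterations are not faster, since the overhead is then never recovered.
    """
    saving = eager_per_iter - compiled_per_iter
    if saving <= 0:
        return None
    return math.ceil(compile_overhead / saving)
```

For example, a 60 s overhead with iterations dropping from 0.5 s to 0.25 s needs 240 iterations to break even.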

Some people suggest it's best to call .cuda() before .compile(), but there is no official documentation about this yet: https://discuss.pytorch.org/t/torch-compile-before-or-after-cuda/176031. People have reported more significant speed-ups for training.

Jiaqi-Lv, Oct 25 '24

Benchmark for SemanticSegmentor on an A100. Data: 91424 patches, patch size 1024x1024.

  • Model: fcn_resnet50_unet-bcss; Batch size: 8; Compile Method: default; Average Speed Up: 1.3x; Average time with compile: 3505 s; Average time without compile: 4558 s

Jiaqi-Lv, Nov 01 '24

Thank you @Jiaqi-Lv for testing and @shaneahmed for your review 🙂

Abdol, Nov 04 '24

@Abdol Do you know why the tests are failing? Only Python 3.12 is passing.

shaneahmed, Nov 11 '24