Memory efficient sliding window inference

Open razorx89 opened this issue 2 years ago • 6 comments

Is your feature request related to a problem? Please describe.
Large input volumes have to be processed with a sliding window algorithm, otherwise OOMs happen quickly. Two properties constrain memory and can cause an OOM: the image size and the number of predicted classes. MONAI's sliding_window_inference allocates an FP32 probability aggregation buffer of size BxCxDxHxW and an FP32 weight aggregation buffer of size BxDxHxW. In most cases, however, we are only interested in the final prediction (the class with the highest probability).

The following example shows that even without a real model (the predictor just multiplies the input by random values), the peak memory usage is very high. This is a conservative image size if you consider, e.g., high-resolution isotropic whole-body CTs:

import torch
from monai.inferers import sliding_window_inference

num_classes = 100
data = torch.rand(1, 1, 384, 384, 256, dtype=torch.float32, device='cuda:0')
model = lambda x: x * torch.rand(x.shape[0], num_classes, x.shape[2], x.shape[3], x.shape[4], dtype=x.dtype, device=x.device)

out = sliding_window_inference(data, (128, 128, 128), 1, model)

print(torch.cuda.memory_summary())
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from large pool |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from small pool |       0 MB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from large pool |   14544 MB |   33504 MB |  166490 MB |  151946 MB |
|       from small pool |       0 MB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   35958 MB |   35958 MB |   35958 MB |       0 B  |
|       from large pool |   35956 MB |   35956 MB |   35956 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |   10853 MB |   12646 MB |   12646 MB |
|       from large pool |       0 B  |   10852 MB |   12638 MB |   12638 MB |
|       from small pool |       0 B  |       1 MB |       8 MB |       8 MB |
|---------------------------------------------------------------------------|
| Allocations           |       2    |      10    |     221    |     219    |
|       from large pool |       2    |       9    |     205    |     203    |
|       from small pool |       0    |       3    |      16    |      16    |
|---------------------------------------------------------------------------|
| Active allocs         |       2    |      10    |     221    |     219    |
|       from large pool |       2    |       9    |     205    |     203    |
|       from small pool |       0    |       3    |      16    |      16    |
|---------------------------------------------------------------------------|
| GPU reserved segments |      12    |      12    |      12    |       0    |
|       from large pool |      11    |      11    |      11    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       5    |      58    |      58    |
|       from large pool |       0    |       4    |      53    |      53    |
|       from small pool |       0    |       2    |       5    |       5    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|
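The current usage in the summary above can be checked by hand. A small sketch, assuming the buffer layout described in the issue (an FP32 probability buffer of BxCxDxHxW plus an FP32 weight buffer of BxDxHxW, here with B=1) and 4 bytes per FP32 element:

```python
MIB = 2**20
FP32_BYTES = 4


def aggregation_buffer_mib(batch: int, num_classes: int, spatial: tuple) -> dict:
    """Size in MiB of the full-volume FP32 buffers of the default sliding window."""
    voxels = 1
    for s in spatial:
        voxels *= s
    probs = batch * num_classes * voxels * FP32_BYTES  # BxCxDxHxW probability sums
    weights = batch * voxels * FP32_BYTES  # BxDxHxW weight (count) map
    return {"probabilities": probs / MIB, "weights": weights / MIB}


# Example from the issue: 100 classes, 384x384x256 volume
print(aggregation_buffer_mib(1, 100, (384, 384, 256)))  # {'probabilities': 14400.0, 'weights': 144.0}
# Larger volume mentioned in the text: 384x384x384
print(aggregation_buffer_mib(1, 100, (384, 384, 384)))  # {'probabilities': 21600.0, 'weights': 216.0}
```

The 14400 MiB probability buffer plus the 144 MiB input tensor (1x1x384x384x256 in FP32) gives 14544, which appears to be the "Cur Usage" left after the call returns. The much larger peak presumably also includes the weight map, per-window model outputs and other temporaries.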

If the image size or the class count is increased further (e.g. to 384x384x384), OOMs occur even on powerful cards such as the A6000 with 48 GB, and the aggregation has to be moved to CPU memory. Since all computations for adding probability crops to the aggregation buffer are then performed on the CPU instead of the GPU, this is very slow.

Describe the solution you'd like
I implemented a sliding_window_inference_with_reduction method, which essentially performs two nested sliding windows. The outer loop iterates over a single spatial dimension and performs the reduction: it takes the probabilities from the inner loop, applies e.g. argmax, and stores the result in an output buffer of size BxDxHxW with an integer dtype (e.g. uint8). The inner loop performs a 2.5D sliding window inference over a slab of data. Since the outer loop can also iterate with some overlap, part of the probabilities from one inner loop are carried over to initialize the buffer for the next inner loop iteration.
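The carry-over idea can be illustrated on a single (outer) axis. The sketch below is a simplified 1D toy, not the author's inferer.py. It has no inner 2.5D loop, the weights are constant (as with BlendMode.CONSTANT), and predictor is a hypothetical stand-in for a model. It keeps only a roi-sized float buffer, carries the overlapping part over to the next window, and writes argmax results into a full-size uint8 output. It is checked against a reference that uses a full-size float buffer.

```python
import torch


def predictor(start: int, roi: int, num_classes: int) -> torch.Tensor:
    # Hypothetical stand-in for a model: deterministic per-window scores of
    # shape (num_classes, roi), seeded by the window start so that overlapping
    # windows disagree and the averaging actually matters.
    gen = torch.Generator().manual_seed(start)
    return torch.rand(num_classes, roi, generator=gen)


def window_starts(size: int, roi: int, overlap: float) -> list:
    # Regular steps, plus a final window aligned with the end of the axis.
    step = int(roi * (1 - overlap))
    starts = list(range(0, size - roi + 1, step))
    if starts[-1] != size - roi:
        starts.append(size - roi)
    return starts


def reference(size: int, roi: int, overlap: float, num_classes: int) -> torch.Tensor:
    # Default approach: full-size float buffers, reduce once at the end.
    probs = torch.zeros(num_classes, size)
    weights = torch.zeros(size)
    for s in window_starts(size, roi, overlap):
        probs[:, s:s + roi] += predictor(s, roi, num_classes)
        weights[s:s + roi] += 1.0
    return torch.argmax(probs / weights, dim=0).to(torch.uint8)


def with_reduction(size: int, roi: int, overlap: float, num_classes: int) -> torch.Tensor:
    # Reduction approach: only a roi-sized float buffer plus a uint8 output.
    step = int(roi * (1 - overlap))
    starts = window_starts(size, roi, overlap)
    output = torch.empty(size, dtype=torch.uint8)
    probs = torch.zeros(num_classes, roi)
    weights = torch.zeros(roi)
    last = None
    for i, s in enumerate(starts):
        if last is not None:
            shift = s - last
            # Carry the still-incomplete overlap to the front of the buffer
            # (clone() because source and destination views overlap), then
            # zero the freshly exposed tail.
            probs[:, :roi - shift] = probs[:, shift:].clone()
            weights[:roi - shift] = weights[shift:].clone()
            probs[:, roi - shift:] = 0.0
            weights[roi - shift:] = 0.0
        probs += predictor(s, roi, num_classes)
        weights += 1.0
        # With a regular step, the first `step` positions are final now; the last
        # window flushes the whole buffer (overwriting any earlier partial flush).
        copy = roi if i == len(starts) - 1 else step
        output[s:s + copy] = torch.argmax(probs[:, :copy] / weights[:copy], dim=0).to(torch.uint8)
        last = s
    return output


print(window_starts(256, 128, 0.25))  # [0, 96, 128]
print(torch.equal(reference(256, 128, 0.25, 100), with_reduction(256, 128, 0.25, 100)))
```

Because each position receives the same contributions in the same order, the two variants produce identical labels, while the reduction variant never holds more than num_classes x roi floats.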

import torch

from inferer import sliding_window_inference_with_reduction

num_classes = 100
data = torch.rand(1, 1, 384, 384, 256, dtype=torch.float32, device='cuda:0')
model = lambda x: x * torch.rand(x.shape[0], num_classes, x.shape[2], x.shape[3], x.shape[4], dtype=x.dtype, device=x.device)

out = sliding_window_inference_with_reduction(data, (128, 128, 128), 1, model)

print(torch.cuda.memory_summary())
inferer.py
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import (
 compute_importance_map,
 dense_patch_slices,
 get_valid_patch_size,
)
from monai.inferers.utils import _get_scan_interval
from monai.utils import (
 BlendMode,
 PytorchPadMode,
 convert_data_type,
 convert_to_dst_type,
 fall_back_tuple,
 look_up_option,
)
from tqdm import tqdm


def sliding_window_inference_with_reduction(
 inputs: torch.Tensor,
 roi_size: Union[Sequence[int], int],
 sw_batch_size: int,
 predictor: Callable[..., torch.Tensor],
 overlap: float = 0.25,
 mode: Union[BlendMode, str] = BlendMode.CONSTANT,
 sigma_scale: Union[Sequence[float], float] = 0.125,
 padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
 cval: float = 0.0,
 sw_device: Optional[Union[torch.device, str]] = None,
 device: Optional[Union[torch.device, str]] = None,
 reduction_fn: Callable[..., torch.Tensor] = torch.argmax,
 reduction_dim: int = 1,
 output_dtype: torch.dtype = torch.uint8,
 progress: bool = False,
 *args: Any,
 **kwargs: Any,
) -> torch.Tensor:
 compute_dtype = inputs.dtype
 num_spatial_dims = len(inputs.shape) - 2
 if overlap < 0 or overlap >= 1:
     raise ValueError("overlap must be >= 0 and < 1.")

 batch_size, _, *orig_image_size = inputs.shape
 if device is None:
     device = inputs.device
 if sw_device is None:
     sw_device = inputs.device

 roi_size_safe: Tuple[int] = fall_back_tuple(roi_size, orig_image_size)

 image_size = tuple(
     max(orig_image_size[i], roi_size_safe[i]) for i in range(num_spatial_dims)
 )
 pad_size = []
 for k in range(len(inputs.shape) - 1, 1, -1):
     diff = max(roi_size_safe[k - 2] - inputs.shape[k], 0)
     half = diff // 2
     pad_size.extend([half, diff - half])

 if max(pad_size) > 0:
     inputs = F.pad(
         inputs,
         pad=pad_size,
         mode=look_up_option(padding_mode, PytorchPadMode),
         value=cval,
     )

 patch_size = get_valid_patch_size(image_size, roi_size_safe)

 importance_map = compute_importance_map(
     patch_size=patch_size,
     mode=mode,
     sigma_scale=sigma_scale,
     device=sw_device,
 )
 importance_map = torch.clamp(
     importance_map,
     min=max(importance_map[importance_map != 0].min().item(), 1e-3),
 )
 importance_map = convert_data_type(
     importance_map, torch.Tensor, sw_device, compute_dtype
 )[0]

 # Identify outer and inner dimensions for the sliding window and aggregation
 outer_dim = 2  # TODO Find heuristic for this

 # Allocate buffers
 output = torch.empty(
     tuple(x for i, x in enumerate(inputs.shape) if i != reduction_dim),
     dtype=output_dtype,
     device=device,
 )

 slab_probabilities: Optional[torch.Tensor] = None
 slab_weights = torch.zeros(
     tuple(
         1
         if i == 1  # a single weight channel, broadcast over all output classes
         else roi_size_safe[outer_dim]
         if i == outer_dim + 2  # account for batch and channel dimensions
         else inputs.shape[i]
         for i in range(inputs.ndim)
     ),
     dtype=compute_dtype,
     device=sw_device,
 )

 # Iterate over outer dimension and aggregate a full slab of reduced predictions
 outer_step_size = int(roi_size_safe[outer_dim] * (1 - overlap))
 outer_indices = list(
     range(
         0,
         inputs.shape[outer_dim + 2] - roi_size_safe[outer_dim] + 1,
         outer_step_size,
     )
 )
 if outer_indices[-1] != image_size[outer_dim] - roi_size_safe[outer_dim]:
     outer_indices.append(image_size[outer_dim] - roi_size_safe[outer_dim])
 last_outer_dim_idx = -1
 for outer_idx, outer_dim_idx in enumerate(
     tqdm(outer_indices, leave=True, position=0) if progress else outer_indices
 ):
     # Move old probabilities and weights based on the actual step size of this slab
     if outer_idx > 0:
         assert slab_probabilities is not None
         actual_step_size = outer_dim_idx - last_outer_dim_idx
         assert 0 < actual_step_size <= outer_step_size
         new_slices = tuple(
             slice(None)
             if i != outer_dim + 2  # account for batch and channel dimensions
             else slice(None, -actual_step_size)
             for i in range(slab_probabilities.ndim)
         )
         old_slices = tuple(
             slice(None)
             if i != outer_dim + 2  # account for batch and channel dimensions
             else slice(actual_step_size, None)
             for i in range(slab_probabilities.ndim)
         )
         null_slices = tuple(
             slice(None)
             if i != outer_dim + 2  # account for batch and channel dimensions
             else slice(-actual_step_size, None)
             for i in range(slab_probabilities.ndim)
         )

         # Source and destination are overlapping views of the same storage, so
         # clone() the source; an in-place overlapping copy can race on the GPU
         slab_probabilities[new_slices] = slab_probabilities[old_slices].clone()
         slab_weights[new_slices] = slab_weights[old_slices].clone()

         slab_probabilities[null_slices] = 0.0
         slab_weights[null_slices] = 0.0

     # Take the slab of the (already padded) input along the outer dimension
     slab_slices = tuple(
         slice(None)
         if i != outer_dim + 2  # account for batch and channel dimensions
         else slice(outer_dim_idx, outer_dim_idx + roi_size_safe[outer_dim])
         for i in range(inputs.ndim)
     )
     slab_input = inputs[slab_slices]

     # Compute crop locations in slab
     scan_interval = _get_scan_interval(
         slab_input.shape[2:], roi_size_safe, num_spatial_dims, overlap
     )
     slices = dense_patch_slices(slab_input.shape[2:], roi_size_safe, scan_interval)
     num_win = len(slices)
     total_slices = num_win * batch_size

     # Perform sliding window inference on slab
     slice_indices = list(range(0, total_slices, sw_batch_size))
     for slice_idx in (
         tqdm(slice_indices, leave=False, position=1) if progress else slice_indices
     ):
         # Get crops from slices
         slice_range = range(slice_idx, min(slice_idx + sw_batch_size, total_slices))
         unravel_slice = [
             [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)]
             + list(slices[idx % num_win])
             for idx in slice_range
         ]
         window_data = torch.cat(
             [
                 convert_data_type(slab_input[win_slice], torch.Tensor)[0]
                 for win_slice in unravel_slice
             ]
         ).to(sw_device)

         # Compute probabilities and aggregate
         probabilities = predictor(window_data, *args, **kwargs)

         if slab_probabilities is None:
             output_classes = probabilities.shape[1]
             slab_probabilities = torch.zeros(
                 (batch_size, output_classes)
                 + tuple(
                     image_size[i] if i != outer_dim else patch_size[outer_dim]
                     for i in range(len(image_size))
                 ),
                 dtype=compute_dtype,
                 device=sw_device,
             )

         probabilities *= importance_map.unsqueeze(0).unsqueeze(0)
         # Use a separate index name so the outer loop variable slice_idx is not shadowed
         for win_idx, win_slice in enumerate(unravel_slice):
             slab_probabilities[win_slice] += probabilities[
                 win_idx : win_idx + 1
             ]
             slab_weights[win_slice] += importance_map

     # Apply reduction operation and move partial output to output buffer
     assert slab_probabilities is not None
     assert slab_weights is not None
     copy_size = (
         roi_size_safe[outer_dim]
         if outer_idx == len(outer_indices) - 1
         else outer_step_size
     )
     reduction_slices = tuple(
         slice(None)
         if i != outer_dim + 2  # account for batch and channel dimensions
         else slice(None, copy_size)
         for i in range(slab_probabilities.ndim)
     )
     predictions = reduction_fn(
         slab_probabilities[reduction_slices] / slab_weights[reduction_slices],
         dim=reduction_dim,
     )
     output[
         tuple(
             slice(None)
             if i != outer_dim + 1  # account for batch dimension
             else slice(outer_dim_idx, outer_dim_idx + copy_size)
             for i in range(output.ndim)
         )
     ] = predictions

     last_outer_dim_idx = outer_dim_idx

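 # Crop to the original image size, skipping the leading padding added above
 # (pad_size holds (front, back) pairs, starting from the last spatial dimension)
 crop_slices = tuple(
     slice(
         pad_size[(num_spatial_dims - 1 - d) * 2],
         pad_size[(num_spatial_dims - 1 - d) * 2] + orig_image_size[d],
     )
     for d in range(num_spatial_dims)
 )
 output = output[(slice(None),) + crop_slices]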

 if isinstance(inputs, MetaTensor):
     return convert_to_dst_type(output, inputs, device=device)[0]

 return output

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from large pool |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from small pool |       0 KB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from large pool |  184320 KB |   16008 MB |  103790 MB |  103610 MB |
|       from small pool |       0 KB |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   22486 MB |   22486 MB |   22486 MB |       0 B  |
|       from large pool |   22484 MB |   22484 MB |   22484 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |   12288 KB |    5308 MB |   75658 MB |   75646 MB |
|       from large pool |   12288 KB |    5308 MB |   75654 MB |   75642 MB |
|       from small pool |       0 KB |       1 MB |       4 MB |       4 MB |
|---------------------------------------------------------------------------|
| Allocations           |       2    |      11    |     168    |     166    |
|       from large pool |       2    |      11    |     162    |     160    |
|       from small pool |       0    |       3    |       6    |       6    |
|---------------------------------------------------------------------------|
| Active allocs         |       2    |      11    |     168    |     166    |
|       from large pool |       2    |      11    |     162    |     160    |
|       from small pool |       0    |       3    |       6    |       6    |
|---------------------------------------------------------------------------|
| GPU reserved segments |      11    |      11    |      11    |       0    |
|       from large pool |      10    |      10    |      10    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       1    |       6    |      49    |      48    |
|       from large pool |       1    |       6    |      46    |      45    |
|       from small pool |       0    |       2    |       3    |       3    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

In this example, the peak memory usage is roughly halved (from 33504 MB to 16008 MB). Furthermore, the outer loop currently runs over the third spatial dimension, so increasing the size of this dimension keeps the peak memory usage roughly constant. For example, 384x384x384 results in an OOM on an A6000 with 48 GB with the default implementation, but peaks at 16 GB with the optimized sliding window algorithm.
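The float buffer kept by the reduction variant only spans roi_size along the outer dimension. A back-of-the-envelope sketch, assuming FP32 scores and the shapes from the example (outer loop over the third spatial axis, roi 128), shows why that buffer does not grow with the third dimension, while the uint8 output (1 byte per voxel) does:

```python
MIB = 2**20


def slab_and_output_mib(num_classes: int, spatial: tuple, roi_outer: int) -> dict:
    s0, s1, s2 = spatial  # the outer loop runs over the third axis (s2) here
    slab_probs = num_classes * s0 * s1 * roi_outer * 4  # FP32 slab of class scores
    output = s0 * s1 * s2 * 1  # uint8 label map of the full volume
    return {"slab_probabilities": slab_probs / MIB, "uint8_output": output / MIB}


for third_dim in (256, 384):
    print(third_dim, slab_and_output_mib(100, (384, 384, third_dim), 128))
```

The slab buffer stays at 7200 MiB for both sizes. The measured 16 GB peak presumably also includes temporaries such as per-window model outputs and the normalized slice created before the reduction.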

Current limitations before being able to integrate it into MONAI:

  • Does not support multi-output models
  • Does not support resizing outputs back to crop size
  • Does not dynamically select the dimension for the outer loop (which could be an interesting heuristic for trade-off between GPU memory consumption and utilization)

Describe alternatives you've considered
Running sliding_window_inference with device='cpu'. This is, of course, significantly slower, since all computations on the aggregation buffer are then performed on the CPU instead of the GPU.

Another approach could be to apply the same two-step sliding window, but perform all the aggregation on the GPU and move the fully aggregated probabilities of each slab into the huge aggregation buffer in CPU memory. Then there is just one large copy operation instead of CPU-based computations for adding and normalizing probabilities.
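A minimal sketch of that alternative, assuming the slab buffers from inferer.py live on the GPU. The normalization stays on the slab's device and only the finished slice is transferred into a preallocated CPU buffer. flush_slab is a hypothetical helper; the example falls back to CPU tensors when no GPU is present so it runs anywhere.

```python
import torch


def flush_slab(slab_probs, slab_weights, cpu_out, start: int, copy: int, dim: int = 4):
    # Hypothetical helper: normalize the finished part of the slab on its own
    # device, then do a single device-to-host copy into the big CPU buffer.
    # dim=4 is the last spatial axis of a BxCxDxHxW tensor.
    finished = slab_probs.narrow(dim, 0, copy) / slab_weights.narrow(dim, 0, copy)
    cpu_out.narrow(dim, start, copy).copy_(finished)


device = "cuda" if torch.cuda.is_available() else "cpu"
num_classes, roi = 4, 8
cpu_out = torch.zeros(1, num_classes, 16, 16, 32)  # full-size probabilities in host memory
slab_probs = torch.rand(1, num_classes, 16, 16, roi, device=device) * 2
slab_weights = torch.full((1, 1, 16, 16, roi), 2.0, device=device)  # one weight channel
flush_slab(slab_probs, slab_weights, cpu_out, start=8, copy=6)
```

The CPU side then only performs the copy; the full-size buffer still needs num_classes x 4 bytes per voxel of host memory.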

Discussion
What do you think about this approach, and would you like to integrate it into MONAI? In my opinion, it would be a much more efficient implementation, especially for inference or full-image validation during training, where we only need the class with the highest probability. The remaining GPU memory could then be used for the model computations instead of just data storage. It could perhaps be extended to support multiple reductions, e.g. argmax plus a floating-point measure of uncertainty.
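One possible shape for such a multi-output reduction, sketched as a hypothetical function (the reduction_fn in inferer.py must return a single tensor, so the outer loop would have to store a tuple instead). It returns the argmax labels together with a per-voxel normalized entropy in FP16. Normalized entropy is just one choice of uncertainty measure, and the input is assumed to be averaged class probabilities that sum to one along dim.

```python
import math

import torch


def argmax_and_entropy(probs: torch.Tensor, dim: int = 1):
    # probs: averaged class probabilities (summing to 1 along `dim`).
    labels = torch.argmax(probs, dim=dim).to(torch.uint8)
    p = probs.clamp_min(1e-12)  # avoid log(0)
    entropy = -(p * p.log()).sum(dim=dim)
    # Normalize to [0, 1] by the maximum entropy log(C) and store compactly.
    uncertainty = (entropy / math.log(probs.shape[dim])).to(torch.float16)
    return labels, uncertainty


probs = torch.softmax(torch.randn(1, 5, 4, 4, 4), dim=1)
labels, uncertainty = argmax_and_entropy(probs)
```

Both outputs are 3 bytes per voxel in total, still far below an FP32 probability map with many classes.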

razorx89, Apr 25 '23

Thanks for the insights. We recently added a similar buffering idea with the buffer_dim and buffer_steps parameters: https://github.com/Project-MONAI/MONAI/blob/9c9777751ab4f96e059a6597b9aa7ac6e7ca3b92/monai/inferers/utils.py#L121-L127

https://github.com/Project-MONAI/MONAI/discussions/6157#discussioncomment-5491346

It's available in MONAI 1.2.0rc5. We haven't explored the reduction_fn idea yet. cc @myron

wyli, Apr 25 '23

We also added a SlidingWindowInfererAdapt class that automatically manages memory to avoid OOMs; you can use it as a replacement for SlidingWindowInferer.

myron, Apr 25 '23

That is good to know, thanks. Still, a more memory-efficient algorithm that only returns the predicted class index would help increase inference speed. SlidingWindowInfererAdapt seems to just try different settings until it no longer runs into an OOM, which further increases inference time, especially if the function is only called once (e.g. in a predict.py script for a single image).

razorx89, Apr 27 '23

@razorx89 thank you for a great example/issue and the code with evaluation. SlidingWindowInfererAdapt() simplifies memory management by attempting to run optimally within the GPU budget, and it uses try/except to do so.
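The try/except idea can be sketched generically. This is not SlidingWindowInfererAdapt's implementation; run_with_fallback and the configuration dicts are hypothetical. It catches a RuntimeError whose message mentions "out of memory" (torch.cuda.OutOfMemoryError is a subclass of RuntimeError), frees cached blocks, and retries with the next, more conservative configuration. Every failed attempt costs time, which is the overhead razorx89 mentions for one-off runs.

```python
import torch


def run_with_fallback(fn, configs):
    # Try each configuration in order (most GPU-hungry first) until one runs
    # without an out-of-memory error; return the result and the config used.
    if not configs:
        raise ValueError("no configurations given")
    last_error = None
    for cfg in configs:
        try:
            return fn(**cfg), cfg
        except RuntimeError as err:
            if "out of memory" not in str(err):
                raise  # unrelated errors are not swallowed
            last_error = err
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached blocks before retrying
    raise last_error
```

A caller would pass, e.g., one configuration that aggregates on the GPU and one that aggregates on the CPU.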

You can, however, run SlidingWindowInferer() directly in "buffered" mode, which is somewhat similar to your suggestion (if you know ahead of time that you have less GPU memory).

(on the dev branch of MONAI)

import torch

from monai.inferers import SlidingWindowInferer
sliding_inferer = SlidingWindowInferer(roi_size=[128, 128, 128], overlap=0.25, buffer_steps=1, device='cpu')

num_classes = 100
data = torch.rand(1, 1, 384, 384, 256, dtype=torch.float32, device='cuda:0')
model = lambda x: x * torch.rand(x.shape[0], num_classes, x.shape[2], x.shape[3], x.shape[4], dtype=x.dtype, device=x.device)

out = sliding_inferer(data, model)

print(torch.cuda.memory_summary())
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from large pool | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from small pool |      0 KiB |      0 MiB |      0 MiB |      0 MiB |
|---------------------------------------------------------------------------|
| Active memory         | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from large pool | 147456 KiB |   9752 MiB | 136952 MiB | 136808 MiB |
|       from small pool |      0 KiB |      0 MiB |      0 MiB |      0 MiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   9766 MiB |   9766 MiB |   9766 MiB |      0 B   |
|       from large pool |   9764 MiB |   9764 MiB |   9764 MiB |      0 B   |
|       from small pool |      2 MiB |      2 MiB |      2 MiB |      0 B   |
|---------------------------------------------------------------------------|
| Non-releasable memory |      0 B   |  14335 KiB |  14337 KiB |  14337 KiB |
|       from large pool |      0 B   |  12288 KiB |  12288 KiB |  12288 KiB |
|       from small pool |      0 B   |   2047 KiB |   2049 KiB |   2049 KiB |
|---------------------------------------------------------------------------|
| Allocations           |       1    |       6    |     152    |     151    |
|       from large pool |       1    |       6    |     149    |     148    |
|       from small pool |       0    |       3    |       3    |       3    |
|---------------------------------------------------------------------------|
| Active allocs         |       1    |       6    |     152    |     151    |
|       from large pool |       1    |       6    |     149    |     148    |
|       from small pool |       0    |       3    |       3    |       3    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       7    |       7    |       7    |       0    |
|       from large pool |       6    |       6    |       6    |       0    |
|       from small pool |       1    |       1    |       1    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       2    |       2    |       2    |
|       from large pool |       0    |       1    |       1    |       1    |
|       from small pool |       0    |       1    |       1    |       1    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

This runs inference and intermediate stitching on the GPU and has a lower peak memory than your example (you can set buffer_steps=2 to get a similar peak memory). In this "buffered" mode, the peak memory shouldn't increase even if your input image is larger along the z-axis (as you mentioned). PS: for some reason I couldn't run your code (some errors), so if you can please "time it" vs example above (with buffer_steps=1 and buffer_steps=2), we can see if the runtime is much different.

notice here we get the probability (float) output, and we can do a) ensembling b) resampling to invert to the original resolution (if we trained at resampled resolution). If the result is only available after argmax, we (mostly) lose these abilities. So your approach seems focused on a specific use case where we only run a single model at a fixed resolution. If it's really much faster than SlidingWindowInferer, we can consider a PR internally. Or you're always welcome to submit a PR too.

myron, May 05 '23

for some reason I couldn't run your code (some errors), so if you can please "time it" vs example above (with buffer_steps=1 and buffer_steps=2), we can see if the runtime is much different.

Yeah, there were some changes to the utility functions and overlap is now considered to be a tuple.

Here are the timings for the above example (average over 10 iterations, version 1.2.dev2318). There seem to be some additional improvements in sliding_window_inference between v1.1 and dev. My implementation is based on the v1.1 version.

v1.1 - default:     9.420s
dev - default:      7.531s
dev - buffer=1:    10.071s
dev - buffer=2:     9.984s
dev - output=cpu: 203.805s
dev - reduction:    9.111s
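The thread reports averages over 10 iterations but does not show the measurement code. A sketch of a harness that yields both kinds of numbers, assuming one warm-up call and that peak memory is read with torch.cuda.max_memory_allocated after resetting the peak statistics (benchmark is a hypothetical helper):

```python
import time

import torch


def benchmark(fn, iters: int = 10, warmup: int = 1):
    # Average wall-clock time per call and peak allocated GPU memory in GiB.
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    elapsed = (time.perf_counter() - start) / iters
    peak = torch.cuda.max_memory_allocated() / 2**30 if use_cuda else None
    return elapsed, peak
```

For the example above, one would call it as benchmark(lambda: sliding_window_inference(data, (128, 128, 128), 1, model)). The synchronize calls matter because CUDA kernels run asynchronously.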

However, in my experiments the buffered implementation has a higher peak memory usage:

dev - default:   16.695 GiB
dev - buffer=1:  23.727 GiB
dev - buffer=2:  27.438 GiB
dev - reduction: 15.633 GiB

Edit: I forgot device="cpu" for the buffered mode, which is why its peak memory usage is higher. Here are also the results of my implementation with device="cpu".

timings:

dev - buffer=1 cpu:  120.505s
dev - buffer=2 cpu:  109.838s
dev - reduction cpu:  11.553s

peak memory usage:

dev - buffer=1 cpu:   9.523 GiB
dev - buffer=2 cpu:  14.797 GiB
dev - reduction cpu: 15.598 GiB

notice here we get the probability (float) output, and we can do a) ensembling b) resampling to invert to the original resolution (if we trained at resampled resolution).

a) can still be done by wrapping all models into one model, assuming that the crop size is the same for all models in the ensemble.
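A minimal sketch of such wrapping, assuming every member model takes the same crop and returns logits with the same number of classes; EnsembleWrapper is a hypothetical name. Softmax is applied per model and the probabilities are averaged with a running sum, so the reduction sees averaged probabilities without stacking all model outputs.

```python
import torch
from torch import nn


class EnsembleWrapper(nn.Module):
    # Wraps several segmentation models into one predictor for sliding window
    # inference: each window goes through every model, probabilities are averaged.
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        total = None
        for m in self.models:
            p = torch.softmax(m(x), dim=1)
            total = p if total is None else total + p
        return total / len(self.models)
```

The wrapper can then be passed as the predictor to either inferer.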

b) True, but in most MONAI tutorials this is not the case. In most examples, the postprocessing pipeline executes an AsDiscrete transformation and resamples afterwards, which is basically the same scenario.

https://github.com/Project-MONAI/tutorials/blob/c014b03c0425eddbf2beed5490dc246543ddd2b4/modules/dynunet_pipeline/inferrer.py#L75 https://github.com/Project-MONAI/tutorials/blob/c014b03c0425eddbf2beed5490dc246543ddd2b4/modules/dynunet_pipeline/inferrer.py#L130-L142

These examples only handle resampling in preprocessing and also apply only torch.argmax at the end without postprocessing transforms: https://github.com/Project-MONAI/tutorials/tree/main/3d_segmentation

Also, resizing huge probability maps (more than 100 classes) back to the original resolution may take a very long time or even cause OOMs, so resampling artifacts may be acceptable.
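For comparison, resampling after the reduction only moves one label per voxel. A sketch assuming nearest-neighbour interpolation is acceptable for labels (i.e. accepting the resampling artifacts mentioned above); resample_labels is a hypothetical helper, and casting to float before torch.nn.functional.interpolate is a precaution rather than a documented requirement.

```python
import torch
import torch.nn.functional as F


def resample_labels(labels: torch.Tensor, size: tuple) -> torch.Tensor:
    # labels: (B, D, H, W) integer class indices, e.g. uint8 after argmax.
    # Nearest-neighbour keeps every output value inside the original label set.
    resized = F.interpolate(labels[:, None].float(), size=size, mode="nearest")
    return resized[:, 0].round().to(labels.dtype)


labels = torch.randint(0, 100, (1, 8, 8, 6), dtype=torch.uint8)
upsampled = resample_labels(labels, (16, 16, 12))
```

The float intermediate costs 4 bytes per voxel, whereas interpolating an FP32 probability map with 100 classes touches 400 bytes per voxel.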

razorx89, May 05 '23

Thank you for the response. I can see it's useful for your case and for certain other cases. I will let other people comment.

myron, May 08 '23