Allow `DataLoader` and `Dataset` to retain `Generic` features from torch

[Open] stuartthomson opened this issue 11 months ago • 3 comments

Is your feature request related to a problem? Please describe.
I would like to provide better type hints for my MONAI code. The `DataLoader` and `Dataset` classes in MONAI inherit from their torch counterparts but hide the fact that, in torch, these are generic classes. For example, in torch I can define a typed dataset like this:

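```python
from typing import TypedDict

import torch
from torch.utils.data import Dataset


class MyData(TypedDict):
    filename: str
    image: torch.Tensor
    segmentation: torch.Tensor


# `create_data()` stands in for whatever code builds the dataset
my_dataset: Dataset[MyData] = create_data()
```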

This means that elsewhere in the code I know what each item will look like. This isn't possible with MONAI's `Dataset` and `DataLoader`, because they do not expose the type parameter.
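To make the benefit concrete, here is a minimal, runnable torch-only sketch. `InMemorySegmentationData`, `first_filename` and the sample records are hypothetical names chosen for illustration. Because `torch.utils.data.Dataset` is generic, a type checker can infer that `dataset[0]["filename"]` is a `str`.

```python
from typing import List, TypedDict

import torch
from torch.utils.data import Dataset


class MyData(TypedDict):
    filename: str
    image: torch.Tensor
    segmentation: torch.Tensor


class InMemorySegmentationData(Dataset[MyData]):
    """Hypothetical dataset that serves pre-built MyData records from a list."""

    def __init__(self, records: List[MyData]) -> None:
        self.records = records

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, index: int) -> MyData:
        return self.records[index]


def first_filename(dataset: Dataset[MyData]) -> str:
    # A type checker knows dataset[0] is MyData, so ["filename"] is a str.
    return dataset[0]["filename"]


records: List[MyData] = [
    {
        "filename": f"case_{i}.nii.gz",
        "image": torch.zeros(1, 4, 4),
        "segmentation": torch.zeros(1, 4, 4),
    }
    for i in range(3)
]
my_dataset: Dataset[MyData] = InMemorySegmentationData(records)
print(first_filename(my_dataset))  # case_0.nii.gz
```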

Describe the solution you'd like
I think you could solve this with something like the following for `Dataset` (`DataLoader` would need a similar change):

```python
from __future__ import annotations

import collections.abc
from typing import Any, Mapping, Sequence, TypeVar, Union, overload

import numpy as np
import torch
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset as _TorchSubset

NdarrayOrTensor = Union[np.ndarray, torch.Tensor]

# The bound is evaluated at runtime, so use `Union` rather than `|` to stay
# compatible with Python versions older than 3.10.
# Note: a TypedDict such as `MyData` is only compatible with `Mapping[str, object]`,
# so this bound may need to be loosened (or dropped) for `Dataset[MyData]` to type-check.
T = TypeVar(
    "T",
    bound=Union[NdarrayOrTensor, Sequence[NdarrayOrTensor], Mapping[Any, NdarrayOrTensor]],
)


class Dataset(_TorchDataset[T]):
    # Leave the rest of the class (__init__, __len__, _transform, ...) as-is
    ...

    @overload
    def __getitem__(self, index: slice) -> _TorchSubset[T]: ...

    @overload
    def __getitem__(self, index: Sequence[int]) -> _TorchSubset[T]: ...

    @overload
    def __getitem__(self, index: int) -> T: ...

    def __getitem__(self, index: int | slice | Sequence[int]) -> T | _TorchSubset[T]:
        """
        Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise.
        """
        if isinstance(index, slice):
            # dataset[:42]
            start, stop, step = index.indices(len(self))
            indices = range(start, stop, step)
            return _TorchSubset(dataset=self, indices=indices)
        if isinstance(index, collections.abc.Sequence):
            # dataset[[1, 3, 4]]
            return _TorchSubset(dataset=self, indices=index)
        return self._transform(index)
```

stuartthomson, Feb 10 '25 15:02

Let me know if this is something you'd be interested in seeing a PR for - I'd be happy to have a go.

stuartthomson, Feb 11 '25 09:02

Is there anything I can do to get a response on this? I'm happy to try opening a PR if there is agreement that this feature would be useful.

stuartthomson, Oct 08 '25 14:10

Hi - this seems like a good feature to add. If you can create a PR and all tests pass, we would welcome the contribution! Sorry for the extended silence!

aylward, Oct 08 '25 16:10