[Feature Request] Datasets Should Use New `torchvision.io` Image Loader APIs and Return `TVTensor` Images by Default
🚀 The feature
- Add "torchvision" image loader backend based on new
torchvision.ioAPIs (See: Release Notes v0.20) and enable it by default. - VisionDatasets should return
TVTensorimages by default instead ofPIL.Image.
Motivation, pitch
- TorchVision v0.20 introduces new `torchvision.io` APIs that enhance its encoding/decoding capabilities.
- Current VisionDatasets return `PIL.Image` by default, but the first step of transforms is usually `transforms.ToImage()`.
- PIL is slow (see: Pillow-SIMD), especially when compared with the new `torchvision.io` APIs.
- Current TorchVision image loader backends are based on PIL or accimage, not including the new `torchvision.io` APIs.
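For illustration, a minimal sketch of the two flows, assuming torchvision >= 0.20; the `example.jpg` path is a placeholder:

```python
from PIL import Image
from torchvision.io import ImageReadMode, decode_image
from torchvision.transforms import v2

path = "example.jpg"  # hypothetical image path

# Current default flow: the dataset returns a PIL image, and the first
# transform converts it to a tensor image anyway.
pil_img = Image.open(path).convert("RGB")
img = v2.ToImage()(pil_img)  # tv_tensors.Image, uint8, CHW

# New flow: decode_image (which accepts a file path in v0.20+) produces
# a uint8 CHW tensor in a single step, with no PIL round trip.
img = decode_image(path, mode=ImageReadMode.RGB)
```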
Alternatives
- The return type of datasets can be `PIL.Image` when using the PIL or accimage backends, and `TVTensor` when using the new APIs (may lose consistency).
Additional context
I would like to make a pull request if the community likes this feature.
Hi @fang-d, thank you for the feature request. This is a great idea, and I think the torchvision decoders are in a stable enough state to enable this now.
We already support the `loader` parameter for some datasets (mostly `ImageFolder`, I think: https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder). But we should enable the same for all existing datasets.
I think the way to go would probably be to add that `loader` parameter to all datasets that currently call `Image.open(...)`:
```
~/dev/vision (main*) » git grep Image.open torchvision/datasets
torchvision/datasets/_optical_flow.py: img = Image.open(file_name)
torchvision/datasets/_stereo_matching.py: img = Image.open(file_path)
torchvision/datasets/_stereo_matching.py: disparity_map = np.asarray(Image.open(file_path)) / 256.0
torchvision/datasets/_stereo_matching.py: disparity_map = np.asarray(Image.open(file_path)) / 256.0
torchvision/datasets/_stereo_matching.py: disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
torchvision/datasets/_stereo_matching.py: depth = np.asarray(Image.open(file_path))
torchvision/datasets/_stereo_matching.py: disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
torchvision/datasets/_stereo_matching.py: valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
torchvision/datasets/_stereo_matching.py: off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
torchvision/datasets/_stereo_matching.py: disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
torchvision/datasets/_stereo_matching.py: valid_mask = Image.open(mask_path)
torchvision/datasets/caltech.py: img = Image.open(
torchvision/datasets/caltech.py: img = Image.open(
torchvision/datasets/celeba.py: X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
torchvision/datasets/cityscapes.py: image = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/cityscapes.py: target = Image.open(self.targets[index][i]) # type: ignore[assignment]
torchvision/datasets/clevr.py: image = Image.open(image_file).convert("RGB")
torchvision/datasets/coco.py: return Image.open(os.path.join(self.root, path)).convert("RGB")
torchvision/datasets/dtd.py: image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/fgvc_aircraft.py: image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/flickr.py: img = Image.open(img_id).convert("RGB")
torchvision/datasets/flickr.py: img = Image.open(filename).convert("RGB")
torchvision/datasets/flowers102.py: image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/folder.py: img = Image.open(f)
torchvision/datasets/food101.py: image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/gtsrb.py: sample = PIL.Image.open(path).convert("RGB")
torchvision/datasets/imagenette.py: image = Image.open(path).convert("RGB")
torchvision/datasets/inaturalist.py: img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
torchvision/datasets/kitti.py: image = Image.open(self.images[index])
torchvision/datasets/lfw.py: img = Image.open(f)
torchvision/datasets/lsun.py: img = Image.open(buf).convert("RGB")
torchvision/datasets/omniglot.py: image = Image.open(image_path, mode="r").convert("L")
torchvision/datasets/oxford_iiit_pet.py: image = Image.open(self._images[idx]).convert("RGB")
torchvision/datasets/oxford_iiit_pet.py: target.append(Image.open(self._segs[idx]))
torchvision/datasets/phototour.py: img = Image.open(fpath)
torchvision/datasets/rendered_sst2.py: image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/sbd.py: img = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/sbu.py: img = Image.open(filename).convert("RGB")
torchvision/datasets/stanford_cars.py: pil_image = Image.open(image_path).convert("RGB")
torchvision/datasets/sun397.py: image = PIL.Image.open(image_file).convert("RGB")
torchvision/datasets/voc.py: img = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/voc.py: target = Image.open(self.masks[index])
torchvision/datasets/voc.py: img = Image.open(self.images[index]).convert("RGB")
torchvision/datasets/widerface.py: img = Image.open(self.img_info[index]["img_path"]) # type: ignore[arg-type]
```
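A hypothetical sketch of that pattern (`SampleDataset` is a stand-in name, not an actual torchvision class): the dataset takes a `loader` callable and defers to it instead of calling `Image.open(...)` directly.

```python
from typing import Any, Callable

from torchvision.datasets.folder import default_loader
from torchvision.datasets.vision import VisionDataset


class SampleDataset(VisionDataset):  # hypothetical stand-in, not a real class
    def __init__(self, root: str, loader: Callable[[str], Any] = default_loader) -> None:
        super().__init__(root)
        self.loader = loader
        self.image_paths: list[str] = []  # populated however each dataset does today

    def __getitem__(self, index: int) -> Any:
        path = self.image_paths[index]
        # Instead of Image.open(path), defer to the loader; the return type
        # follows from the user's choice (PIL by default, a Tensor when the
        # user passes torchvision.io.decode_image).
        return self.loader(path)
```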
Hi @NicolasHug
I think since the ImageNet dataset already supports the `loader` parameter, for consistency reasons we could also add `**kwargs` to the Country211 and EuroSAT datasets?
Hi @NicolasHug, I'd like to tackle this issue, but there are some concerns I'd like to raise here first.
- Would it be OK to move `default_loader`, `pil_loader`, and `accimage_loader` from the original `folder.py` to `vision.py`? The change would enable other classes to set the default `loader` value to `default_loader`, like what's currently done in `ImageNet`, without having to import `folder`, given that some of the dataset classes are not subclasses of `ImageFolder`. But if this change counts as a BC break, then we'd need to keep them importable from the original `folder` module?
- We could add `torchvision.io.decode_image` as one of the supported loaders, but still keep the default loader as PIL?
- This change would raise the inconsistency issue mentioned under "Alternatives":
  > The return type of datasets can be `PIL.Image` when using the PIL or accimage backends, and `TVTensor` when using the new APIs (may lose consistency).

  But since the loader and transform are both determined by users, this probably wouldn't be a huge problem?
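For context on the first point, a short sketch of the current layout as the comment describes it (the image path is a placeholder): the loader helpers live in `torchvision.datasets.folder`, and `default_loader` dispatches on the configured image backend.

```python
from torchvision.datasets.folder import default_loader, pil_loader

img = pil_loader("example.jpg")      # always PIL, converted to RGB
img = default_loader("example.jpg")  # pil_loader unless the accimage backend is set
```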
Hi @GdoongMathew ,
> I think since the ImageNet dataset already supports the `loader` parameter, for consistency reasons we could also add `**kwargs` to the Country211 and EuroSAT datasets?
Yes, good point: we should also add the `loader` parameter to all datasets that inherit from `ImageFolder`. We shouldn't need to use `**kwargs` for that; a new `loader` parameter should work.
> Would it be OK to move `default_loader`, `pil_loader`, and `accimage_loader` from `folder.py` to `vision.py`?
Unfortunately, no... this would be BC-breaking :/ I'm not sure I understand the concern, though: is there any issue with having the datasets import things from `folder.py`?
> We could add `torchvision.io.decode_image` as one of the supported loaders, but still keep the default loader as PIL?
I think we'll need to keep PIL as the default loader, yes. But we can encourage users to rely on `decode_image` in the docstring.
Overall, and to answer your final point, I don't think we should implement all this via a new "image backend". The image backend logic is really old, and I'm not even sure we still support accimage at all. I think this is better done by simply adding a `loader` parameter to most datasets; it makes what's going on more obvious. And like you said, when a user explicitly passes `loader=decode_image`, it's perfectly fine for the dataset to return a `Tensor`, since that's what the user asked for.
Thanks a lot for working on this!
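A minimal sketch of the opt-in behavior described above, using `ImageFolder` (which already exposes `loader`); `path/to/data` is a placeholder root:

```python
from torchvision.datasets import ImageFolder
from torchvision.io import decode_image

# Default: samples come back as PIL.Image
pil_ds = ImageFolder("path/to/data")

# Explicit opt-in: decode_image accepts a path (torchvision >= 0.20),
# so samples come back as uint8 CHW tensors.
tensor_ds = ImageFolder("path/to/data", loader=decode_image)
```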
Well, the first point is more about the dependency direction. From what I understand, most of the datasets subclass `vision.VisionDataset`, whereas `default_loader` lives in the `folder.py` module, so importing it feels like reaching across the import chain. But as you mentioned, moving it would be a BC break, which is much less desirable.
Hi @NicolasHug, I think there's another problem, though.
What if a dataset's labels are masks: should it use the provided `loader` to load the mask data, or should the dataset take another argument, something like `mask_loader`?
If we reuse the provided `loader`, we'd probably need to add an extra argument to `default_loader` and `pil_loader` so that a dataset could read images and masks in different modes, but then it would have to figure out which argument to pass to which loader function.
Or, for the time being, we just expose `loader` and don't change anything about how masks are read?
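To make the question concrete, a hypothetical sketch of the `mask_loader` idea (names are illustrative, not an agreed design); masks usually need to keep their `L`/`P` mode so pixel values remain class ids, which is why reusing an RGB image loader is awkward:

```python
from PIL import Image


def image_loader(path: str) -> Image.Image:
    with open(path, "rb") as f:
        return Image.open(f).convert("RGB")


def mask_loader(path: str) -> Image.Image:
    with open(path, "rb") as f:
        mask = Image.open(f)
        mask.load()  # force the read; keep "L"/"P" mode so values stay class ids
    return mask


class SegmentationDatasetSketch:  # hypothetical, not a torchvision class
    def __init__(self, loader=image_loader, mask_loader=mask_loader):
        self.loader = loader
        self.mask_loader = mask_loader

    def load_sample(self, img_path: str, mask_path: str):
        return self.loader(img_path), self.mask_loader(mask_path)
```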
Great point about loading masks. Let's just take it step by step: maybe we can start with the classification datasets, and look into the detection/segmentation datasets as a second iteration.
Hi @NicolasHug, here are the rest of the datasets that still need further refactoring:
1. optical flow
- [x] FlowDataset
- [x] Sintel
- [x] KittiFlow
- [x] FlyingThings3D
- [x] HD1K
2. stereo matching
- [ ] StereoMatchingDataset
- [ ] CarlaStereo
- [ ] Kitti2012Stereo
- [ ] Kitti2015Stereo
- [ ] Middlebury2014Stereo
- [ ] CREStereo
- [ ] FallingThingsStereo
- [ ] SceneFlowStereo
- [ ] SintelStereo
- [ ] InStereo2k
- [ ] ETH3DStereo
3. datasets with mask annotations
- [ ] Cityscapes
- [ ] OxfordIIITPet
- [ ] SBDataset
- [ ] _VOCBase (VOCSegmentation / VOCDetection)
4. datasets with keypoint / bbox / contour annotations
- [ ] Caltech101
- [ ] CelebA
- [ ] CocoDetection / CocoCaptions
- [ ] Kitti
- [ ] WIDERFace
5. classification-only datasets
- [ ] Caltech256
6. datasets that store images in other formats
- [ ] CIFAR10 / CIFAR100
- [ ] FER2013
- [ ] GTSRB (to enable loading images using the `torchvision.io` API, the API will also have to support `ppm` images; see the sketch after this list)
- [ ] FlyingChairs (`ppm` images)
- [ ] LSUNClass / LSUN
- [ ] MNIST / FashionMNIST / KMNIST / EMNIST / QMNIST
- [ ] PCAM
- [ ] PhotoTour
- [ ] SEMEION
- [ ] STL10
- [ ] SVHN
- [ ] USPS
I think we could keep working on the datasets in groups 1, 2, 4, and 5, and leave those in groups 3 and 6 for further discussion?
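Regarding the `ppm` limitation flagged in group 6, a hypothetical fallback (illustrative only; `ppm_tensor_loader` is not a real torchvision helper): decode with PIL and convert to the `uint8` CHW layout that `decode_image` produces, so such datasets could still offer a tensor-returning loader until `torchvision.io` supports the format.

```python
import numpy as np
import torch
from PIL import Image


def ppm_tensor_loader(path: str) -> torch.Tensor:  # hypothetical helper
    with open(path, "rb") as f:
        img = Image.open(f).convert("RGB")
    # HWC uint8 -> CHW uint8, matching decode_image's output layout
    return torch.from_numpy(np.array(img)).permute(2, 0, 1).contiguous()
```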
Might be related to:
- https://github.com/pytorch/vision/issues/4975
- https://github.com/pytorch/vision/issues/4991