Please change dcgan to load truncated images.
When I ran the dcgan example, it failed with:
```
Traceback (most recent call last):
  File "dcgan.py", line 220, in <module>
    for i, data in enumerate(dataloader, 0):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 403, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 137, in __getitem__
    sample = self.loader(path)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 173, in default_loader
    return pil_loader(path)
  File "/opt/conda/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 156, in pil_loader
    return img.convert('RGB')
  File "/opt/conda/lib/python3.8/site-packages/PIL/Image.py", line 902, in convert
    self.load()
  File "/opt/conda/lib/python3.8/site-packages/PIL/ImageFile.py", line 255, in load
    raise OSError(
OSError: image file is truncated (150 bytes not processed)
```
So I added:

```python
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
```

to examples/tree/master/dcgan/main.py, and that fixed it. Thanks!
It would also be really helpful if we could log the filename of the truncated image.
@cynthia-rempel would you like to make a PR with this fix?
Agreed, would be very useful!
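One way to get the filename logged, sketched under some assumptions: pass a custom loader to `ImageFolder` through its `loader` argument. The helper names below (`logging_pil_loader`, `_open_and_load`) are hypothetical, not part of torchvision. The loader decodes strictly first, and only if PIL raises `OSError` does it log the path and retry with `ImageFile.LOAD_TRUNCATED_IMAGES` temporarily set to `True`, so intact files are handled like torchvision's default `pil_loader`.

```python
import logging

from PIL import Image, ImageFile

logger = logging.getLogger(__name__)


def _open_and_load(path, allow_truncated):
    """Decode `path` fully and return an RGB copy, with the truncation flag set as requested."""
    previous = ImageFile.LOAD_TRUNCATED_IMAGES
    ImageFile.LOAD_TRUNCATED_IMAGES = allow_truncated
    try:
        with open(path, "rb") as f:
            img = Image.open(f)
            img.load()  # force a full decode so truncation surfaces here
            return img.convert("RGB")
    finally:
        # Restore whatever the rest of the program had configured.
        ImageFile.LOAD_TRUNCATED_IMAGES = previous


def logging_pil_loader(path):
    """Hypothetical ImageFolder loader: strict load, log the filename, then retry leniently."""
    try:
        return _open_and_load(path, allow_truncated=False)
    except OSError as err:
        logger.warning("Could not fully load %s (%s); retrying with LOAD_TRUNCATED_IMAGES", path, err)
        # If the file is unreadable for another reason, this second attempt raises again.
        return _open_and_load(path, allow_truncated=True)
```

It would be wired in as `dset.ImageFolder(root=dataroot, transform=..., loader=logging_pil_loader)`. Flipping a module-level flag is acceptable here because each DataLoader worker is its own process; a multi-threaded loader could race on it.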
I tried switching to this dataset: https://huggingface.co/datasets/cats_vs_dogs
@cynthia-rempel's solution above addresses the truncated-image issue, but I'm hitting a snag during training: values turn to NaN, after which (as you can imagine) training is broken.
Has anyone experienced something similar, or does anyone have ideas on how to start debugging this?
For what it's worth, I also wrote an image checker (it probably logs too much for this example project, but it was useful for the debugging above):
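A possible first debugging step for the NaN problem is to stop at the first non-finite loss, so you know exactly which iteration triggered it. A minimal sketch, assuming the discriminator and generator losses are named `errD` and `errG` as in the example's training loop; `check_finite` is a hypothetical helper. PyTorch's `torch.autograd.set_detect_anomaly(True)` can additionally point at the backward op that produced the NaN, at a noticeable speed cost.

```python
import math


def check_finite(step, **losses):
    """Raise FloatingPointError naming every loss that is NaN or infinite.

    `losses` may be plain floats or 0-dim torch tensors; float() handles both
    (for a CUDA tensor it also forces a device sync, so call it sparingly).
    """
    bad = {name: float(value) for name, value in losses.items()
           if not math.isfinite(float(value))}
    if bad:
        raise FloatingPointError(f"non-finite loss at step {step}: {bad}")
```

Inside the loop it would go right after the losses are computed, e.g. `check_finite(i, errD=errD, errG=errG)`. Knowing the first failing step makes it easier to tell whether the blow-up follows particular batches or builds up gradually as the losses diverge.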
```python
from pathlib import Path
from PIL import Image, ImageFile

# Load truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Check images are not corrupt and conform to rules
img_rules = {
    "suffix": ['jpg', 'jpeg', 'png'],
    "formats": ['JPEG', 'PNG'],
    "size": 1000,
    "modes": ['RGB'],
}

img_counts = {
    "total": 0,
    "rejected": 0,
}


def is_image(path):
    img_counts["total"] += 1
    file = Path(path)

    # Check image has an allowed suffix (case-insensitive, so .JPG is accepted)
    suffix = file.suffix.lstrip('.').lower()
    if suffix not in img_rules['suffix']:
        print(f"{path} >> ❌ Suffix")
        img_counts["rejected"] += 1
        return False

    try:
        # Open image and check each rule (Image.open only reads the header)
        with Image.open(path) as img:
            rej = None
            if img.format not in img_rules['formats']:
                rej = f"Format: {img.format}"
            if img.size[0] > img_rules['size'] or img.size[1] > img_rules['size']:
                rej = f"Size: {img.size}"
            if img.mode not in img_rules['modes']:
                rej = f"Mode: {img.mode}"
    except Exception as err:
        # Unreadable file: report and count it instead of silently dropping it
        print(f"{path} >> ❌ Error: {err}")
        img_counts["rejected"] += 1
        return False

    if rej:
        print(f"{path} >> ❌ {rej}")
        img_counts["rejected"] += 1
        return False
    return True


# Print the summary after the dataset has been created: ImageFolder calls
# is_valid_file while it scans dataroot, so the counts are only filled in then.
print(
    f"{img_counts['total']} Images | ❌ Rejected: {img_counts['rejected']} | "
    f"✅ Accepted: {img_counts['total'] - img_counts['rejected']}")
```
You can then run the checker on every file in the dataset by passing it as the is_valid_file argument of datasets.ImageFolder (it is called once per file while ImageFolder scans dataroot):
```python
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]),
                           is_valid_file=is_image)
```
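Because the checker only opens each file (`Image.open` reads just the header) and `LOAD_TRUNCATED_IMAGES` is switched on, it won't flag truncated files. If you'd rather get a list of broken files before training, a one-off scan that fully decodes every image with the flag temporarily off could look like this. `find_unreadable_images` is a hypothetical name, and the default suffix list simply mirrors `img_rules` above.

```python
from pathlib import Path

from PIL import Image, ImageFile


def find_unreadable_images(root, suffixes=(".jpg", ".jpeg", ".png")):
    """Fully decode every image under `root`; return a list of (path, error) for failures."""
    previous = ImageFile.LOAD_TRUNCATED_IMAGES
    ImageFile.LOAD_TRUNCATED_IMAGES = False  # make truncation raise instead of being padded
    bad = []
    try:
        # rglob walks the class subfolders that ImageFolder expects
        for path in sorted(Path(root).rglob("*")):
            if not path.is_file() or path.suffix.lower() not in suffixes:
                continue
            try:
                with Image.open(path) as img:
                    img.load()  # full decode, not just the header
            except OSError as err:  # includes UnidentifiedImageError and truncation
                bad.append((str(path), str(err)))
    finally:
        ImageFile.LOAD_TRUNCATED_IMAGES = previous
    return bad
```

Running `for path, err in find_unreadable_images(dataroot): print(path, err)` once gives the filenames, which you can then delete or re-download instead of training on padded images.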