
Segmentation Fault when using FFCV alongside PyTorch DataLoader (Possible bug)

Open · numpee opened this issue on Feb 24, 2023 · 0 comments

I've just encountered a weird issue where the PyTorch DataLoader segfaults under the following conditions:

  • An FFCV Loader and a PyTorch DataLoader are both created in the same process
  • The FFCV loader uses more than 2 workers

Basically, the FFCV loader is capped at 2 workers: with anything above that, iterating the PyTorch DataLoader segfaults. Any ideas why this might occur?
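
As an aside, the crash produces no Python traceback on its own; one way that may surface a Python-level traceback at the crash point is the standard-library faulthandler module. This is purely a diagnostic aid, not part of the repro:

import faulthandler

faulthandler.enable()   # dump a Python traceback if the process receives SIGSEGV
# (equivalently, run the script with `python -X faulthandler repro.py`)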

Please check out the code snippet below to reproduce the error.

import numpy as np
import torch
from ffcv import Loader
from ffcv.fields.basics import IntDecoder
from ffcv.fields.rgb_image import SimpleRGBImageDecoder
from ffcv.loader import OrderOption
from ffcv.transforms import RandomHorizontalFlip, ToTensor, ToDevice, ToTorchImage, NormalizeImage, Squeeze
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor as TT
from tqdm import tqdm


def main():
    pytorch_workers = 8
    ffcv_workers = 3    # Works when ffcv_workers <= 2; segfaults with 3 or more
    beton_path = 'cifar10_val.beton'    # Your FFCV .beton file here
    image_pipeline = [SimpleRGBImageDecoder(), RandomHorizontalFlip(), ToTensor(),
                      ToDevice(torch.device(0), non_blocking=True), ToTorchImage(),
                      NormalizeImage(mean=np.array([0.5, 0.5, 0.5]) * 255, std=np.array([1., 1., 1.]) * 255,
                                     type=np.float32)]
    label_pipeline = [IntDecoder(), ToTensor(), Squeeze(), ToDevice(torch.device(0), non_blocking=True)]
    # The FFCV loader is only constructed here; it is never iterated below
    loader = Loader(beton_path, batch_size=16, num_workers=ffcv_workers, order=OrderOption.SEQUENTIAL, os_cache=True,
                    drop_last=False, pipelines={'image': image_pipeline, 'label': label_pipeline})
    dataset = MNIST(root="./", download=True, transform=TT())
    pytorch_dataloader = DataLoader(dataset, batch_size=16, num_workers=pytorch_workers, pin_memory=True)

    # Iterating only the standard PyTorch DataLoader is what segfaults when ffcv_workers > 2
    for i, _batch in enumerate(tqdm(pytorch_dataloader)):
        print(i)
        if i == 10:
            break

if __name__ == "__main__":
    main()
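
For reference, the cifar10_val.beton file used above can be generated with FFCV's DatasetWriter. A minimal sketch, assuming the CIFAR-10 test split and default RGBImageField/IntField options:

from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField
from torchvision.datasets import CIFAR10

# CIFAR-10 yields (PIL image, int label) pairs, which these fields accept directly
val_set = CIFAR10(root="./", train=False, download=True)
writer = DatasetWriter('cifar10_val.beton', {'image': RGBImageField(), 'label': IntField()})
writer.from_indexed_dataset(val_set)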
