How to implement batch sampler on webdataset?
Hi everyone, I have a technical question about webdataset. I am trying to implement a custom batch sampler. With the DataLoader from torch, we can define a custom batch sampler like this:
import numpy as np

class ExpSampler:
    """Batch sampler that yields the indices of one experiment per batch."""

    def __init__(self, dataset, random: bool = True):
        self.dataset = dataset
        self.exps = self.dataset.df["experiment"].unique()
        self.random = random

    def __iter__(self):
        indexes = np.arange(len(self.exps))
        if self.random:
            np.random.shuffle(indexes)
        for exp_id in indexes:
            exp = self.exps[exp_id]
            mask = self.dataset.df["experiment"] == exp
            # All row indices (wells) of this experiment form one batch.
            all_wells = np.array(self.dataset.df[mask].index)
            yield all_wells

    def __len__(self):
        return len(self.exps)
Then we pass this sampler to the DataLoader as batch_sampler:
from torch.utils.data import DataLoader

batch_sampler = ExpSampler(train_data)
train_dl = DataLoader(
    train_data,
    num_workers=12,
    pin_memory=True,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
)
With webdataset, we first create the dataset:
dataset = webdataset.WebDataset(
    file_names,
    resampled=False,
    nodesplitter=webdataset.split_by_node,
    shardshuffle=False,
    empty_check=False,
    handler=log_and_continue,
)
and then the loader:
loader = webdataset.WebLoader(
    dataset.batched(16, collation_fn=ban_full_lib_collate_fn),
    num_workers=num_workers,
    persistent_workers=False,
    pin_memory=True,
)
I searched the webdataset documentation but could not find a way to create a custom sampler. This is something any bioML project needs in order to prevent experimental batch effects (each batch should contain data from one experimental condition only).
PyTorch has two fundamentally different forms of datasets: indexed datasets, and iterable datasets. Both are recognized by DataLoader but are treated very differently by the PyTorch library. This is just the way PyTorch is written, independent of WebDataset. Only indexed datasets use samplers. There are two corresponding libraries in webdataset: wids for indexed datasets and webdataset for iterable datasets.
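To make the distinction concrete, here is a plain-Python sketch that does not use the real torch classes. `IndexedToy`, `StreamToy`, and `load_with_batch_sampler` are hypothetical names made up for illustration; the loader function only approximates what a DataLoader does with a batch sampler.

```python
class IndexedToy:
    """Map-style dataset: supports random access by index."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


class StreamToy:
    """Iterable-style dataset: can only be read in storage order."""

    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):
        return iter(self.items)


def load_with_batch_sampler(dataset, batch_sampler):
    # A batch sampler yields lists of indices; the loader looks each one up.
    for indices in batch_sampler:
        yield [dataset[i] for i in indices]


# Works for the indexed dataset: the sampler decides what goes in each batch.
batches = list(load_with_batch_sampler(IndexedToy("abcdef"), [[0, 2, 4], [1, 3, 5]]))

# Fails for the stream: there is no __getitem__ for the sampler's indices.
stream_error = None
try:
    list(load_with_batch_sampler(StreamToy("abcdef"), [[0, 2, 4]]))
except TypeError as err:
    stream_error = err
```

Because a stream has no index lookup, any grouping has to happen on the samples as they arrive, not on indices chosen in advance.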
The recommended and scalable way of dealing with this is to write a custom batch function, as described in issue #448. The recommended way of DDP training with an iterable dataset is using resampling.
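As a rough illustration of what such a custom batch function might look like (this is a sketch, not the code from issue #448), the generator below buckets an incoming stream of samples by experiment and emits a batch only from a single bucket. The name `batch_by_experiment`, the `keep_partial` flag, and the assumption that each sample is a dict with an "experiment" key are all hypothetical. How the function is attached to a webdataset pipeline should be checked against the webdataset documentation.

```python
from collections import defaultdict


def batch_by_experiment(samples, batch_size=16, key="experiment", keep_partial=True):
    """Group a stream of sample dicts into batches that share one experiment.

    Assumes every sample is a dict with `key` identifying its experiment
    (a hypothetical layout; adapt it to how your shards store metadata).
    """
    buckets = defaultdict(list)
    for sample in samples:
        bucket = buckets[sample[key]]
        bucket.append(sample)
        if len(bucket) == batch_size:
            # Emit a copy, then reuse the bucket for the next batch.
            yield list(bucket)
            bucket.clear()
    if keep_partial:
        # Flush leftover samples as smaller single-experiment batches.
        for bucket in buckets.values():
            if bucket:
                yield bucket


# Toy stream that alternates between two experiments.
stream = [{"experiment": "A" if i % 2 == 0 else "B", "x": i} for i in range(6)]
toy_batches = list(batch_by_experiment(stream, batch_size=2))
```

Unlike the ExpSampler above, this emits fixed-size batches rather than whole experiments, and it holds up to batch_size samples per experiment in memory. If the shards are written so that each experiment is stored contiguously, the buckets stay small.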
The wids library is closer to what you are used to (custom sampler, etc.) and is easier to use with DDP. However, writing efficient custom samplers for large datasets is more work.