Batched IterableDataset
Feature request
Hi,
Could you add an implementation of a batched IterableDataset? The library already supports batch iteration via .iter(batch_size=...), but this cannot be used in combination with a torch DataLoader since it just returns an iterator.
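For reference, here is a minimal sketch of the existing .iter() API (the dataset name is only an example):

from datasets import load_dataset

ds = load_dataset("imdb", split="train")  # any map-style Dataset; name is an example
# .iter() already yields dict-of-lists batches, but it returns a plain generator,
# so it cannot be handed to a torch DataLoader the way a Dataset can.
for batch in ds.iter(batch_size=128):
    pass  # batch looks like {"text": [... 128 items ...], "label": [... 128 items ...]}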
Motivation
The current implementation loads each element of a batch individually, which can be very slow for a large batch_size. I did some experiments here, and using batched iteration would speed up data loading significantly.
Your contribution
N/A
This is exactly what I was looking for. It would also be very useful for me :-)
This issue really undermines the selling point of HF datasets... The only workaround I've found so far is to create a customized IterableDataset wrapper, which improves the loading speed to some extent.
For example, I have an HF dataset dt_train with len(dt_train) == 1M. Using a plain DataLoader is extremely slow:
%%time
dl_train = DataLoader(dt_train, batch_size=128, shuffle=True)
for batch in dl_train:
    pass
CPU times: user 24min 35s, sys: 704 ms, total: 24min 36s
Wall time: 24min 37s
And DataLoader works even worse with HF's iterable_dataset:
%%time
dt_train_ = dt_train.with_format(None).to_iterable_dataset(num_shards=64).shuffle(buffer_size=10_000)
dl_train = DataLoader(dt_train_, batch_size=128)
for batch in dl_train:
    pass
CPU times: user 1h 6min 2s, sys: 4.28 s, total: 1h 6min 6s
Wall time: 1h 7min 53s
Workaround by running a customized wrapper:
%%time
from torch.utils.data import DataLoader, IterableDataset
class Dataset2Iterable(IterableDataset):
    """
    Wrapper to use a HF dataset as a pytorch IterableDataset to speed up data loading.
    """

    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # Dataset.shuffle() returns a new dataset rather than shuffling in place
            self.dataset = self.dataset.shuffle()
        return self.dataset.iter(batch_size=self.batch_size)

dl_train = DataLoader(Dataset2Iterable(dt_train, batch_size=128), batch_size=1, num_workers=0)
for n in range(2):
    for batch in dl_train:
        pass
The speed is still slower than TensorFlow's loader but much improved over the previous code:
CPU times: user 4min 18s, sys: 0 ns, total: 4min 18s
Wall time: 4min 20s
Note that the way I implemented Dataset2Iterable will only work with num_workers=0.
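A possible way around that limitation would be to shard the dataset across dataloader workers inside __iter__ (an untested sketch; ShardedDataset2Iterable is a hypothetical name, not part of the workaround above):

from torch.utils.data import IterableDataset, get_worker_info

class ShardedDataset2Iterable(IterableDataset):
    """Hypothetical variant of Dataset2Iterable that splits the HF dataset across
    dataloader workers so that num_workers > 0 does not yield duplicated batches."""

    def __init__(self, dataset, batch_size=1):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        worker_info = get_worker_info()
        ds = self.dataset
        if worker_info is not None and worker_info.num_workers > 1:
            # give each dataloader worker its own contiguous shard of the dataset
            ds = ds.shard(num_shards=worker_info.num_workers, index=worker_info.id)
        return ds.iter(batch_size=self.batch_size)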
I can confirm that @zhh210's solution works with num_workers=0. However, for my use case, this was still slower than tokenizing on the fly through a collator and leveraging multiple workers in the dataloader.
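Roughly, the collator approach looks like this (a sketch only; the checkpoint name and the "text" column are assumptions):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint

def collate_fn(examples):
    # tokenize per batch inside the collator, so num_workers > 0 parallelizes
    # both loading and tokenization
    texts = [ex["text"] for ex in examples]  # assumes a "text" column
    return tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

dl_train = DataLoader(dt_train, batch_size=128, num_workers=4, collate_fn=collate_fn)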
@lhoestq I think this is an important use case (e.g., streaming from a large dataset, online or stored on disk). What do you think might be the best solution to move forward?
I guess it can be implemented using a batched .map() under the hood that returns a single item containing the input batch.
In the meantime you can use this:
def batch(unbatched: dict[str, list]) -> dict[str, list]:
    return {k: [v] for k, v in unbatched.items()}

batched_dataset = dataset.map(batch, batched=True, batch_size=batch_size)
Though it would be great to have a .batch() method indeed. I'd be happy to help if anyone wants to open a PR.
If no one else is planning to work on this, I can take it on. I'll wait until next week, and if no one has started a PR by then, I'll go ahead and open one.
It looks like the implementation of IterableDataset is still using a hardcoded batch size of 1, for example in line 2063 of /datasets/src/datasets/iterable_dataset.py. Iterating over an IterableDataset with large batch sizes therefore remains slow, even when using batch(). I guess the data are then not being read from one contiguous chunk of memory; instead, every example is retrieved one by one, leading to long data loading times. As a minimal example: load the c4 dataset and iterate over it with a large batch size.
import datasets
from timeit import default_timer as timer

c4 = datasets.load_dataset("allenai/c4", "en", streaming=True, split="train")
c4_batched = c4.batch(512**2)  # use large batch size
iterator = iter(c4_batched)

for i in range(5):
    start_time = timer()
    next(iterator)  # get next batch
    end_time = timer()
    print(f"time for one batch: {end_time-start_time}")
This results in the following output for me:
time for one batch: 12.615376660600305
time for one batch: 13.011422813870013
time for one batch: 14.157325950451195
time for one batch: 14.225894245319068
time for one batch: 13.898222777992487
Because I want to use my IterableDataset with the pytorch DataLoader, I rewrote the _iter_pytorch and __iter__ functions like so and am getting much faster data loading times. I marked the lines I changed with "# changed here":
from datasets.iterable_dataset import _convert_to_arrow
from datasets.formatting import TensorFormatter, get_formatter
from datasets.features.features import cast_to_python_objects
import sys
import fsspec.asyn
from itertools import islice
from datasets.utils.logging import get_logger
from datasets.iterable_dataset import _examples_to_batch, _apply_feature_types_on_batch, _apply_feature_types_on_example
logger = get_logger(__name__)
def __iter__(self):
    if "torch" in sys.modules:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if isinstance(self, torch.utils.data.IterableDataset) and worker_info is not None:
            # We're a torch.utils.data.IterableDataset in a PyTorch worker process
            yield from self._iter_pytorch()
            return

    ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=self.batch_size, drop_last_batch=self.drop_last_batch)  # changed here
    if self._formatting:
        formatter = get_formatter(self._formatting.format_type, features=self.features)
        format_dict = (
            formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
        )
    else:
        format_dict = None

    if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"):
        if ex_iterable.iter_arrow:
            iterator = ex_iterable.iter_arrow()
        else:
            iterator = _convert_to_arrow(ex_iterable, batch_size=self.batch_size)  # changed here
        for key, pa_table in iterator:
            yield formatter.format_row(pa_table)
        return

    for key, example in ex_iterable:
        if self.features and not ex_iterable.is_typed:
            # `IterableDataset` automatically fills missing columns with None.
            # This is done with `_apply_feature_types_on_example`.
            example = _apply_feature_types_on_example(
                example, self.features, token_per_repo_id=self._token_per_repo_id
            )
        yield format_dict(example) if format_dict else example


def _iter_pytorch(self):
    ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=self.batch_size, drop_last_batch=self.drop_last_batch)  # changed here
    # Fix for fsspec when using multiprocess to avoid hanging in the ML training loop. (only required for fsspec >= 0.9.0)
    # See https://github.com/fsspec/gcsfs/issues/379
    fsspec.asyn.reset_lock()
    # check if there aren't too many workers
    import torch.utils.data

    worker_info = torch.utils.data.get_worker_info()
    if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers:
        logger.warning(
            f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). "
            f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers."
        )
        logger.info(
            f"To parallelize data loading, we give each process some shards (or data sources) to process. "
            f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. "
            f"To enable more parallelism, please split the dataset in more files than {ex_iterable.n_shards}."
        )
    # split workload
    _log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
    shards_indices = ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers)
    if shards_indices:
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards."
        )
        ex_iterable = ex_iterable.shard_data_sources(worker_id=worker_info.id, num_workers=worker_info.num_workers)
        self._state_dict = ex_iterable._init_state_dict()
        if self._starting_state_dict:
            ex_iterable.load_state_dict(self._starting_state_dict)

        if self._formatting:
            formatter = get_formatter(self._formatting.format_type, features=self.features)
            format_dict = (
                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
            )
        else:
            format_dict = None

        if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"):
            if ex_iterable.iter_arrow:
                iterator = ex_iterable.iter_arrow()
            else:
                iterator = _convert_to_arrow(ex_iterable, batch_size=self.batch_size)  # changed here
            if self.batch_size > 1:  # changed here until end of file
                for key, pa_table in iterator:
                    yield formatter.format_batch(pa_table)
                return
            else:
                for key, pa_table in iterator:
                    yield formatter.format_row(pa_table)
                return

        iterator = iter(ex_iterable)
        if self.batch_size > 1:
            for key, example in iterator:
                # If batched, first build the batch
                examples = [example] + [example for key, example in islice(iterator, self.batch_size - 1)]
                if self.drop_last_batch and len(examples) < self.batch_size:  # ignore last batch
                    return
                batch = _examples_to_batch(examples)
                if self.features and not ex_iterable.is_typed:
                    # `IterableDataset` automatically fills missing columns with None.
                    # This is done with `_apply_feature_types_on_batch`.
                    batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id)
                yield format_dict(batch) if format_dict else batch
        else:
            for key, example in ex_iterable:
                if self.features and not ex_iterable.is_typed:
                    # `IterableDataset` automatically fills missing columns with None.
                    # This is done with `_apply_feature_types_on_example`.
                    example = _apply_feature_types_on_example(
                        example, self.features, token_per_repo_id=self._token_per_repo_id
                    )
                yield format_dict(example) if format_dict else example
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.n_shards} shards."
        )
    else:
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.n_shards}<{worker_info.num_workers})."
        )
For anyone wanting to try it, you can patch it into datasets by overwriting the function via setattr(datasets.IterableDataset, '_iter_pytorch', _iter_pytorch).
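Since __iter__ was rewritten as well, presumably both functions need to be patched, e.g. (a sketch, assuming the two functions above are in scope):

import datasets

# monkey-patch the rewritten functions onto IterableDataset;
# this replaces library internals, so it is only a stopgap rather than an official API
setattr(datasets.IterableDataset, "__iter__", __iter__)
setattr(datasets.IterableDataset, "_iter_pytorch", _iter_pytorch)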
I don't really know what most of the rest of the code is doing, so I have no idea whether this is a valid fix, but it seems to work for me. Example of running the fix:
from torch.utils.data.dataloader import DataLoader

c4.batch_size = 512**2  # set batch size here
c4.drop_last_batch = False  # the patched functions also read this attribute
dataloader = DataLoader(c4, batch_size=None)  # use the custom batching from IterableDataset
iterator = iter(dataloader)

for i in range(5):
    start_time = timer()
    next(iterator)  # get the batch
    end_time = timer()
    print(f"time for one batch: {end_time-start_time}")
I now get:
time for one batch: 0.6047679269686341
time for one batch: 0.000248616561293602
time for one batch: 0.00017435848712921143
time for one batch: 0.00015910807996988297
time for one batch: 0.00015317369252443314
I love the datasets library, and it would be great if iterating with large batch sizes were supported directly, either with a fix similar to mine or in some other way :)
Hi @taczin, thanks for reporting!
Indeed, the IterableDataset.batch() implementation is quite naive, as it manipulates Python objects:
https://github.com/huggingface/datasets/blob/d37ed46ebf45981131bd3678173dbb4b7e2b2f1a/src/datasets/iterable_dataset.py#L3026-L3029
However, it could be much faster if it were applied on the Arrow data directly, maybe using something like this (untested):
import pyarrow as pa

def batch_fn(unbatched):
    return {k: [v] for k, v in unbatched.items()}

def batch_fn_arrow(unbatched_pa_table):
    offsets = pa.array([0, len(unbatched_pa_table)])
    return pa.Table.from_arrays([
        pa.ListArray.from_arrays(offsets, unbatched_pa_table[k])
        for k in unbatched_pa_table.column_names
    ], unbatched_pa_table.column_names)

if self._ex_iterable.iter_arrow:
    return self.with_format("arrow").map(
        batch_fn_arrow, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch
    ).with_format(self._formatting.format_type if self._formatting else None)
else:
    return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)
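As a rough standalone illustration of the Arrow-level batching idea (a toy example, not the datasets internals; note that ListArray.from_arrays expects int32 offsets and plain arrays rather than chunked columns):

import pyarrow as pa

# wrap every column of a small table into a single list entry, i.e. one "batch" row
unbatched = pa.table({"text": ["a", "b", "c"], "label": [0, 1, 2]})

offsets = pa.array([0, len(unbatched)], type=pa.int32())  # one list spanning all rows
batched = pa.Table.from_arrays(
    [
        pa.ListArray.from_arrays(offsets, unbatched[name].combine_chunks())
        for name in unbatched.column_names
    ],
    names=unbatched.column_names,
)
print(batched.to_pydict())  # {'text': [['a', 'b', 'c']], 'label': [[0, 1, 2]]}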
Hi @lhoestq, thanks for your answer. I was wondering: is there a reason why the internal call to ex_iterable = self._prepare_ex_iterable_for_iteration() in the IterableDataset code does not pass the batch size even though it could? If it is not passed, the default of 1 is used, leading to the observed slow loading.
After calling .batch(), _prepare_ex_iterable_for_iteration should use batch_size=1 since now each row in the dataset is actually a batch of the original dataset.
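For illustration, a quick sketch (small batch size just to keep it fast; streaming c4 as in the example above):

from datasets import load_dataset

c4 = load_dataset("allenai/c4", "en", streaming=True, split="train")
batched = c4.batch(4)

# after .batch(4), each "example" yielded by the iterable dataset is already
# a dict of lists of length 4, i.e. one row per batch of the original dataset
first = next(iter(batched))
print({k: len(v) for k, v in first.items()})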