
Batched IterableDataset

lneukom opened this issue 2 years ago

Feature request

Hi,

could you add an implementation of a batched IterableDataset? It already supports batch iteration via .iter(batch_size=...), but this cannot be used in combination with a torch DataLoader since it just returns an iterator.
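
For illustration, a minimal sketch of what .iter() already provides today (dataset and column names are hypothetical):

# .iter() yields plain dicts of lists (one per batch) from a generator,
# which is why it cannot be wrapped by a torch DataLoader directly
for batch in dataset.iter(batch_size=32):
    # batch is e.g. {"text": [... 32 strings ...], "label": [... 32 ints ...]}
    pass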

Motivation

The current implementation loads each element of a batch individually, which can be very slow for a big batch_size. I did some experiments here, and using batched iteration would speed up data loading significantly.

Your contribution

N/A

lneukom avatar Oct 05 '23 11:10 lneukom

This is exactly what I was looking for. It would also be very useful for me :-)

VascoSch92 avatar Oct 05 '23 11:10 VascoSch92

This issue really undermines the selling point of HF datasets... The only workaround I've found so far is to create a customized IterableDataset wrapper, which improves the loading speed to some extent.

For example, I have an HF dataset dt_train with len(dt_train) == 1M. Using a plain DataLoader is extremely slow:

%%time
dl_train = DataLoader(dt_train, batch_size=128, shuffle=True)
for batch in dl_train:
    pass
CPU times: user 24min 35s, sys: 704 ms, total: 24min 36s
Wall time: 24min 37s

And DataLoader works even worse with HF's iterable_dataset:

%%time
dt_train_ = dt_train.with_format(None).to_iterable_dataset(num_shards=64).shuffle(buffer_size=10_000)
dl_train = DataLoader(dt_train_, batch_size=128)
for batch in dl_train:
    pass
CPU times: user 1h 6min 2s, sys: 4.28 s, total: 1h 6min 6s
Wall time: 1h 7min 53s

Workaround by running a customized wrapper:

%%time
from torch.utils.data import DataLoader, IterableDataset

class Dataset2Iterable(IterableDataset):
    """
    Wrapper to use a HF dataset as pytorch IterableDataset to speed up data loading.
    """
    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # Dataset.shuffle() returns a new (shuffled) dataset, it does not shuffle in place
            self.dataset = self.dataset.shuffle()
        return self.dataset.iter(batch_size=self.batch_size)

dl_train = DataLoader(Dataset2Iterable(dt_train, batch_size=128), batch_size=1, num_workers=0)
for n in range(2):
    for batch in dl_train:
        pass

The speed is still slower than using tensorflow's loader, but improved a lot over the previous code:

CPU times: user 4min 18s, sys: 0 ns, total: 4min 18s
Wall time: 4min 20s

Note that the way I implemented Dataset2Iterable will only work with num_workers=0.
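
A possible way to lift that restriction (untested sketch, class name hypothetical, assuming the wrapped object is a regular map-style datasets.Dataset so that Dataset.shard() is available) would be to give each dataloader worker its own contiguous shard:

from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class ShardedDataset2Iterable(IterableDataset):
    """Variant of Dataset2Iterable where each dataloader worker iterates over its own shard."""
    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        ds = self.dataset
        worker_info = get_worker_info()
        if worker_info is not None:
            # give each worker a disjoint contiguous shard so examples are not duplicated
            ds = ds.shard(num_shards=worker_info.num_workers, index=worker_info.id)
        if self.shuffle:
            ds = ds.shuffle()  # returns a new, shuffled dataset
        return ds.iter(batch_size=self.batch_size)

# usage: items are already batches, so disable the DataLoader's own batching
# dl_train = DataLoader(ShardedDataset2Iterable(dt_train, batch_size=128), batch_size=None, num_workers=4)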

zhh210 avatar Jun 03 '24 21:06 zhh210

I can confirm that @zhh210's solution works with num_workers=0. However, for my use case, this was still slower than tokenizing on the fly through a collator and leveraging multiple workers in the dataloader.

@lhoestq I think this is an important use case (e.g., streaming from a large dataset, online or stored on disk). What do you think might be the best solution to move forward?

jaketae avatar Jul 05 '24 06:07 jaketae

I guess it can be implemented using a batched .map() under the hood that returns a single item containing the input batch.

In the meantime you can use this:

def batch(unbatched: dict[str, list]) -> dict[str, list]:
    return {k: [v] for k, v in unbatched.items()}

batched_dataset = dataset.map(batch, batched=True, batch_size=batch_size)
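
Each element of batched_dataset is then already a full batch, so it can be passed to a torch DataLoader with automatic batching disabled (untested sketch):

from torch.utils.data import DataLoader

# every item yielded by batched_dataset is already a dict of lists (one batch),
# so turn off the DataLoader's own batching with batch_size=None
dl = DataLoader(batched_dataset, batch_size=None)
for batch in dl:
    pass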

Though it would be great to have a .batch() method indeed, I'd be happy to help if anyone wants to open a PR

lhoestq avatar Jul 08 '24 09:07 lhoestq

If no one else is planning to work on this, I can take it on. I'll wait until next week, and if no one has started a PR by then, I'll go ahead and open one.

lappemic avatar Jul 08 '24 11:07 lappemic

It looks like the implementation of IterableDataset is still using a hardcoded batch size of 1, for example in line 2063 of /datasets/src/datasets/iterable_dataset.py. Iterating over an IterableDataset with large batch sizes therefore remains slow, even when using batch(). I guess the data are then not being read from one contiguous chunk of memory; instead every example is retrieved one by one, leading to long data loading times. As a minimal example: load the c4 dataset and iterate over it with a large batch size.

import datasets
from timeit import default_timer as timer
c4 = datasets.load_dataset("allenai/c4", "en", streaming=True, split="train")
c4_batched = c4.batch(512**2) # use large batch size
iterator = iter(c4_batched)
for i in range(5):
    start_time=timer()
    next(iterator) # get next batch
    end_time = timer()
    print(f"time for one batch: {end_time-start_time}")

This results in the following output for me:

time for one batch: 12.615376660600305
time for one batch: 13.011422813870013
time for one batch: 14.157325950451195
time for one batch: 14.225894245319068
time for one batch: 13.898222777992487

Because I want to use my IterableDataset with the pytorch DataLoader, I rewrote the _iter_pytorch and __iter__ functions like so and am getting much faster data loading times. I marked the lines I changed with "# changed here":

from datasets.iterable_dataset import _convert_to_arrow
from datasets.formatting import TensorFormatter, get_formatter
from datasets.features.features import cast_to_python_objects
import sys
import fsspec.asyn
from itertools import islice
from datasets.utils.logging import get_logger
from datasets.iterable_dataset import _examples_to_batch, _apply_feature_types_on_batch, _apply_feature_types_on_example

logger = get_logger(__name__)

def __iter__(self):
    if "torch" in sys.modules:
        import torch.utils.data

        worker_info = torch.utils.data.get_worker_info()
        if isinstance(self, torch.utils.data.IterableDataset) and worker_info is not None:
            # We're a torch.utils.data.IterableDataset in a PyTorch worker process
            yield from self._iter_pytorch()
            return

    ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=self.batch_size, drop_last_batch=self.drop_last_batch) # changed here
    if self._formatting:
        formatter = get_formatter(self._formatting.format_type, features=self.features)
        format_dict = (
            formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
        )
    else:
        format_dict = None

    if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"):
        if ex_iterable.iter_arrow:
            iterator = ex_iterable.iter_arrow()
        else:
            iterator = _convert_to_arrow(ex_iterable, batch_size=self.batch_size) # changed here
        for key, pa_table in iterator:
            yield formatter.format_row(pa_table)
        return

    for key, example in ex_iterable:
        if self.features and not ex_iterable.is_typed:
            # `IterableDataset` automatically fills missing columns with None.
            # This is done with `_apply_feature_types_on_example`.
            example = _apply_feature_types_on_example(
                example, self.features, token_per_repo_id=self._token_per_repo_id
            )
        yield format_dict(example) if format_dict else example



def _iter_pytorch(self):
    ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=self.batch_size, drop_last_batch=self.drop_last_batch) # changed here
    # Fix for fsspec when using multiprocess to avoid hanging in the ML training loop. (only required for fsspec >= 0.9.0)
    # See https://github.com/fsspec/gcsfs/issues/379
    fsspec.asyn.reset_lock()
    # check if there aren't too many workers
    import torch.utils.data

    worker_info = torch.utils.data.get_worker_info()
    if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers:
        logger.warning(
            f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). "
            f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers."
        )
        logger.info(
            f"To parallelize data loading, we give each process some shards (or data sources) to process. "
            f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. "
            f"To enable more parallelism, please split the dataset in more files than {ex_iterable.n_shards}."
        )
    # split workload
    _log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
    shards_indices = ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers)
    if shards_indices:
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards."
        )
        ex_iterable = ex_iterable.shard_data_sources(worker_id=worker_info.id, num_workers=worker_info.num_workers)
        self._state_dict = ex_iterable._init_state_dict()
        if self._starting_state_dict:
            ex_iterable.load_state_dict(self._starting_state_dict)

        if self._formatting:
            formatter = get_formatter(self._formatting.format_type, features=self.features)
            format_dict = (
                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
            )
        else:
            format_dict = None

        if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"):
            if ex_iterable.iter_arrow:
                iterator = ex_iterable.iter_arrow()
            else:
                iterator = _convert_to_arrow(ex_iterable, batch_size=self.batch_size) # changed here
            if self.batch_size > 1: # changed here until end of file
                for key, pa_table in iterator:
                    yield formatter.format_batch(pa_table)
                return
            else:
                for key, pa_table in iterator:
                    yield formatter.format_row(pa_table)
                return

        iterator = iter(ex_iterable)
        if self.batch_size > 1:
            for key, example in iterator:
                # If batched, first build the batch
                examples = [example] + [example for key, example in islice(iterator, self.batch_size - 1)]
                if self.drop_last_batch and len(examples) < self.batch_size:  # ignore last batch
                    return
                batch = _examples_to_batch(examples)
                if self.features and not ex_iterable.is_typed:
                    # `IterableDataset` automatically fills missing columns with None.
                    # This is done with `_apply_feature_types_on_batch`.
                    batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id)
                yield format_dict(batch) if format_dict else batch
        else:
            for key, example in ex_iterable:
                if self.features and not ex_iterable.is_typed:
                    # `IterableDataset` automatically fills missing columns with None.
                    # This is done with `_apply_feature_types_on_example`.
                    example = _apply_feature_types_on_example(
                        example, self.features, token_per_repo_id=self._token_per_repo_id
                    )
                yield format_dict(example) if format_dict else example
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.n_shards} shards."
        )
    else:
        logger.debug(
            f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.n_shards}<{worker_info.num_workers})."
        )

For anyone wanting to try it, you can patch it into datasets by overwriting the functions via setattr(datasets.IterableDataset, '_iter_pytorch', _iter_pytorch) (and similarly for __iter__)

I don't really know what most of the rest of the code is doing, so I have no idea whether this is a valid fix or not, but it seems to work for me. Example of running the fix:

from torch.utils.data.dataloader import DataLoader
c4.batch_size = 512**2 # set batch size here
dataloader = DataLoader(c4, batch_size=None) # use custom batching from IterableDataset
iterator = iter(dataloader)
for i in range(5):
    start_time=timer()
    next(iterator) #get the batch
    end_time = timer()
    print(f"time for one batch: {end_time-start_time}")

I now get:

time for one batch: 0.6047679269686341
time for one batch: 0.000248616561293602
time for one batch: 0.00017435848712921143
time for one batch: 0.00015910807996988297
time for one batch: 0.00015317369252443314

I love the datasets library, and it would be great if iterating with large batch sizes were supported directly, either with a fix similar to mine or in some other way :)

taczin avatar Nov 05 '24 11:11 taczin

Hi @taczin, thanks for reporting!

Indeed the IterableDataset.batch() implementation is quite naive: it manipulates python objects:

https://github.com/huggingface/datasets/blob/d37ed46ebf45981131bd3678173dbb4b7e2b2f1a/src/datasets/iterable_dataset.py#L3026-L3029

However it could be much faster if it were applied to the Arrow data directly, maybe using something like this (untested):

def batch_fn(unbatched): 
    return {k: [v] for k, v in unbatched.items()} 

def batch_fn_arrow(unbatched_pa_table): 
    offsets = pa.array([0, len(unbatched_pa_table)])
    return pa.Table.from_arrays([
        pa.ListArray.from_arrays(offsets, unbatched_pa_table[k])
        for k in unbatched_pa_table.column_names
    ], unbatched_pa_table.column_names)

if self._ex_iterable.iter_arrow:
    return self.with_format("arrow").map(
        batch_fn_arrow, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch
    ).with_format(self._formatting.format_type if self._formatting else None)
else:
    return self.map(batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)

lhoestq avatar Nov 05 '24 14:11 lhoestq

Hi @lhoestq, thanks for your answer. I was wondering: is there a reason why the internal call to ex_iterable = self._prepare_ex_iterable_for_iteration() in the IterableDataset code does not pass the batch size even though it could? If it's not passed, the default of 1 is used, leading to the observed slow loading.

taczin avatar Nov 05 '24 15:11 taczin

After calling .batch(), _prepare_ex_iterable_for_iteration should use batch_size=1, since each row of the dataset is now actually a batch of the original dataset.
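
For illustration (sketch, dataset and column names hypothetical): after .batch(), every element the dataset yields is already a full batch, so there is nothing left to group during iteration:

batched = dataset.batch(batch_size=4)
first = next(iter(batched))
# first is a dict of lists of length 4, e.g. {"text": [t0, t1, t2, t3]}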

lhoestq avatar Nov 05 '24 15:11 lhoestq