
Support for on-the-fly filtering

Open ColinToft opened this issue 1 year ago • 1 comment

🚀 Feature Request

My company is currently using Mosaic streaming for our training, and we would like to implement on-the-fly filtering based on conditions loaded from a config at runtime.

Motivation

We train models using many different datasets and model configurations and would like to be able to adjust these datasets as necessary, filtering at runtime based on different properties as opposed to spending time in advance making many different variations of the same dataset. I believe that many other users may benefit from this feature, so it may be worth taking the time to implement, but if not I would appreciate guidance on how I could accomplish this behavior locally.

Implementation

This presents two challenges:

  1. Dynamically filtering the stream so that only some items are yielded to the model
  2. Handling correct resumption from checkpoints/state_dict

The first we have tackled locally by creating our own FilterDataset that takes in a StreamingDataset as a parameter. By overriding __iter__ we can selectively yield items (from self.original_dataset.__iter__()) that pass the given filtering criteria.
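For readers unfamiliar with the wrapper approach, here is a minimal sketch of what such a FilterDataset might look like. It is not our actual code: the predicate argument and the sample fields are hypothetical, and a plain list stands in for the StreamingDataset so the snippet runs without any dependencies. In a real setup the class would probably subclass torch.utils.data.IterableDataset.

```python
from typing import Any, Callable, Iterable, Iterator


class FilterDataset:
    """Wraps an iterable dataset and yields only samples that pass `predicate`."""

    def __init__(self, original_dataset: Iterable[Any],
                 predicate: Callable[[Any], bool]) -> None:
        self.original_dataset = original_dataset
        # Hypothetical: in practice this would be built from the runtime config.
        self.predicate = predicate

    def __iter__(self) -> Iterator[Any]:
        # Pull every sample from the wrapped dataset, but only pass on
        # the ones that satisfy the filtering criteria.
        for sample in self.original_dataset:
            if self.predicate(sample):
                yield sample


# Stand-in for a StreamingDataset: any iterable of dict samples works here.
samples = [{"id": i, "lang": "en" if i % 2 == 0 else "fr"} for i in range(6)]
filtered = FilterDataset(samples, lambda s: s["lang"] == "en")
kept_ids = [s["id"] for s in filtered]
print(kept_ids)
```

Note that the wrapper consumes more samples from the underlying dataset than it yields, which is what makes checkpoint resumption tricky.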

However, the second issue is more challenging. From my understanding, Mosaic will save samples_in_epoch which is the number of samples yielded in total amongst all workers. However, due to filtering, each worker will have gotten further along than samples_in_epoch since it will have filtered out some samples along the way. This causes the worker to "peek ahead" at samples that the streaming library doesn't expect it to have seen yet. Ultimately, we found that this causes our model to see some samples twice when we save and then resume from a checkpoint.

(If that was confusing I can share a small reproducible example.)
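In the meantime, here is a rough, library-free simulation of the effect described above. It assumes a single worker and a filter that keeps only even sample indices, and it models the checkpoint as nothing more than a count of yielded samples. None of the names below come from the streaming library.

```python
from typing import Iterator, List


def filtered_stream(start: int, total: int) -> Iterator[int]:
    """Read sample indices from `start`, yielding only the even ones."""
    for index in range(start, total):
        if index % 2 == 0:
            yield index


TOTAL = 20

# First run: the trainer consumes 4 filtered samples, then checkpoints.
first_run: List[int] = []
stream = filtered_stream(0, TOTAL)
for _ in range(4):
    first_run.append(next(stream))

# The checkpoint only knows how many samples were yielded, although the
# underlying stream has actually been read past that point.
yielded_count = len(first_run)
true_position = first_run[-1] + 1

# Resuming from the yielded count restarts too early in the stream.
second_run = list(filtered_stream(yielded_count, TOTAL))
seen_twice = sorted(set(first_run) & set(second_run))

print("yielded:", yielded_count, "actually read up to:", true_position)
print("samples seen twice:", seen_twice)
```

The gap between the yielded count and the true read position is the "peek ahead", and every kept sample inside that gap is replayed after resumption.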

Additional context

This doesn't necessarily have to be implemented as a feature; any help getting this working myself would be appreciated as well!

ColinToft avatar Oct 09 '24 22:10 ColinToft

You can filter the data in Streaming**DataLoader**.__iter__() instead of StreamingDataset.__iter__(). Because the state_dict depends on num_sample_yielded in StreamingDataLoader, you need to update num_sample_yielded first, counting every sample including the ones you will drop, and only then filter out the samples you don't need.
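A minimal sketch of that ordering, counting first and filtering second, might look like the following. It is a plain-Python stand-in rather than a real StreamingDataLoader subclass: the class name, the keep and start parameters, and the list-based data are all assumptions made for illustration, and the library's actual bookkeeping across ranks and workers is not modelled.

```python
from typing import Any, Callable, Dict, Iterator, List


class CountingFilterLoader:
    """Hypothetical loader that counts unfiltered samples before filtering."""

    def __init__(self, samples: List[Any], batch_size: int,
                 keep: Callable[[Any], bool], start: int = 0) -> None:
        self.samples = samples
        self.batch_size = batch_size
        self.keep = keep
        self.start = start  # resume offset, measured in unfiltered samples
        self.num_sample_yielded = start

    def __iter__(self) -> Iterator[List[Any]]:
        self.num_sample_yielded = self.start
        for i in range(self.start, len(self.samples), self.batch_size):
            batch = self.samples[i:i + self.batch_size]
            # Count the whole unfiltered batch first ...
            self.num_sample_yielded += len(batch)
            # ... and only then drop the samples that fail the filter.
            kept = [s for s in batch if self.keep(s)]
            if kept:
                yield kept

    def state_dict(self) -> Dict[str, int]:
        return {"num_sample_yielded": self.num_sample_yielded}


def is_even(x: int) -> bool:
    return x % 2 == 0


data = list(range(20))
loader = CountingFilterLoader(data, batch_size=4, keep=is_even)
batches = iter(loader)
first_run = next(batches) + next(batches)
state = loader.state_dict()

# Resume from the unfiltered count stored in the checkpoint.
resumed = CountingFilterLoader(data, batch_size=4, keep=is_even,
                               start=state["num_sample_yielded"])
second_run = [s for batch in resumed for s in batch]
print(first_run, state, second_run)
```

Because the counter tracks the unfiltered position, the resumed run picks up exactly where the first one stopped reading. One trade-off of filtering at this level is that batches can end up smaller than batch_size.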

YaoyaoChang avatar Dec 26 '24 02:12 YaoyaoChang