
Advice Needed: handling a large number of streams

Open suessmann opened this issue 2 years ago • 3 comments

Hiya!

I have approximately 1k data streams, each containing pickled numpy arrays. When a sample is loaded, I need to draw a random subsequence from it, so my dataset class looks like this:

import numpy as np
from streaming import StreamingDataset


class WatermazeDatasetNew(StreamingDataset):
    def __init__(
        self,
        streams,
        seq_len,
        batch_size,
    ):
        super().__init__(
            streams=streams,
            batch_size=batch_size,
        )
        self.seq_len = seq_len

    def __getitem__(self, idx: int):
        obj = super().__getitem__(idx)
        # Sample a random window of length seq_len from the trajectory.
        idx_traj = np.random.randint(0, len(obj["img"]) - self.seq_len)
        state = obj["img"][idx_traj : idx_traj + self.seq_len] / 255
        action = obj["action"][idx_traj : idx_traj + self.seq_len]
        reward = obj["reward"][idx_traj : idx_traj + self.seq_len]
        return state, action, reward

It does what it should, but data loading takes too much time. The data cannot all be held in RAM: the machine has around 100 GiB of RAM, while the data itself is ~150 GiB. I wonder if there is a way to keep the .mds files open in RAM and close them only when RAM is close to full, something like cache_limit, but for RAM. I tried increasing predownload and epoch_size, but that didn't help much. The timings I got looked like this:

GETTING DATA TOOK: 66.773770
forward took 2.23513
GETTING DATA TOOK: 1.681198
forward took 0.01434
GETTING DATA TOOK: 1.234387
forward took 0.00848
GETTING DATA TOOK: 1.397633
forward took 0.00872
GETTING DATA TOOK: 1.306774
forward took 0.00877
GETTING DATA TOOK: 1.275940
forward took 0.01117
GETTING DATA TOOK: 1.560843
forward took 0.00788
GETTING DATA TOOK: 1.619301
forward took 0.01183
GETTING DATA TOOK: 42.097745
forward took 0.00698
GETTING DATA TOOK: 10.102559
forward took 0.01008
GETTING DATA TOOK: 1.982731
forward took 0.01147
GETTING DATA TOOK: 1.585924
forward took 0.00956
GETTING DATA TOOK: 1.719182
forward took 0.01002
GETTING DATA TOOK: 1.663528
forward took 0.00875
GETTING DATA TOOK: 1.777168
forward took 0.00906
GETTING DATA TOOK: 1.621303
forward took 0.01053
GETTING DATA TOOK: 49.404249

So clearly every now and then the data files are closed and reopened again.
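For what it's worth, here is roughly the setup I have in mind; the bucket paths, the tmpfs mount, and all the numbers are placeholders, so treat it as a sketch of the "cache_limit, but for RAM" idea rather than working code:

from streaming import Stream, StreamingDataset

# Sketch only: point each stream's local cache at a tmpfs mount so that
# cached shards live in RAM, and cap the cache with cache_limit so old
# shards are evicted before RAM fills up.
streams = [
    Stream(
        remote=f"s3://my-bucket/watermaze/stream_{i:04d}",  # hypothetical remote
        local=f"/dev/shm/mds_cache/stream_{i:04d}",         # tmpfs: shards cached in RAM
    )
    for i in range(1000)
]

dataset = StreamingDataset(
    streams=streams,
    batch_size=32,        # placeholder
    cache_limit="90gb",   # stay below the ~100 GiB of RAM
    predownload=8,        # placeholder
)

(My subclass above would need to forward cache_limit through to super().__init__ for this to work.)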

Is there a simple solution, or am I missing something?

suessmann avatar Dec 20 '23 18:12 suessmann

Hey @suessmann, thanks for filing an issue. A few questions/suggestions:

  • Do you know if GETTING DATA TOOK is measuring the time for shard downloading + sample reading, or just sample reading?
  • What are the streaming hyperparameters that you are providing? And how many physical nodes?
  • Have you tried the streaming simulator to find the right set of hyperparameters?
  • Not related to the above issue, but do you actually have 1K different datasets? Are you aware of the merge_index() method? If you created the MDS shard files using multiple processes, each process creates its own directory and index.json file; you can use merge_index() to merge all the sub-datasets into one (see the sketch after this list). I would recommend doing this per dataset.
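To illustrate that last point, a minimal sketch, assuming each writer process wrote its sub-dataset under a common root (the root path below is made up):

from streaming.base.util import merge_index

# Hypothetical layout: each writer process produced its own sub-directory,
# e.g. out_root/0/index.json, out_root/1/index.json, ...
out_root = "s3://my-bucket/watermaze/stream_0000"

# Scan out_root for the per-process index.json files and merge them into a
# single index.json at out_root, so the root reads as one MDS dataset.
merge_index(out_root, keep_local=True)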

karan6181 avatar Jan 02 '24 18:01 karan6181

@suessmann Gentle reminder ^. Thank You!

karan6181 avatar Feb 28 '24 19:02 karan6181

@suessmann What is the column format of your data? StreamingDataset natively supports ndarray serialization and deserialization in MDS, in three different ways:

  1. dynamic shape and dtype
  2. dynamic shape but fixed dtype
  3. fixed shape and dtype
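For illustration, here is roughly how the three variants are declared in an MDSWriter columns spec; the column names, shapes, and output path are just examples:

import numpy as np
from streaming import MDSWriter

# Example column spec showing the three ndarray encodings:
columns = {
    "reward": "ndarray",              # 1. dynamic shape and dtype
    "action": "ndarray:float32",      # 2. dynamic shape, fixed dtype
    "img": "ndarray:uint8:64,64,3",   # 3. fixed shape and dtype (illustrative shape)
}

with MDSWriter(out="/tmp/watermaze_mds", columns=columns) as writer:
    writer.write({
        "reward": np.zeros(100),                       # any shape, any dtype
        "action": np.zeros(100, dtype=np.float32),     # shape may vary, dtype may not
        "img": np.zeros((64, 64, 3), dtype=np.uint8),  # must match declared shape/dtype
    })

The more you pin down up front, the less metadata has to be stored per sample, so option 3 is typically the most compact on disk and the fastest to decode.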

karan6181 avatar Apr 23 '24 11:04 karan6181