
Facing issue with resuming training for saved dataset state (>1 epoch)

Open rodosingh opened this issue 1 year ago • 2 comments

Environment

  • OS: Ubuntu 22.04
  • Hardware (GPU, or instance type): MI300, ROCm 6.1.0
  • NUM_NODES: 2
  • GPUs/NODE: 8

Context

  • When resuming training from a saved state (both the model checkpoint and the S3 dataset state) beyond one epoch, an error is thrown at the point where the first epoch ended.
  • Please see the error below. I am also using the choose and repeat functionality of the StreamingDataset class to downsample and upsample datasets, respectively, as required (the demo .yaml file is attached below; a sketch of the save/resume flow follows the traceback).
0: [rank0]: Traceback (most recent call last):
0: [rank0]:   File "/home/<user>/LLaVA-NeXT/llava/train/train_mem.py", line 4, in <module>
0: [rank0]:     train()
0: [rank0]:   File "/home/<user>/LLaVA-NeXT/llava/train/train.py", line 2034, in train
0: [rank0]:     trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 1948, in train
0: [rank0]:     return inner_training_loop(
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/transformers/trainer.py", line 2246, in _inner_training_loop
0: [rank0]:     for step, inputs in enumerate(epoch_iterator):
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/accelerate/data_loader.py", line 552, in __iter__
0: [rank0]:     current_batch = next(dataloader_iter)
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
0: [rank0]:     data = self._next_data()
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
0: [rank0]:     return self._process_data(data)
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
0: [rank0]:     data.reraise()
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_utils.py", line 706, in reraise
0: [rank0]:     raise exception
0: [rank0]: ValueError: Caught ValueError in DataLoader worker process 0.
0: [rank0]: Original Traceback (most recent call last):
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
0: [rank0]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
0: [rank0]:   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
0: [rank0]:     data.append(next(self.dataset_iter))
0: [rank0]:   File "/home/<user>/LLaVA-NeXT/streaming/streaming/base/dataset.py", line 1501, in __iter__
0: [rank0]:     sample_ids = self._get_work(epoch, sample_in_epoch)
0: [rank0]:   File "/home/<user>/LLaVA-NeXT/streaming/streaming/base/dataset.py", line 1046, in _get_work
0: [rank0]:     epoch_sample_ids = generate_work(self.batching_method, self, p_world, epoch,
0: [rank0]:   File "/home/<user>/LLaVA-NeXT/streaming/streaming/base/batching/__init__.py", line 45, in generate_work
0: [rank0]:     return get(dataset, world, epoch, sample_in_epoch)
0: [rank0]:   File "/home/<user>/LLaVA-NeXT/streaming/streaming/base/batching/random.py", line 57, in generate_work_random_batching
0: [rank0]:     big_ids = get_partitions(dataset.partition_algo, dataset.epoch_size,
0: [rank0]:   File "/home/<user>/LLaVA-NeXT/streaming/streaming/base/partition/__init__.py", line 69, in get_partitions
0: [rank0]:     raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
0: [rank0]: ValueError: Resuming further into the dataset (7824000) than it has samples (6468555)
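
For context, here is a minimal sketch of the save/resume round trip, assuming the stock StreamingDataset checkpointing API (state_dict / load_state_dict); the remote path, batch size, and sample count below are illustrative placeholders, not my actual values:

from streaming import StreamingDataset

# Hypothetical single-stream dataset; the real run uses the multi-stream
# YAML config shown below.
dataset = StreamingDataset(
    remote='s3://object/data/shards/LLaVA_Stage2/scienceqa/',
    batch_size=8,
)

# Save side: num_samples is the number of samples processed so far in the
# current epoch; from_beginning=True means it is counted from the start of
# the epoch rather than from the last resumption point.
state = dataset.state_dict(num_samples=7_824_000, from_beginning=True)

# Resume side: restore the state before iterating again.
dataset.load_state_dict(state)
for sample in dataset:  # __iter__ regenerates the partition from the state;
    pass                # get_partitions raises the ValueError above when the
                        # restored sample_in_epoch exceeds the epoch size.

The error message suggests the restored sample offset (7,824,000) carried past the epoch boundary instead of wrapping, since the dataset only has 6,468,555 samples per epoch.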

Corresponding YAML file specifying the paths to the S3 shards:

datasets:
- shard_path: 's3://object/data/shards/LLaVA_Stage2/VQA-RAD/'
- shard_path: 's3://object/data/shards/LLaVA_Stage2/infographic_vqa/'
- shard_path: 's3://object/data/shards/LLaVA_Stage2/iconqa/'
  choose: 1365
- shard_path: 's3://object/data/shards/LLaVA_Stage2/TabMWP/'  
- shard_path: 's3://object/data/shards/LLaVA_Stage2/scienceqa_nona_context/'  
  choose: 960
- shard_path: 's3://object/data/shards/LLaVA_Stage2/scienceqa_nona_context/'  
- shard_path: 's3://object/data/shards/LLaVA_Stage2/scienceqa/' 
  repeat: 2
- shard_path: 's3://object/data/shards/LLaVA_Stage2/ureader_kg/' 
- shard_path: 's3://object/data/shards/LLaVA_Stage2/aokvqa/' 
- shard_path: 's3://object/data/shards/LLaVA_Stage2/k12_printing/' 
  choose: 2566
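
For reference, a minimal sketch of how I understand these YAML entries map onto streaming's Stream objects (the actual loader code differs; entries with neither choose nor repeat are abbreviated here):

from streaming import Stream, StreamingDataset

# Each YAML entry becomes one Stream; choose caps the number of samples
# drawn from that stream per epoch, while repeat multiplies them.
streams = [
    Stream(remote='s3://object/data/shards/LLaVA_Stage2/iconqa/', choose=1365),
    Stream(remote='s3://object/data/shards/LLaVA_Stage2/scienceqa/', repeat=2),
    Stream(remote='s3://object/data/shards/LLaVA_Stage2/k12_printing/', choose=2566),
    # ... remaining shard paths use the default sampling behavior
]

dataset = StreamingDataset(streams=streams, batch_size=8, shuffle=True)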

Could someone please take a look at this issue? If any further info is needed, please let me know in the thread.

Thanks for your help!

rodosingh · Jan 28 '25

Just to confirm: do you only want to repeat shard_path: 's3://object/data/shards/LLaVA_Stage2/scienceqa/' twice?

ethantang-db · Jan 29 '25

Thanks @ethantang-db 🙌.

Yes, there are multiple such datasets; just for illustration I have provided this one example of repeat and a few of choose.

rodosingh · Jan 29 '25