
Unable to get mid-epoch resumption working

Open JLenzy opened this issue 11 months ago • 0 comments

I am using Streaming in conjunction with PyTorch Lightning, but loading my dataloaders (including state resumption) separately from the PL logic.

Within my pl.Trainer I have the following to include the training dataloader state dict in the checkpoint:

  def on_save_checkpoint(self, checkpoint):
      # Persist the training dataloader's state alongside the Lightning checkpoint.
      dataloader = self.trainer.train_dataloader
      if isinstance(dataloader, list):
          dataloader = dataloader[0]
      checkpoint['dataloader_state'] = dataloader.state_dict()
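For completeness, the restore side I implement is roughly the following sketch. It is simplified and stdlib-only so it runs standalone: `pickle` stands in for `torch.load`, the path is temporary, and everything except the `dataloader_state` key is hypothetical.

```python
import os
import pickle
import tempfile

# Simulated checkpoint dict; in real code this is what Lightning saved and
# would be read back with torch.load(checkpoint_path).
checkpoint = {
    'dataloader_state': {
        'epoch': 0,
        'initial_physical_nodes': 1,
        'num_canonical_nodes': 1,
        'sample_in_epoch': 100,
        'shuffle_seed': 9176,
    },
}

path = os.path.join(tempfile.mkdtemp(), 'last.ckpt')
with open(path, 'wb') as f:
    pickle.dump(checkpoint, f)

# On resume: load the checkpoint and extract only the dataloader state.
with open(path, 'rb') as f:
    restored = pickle.load(f)
dataloader_state = restored['dataloader_state']
```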

Then, in my dataset class, I implement the logic to load this checkpoint and specifically pull out the dataloader state (if resuming). As an example, this is the dict I extract:

self.dataloader_state = {
    'epoch': 0,
    'initial_physical_nodes': 1,
    'num_canonical_nodes': 1,
    'sample_in_epoch': 100,
    'shuffle_seed': 9176,
}
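As a sanity check, that `sample_in_epoch` of 100 is exactly what 25 completed steps at batch size 4 on a single node should produce:

```python
# samples consumed so far = optimizer steps * batch size
# (single node, no gradient accumulation assumed)
steps = 25
batch_size = 4
sample_in_epoch = steps * batch_size
print(sample_in_epoch)  # 100
```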

This is from a checkpoint saved at step 25 with a batch size of 4; so far so good. Now I load my dataset and dataloader (same as before) and restore from the state dict:

train_dataloader.load_state_dict(self.dataloader_state)
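For reference, this is the behavior I expect `load_state_dict` to trigger, sketched as a toy stdlib-only loader (hypothetical class, not the actual Streaming API): after restoring the state, iteration should fast-forward past the first `sample_in_epoch` samples rather than start at sample 0.

```python
# Toy sketch (hypothetical names, NOT the Streaming API) of the behavior
# expected after load_state_dict: skip the samples already consumed.
class ToyResumableLoader:
    def __init__(self, samples, batch_size):
        self.samples = samples
        self.batch_size = batch_size
        self.sample_in_epoch = 0  # samples already consumed this epoch

    def load_state_dict(self, state):
        self.sample_in_epoch = state['sample_in_epoch']

    def __iter__(self):
        # Resume from the recorded offset instead of sample 0.
        for start in range(self.sample_in_epoch, len(self.samples), self.batch_size):
            yield self.samples[start:start + self.batch_size]

loader = ToyResumableLoader(list(range(400)), batch_size=4)
loader.load_state_dict({'sample_in_epoch': 100})
first_batch = next(iter(loader))
print(first_batch)  # [100, 101, 102, 103]
```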

And for debugging purposes:

        for i, batch in enumerate(train_dataloader):
            print(i)

And the prints start from 0. If I understand correctly, this should not be the case with resumption, right? I ask because when this runs inside my actual PyTorch Lightning code, it does not seem to resume from the correct step either.

JLenzy avatar Feb 25 '25 20:02 JLenzy