Unable to get mid-epoch resumption working
I am using Streaming in conjunction with PyTorch Lightning, but loading my dataloaders (including state resumption) separately from the PL logic.
Within my pl.Trainer I have the following to include the training dataloader state dict in the checkpoint:
def on_save_checkpoint(self, checkpoint):
    # Grab the training dataloader and stash its state dict in the checkpoint
    dataloader = self.trainer.train_dataloader
    if isinstance(dataloader, list):
        dataloader = dataloader[0]
    checkpoint['dataloader_state'] = dataloader.state_dict()
And then in my dataset class I implement the logic to load this checkpoint and pull out the dataloader state (if resuming). As an example, this is the dict I extract:
self.dataloader_state = {
    'epoch': 0,
    'initial_physical_nodes': 1,
    'num_canonical_nodes': 1,
    'sample_in_epoch': 100,
    'shuffle_seed': 9176
}
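The extraction itself is roughly the following (checkpoint path and key handling simplified for illustration):

import torch

# Load the Lightning checkpoint and pull out the dataloader state saved by the hook above
checkpoint = torch.load(checkpoint_path, map_location='cpu')
self.dataloader_state = checkpoint.get('dataloader_state')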
This is from a checkpoint saved at step 25 with a batch size of 4, so far so good. Now I load my dataset + dataloader (same setup as before) and restore from the state dict:
train_dataloader.load_state_dict(self.dataloader_state)
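For context, the dataset and dataloader are constructed essentially like this before the load_state_dict call above (remote/local paths and batch size are placeholders; my real dataset class wraps StreamingDataset):

from streaming import StreamingDataset, StreamingDataLoader

# Placeholder construction; the real code uses my own dataset class built on StreamingDataset
train_dataset = StreamingDataset(remote='s3://my-bucket/train', local='/tmp/streaming_cache',
                                 shuffle=True, batch_size=4)
train_dataloader = StreamingDataLoader(train_dataset, batch_size=4)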
And for debugging purposes:
for i, batch in enumerate(train_dataloader):
    print(i)
And the prints start from 0. If I understand correctly, this should not be the case with resumption, right? I ask because when this is wired into my actual PyTorch Lightning code, it does not seem to resume from the correct step either.