
Feedback about Asynchronous Saving with Distributed Checkpoint (DCP)

Open · lebrice opened this issue 4 months ago • 2 comments

Hey there! Little nitpick about the last block of this docs page: https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html

The checkpoint_future variable is never written to in the last block. Perhaps the intent was to have this instead?

checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

cc @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn @ekr0 @haochengsong @Saiteja64

lebrice · Sep 23 '25 13:09

Also, I thought (based on this blog post: https://discuss.pytorch.org/t/distributed-w-torchtitan-optimizing-checkpointing-efficiency-with-pytorch-dcp/211250) that one requirement for lowering the overhead of async DCP was that the checkpointing work had to overlap with the forward and backward passes, and be completed (with the CUDA stream synchronized) before the next optimizer.step(), no?

Here the code only makes sure that the previous checkpoint is done right before the next call to dcp.async_save, but shouldn't that wait happen before the optimizer step modifies the weights in-place? And shouldn't there be some kind of CUDA stream synchronization (see the sketch after the snippet below), or is that hidden in the StorageWriter?

    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        
        # Shouldn't this be done here, before the optimizer step?
        if checkpoint_future is not None:
            # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
            checkpoint_future.result()

        optimizer.step()
        
        state_dict = { "app": AppState(model, optimizer) }
        # Missing this assignment I think
        checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    cleanup()
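
If an explicit synchronization is indeed required, I'd expect it to sit right before the step, something like this (just a sketch of my assumption that the staging copy runs on a separate CUDA stream, using the same names as the snippet above; torch.cuda.synchronize() simply waits for all outstanding work on the device):

    if checkpoint_future is not None:
        checkpoint_future.result()
    # hypothetical: wait for any in-flight staging (device-to-host) copies
    # before optimizer.step() mutates the parameters in-place
    torch.cuda.synchronize()
    optimizer.step()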

lebrice · Sep 23 '25 13:09

@lebrice looking at the file again, I can see that line 266 already has the assignment checkpoint_future = dcp.async_save(...)
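
For reference, with the assignment in place the final block reads roughly as follows (a sketch based on the snippet quoted above, assuming StorageWriter is the tutorial's alias for torch.distributed.checkpoint.FileSystemWriter and that AppState, CHECKPOINT_DIR, model, optimizer, and cleanup are defined earlier in the recipe):

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        state_dict = {"app": AppState(model, optimizer)}
        if checkpoint_future is not None:
            # wait for the previous checkpoint before queuing another one
            checkpoint_future.result()
        checkpoint_future = dcp.async_save(
            state_dict,
            storage_writer=writer,
            checkpoint_id=f"{CHECKPOINT_DIR}_step{step}",
        )

    cleanup()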

patrocinio · Dec 08 '25 23:12