fms-fsdp

[Checkpoint] Rewrite Checkpointing

Open · lchu6 opened this issue 10 months ago · 1 comment

Background

It is time to rewrite the current checkpointing code, for two reasons:

  1. Functionality support: move from FSDPv1 logic to DTensor logic, so that we can support DCP (Distributed Checkpoint) for FSDPv2, TP, etc.
  2. Cleanup: the current checkpointing code has served many generations of fms-fsdp, and over that time many of the features developed for it were either never used or are no longer used. We can take this chance to rebuild it from scratch.

Proposal

  1. move from FSDPv1 checkpointing to DTensor checkpointing
  2. remove saving and loading of the metadata.pth file and move all metadata into the state_dict as train_state, so that the state_dict has a uniform layout:
     ```python
     state_dict = {
         "model_state": model_state,
         "optim_state": optim_state,
         "train_state": train_state,
     }
     ```
  3. simplify the APIs:
     - old: ~~`def __init__(self, ckpdir, n_to_save, parallel_mode, rank, local_rank, report_fn=None, model_auto_placement=False)`~~
       new: `def __init__(self, path)`
     - old: ~~`def load(self, model, optimizer, dataloader, path="", reset_stepcount=False, strict=True, is_compiled=False)`~~
       new: `def load(self, model, optimizer, train_state)`
     - old: ~~`def save(self, step, model, optimizer, dataloader, **kwargs)`~~
       new: `def save(self, step, model, optimizer, train_state)`
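Because the proposed `load` receives `train_state` as an argument rather than returning it, a natural reading is that the checkpointer fills the caller's objects in place, much as DCP's `load` fills a pre-built state_dict in place. The sketch below is a pure-Python stand-in for that behavior, not DCP itself: `make_state_dict` and `load_in_place` are hypothetical helpers, and the `step`/`tokens_seen` fields are assumptions, since the issue does not say what train_state holds.

```python
import copy


def make_state_dict(model_state, optim_state, train_state):
    """Assemble the uniform checkpoint layout from the proposal."""
    return {
        "model_state": model_state,
        "optim_state": optim_state,
        "train_state": train_state,
    }


def load_in_place(target, saved):
    """Hypothetical stand-in for in-place loading: recurse into dicts that
    already exist in `target` and overwrite leaf values from `saved`, so
    the caller's own dict objects are updated rather than replaced."""
    for key, value in saved.items():
        if isinstance(value, dict) and isinstance(target.get(key), dict):
            load_in_place(target[key], value)
        else:
            target[key] = copy.deepcopy(value)


# Usage: the caller keeps a reference to train_state and sees it updated.
# The field names below are illustrative assumptions.
train_state = {"step": 0, "tokens_seen": 0}
state_dict = make_state_dict({"w": [0.0]}, {"lr": 0.0}, train_state)
saved = make_state_dict({"w": [1.5]}, {"lr": 3e-4}, {"step": 1000, "tokens_seen": 4096000})
load_in_place(state_dict, saved)
print(train_state)  # {'step': 1000, 'tokens_seen': 4096000}
```

One consequence of this design: train_state has to be a mutable container (such as a dict), because plain ints like a bare `step` variable cannot be updated in place.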

Detailed cleanups

  1. remove dataloader support, as the dataloader was moved out of the checkpoint util a long time ago.
  2. remove support for loading from and saving to a single-file ckpt, for the following reasons:
    1. we barely, if ever, used them.
    2. conversion, when needed, should be done as a post-processing step.
  3. remove ckpt cleanup via max_ckpt_count, as it is not used and we now keep all ckpts thanks to the smart GPFS solution.
  4. remove get_oldest, as it is never used.
  5. remove the special handling for HSDP, as FSDPv2 no longer has this issue.
  6. remove the parallel_mode arg, for the same reason.
  7. remove the report_fn arg; the default is good enough.
  8. remove the model_auto_placement arg.
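Putting the simplified signatures and the cleanups together, a checkpointer could be roughly as small as the sketch below. It is a hypothetical single-process illustration, not the planned implementation: pickle stands in for DCP's sharded save/load, the `step_<N>` directory naming and the `_latest` helper are assumptions (the issue does not specify an on-disk layout), and the boolean returned by `load` is an addition for illustration. Every checkpoint is kept, with no `n_to_save`, `max_ckpt_count`, or `get_oldest`, matching the list above.

```python
import os
import pickle
import re
import tempfile


class Checkpointer:
    """Hypothetical sketch of the proposed API; pickle replaces DCP here."""

    def __init__(self, path):
        self.path = path
        os.makedirs(path, exist_ok=True)

    def _step_dir(self, step):
        # Assumed layout: one subdirectory per saved step.
        return os.path.join(self.path, f"step_{step}")

    def save(self, step, model, optimizer, train_state):
        state_dict = {
            "model_state": model.state_dict(),
            "optim_state": optimizer.state_dict(),
            "train_state": train_state,
        }
        step_dir = self._step_dir(step)
        os.makedirs(step_dir, exist_ok=True)
        with open(os.path.join(step_dir, "state.pkl"), "wb") as f:
            pickle.dump(state_dict, f)

    def _latest(self):
        # Highest step among subdirectories named step_<N>; None if empty.
        steps = [
            int(m.group(1))
            for name in os.listdir(self.path)
            if (m := re.fullmatch(r"step_(\d+)", name))
        ]
        return max(steps) if steps else None

    def load(self, model, optimizer, train_state):
        step = self._latest()
        if step is None:
            return False  # nothing to resume from; caller starts fresh
        with open(os.path.join(self._step_dir(step), "state.pkl"), "rb") as f:
            state_dict = pickle.load(f)
        model.load_state_dict(state_dict["model_state"])
        optimizer.load_state_dict(state_dict["optim_state"])
        train_state.update(state_dict["train_state"])  # fill caller's dict
        return True


class Toy:
    """Minimal object exposing state_dict/load_state_dict for the demo."""

    def __init__(self, value):
        self.value = value

    def state_dict(self):
        return {"value": self.value}

    def load_state_dict(self, sd):
        self.value = sd["value"]


with tempfile.TemporaryDirectory() as d:
    ckp = Checkpointer(d)
    ckp.save(100, Toy(1), Toy(2), {"step": 100})
    ckp.save(200, Toy(3), Toy(4), {"step": 200})
    model, optim, train_state = Toy(0), Toy(0), {"step": 0}
    resumed = ckp.load(model, optim, train_state)
    print(resumed, model.value, optim.value, train_state)  # True 3 4 {'step': 200}
```

With DCP, `save` and `load` would instead be collective calls over a sharded directory, so every rank would run them rather than a single process.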

Many of these dropped features could stay without harm, but it is better to start from a cleaner version; we can always add them back if needed.

lchu6 · Apr 08 '25 16:04

@daviswer to review.

cc @raghukiran1224

lchu6 · Apr 08 '25 16:04