Feature request - load best checkpoint

Status: Open · giladfrid009 opened this issue 5 months ago · 2 comments

Currently, `TrainableModel.delete_checkpoints(best_checkpoint_metric)` removes all checkpoints except the best one and the latest one.
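
To make that retention rule concrete, here is a minimal sketch of which steps a keep-best-and-latest cleanup would delete. It assumes, hypothetically, that checkpoints are identified by integer training steps, that each step has a dict of logged metrics, and that a higher metric value is better. `steps_to_delete` is an illustrative helper, not ART's actual implementation.

```python
def steps_to_delete(
    metrics_by_step: dict[int, dict[str, float]],
    best_checkpoint_metric: str,
) -> list[int]:
    """Return the steps a keep-best-and-latest cleanup would remove.

    Hypothetical illustration: assumes higher metric values are better and
    that every step has `best_checkpoint_metric` logged.
    """
    if not metrics_by_step:
        return []
    # The latest checkpoint is the one with the highest step number.
    latest_step = max(metrics_by_step)
    # The best checkpoint is the one with the highest value of the chosen metric.
    best_step = max(
        metrics_by_step,
        key=lambda step: metrics_by_step[step][best_checkpoint_metric],
    )
    keep = {latest_step, best_step}
    return sorted(step for step in metrics_by_step if step not in keep)
```

For example, with rewards of 0.4, 0.7 and 0.5 at steps 10, 20 and 30, step 20 is kept as the best, step 30 as the latest, and only step 10 would be deleted.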

Unfortunately, there is no straightforward way to load the weights from the best checkpoint. It would be nice if such a function existed.

Example signature:


```python
from typing import Literal


class TrainableModel:
    def load_checkpoint(
        self,
        which: int | Literal["best", "latest"] = "latest",
        best_checkpoint_metric: str = "val/reward",
    ) -> None:
        """Load model weights from a saved checkpoint.

        Args:
            which (int | Literal["best", "latest"]): The checkpoint to load.
                - "best" loads the best checkpoint according to `best_checkpoint_metric`.
                - "latest" loads the latest checkpoint available.
                - An integer value selects the checkpoint saved at that step number.
            best_checkpoint_metric (str): The name of the metric that determines
                which checkpoint is best.
        """
        ...
```

giladfrid009 · Aug 05 '25 12:08