ART
Feature request - load best checkpoint
Currently, `TrainableModel.delete_checkpoints(best_checkpoint_metric)` removes all checkpoints except the best one and the latest one.
Unfortunately, there is no straightforward way to load the weights from the best checkpoint. It would be nice if such a function existed.
Example signature:
```python
from typing import Literal


class TrainableModel:
    def load_checkpoint(
        self,
        which: int | Literal["best", "latest"] = "latest",
        best_checkpoint_metric: str = "val/reward",
    ):
        """Load model weights from a saved checkpoint.

        Args:
            which (int | Literal["best", "latest"]): The checkpoint to load.
                - "best" loads the best checkpoint according to `best_checkpoint_metric`.
                - "latest" loads the latest available checkpoint.
                - An integer selects the checkpoint saved at that step number.
            best_checkpoint_metric (str): The name of the metric that determines
                which checkpoint is best.
        """
        ...
```
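A minimal sketch of how the `which` argument could be resolved to a step number before any weights are loaded. It assumes the metrics logged at each checkpoint's step are available as a mapping from step to metric values, and that a higher metric value is better (as for a reward). `resolve_checkpoint_step` and `metrics_by_step` are hypothetical names for illustration only, not part of ART's API, and the actual weight loading is left out.

```python
from typing import Literal, Mapping, Union

Which = Union[int, Literal["best", "latest"]]


def resolve_checkpoint_step(
    which: Which,
    metrics_by_step: Mapping[int, Mapping[str, float]],
    best_checkpoint_metric: str = "val/reward",
) -> int:
    """Return the step of the checkpoint to load.

    `metrics_by_step` is a hypothetical structure mapping each saved
    checkpoint's step to the metrics logged at that step.
    """
    if not metrics_by_step:
        raise ValueError("no checkpoints available")

    if which == "latest":
        # The latest checkpoint is the one with the highest step number.
        return max(metrics_by_step)

    if which == "best":
        # Only consider checkpoints that actually logged the metric.
        scored = {
            step: metrics[best_checkpoint_metric]
            for step, metrics in metrics_by_step.items()
            if best_checkpoint_metric in metrics
        }
        if not scored:
            raise ValueError(f"no checkpoint has metric {best_checkpoint_metric!r}")
        # Assumption: higher is better. Ties go to the later step.
        return max(scored, key=lambda step: (scored[step], step))

    # Reject bool explicitly, since bool is a subclass of int.
    if isinstance(which, int) and not isinstance(which, bool):
        if which not in metrics_by_step:
            raise ValueError(f"no checkpoint saved at step {which}")
        return which

    raise ValueError(f"invalid value for which: {which!r}")


metrics = {
    10: {"val/reward": 0.4},
    20: {"val/reward": 0.7},
    30: {"val/reward": 0.5},
}
print(resolve_checkpoint_step("best", metrics))    # 20
print(resolve_checkpoint_step("latest", metrics))  # 30
print(resolve_checkpoint_step(10, metrics))        # 10
```

Keeping the selection logic separate from the loading step would let `load_checkpoint` and `delete_checkpoints` agree on what "best" means by sharing a single rule.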