utils.checkpointer.checkpoint_utils

Checkpointer class and utility functions.

Checkpoint Objects#

@dataclass
class Checkpoint()

Checkpoint data class that holds the states for module_interface, trainer, and trainer_backend.

DefaultCheckpointerArguments Objects#

@dataclass
class DefaultCheckpointerArguments()

Default Checkpointer Arguments.

Arguments:

  • checkpoint bool - Flag indicating whether to checkpoint the model when save() is called. All other conditions are evaluated inside save(), so save() can be called unconditionally within training loops; the checkpointing logic lives in this class rather than in Trainer.
  • delete_existing_checkpoints bool - Flag indicating whether to delete checkpoints under save_dir before training. New checkpoints are saved regardless.
  • period int - Period of the index at which to checkpoint the model, evaluated as index % period == 0. save() is called with index set to the epoch, so a checkpoint is written every "period" epochs. The last epoch is always checkpointed regardless.
  • save_dir str - Path to directory where checkpoints are to be stored. Creates folder if it does not exist.
  • model_state_save_dir str - Path to directory where checkpointed models are to be stored. Creates folder if it does not exist.
  • load_dir str - Path to directory where checkpoints are to be loaded from. If not set, will not attempt to load a checkpoint. If load_filename is set, will search for this filename within the directory to load it. If load_filename is not set, will load the file via get_latest_file().
  • load_filename str - Filename of checkpoint to load under load_dir, overrides automatic loading via get_latest_file().
  • file_prefix str - Prefix of the checkpoint filename. The final filename will be {file_prefix}_{index}.{file_ext}, or in the case of saving with save_model(), {file_prefix}_mode{index}.{file_ext}.
  • file_ext str - File extension for the checkpoint filename when saving and when searching under load_dir for loading via get_latest_file(). When cleaning save_dir via delete_existing_checkpoints=True, only files with this extension are considered.
  • log_level str - Logging level for checkpointer module (Default: 'INFO').
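
To make the period rule concrete, the sketch below lists which epochs would be checkpointed for a given period. It is a minimal illustration, not library code: checkpointed_epochs is a hypothetical helper, and it assumes epochs are numbered from 1.

```python
from typing import List


def checkpointed_epochs(num_epochs: int, period: int) -> List[int]:
    """Hypothetical helper: epochs saved under the `index % period == 0` rule.

    Assumes epochs are numbered 1..num_epochs. The final epoch is always
    included, mirroring the "last epoch is always checkpointed" behavior.
    """
    epochs = [e for e in range(1, num_epochs + 1) if e % period == 0]
    if num_epochs >= 1 and num_epochs not in epochs:
        epochs.append(num_epochs)  # last epoch is saved regardless of period
    return epochs


print(checkpointed_epochs(num_epochs=10, period=3))  # [3, 6, 9, 10]
```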

AbstractCheckpointer Objects#

class AbstractCheckpointer(ABC)

Abstract class for a checkpointer.

To create a custom checkpointer, users must implement the abstract methods of this class and pass along an instance to ModuleInterface. Custom checkpointers can be used at other stages of the training lifecycle via callbacks.
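
A minimal sketch of a custom checkpointer follows. To keep it runnable without the library installed, it defines local stand-ins for Checkpoint and AbstractCheckpointer; in real use these would be imported from this module. The Checkpoint field names and the InMemoryCheckpointer class are assumptions made for illustration.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional


# Stand-ins so the sketch runs on its own. The field names are assumptions
# based on the Checkpoint description, not confirmed library attributes.
@dataclass
class Checkpoint:
    module_interface_state: Any = None
    trainer_state: Any = None
    trainer_backend_state: Any = None


class AbstractCheckpointer(ABC):
    @abstractmethod
    def save(self, checkpoint_state: Checkpoint, index: int,
             force: Optional[bool] = False) -> None: ...

    @abstractmethod
    def load(self) -> Optional[Checkpoint]: ...


class InMemoryCheckpointer(AbstractCheckpointer):
    """Hypothetical checkpointer that keeps states in a dict keyed by index."""

    def __init__(self, period: int = 1) -> None:
        self.period = period
        self._store: Dict[int, Checkpoint] = {}

    def save(self, checkpoint_state, index, force=False):
        # Save on every `period`-th index, or whenever force is set.
        if force or index % self.period == 0:
            self._store[index] = checkpoint_state

    def load(self):
        # Return the state with the largest index, or None if nothing is stored.
        if not self._store:
            return None
        return self._store[max(self._store)]


checkpointer = InMemoryCheckpointer(period=2)
checkpointer.save(Checkpoint(trainer_state={"epoch": 1}), index=1)  # skipped
checkpointer.save(Checkpoint(trainer_state={"epoch": 2}), index=2)  # stored
print(checkpointer.load().trainer_state)  # {'epoch': 2}
```

An instance like this would then be handed over in place of the default checkpointer; how it is passed depends on the surrounding training code.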

save#

@abstractmethod
def save(checkpoint_state: Checkpoint, index: int,
force: Optional[bool]) -> None

Creates a checkpoint by saving a Checkpoint dataclass containing any relevant states.

Arguments:

  • checkpoint_state Checkpoint - Checkpointed states.
  • index int - Using epoch as index is suggested.
  • force bool, optional - Saves checkpoint regardless of conditions if args.checkpoint is set to True. Used to always checkpoint models after the last epoch.

save_model#

def save_model(model_state: Dict, index: int) -> None

Creates a model checkpoint by saving model state.

Arguments:

  • model_state Dict - Model state as provided by ModuleInterface.
  • index int - Number to use to create a unique filename. Using epoch as index is suggested.

load#

@abstractmethod
def load() -> Checkpoint

Load and return a checkpointed file.

Implements logic to load a checkpointed file as configured via args used when constructing the checkpointer object. Always called upon initialization of Trainer.

Returns:

  • Checkpoint - Checkpointed states.

DefaultCheckpointer Objects#

class DefaultCheckpointer(AbstractCheckpointer)

Default checkpointer implementation. Implements AbstractCheckpointer and contains a few helper functions for managing checkpointed files.

Must be initialized with DefaultCheckpointerArguments.

__init__#

def __init__(args: DefaultCheckpointerArguments)

Initialize checkpointer and delete existing checkpointed files under save_dir if delete_existing_checkpoints is set to True.

save#

def save(checkpoint_state: Checkpoint, index: int, force=False) -> str

Creates a checkpoint by saving a Checkpoint dataclass containing any relevant states as a Python dict.

Evaluates the checkpointing conditions and, if they are met, saves the provided dataclass under args.save_dir. The dataclass should contain every state the user needs as part of a checkpoint. The index argument is required to create a unique filename. The optional force flag bypasses every condition except the args.checkpoint flag, which must still be True. With DefaultCheckpointer, the condition for saving is that index is a multiple of args.period.

Arguments:

  • checkpoint_state Checkpoint - Checkpointed states.
  • index int - Number to use to create a unique filename and evaluate conditions for checkpointing. Using epoch as index is suggested.
  • force bool, optional - Saves checkpoint regardless of conditions if args.checkpoint is set to True. Used to always checkpoint states after the last epoch.

Returns:

  • str - Path to checkpointed file.
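
The interaction between args.checkpoint, force, and args.period described above can be summarized as a small predicate. This is a sketch of the documented conditions only; should_save is a hypothetical name, not a function of this module.

```python
def should_save(index: int, period: int, checkpoint: bool,
                force: bool = False) -> bool:
    """Hypothetical predicate mirroring the documented save() conditions."""
    if not checkpoint:
        return False  # checkpointing disabled: force cannot override this
    if force:
        return True   # e.g. the final epoch, saved regardless of period
    return index % period == 0


print(should_save(index=4, period=2, checkpoint=True))               # True
print(should_save(index=3, period=2, checkpoint=True))               # False
print(should_save(index=3, period=2, checkpoint=True, force=True))   # True
print(should_save(index=4, period=2, checkpoint=False, force=True))  # False
```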

save_model#

def save_model(model_state: Dict, index: int) -> str

Checkpoints a model state, leveraging torch.save().

Evaluates if checkpointing is enabled and if a model save directory has been set, and saves a provided model state. An additional index argument is required to create a unique name for the file to be saved.

Arguments:

  • model_state Dict - Model state as provided by ModuleInterface.
  • index int - Number to use to create a unique filename. Using epoch as index is suggested.

Returns:

  • str - Path to checkpointed file.

load#

def load() -> Checkpoint

Attempts to load and return a checkpointed file using torch.load(). The checkpointed file is assumed to have been created with save() and thus to be a Python dict.

This method is always called upon initialization of Trainer. It searches for and attempts to load a checkpointed file based on args. If no load_dir is set, it returns None. If both load_dir and load_filename are set, the file load_filename under load_dir is loaded directly (the filename must include the extension). If only load_dir is set, get_latest_file() is called to search the folder for the file with the largest integer (index) in its filename, and that path is loaded.

Returns:

  • Checkpoint - Checkpointed states.
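
The three cases above amount to a short path-resolution step. The sketch below illustrates that step under stated assumptions: resolve_load_path and its find_latest parameter are hypothetical names, with find_latest standing in for get_latest_file().

```python
import os
from typing import Callable, Optional


def resolve_load_path(load_dir: Optional[str],
                      load_filename: Optional[str],
                      find_latest: Callable[[str], Optional[str]]) -> Optional[str]:
    """Hypothetical helper: choose which file load() would read."""
    if not load_dir:
        return None  # nothing configured, so nothing to load
    if load_filename:
        # An explicit filename (including extension) overrides the search.
        return os.path.join(load_dir, load_filename)
    return find_latest(load_dir)  # fall back to the latest checkpoint
```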

get_latest_file#

@staticmethod
def get_latest_file(load_dir: str,
file_prefix: str,
file_ext: str = 'pt',
logger=getlogger(__name__)) -> str

Get the path to the last checkpointed file.

Finds and returns the path of the file with the greatest number of completed epochs under load_dir (recursive search) for a given file prefix and, optionally, file extension.

Arguments:

  • load_dir str - Directory under which to search for checkpointed files.
  • file_prefix str - Prefix to match for when searching for candidate files.
  • file_ext str, optional - File extension to consider when searching.

Returns:

  • str - Path to latest checkpointed file.
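
One way such a search could work is sketched below. It is not the library's implementation: latest_checkpoint is a hypothetical stand-in that walks the directory tree, keeps files matching the prefix and extension, and picks the one with the largest integer in the rest of its name.

```python
import os
import re
import tempfile
from typing import Optional


def latest_checkpoint(load_dir: str, file_prefix: str,
                      file_ext: str = "pt") -> Optional[str]:
    """Hypothetical stand-in for get_latest_file(); not the library's code."""
    suffix = "." + file_ext
    best_path, best_index = None, -1
    for root, _dirs, files in os.walk(load_dir):  # recursive search
        for name in files:
            if not (name.startswith(file_prefix) and name.endswith(suffix)):
                continue
            # Look for integers between the prefix and the extension.
            stem = name[len(file_prefix):-len(suffix)]
            numbers = re.findall(r"\d+", stem)
            if not numbers:
                continue
            index = max(int(n) for n in numbers)
            if index > best_index:
                best_path, best_index = os.path.join(root, name), index
    return best_path


# Usage with throwaway files; the file names are made up for the demo.
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "run1"))
    for rel in ("ckpt_2.pt", "ckpt_10.pt",
                os.path.join("run1", "ckpt_7.pt"), "other_99.pt"):
        open(os.path.join(tmp, rel), "w").close()
    found = latest_checkpoint(tmp, file_prefix="ckpt")
    print(os.path.basename(found))  # ckpt_10.pt
```

Comparing the indices as integers matters here: a plain string sort would rank ckpt_2.pt above ckpt_10.pt.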

check_mk_dir#

def check_mk_dir(dirpath: str) -> None

Checks whether the path exists and creates it if it does not.

Arguments:

  • dirpath str - Path of the directory to check and, if missing, create.
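
For reference, the same check-then-create behavior can be obtained with the standard library alone. The helper below is a hypothetical equivalent named ensure_dir, shown only to illustrate the idea.

```python
import os
import tempfile


def ensure_dir(dirpath: str) -> None:
    """Hypothetical equivalent of check_mk_dir(): create dirpath if missing."""
    # exist_ok=True makes the call a no-op when the directory already exists;
    # os.makedirs also creates any missing parent directories.
    os.makedirs(dirpath, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "checkpoints", "models")
    ensure_dir(target)
    ensure_dir(target)  # calling it again is harmless
    created = os.path.isdir(target)
```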