utils.checkpointer.checkpoint_utils
Checkpointer class and utility functions.
Checkpoint Objects#
Checkpoint data class that holds the states for module_interface, trainer and trainer_backend.
DefaultCheckpointerArguments Objects#
Default Checkpointer Arguments.
Arguments:
checkpoint (bool) - Flag indicating whether to checkpoint the model when save() is called. Other conditions are implemented within save(), so save() can always be called within training loops and the checkpointing logic lives in this class rather than in Trainer.
delete_existing_checkpoints (bool) - Flag indicating whether to delete checkpoints under save_dir before training. New checkpoints are saved regardless.
period (int) - Period of index at which to checkpoint the model. Evaluates index % period == 0. save() is called with index set to the epoch, and thus checkpoints every "period" number of epochs. The last epoch is always checkpointed regardless.
save_dir (str) - Path to directory where checkpoints are to be stored. Creates folder if it does not exist.
model_state_save_dir (str) - Path to directory where checkpointed models are to be stored. Creates folder if it does not exist.
load_dir (str) - Path to directory where checkpoints are to be loaded from. If not set, will not attempt to load a checkpoint. If load_filename is set, will search for this filename within the directory to load it. If load_filename is not set, will load the file via get_latest_file().
load_filename (str) - Filename of checkpoint to load under load_dir, overrides automatic loading via get_latest_file().
file_prefix (str) - Prefix of the checkpoint filename. Final filename to save will be {file_prefix}_{index}.{file_ext}, or in the case of saving with save_model(), {file_prefix}_mode{index}.{file_ext}.
file_ext (str) - File extension for the checkpoint filename when saving and when searching under load_dir for loading via get_latest_file(). When cleaning save_dir via delete_existing_checkpoints=True, only files with this extension are considered.
log_level (str) - Logging level for checkpointer module (Default: 'INFO').
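The period rule above can be illustrated with a minimal sketch. The helper name checkpointed_indices is hypothetical and not part of this module; it only reproduces the documented condition (index % period == 0, plus the last epoch) on a plain list of epoch indices.

```python
from typing import Iterable, List


def checkpointed_indices(indices: Iterable[int], period: int) -> List[int]:
    """Return the indices that would trigger a checkpoint.

    Hypothetical helper: applies index % period == 0 and always
    includes the final index, as described for the period argument.
    """
    ordered = list(indices)
    if not ordered:
        return []
    last = ordered[-1]
    return [i for i in ordered if i % period == 0 or i == last]


# Epochs numbered 1..7 (an assumption) with period=3.
epochs = checkpointed_indices(range(1, 8), period=3)
print(epochs)  # [3, 6, 7]
```

Epoch 7 appears only because it is the last index; with period=3 it would otherwise be skipped.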
AbstractCheckpointer Objects#
Abstract class for a checkpointer.
To create a custom checkpointer, users must implement the abstract methods of this class and pass along an instance to ModuleInterface. Custom checkpointers can be used at other stages of the training lifecycle via callbacks.
save#
Creates a checkpoint by saving a Checkpoint dataclass containing any relevant states.
Arguments:
checkpoint_state (Checkpoint) - Checkpointed states.
index (int) - Using epoch as index is suggested.
force (bool, optional) - Saves checkpoint regardless of conditions if args.checkpoint is set to True. Used to always checkpoint models after the last epoch.
save_model#
Creates a model checkpoint by saving model state.
Arguments:
model_state (Dict) - Model state as provided by ModuleInterface.
index (int) - Number to use to create a unique filename. Using epoch as index is suggested.
load#
Load and return a checkpointed file.
Implements logic to load a checkpointed file as configured via args used when constructing the checkpointer object. Always called upon initialization of Trainer.
Returns:
Checkpoint - Checkpointed states.
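As a rough illustration of the three abstract methods described above, the sketch below defines a stand-in base class and an in-memory subclass. All class names and the Checkpoint field names here are hypothetical; a real custom checkpointer would subclass this module's AbstractCheckpointer instead, and its exact signatures should be taken from the source.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class SketchCheckpoint:
    """Stand-in for Checkpoint; field names are assumptions."""
    module_interface_state: Any = None
    trainer_state: Any = None
    trainer_backend_state: Any = None


class SketchAbstractCheckpointer(ABC):
    """Stand-in mirroring the documented abstract methods."""

    @abstractmethod
    def save(self, checkpoint_state: SketchCheckpoint, index: int,
             force: bool = False) -> None: ...

    @abstractmethod
    def save_model(self, model_state: Dict, index: int) -> None: ...

    @abstractmethod
    def load(self) -> SketchCheckpoint: ...


class InMemoryCheckpointer(SketchAbstractCheckpointer):
    """Keeps every checkpoint in a dict keyed by index (no disk I/O)."""

    def __init__(self) -> None:
        self.checkpoints: Dict[int, SketchCheckpoint] = {}
        self.models: Dict[int, Dict] = {}

    def save(self, checkpoint_state, index, force=False):
        self.checkpoints[index] = checkpoint_state

    def save_model(self, model_state, index):
        self.models[index] = dict(model_state)

    def load(self):
        # Return the checkpoint with the largest index, or an empty one.
        if not self.checkpoints:
            return SketchCheckpoint()
        return self.checkpoints[max(self.checkpoints)]


ckpt = InMemoryCheckpointer()
ckpt.save(SketchCheckpoint(trainer_state={"epoch": 1}), index=1)
ckpt.save(SketchCheckpoint(trainer_state={"epoch": 2}), index=2)
latest = ckpt.load()
```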
DefaultCheckpointer Objects#
Default checkpointer implementation. Implements AbstractCheckpointer and contains a few helper functions for managing checkpointed files.
Must be initialized with DefaultCheckpointerArguments.
__init__#
Initialize checkpointer and delete existing checkpointed files under save_dir if delete_existing_checkpoints is set to True.
save#
Creates a checkpoint by saving a Checkpoint dataclass containing any relevant states as a Python Dict.
Evaluates conditions and, if they are met, saves the provided dataclass under args.save_dir. The dataclass should contain any states that users need to save as part of a checkpoint. An additional index argument is required to create a unique name for the file to be saved. The optional force flag disregards all conditions other than the checkpoint flag, which enables checkpointing in the first place. With DefaultCheckpointer, the condition for saving is that index is a multiple of args.period.
Arguments:
checkpoint_state (Checkpoint) - Checkpointed states.
index (int) - Number to use to create a unique filename and evaluate conditions for checkpointing. Using epoch as index is suggested.
force (bool, optional) - Saves checkpoint regardless of conditions if args.checkpoint is set to True. Used to always checkpoint states after the last epoch.
Returns:
str - Path to checkpointed file.
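The gating described for save() can be sketched as follows. This is not the library's implementation: the function name save_sketch is hypothetical, pickle is swapped in for torch.save() to keep the example dependency-free, returning None when nothing is saved is an assumption, and the filename follows the documented {file_prefix}_{index}.{file_ext} pattern.

```python
import os
import pickle
import tempfile
from typing import Any, Optional


def save_sketch(state: Any, index: int, save_dir: str, period: int,
                checkpoint: bool = True, force: bool = False,
                file_prefix: str = "model",
                file_ext: str = "pkl") -> Optional[str]:
    """Save state if checkpointing is enabled and the period condition holds."""
    if not checkpoint:
        # The checkpoint flag gates everything, even force=True.
        return None
    if not force and index % period != 0:
        return None
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f"{file_prefix}_{index}.{file_ext}")
    with open(path, "wb") as handle:
        pickle.dump(state, handle)  # stand-in for torch.save()
    return path


with tempfile.TemporaryDirectory() as tmp:
    skipped = save_sketch({"epoch": 1}, index=1, save_dir=tmp, period=2)
    saved = save_sketch({"epoch": 2}, index=2, save_dir=tmp, period=2)
    forced = save_sketch({"epoch": 3}, index=3, save_dir=tmp, period=2,
                         force=True)
    disabled = save_sketch({"epoch": 4}, index=4, save_dir=tmp, period=2,
                           checkpoint=False, force=True)
    saved_name = os.path.basename(saved)
    forced_name = os.path.basename(forced)
```

Index 1 is skipped, index 2 matches the period, index 3 is saved only because of force, and index 4 is never saved because the checkpoint flag is off.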
save_model#
Checkpoints a model state, leveraging torch.save().
Evaluates if checkpointing is enabled and if a model save directory has been set, and saves a provided model state. An additional index argument is required to create a unique name for the file to be saved.
Arguments:
model_state (Dict) - Model state as provided by ModuleInterface.
index (int) - Number to use to create a unique filename. Using epoch as index is suggested.
Returns:
str - Path to checkpointed file.
load#
Attempt to load and return a checkpointed file leveraging torch.load(). The checkpointed file is assumed to be created with save() and thus be a Python Dict.
This method is always called upon initialization of Trainer. Searches for and attempts to load a checkpointed file based on args. If no load_dir is set, returns None. If a load_dir and load_filename have been set, the file "load_filename" under load_dir is directly loaded (the filename must include extension). If only load_dir is set, get_latest_file() is called to search the folder for the file with the largest integer (index) in its filename, and the file at the returned path is loaded.
Returns:
Checkpoint - Checkpointed states.
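The three cases described for load() reduce to a small path-resolution step, sketched below. The function resolve_load_path and its find_latest callback are hypothetical; the callback stands in for get_latest_file(), and the actual deserialization via torch.load() is left out.

```python
import os
from typing import Callable, Optional


def resolve_load_path(load_dir: Optional[str],
                      load_filename: Optional[str],
                      find_latest: Callable[[str], Optional[str]]
                      ) -> Optional[str]:
    """Pick the file to load, following the documented precedence."""
    if not load_dir:
        # No load_dir: nothing is loaded.
        return None
    if load_filename:
        # An explicit filename (with extension) overrides the search.
        return os.path.join(load_dir, load_filename)
    # Otherwise defer to the latest-file search.
    return find_latest(load_dir)


no_dir = resolve_load_path(None, "model_3.pt", lambda d: "unused")
explicit = resolve_load_path("ckpts", "model_3.pt", lambda d: "unused")
searched = resolve_load_path("ckpts", None,
                             lambda d: os.path.join(d, "model_9.pt"))
```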
get_latest_file#
Get the path to the last checkpointed file.
Find and return the path of the file with the greatest number of completed epochs under load_dir (recursive search) for a given file prefix and, optionally, file extension.
Arguments:
load_dir (str) - Directory under which to search for checkpointed files.
file_prefix (str) - Prefix to match for when searching for candidate files.
file_ext (str, optional) - File extension to consider when searching.
Returns:
str - Path to latest checkpointed file.
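One way such a search could work is sketched below. The function latest_file_sketch is hypothetical, not the module's implementation. It assumes the index is the largest integer found in the filename after the prefix and before the extension, and that no match returns None.

```python
import os
import re
import tempfile
from typing import Optional


def latest_file_sketch(load_dir: str, file_prefix: str,
                       file_ext: Optional[str] = None) -> Optional[str]:
    """Recursively find the matching file with the largest index."""
    best_path, best_index = None, -1
    for root, _dirs, files in os.walk(load_dir):
        for name in files:
            if not name.startswith(file_prefix):
                continue
            if file_ext is not None and not name.endswith("." + file_ext):
                continue
            # Only look at the part between the prefix and the extension.
            stem = os.path.splitext(name)[0][len(file_prefix):]
            numbers = re.findall(r"\d+", stem)
            if not numbers:
                continue
            index = max(int(n) for n in numbers)
            if index > best_index:
                best_path, best_index = os.path.join(root, name), index
    return best_path


with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "run_a"))
    for rel in ("model_1.pt", "model_10.pt",
                os.path.join("run_a", "model_2.pt"),
                "other_99.pt", "model_50.txt"):
        open(os.path.join(tmp, rel), "w").close()
    found_name = os.path.basename(latest_file_sketch(tmp, "model", "pt"))
    found_any_ext = os.path.basename(latest_file_sketch(tmp, "model"))
    found_none = latest_file_sketch(tmp, "missing", "pt")
```

Comparing indices as integers matters here: a plain string sort would rank model_2.pt above model_10.pt.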
check_mk_dir#
Checks whether the path exists and creates it if it does not.
Arguments:
dirpath (str) - Path of the directory to check and, if missing, create.
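A minimal equivalent using only the standard library is shown below. The name check_mk_dir_sketch is hypothetical, and creating missing parent directories as well is an assumption about the intended behavior.

```python
import os
import tempfile


def check_mk_dir_sketch(dirpath: str) -> None:
    """Create dirpath (and any missing parents) if it does not exist."""
    # exist_ok=True makes the call a no-op when the directory is present.
    os.makedirs(dirpath, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "checkpoints", "run_1")
    existed_before = os.path.isdir(target)
    check_mk_dir_sketch(target)
    check_mk_dir_sketch(target)  # second call must not raise
    exists_after = os.path.isdir(target)
```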