utils.checkpointer.checkpoint_utils
Checkpointer class and utility functions.
#
Checkpoint Objects
Checkpoint data class that holds the states for module_interface, trainer, and trainer_backend.
#
DefaultCheckpointerArguments Objects
Default Checkpointer Arguments.
Arguments:
checkpoint
bool - Flag indicating whether to checkpoint the model when save() is called. Other conditions are evaluated within save(), so the method can always be called within training loops. This keeps the checkpointing logic out of Trainer and inside this class.
delete_existing_checkpoints
bool - Flag indicating whether to delete checkpoints under save_dir before training. New checkpoints are saved regardless.
period
int - Period of index at which to checkpoint the model. Evaluates index % period == 0. save() is called with index set to the epoch, and thus checkpoints every "period" epochs. The last epoch is always checkpointed regardless.
save_dir
str - Path to the directory where checkpoints are stored. The folder is created if it does not exist.
model_state_save_dir
str - Path to the directory where checkpointed models are stored. The folder is created if it does not exist.
load_dir
str - Path to the directory from which checkpoints are loaded. If not set, no checkpoint is loaded. If load_filename is set, that filename is searched for within the directory and loaded. If load_filename is not set, the file is chosen via get_latest_file().
load_filename
str - Filename of the checkpoint to load under load_dir. Overrides automatic loading via get_latest_file().
file_prefix
str - Prefix of the checkpoint filename. The final filename to save is {file_prefix}{index}.{file_ext}, or, when saving with save_model(), {file_prefix}_mode{index}.{file_ext}.
file_ext
str - File extension for the checkpoint filename, used when saving and when searching under load_dir for loading via get_latest_file(). When cleaning save_dir via delete_existing_checkpoints=True, only files with this extension are considered.
log_level
str - Logging level for the checkpointer module (Default: 'INFO').
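As a rough illustration of how checkpoint and period interact, the sketch below lists the epochs that would be checkpointed over a run. It assumes epochs are numbered from 1 and that the last epoch is always saved when checkpointing is enabled. checkpointed_epochs is a hypothetical helper and is not part of this module.

```python
from typing import List


def checkpointed_epochs(num_epochs: int, period: int, checkpoint: bool = True) -> List[int]:
    """Return the epochs at which a checkpoint would be written (hypothetical helper)."""
    if not checkpoint:
        # With checkpoint=False nothing is saved, not even the last epoch.
        return []
    saved = []
    for epoch in range(1, num_epochs + 1):  # assumption: epochs start at 1
        is_last = epoch == num_epochs
        # Save on every multiple of period, and always on the last epoch.
        if epoch % period == 0 or is_last:
            saved.append(epoch)
    return saved


print(checkpointed_epochs(num_epochs=10, period=3))
```

With period=3 over ten epochs, the multiples of three are saved and the final epoch is added on top of them.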
#
AbstractCheckpointer Objects
Abstract class for a checkpointer.
To create a custom checkpointer, users must implement the abstract methods of this class and pass along an instance to ModuleInterface. Custom checkpointers can be used at other stages of the training lifecycle via callbacks.
#
save
Creates a checkpoint by saving a Checkpoint dataclass containing any relevant states.
Arguments:
checkpoint_state
Checkpoint - Checkpointed states.
index
int - Using epoch as index is suggested.
force
bool, optional - Saves checkpoint regardless of conditions if args.checkpoint is set to True. Used to always checkpoint models after the last epoch.
#
save_model
Creates a model checkpoint by saving model state.
Arguments:
model_state
Dict - Model state as provided by ModuleInterface.
index
int - Number to use to create a unique filename. Using epoch as index is suggested.
#
load
Loads and returns a checkpointed file.
Implements logic to load a checkpointed file as configured via args used when constructing the checkpointer object. Always called upon initialization of Trainer.
Returns:
Checkpoint - Checkpointed states.
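The interface described above can be sketched with a small custom checkpointer. The block below is self-contained and therefore defines stand-ins for Checkpoint and AbstractCheckpointer instead of importing the real classes. The Checkpoint field names, the method signatures, and InMemoryCheckpointer itself are assumptions made for illustration.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Checkpoint:
    """Stand-in for the Checkpoint data class; field names are assumptions."""
    module_interface_state: Any = None
    trainer_state: Any = None
    trainer_backend_state: Any = None


class AbstractCheckpointer(ABC):
    """Stand-in that mirrors the three documented abstract methods."""

    @abstractmethod
    def save(self, checkpoint_state: Checkpoint, index: int, force: bool = False) -> None:
        ...

    @abstractmethod
    def save_model(self, model_state: Dict, index: int) -> None:
        ...

    @abstractmethod
    def load(self) -> Optional[Checkpoint]:
        ...


class InMemoryCheckpointer(AbstractCheckpointer):
    """Hypothetical custom checkpointer that keeps states in dictionaries."""

    def __init__(self) -> None:
        self.checkpoints: Dict[int, Checkpoint] = {}
        self.models: Dict[int, Dict] = {}

    def save(self, checkpoint_state: Checkpoint, index: int, force: bool = False) -> None:
        # This sketch has no extra conditions, so force is accepted but unused.
        self.checkpoints[index] = checkpoint_state

    def save_model(self, model_state: Dict, index: int) -> None:
        self.models[index] = model_state

    def load(self) -> Optional[Checkpoint]:
        # Return the checkpoint with the largest index, or None if empty.
        if not self.checkpoints:
            return None
        return self.checkpoints[max(self.checkpoints)]
```

Keeping states in memory is only useful for tests or short experiments. A real implementation would persist them, but the three methods to override stay the same.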
#
DefaultCheckpointer Objects
Default checkpointer implementation. Implements AbstractCheckpointer and contains a few helper functions for managing checkpointed files.
Must be initialized with DefaultCheckpointerArguments.
#
__init__
Initializes the checkpointer and deletes existing checkpointed files under save_dir if delete_existing_checkpoints is set to True.
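The cleanup performed at initialization can be pictured roughly as follows. This is a minimal sketch, assuming only files directly under save_dir are considered and that matching is done on the file extension. The helper name and the example filenames are hypothetical.

```python
import os
import tempfile
from typing import List


def delete_existing_checkpoints(save_dir: str, file_ext: str) -> List[str]:
    """Remove files ending in .{file_ext} directly under save_dir (hypothetical helper)."""
    removed: List[str] = []
    if not os.path.isdir(save_dir):
        return removed  # nothing to clean
    for name in sorted(os.listdir(save_dir)):
        path = os.path.join(save_dir, name)
        # Only regular files with the checkpoint extension are deleted.
        if os.path.isfile(path) and name.endswith("." + file_ext):
            os.remove(path)
            removed.append(name)
    return removed


with tempfile.TemporaryDirectory() as demo_dir:
    for demo_name in ("model_1.pt", "model_2.pt", "notes.txt"):
        open(os.path.join(demo_dir, demo_name), "w").close()
    print(delete_existing_checkpoints(demo_dir, "pt"))
    print(os.listdir(demo_dir))  # files with other extensions are left in place
```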
#
save
Creates a checkpoint by saving a Checkpoint dataclass containing any relevant states as a Python Dict.
Evaluates the checkpointing conditions and, if they are met, saves the provided dataclass under args.save_dir. The dataclass should contain every state the user needs as part of a checkpoint. The index argument is required to create a unique name for the saved file. The optional force flag bypasses every condition except the checkpoint flag, which enables checkpointing in the first place. With DefaultCheckpointer, the condition for saving is that index is a multiple of args.period.
Arguments:
checkpoint_state
Checkpoint - Checkpointed states.
index
int - Number to use to create a unique filename and evaluate conditions for checkpointing. Using epoch as index is suggested.
force
bool, optional - Saves checkpoint regardless of conditions if args.checkpoint is set to True. Used to always checkpoint states after the last epoch.
Returns:
str - Path to checkpointed file.
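The save logic described above can be sketched as a standalone function. To keep the example free of third-party dependencies, pickle stands in for torch.save(). The Checkpoint field names, the default argument values, the filename layout, and the None return value when nothing is written are all assumptions.

```python
import os
import pickle
import tempfile
from dataclasses import asdict, dataclass
from typing import Any, Optional


@dataclass
class Checkpoint:
    """Stand-in for the Checkpoint data class; field names are assumptions."""
    module_interface_state: Any = None
    trainer_state: Any = None
    trainer_backend_state: Any = None


def save(checkpoint_state: Checkpoint, index: int, save_dir: str, period: int = 1,
         checkpoint: bool = True, force: bool = False,
         file_prefix: str = "model", file_ext: str = "pt") -> Optional[str]:
    # The checkpoint flag gates everything; force only bypasses the period test.
    if not checkpoint or not (force or index % period == 0):
        return None  # assumption: nothing is returned when no file is written
    os.makedirs(save_dir, exist_ok=True)
    # Assumed filename layout, for illustration only.
    path = os.path.join(save_dir, f"{file_prefix}_{index}.{file_ext}")
    with open(path, "wb") as f:
        # The dataclass is stored as a plain dict; pickle stands in for torch.save().
        pickle.dump(asdict(checkpoint_state), f)
    return path


with tempfile.TemporaryDirectory() as demo_dir:
    demo_state = Checkpoint(trainer_state={"epoch": 4})
    print(save(demo_state, index=3, save_dir=demo_dir, period=2))  # not a multiple of period
    print(save(demo_state, index=3, save_dir=demo_dir, period=2, force=True) is not None)
```

Storing a plain dict instead of the dataclass instance means the file can be read back without the class definition being importable at load time.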
#
save_model
Checkpoints a model state, leveraging torch.save().
Evaluates if checkpointing is enabled and if a model save directory has been set, and saves a provided model state. An additional index argument is required to create a unique name for the file to be saved.
Arguments:
model_state
Dict - Model state as provided by ModuleInterface.
index
int - Number to use to create a unique filename. Using epoch as index is suggested.
Returns:
str - Path to checkpointed file.
#
load
Attempts to load and return a checkpointed file, leveraging torch.load(). The checkpointed file is assumed to have been created with save() and thus to be a Python Dict.
This method is always called upon initialization of Trainer. It searches for and attempts to load a checkpointed file based on args. If no load_dir is set, it returns None. If both load_dir and load_filename are set, the file "load_filename" under load_dir is loaded directly (the filename must include the extension). If only load_dir is set, get_latest_file() is called to search the folder for the file with the largest integer (index) in its filename, and that path is used for loading.
Returns:
Checkpoint - Checkpointed states.
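The three cases above amount to a small path-resolution step before the file is read. The sketch below covers only that step and leaves out the actual torch.load() call. resolve_load_path and its find_latest parameter are hypothetical names used for illustration.

```python
import os
from typing import Callable, Optional


def resolve_load_path(load_dir: Optional[str], load_filename: Optional[str],
                      find_latest: Callable[[str], Optional[str]]) -> Optional[str]:
    """Pick the checkpoint path to load, or None (hypothetical helper)."""
    if not load_dir:
        # Case 1: no load_dir, so no checkpoint is loaded.
        return None
    if load_filename:
        # Case 2: an explicit filename (with extension) overrides the search.
        return os.path.join(load_dir, load_filename)
    # Case 3: only load_dir is set, so defer to a latest-file search.
    return find_latest(load_dir)


print(resolve_load_path(None, None, lambda d: None))
print(resolve_load_path("ckpts", "model_3.pt", lambda d: None))
```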
#
get_latest_file
Gets the path to the last checkpointed file.
Finds and returns the path of the file with the greatest number of completed epochs under load_dir (recursive search) for a given file prefix and, optionally, file extension.
Arguments:
load_dir
str - Directory under which to search for checkpointed files.
file_prefix
str - Prefix to match for when searching for candidate files.
file_ext
str, optional - File extension to consider when searching.
Returns:
str - Path to latest checkpointed file.
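One way such a search could work is sketched below. It is not the library's implementation: it assumes the index is the largest integer found in the filename after the prefix, compares indices numerically, and returns None when nothing matches. The example filenames are made up.

```python
import os
import re
import tempfile
from typing import Optional


def get_latest_file(load_dir: str, file_prefix: str,
                    file_ext: Optional[str] = None) -> Optional[str]:
    """Return the matching file with the largest integer index, or None."""
    best_index, best_path = -1, None
    for root, _dirs, files in os.walk(load_dir):  # os.walk makes the search recursive
        for name in files:
            stem, ext = os.path.splitext(name)
            if not stem.startswith(file_prefix):
                continue
            if file_ext is not None and ext != "." + file_ext:
                continue
            # Look for integers only in the part that follows the prefix.
            numbers = [int(n) for n in re.findall(r"\d+", stem[len(file_prefix):])]
            if numbers and max(numbers) > best_index:
                best_index, best_path = max(numbers), os.path.join(root, name)
    return best_path


with tempfile.TemporaryDirectory() as demo_dir:
    for demo_name in ("model_2.pt", "model_10.pt"):
        open(os.path.join(demo_dir, demo_name), "w").close()
    print(os.path.basename(get_latest_file(demo_dir, "model", "pt")))
```

Comparing the indices as integers matters here: a plain string sort would rank model_2 after model_10.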
#
check_mk_dir
Checks whether the path exists and creates it if it does not.
Arguments:
dirpath
str - Path of the directory to check and, if missing, create.
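A helper with this behavior can be written in a few lines with the standard library. The version below is a sketch and may differ from the actual implementation, for example in whether missing parent directories are also created.

```python
import os
import tempfile


def check_mk_dir(dirpath: str) -> None:
    """Create dirpath, including missing parents, if it does not already exist."""
    if not os.path.isdir(dirpath):
        # exist_ok avoids an error if another process creates it in between.
        os.makedirs(dirpath, exist_ok=True)


with tempfile.TemporaryDirectory() as demo_dir:
    target = os.path.join(demo_dir, "checkpoints", "run_1")
    check_mk_dir(target)
    check_mk_dir(target)  # calling it again on an existing directory is a no-op
    print(os.path.isdir(target))
```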