Checkpointing
The pymarlin.utils.checkpointer.checkpoint_utils module contains the
AbstractCheckpointer class, which can be extended and
passed to pymarlin.core.trainer.Trainer for checkpointing. If no checkpointer is passed to
pymarlin.core.trainer.Trainer, a default implementation, DefaultCheckpointer, is used.
Users can control the DefaultCheckpointer behavior via the
DefaultCheckpointerArguments dataclass,
and can modify that arguments dataclass for their own checkpointers.
Here is an example of how to create your own checkpointer:
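A minimal sketch of such a checkpointer, under stated assumptions: the Checkpoint and AbstractCheckpointer classes below are local stand-ins so the snippet runs without pymarlin installed, and their field names and method signatures are assumed rather than copied from the library. JsonCheckpointer and JsonCheckpointerArguments are hypothetical names, and JSON files stand in for whatever serialization a real checkpointer would use. In real code, AbstractCheckpointer would be imported from pymarlin.utils.checkpointer.checkpoint_utils instead of being defined locally.

```python
import json
import os
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional


# Stand-in for pymarlin's Checkpoint container (field names are assumptions).
@dataclass
class Checkpoint:
    module_interface_state: Optional[dict] = None
    trainer_state: Optional[dict] = None
    trainer_backend_state: Optional[dict] = None


# Stand-in for pymarlin's AbstractCheckpointer (signatures are assumptions).
class AbstractCheckpointer(ABC):
    @abstractmethod
    def save(self, checkpoint_state: Checkpoint, index: int, force: bool = False) -> None: ...

    @abstractmethod
    def save_model(self, model_state: dict, index: int) -> None: ...

    @abstractmethod
    def load(self) -> Checkpoint: ...


# Hypothetical arguments dataclass for the custom checkpointer.
@dataclass
class JsonCheckpointerArguments:
    save_dir: str = "checkpoints"
    period: int = 1                      # save every `period` epochs
    load_filename: Optional[str] = None  # file inside save_dir to resume from


class JsonCheckpointer(AbstractCheckpointer):
    def __init__(self, args: JsonCheckpointerArguments):
        self.args = args
        os.makedirs(args.save_dir, exist_ok=True)

    def _write(self, payload: dict, filename: str) -> None:
        with open(os.path.join(self.args.save_dir, filename), "w") as f:
            json.dump(payload, f)

    def save(self, checkpoint_state: Checkpoint, index: int, force: bool = False) -> None:
        # Skip off-period epochs unless the caller forces a save.
        if not force and index % self.args.period != 0:
            return
        self._write(asdict(checkpoint_state), f"checkpoint_{index}.json")

    def save_model(self, model_state: dict, index: int) -> None:
        self._write(model_state, f"model_{index}.json")

    def load(self) -> Checkpoint:
        # With nothing to resume from, return an empty Checkpoint.
        if self.args.load_filename is None:
            return Checkpoint()
        with open(os.path.join(self.args.save_dir, self.args.load_filename)) as f:
            return Checkpoint(**json.load(f))
```

An instance of the custom class is then passed along to pymarlin.core.trainer.Trainer in place of the DefaultCheckpointer.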
Recall that these three methods are called automatically by pymarlin.core.trainer.Trainer:
- load(): at pymarlin.core.trainer.Trainer initialization, before training.
- save(): at the end of every epoch and once more after training with force=True.
- save_model(): at the end of training.
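The call order listed above can be illustrated with a small recording checkpointer. This is only a sketch: run_training is a hypothetical stand-in for the calls the Trainer makes, not pymarlin code, and the epoch numbering (starting at 1) and the shape of the state dictionaries are assumptions.

```python
class RecordingCheckpointer:
    """Records every call so the order can be inspected afterwards."""

    def __init__(self):
        self.calls = []

    def load(self):
        self.calls.append("load")
        return None

    def save(self, checkpoint_state, index, force=False):
        self.calls.append(f"save(index={index}, force={force})")

    def save_model(self, model_state, index):
        self.calls.append(f"save_model(index={index})")


def run_training(checkpointer, epochs):
    """Hypothetical stand-in for the Trainer's checkpointing calls."""
    checkpointer.load()                               # at initialization, before training
    state = {"last_epoch": 0}
    for epoch in range(1, epochs + 1):
        state = {"last_epoch": epoch}
        checkpointer.save(state, epoch)               # at the end of every epoch
    checkpointer.save(state, epochs, force=True)      # once more after training
    checkpointer.save_model({"weights": []}, epochs)  # at the end of training


recorder = RecordingCheckpointer()
run_training(recorder, epochs=2)
for call in recorder.calls:
    print(call)
```

Because save() is invoked every epoch, a custom checkpointer that should write less often has to apply its own period check inside save() and honor force=True for the final call.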
If you are creating a custom checkpointer, please review the AbstractCheckpointer documentation for the precise method signatures needed to interface correctly with pymarlin.core.trainer.Trainer.
To customize what is checkpointed as a part of the attributes of Checkpoint, please override the get_state() methods at ModuleInterface.get_state(), Trainer.get_state() and TrainerBackend.get_state(). For example, for Trainer.get_state():
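A minimal sketch of such an override, assuming that get_state() returns a plain dictionary and that update_state() receives that same dictionary when a checkpoint is restored. BaseTrainer is a local stand-in for pymarlin.core.trainer.Trainer so the snippet runs on its own; the last_epoch and best_val_loss attributes are hypothetical.

```python
class BaseTrainer:
    """Stand-in for pymarlin.core.trainer.Trainer (behavior is assumed)."""

    def __init__(self):
        self.last_epoch = -1

    def get_state(self) -> dict:
        return {"last_epoch": self.last_epoch}

    def update_state(self, state: dict) -> None:
        if state:
            self.last_epoch = state["last_epoch"]


class MyTrainer(BaseTrainer):
    def __init__(self):
        super().__init__()
        self.best_val_loss = float("inf")  # hypothetical extra attribute

    def get_state(self) -> dict:
        # Extend the base state with the extra value to checkpoint.
        state = super().get_state()
        state["best_val_loss"] = self.best_val_loss
        return state

    def update_state(self, state: dict) -> None:
        # Restore the extra value; tolerate older checkpoints that lack it.
        super().update_state(state)
        if state:
            self.best_val_loss = state.get("best_val_loss", self.best_val_loss)
```

Extending the dictionary returned by the parent class, rather than replacing it, keeps whatever the base implementation already checkpoints.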
Please remember to also update the corresponding update_state() methods if appropriate, so that any state added in get_state() is restored when a checkpoint is loaded.