Checkpointing
The pymarlin.utils.checkpointer.checkpoint_utils module contains the AbstractCheckpointer class, which can be extended and passed to pymarlin.core.trainer.Trainer for checkpointing. If no checkpointer is passed to pymarlin.core.trainer.Trainer, a default implementation, DefaultCheckpointer, is used instead.
Users can control the behavior of DefaultCheckpointer via DefaultCheckpointerArguments, and can adapt this arguments dataclass for their own checkpointers.
To create your own checkpointer, extend AbstractCheckpointer. Three of its methods are called automatically by pymarlin.core.trainer.Trainer:
- load(): at pymarlin.core.trainer.Trainer initialization, before training.
- save(): at the end of every epoch, and once more after training with force=True.
- save_model(): at the end of training.
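The sketch below shows one way a custom checkpointer could implement these three methods and the order in which they are invoked. To stay runnable without pymarlin or PyTorch, it defines simplified stand-ins for Checkpoint and AbstractCheckpointer and writes files with pickle; a real implementation would subclass pymarlin's AbstractCheckpointer using the exact signatures from its documentation and would more likely serialize with something like torch.save. The parameter names index and force, the file naming scheme, the PickleCheckpointer class, and the simulated training loop are all assumptions.

```python
import os
import pickle
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Optional


# Simplified stand-in for pymarlin's Checkpoint (assumed: three state dicts).
@dataclass
class Checkpoint:
    module_interface_state: Dict = field(default_factory=dict)
    trainer_state: Dict = field(default_factory=dict)
    trainer_backend_state: Dict = field(default_factory=dict)


# Simplified stand-in for AbstractCheckpointer; check the real class for exact signatures.
class AbstractCheckpointer(ABC):
    @abstractmethod
    def load(self) -> Checkpoint: ...

    @abstractmethod
    def save(self, checkpoint_state: Checkpoint, index: int, force: bool = False) -> None: ...

    @abstractmethod
    def save_model(self, model_state: Dict, index: int) -> None: ...


class PickleCheckpointer(AbstractCheckpointer):
    """Hypothetical checkpointer that pickles state every `period` epochs."""

    def __init__(self, save_dir: str, period: int = 1, load_path: Optional[str] = None):
        self.save_dir = save_dir
        self.period = period
        self.load_path = load_path
        os.makedirs(save_dir, exist_ok=True)

    def load(self) -> Checkpoint:
        # Called once before training; an empty Checkpoint means "start fresh".
        if self.load_path is None:
            return Checkpoint()
        with open(self.load_path, "rb") as f:
            return pickle.load(f)

    def save(self, checkpoint_state: Checkpoint, index: int, force: bool = False) -> None:
        # Called after every epoch; only writes on period boundaries unless forced.
        if force or index % self.period == 0:
            path = os.path.join(self.save_dir, f"checkpoint_{index}.pkl")
            with open(path, "wb") as f:
                pickle.dump(checkpoint_state, f)

    def save_model(self, model_state: Dict, index: int) -> None:
        # Called at the end of training with the final model state only.
        path = os.path.join(self.save_dir, f"model_{index}.pkl")
        with open(path, "wb") as f:
            pickle.dump(model_state, f)


# Simulated call order (an assumed, simplified stand-in for the Trainer loop).
save_dir = tempfile.mkdtemp()
ckpt = PickleCheckpointer(save_dir, period=2)
state = ckpt.load()
for epoch in range(3):
    state.trainer_state["epoch"] = epoch
    ckpt.save(state, index=epoch)
ckpt.save(state, index=epoch, force=True)  # final forced save after training
ckpt.save_model({"weights": [0.1, 0.2]}, index=epoch)
print(sorted(os.listdir(save_dir)))
```

With period=2, the per-epoch calls write only epochs 0 and 2, and the forced call guarantees the final state is saved regardless of the period.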
Please review the AbstractCheckpointer documentation for the precise method signatures required to interface correctly with pymarlin.core.trainer.Trainer when creating a custom checkpointer. To customize what is checkpointed as part of the attributes of Checkpoint, override the get_state() methods at ModuleInterface.get_state(), Trainer.get_state() and TrainerBackend.get_state().
Remember to also update the corresponding update_state() methods, if appropriate.
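A minimal sketch of the override pattern, using a stand-in base class rather than pymarlin's Trainer: the subclass extends get_state() with an extra entry and mirrors it in update_state(), so the value survives a save/load round trip. BaseTrainer, MyTrainer, the last_epoch entry, and the best_val_loss field are hypothetical; the real Trainer.get_state() returns its own dictionary, and an override would typically call super() to keep those entries.

```python
from typing import Any, Dict


class BaseTrainer:
    """Stand-in for a trainer exposing get_state/update_state (assumed shape)."""

    def __init__(self) -> None:
        self.last_epoch = -1

    def get_state(self) -> Dict[str, Any]:
        return {"last_epoch": self.last_epoch}

    def update_state(self, state: Dict[str, Any]) -> None:
        self.last_epoch = state["last_epoch"]


class MyTrainer(BaseTrainer):
    """Checkpoints an extra field on top of whatever the base class saves."""

    def __init__(self) -> None:
        super().__init__()
        self.best_val_loss = float("inf")

    def get_state(self) -> Dict[str, Any]:
        state = super().get_state()  # keep the base-class entries
        state["best_val_loss"] = self.best_val_loss
        return state

    def update_state(self, state: Dict[str, Any]) -> None:
        super().update_state(state)
        # .get() keeps older checkpoints without the new key loadable.
        self.best_val_loss = state.get("best_val_loss", float("inf"))


# Round trip: state captured from one trainer restores into a fresh one.
old = MyTrainer()
old.last_epoch, old.best_val_loss = 4, 0.37
new = MyTrainer()
new.update_state(old.get_state())
print(new.last_epoch, new.best_val_loss)  # 4 0.37
```

Pairing each get_state() change with a matching update_state() change is what makes the extra field restorable; saving it without reading it back would silently reset it on resume.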