Trainer module:
The Trainer is responsible for coordinating the model definition (ModuleInterface) and the TrainerBackend, connecting the high-level model recipe with the backend on which it will be trained.
It accepts a module implementing ModuleInterface that contains the model definition, the definitions of the train and evaluation steps, the optimizers and schedulers, and any optional callbacks.
It also accepts a TrainerBackend that defines how training is run, e.g. single-node vs. distributed training. TrainerBackends for the most common scenarios are available out of the box; alternatively, a user can provide a custom TrainerBackend.
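The division of labour described above (the module supplies the recipe, the backend decides how it runs, the Trainer connects the two) can be illustrated with a minimal sketch. All class names here (ToyModule, ToySingleProcessBackend, ToyTrainer) are hypothetical stand-ins, not the library's classes, and the "training step" is a toy mean over a list of floats so the example runs without PyTorch.

```python
from typing import List


class ToyModule:
    """Hypothetical stand-in for a ModuleInterface implementation: it owns the
    recipe (here a single per-step function) but not how steps are executed."""

    def train_step(self, batch: List[float]) -> float:
        # Toy "loss": the mean of the batch values.
        return sum(batch) / len(batch)


class ToySingleProcessBackend:
    """Hypothetical stand-in for a TrainerBackend: decides how steps run
    (here, sequentially in a single process)."""

    def train_epoch(self, module: ToyModule, batches: List[List[float]]) -> List[float]:
        return [module.train_step(batch) for batch in batches]


class ToyTrainer:
    """Connects the recipe (module) with the execution strategy (backend)."""

    def __init__(self, module: ToyModule, backend: ToySingleProcessBackend) -> None:
        self.module = module
        self.backend = backend

    def train(self, batches: List[List[float]]) -> List[float]:
        # The trainer does not run steps itself; it delegates to the backend.
        return self.backend.train_epoch(self.module, batches)


losses = ToyTrainer(ToyModule(), ToySingleProcessBackend()).train([[1.0, 3.0], [2.0, 2.0]])
```

Swapping in a different backend object (for example, a distributed one) would change how steps execute without touching the module, which is the point of keeping the two separate.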
TrainerArguments Objects
Trainer Arguments class.
AbstractTrainer Objects
Abstract Trainer class.
train
Run the train loop.
validate
Run the evaluation loop.
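AbstractTrainer only fixes the two entry points, train and validate. A minimal sketch of that contract using Python's abc module follows; AbstractTrainerSketch and LoggingTrainer are hypothetical names, and the real base class may declare more than these two methods.

```python
from abc import ABC, abstractmethod
from typing import List


class AbstractTrainerSketch(ABC):
    """Hypothetical mirror of AbstractTrainer: subclasses must supply both loops."""

    @abstractmethod
    def train(self) -> None:
        """Run the train loop."""

    @abstractmethod
    def validate(self) -> None:
        """Run the evaluation loop."""


class LoggingTrainer(AbstractTrainerSketch):
    """Toy concrete subclass that only records which loop was invoked."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def train(self) -> None:
        self.calls.append("train")

    def validate(self) -> None:
        self.calls.append("validate")
```

Because both methods are abstract, instantiating AbstractTrainerSketch directly raises TypeError, so a subclass that forgets one of the loops fails at construction time rather than mid-training.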
Trainer Objects
Orchestrates model training.
Arguments:
module (ModuleInterface) - Contains the model definition, the train and validation definitions, the optimizer and scheduler, and optional callbacks.
args (TrainerArguments) - Training hyperparameters.
Optional keyword arguments:
trainer_backend (TrainerBackend) - How the training will be carried out, for example whether it is distributed and/or uses AMP (automatic mixed precision). This can also be specified in args using the backend keyword. Defaults to singleprocess. Options are: sp (singleprocess), sp-amp, ddp, ddp-amp.
checkpointer (AbstractCheckpointer) - Used to handle model checkpointing.
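The four backend option strings can be read as two independent switches: distributed or not, and AMP or not. The sketch below resolves a backend name under stated assumptions. The helper resolve_backend and the BackendChoice type are hypothetical; the precedence order (an explicit choice first, then the backend keyword in args, then "sp") is our assumption, and the real Trainer receives a TrainerBackend object rather than a string for the explicit case.

```python
from dataclasses import dataclass
from typing import Optional

# The option strings listed above, mapped to (distributed, amp) flags.
# Reading "ddp" as distributed data parallel and "-amp" as automatic mixed
# precision is an interpretation of the names, not taken from the library.
BACKEND_OPTIONS = {
    "sp": (False, False),
    "sp-amp": (False, True),
    "ddp": (True, False),
    "ddp-amp": (True, True),
}


@dataclass
class BackendChoice:
    name: str
    distributed: bool
    amp: bool


def resolve_backend(explicit: Optional[str] = None,
                    args_backend: Optional[str] = None) -> BackendChoice:
    """Hypothetical helper: explicit choice wins, then args.backend, then 'sp'."""
    name = explicit or args_backend or "sp"
    if name not in BACKEND_OPTIONS:
        raise ValueError(f"unknown backend {name!r}; expected one of {sorted(BACKEND_OPTIONS)}")
    distributed, amp = BACKEND_OPTIONS[name]
    return BackendChoice(name, distributed, amp)
```

Validating the string up front turns a typo such as "dpp" into an immediate error instead of a silent fallback to single-process training.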
__init__
Initializes stats, writers, trainer_backend, and other helper functions.
train
Train and validate the model.
validate
Run evaluation over multiple validation dataloaders.
save_checkpoint
Checkpoint the current state of the Trainer, TrainerBackend, and ModuleInterface.
Saves the state of each object in a dictionary by calling its get_state() method and passing the collected states to the checkpointer's save() method.
save_model_checkpoint
Checkpoint the current state of the ModuleInterface; used to save the final model in the training loop.
Saves the state of the ModuleInterface by calling its get_state() method and passing the result to the checkpointer's save_model() method.
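Both save paths follow the same pattern: ask each component for its get_state() and hand the result to the checkpointer. The sketch below shows that flow with hypothetical stand-ins (Stateful, RecordingCheckpointer). The dictionary keys "trainer", "trainer_backend" and "module_interface" are assumptions, and the real checkpointer's save() and save_model() may take additional parameters not shown here.

```python
from typing import Any, Dict, List


class Stateful:
    """Hypothetical component exposing get_state(), like Trainer,
    TrainerBackend and ModuleInterface."""

    def __init__(self, **state: Any) -> None:
        self._state = dict(state)

    def get_state(self) -> Dict[str, Any]:
        return dict(self._state)


class RecordingCheckpointer:
    """Hypothetical checkpointer that records what it was asked to save."""

    def __init__(self) -> None:
        self.saved: List[Dict[str, Any]] = []
        self.saved_models: List[Dict[str, Any]] = []

    def save(self, checkpoint_state: Dict[str, Any]) -> None:
        self.saved.append(checkpoint_state)

    def save_model(self, model_state: Dict[str, Any]) -> None:
        self.saved_models.append(model_state)


def save_checkpoint(trainer, backend, module, checkpointer) -> None:
    # Gather every component's state into one dictionary, then hand it off.
    checkpointer.save({
        "trainer": trainer.get_state(),
        "trainer_backend": backend.get_state(),
        "module_interface": module.get_state(),
    })


def save_model_checkpoint(module, checkpointer) -> None:
    # Only the ModuleInterface state is needed to persist the final model.
    checkpointer.save_model(module.get_state())
```

Keeping the "how to write to disk" logic inside the checkpointer means the trainer code stays the same whether checkpoints go to local files or remote storage.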
load_checkpoints
Load the state of the Trainer, TrainerBackend, and ModuleInterface from a checkpoint.
Loading logic is determined by the checkpointer used; see DefaultCheckpointer for the default implementation logic. If a checkpoint is loaded, all module states are updated.
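Restoring is the mirror image of saving: the checkpointer produces a state dictionary, and if one exists each component receives its slice through update_state(). A minimal sketch, assuming (hypothetically) that the checkpointer's load() returns an empty dictionary when no checkpoint is found and that the keys match those used at save time. InMemoryCheckpointer and Stateful are illustrative stand-ins.

```python
from typing import Any, Dict, Optional


class Stateful:
    """Hypothetical component exposing update_state()."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {}

    def update_state(self, state: Dict[str, Any]) -> None:
        self.state.update(state)


class InMemoryCheckpointer:
    """Hypothetical checkpointer: returns a stored checkpoint, or {} if none exists."""

    def __init__(self, stored: Optional[Dict[str, Any]] = None) -> None:
        self.stored = stored or {}

    def load(self) -> Dict[str, Any]:
        return dict(self.stored)


def load_checkpoints(trainer, backend, module, checkpointer) -> bool:
    """Restore all three components if a checkpoint exists; report whether one was loaded."""
    checkpoint = checkpointer.load()
    if not checkpoint:
        return False  # nothing to resume from, so keep the fresh states
    trainer.update_state(checkpoint["trainer"])
    backend.update_state(checkpoint["trainer_backend"])
    module.update_state(checkpoint["module_interface"])
    return True
```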
get_state
Get the current state of the Trainer for checkpointing.
The default implementation returns the number of epochs finished; override it to include additional state properties.
Returns:
dict - Dictionary of variables or objects to checkpoint.
update_state
Update the Trainer's state from a checkpointed state.
Arguments:
state - Output of get_state() during checkpointing.
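Overriding get_state() and update_state() as a pair keeps a checkpoint round trip lossless. The sketch below assumes a hypothetical key name "epochs_finished" for the default state and adds a hypothetical best_val_loss field in a subclass; neither name is taken from the library.

```python
from typing import Any, Dict


class TrainerStateSketch:
    """Hypothetical trainer whose default checkpoint state is the epoch count."""

    def __init__(self) -> None:
        self.epochs_finished = 0

    def get_state(self) -> Dict[str, Any]:
        return {"epochs_finished": self.epochs_finished}

    def update_state(self, state: Dict[str, Any]) -> None:
        self.epochs_finished = state["epochs_finished"]


class TrainerWithBestMetric(TrainerStateSketch):
    """Extends both methods so an extra property survives checkpointing."""

    def __init__(self) -> None:
        super().__init__()
        self.best_val_loss = float("inf")

    def get_state(self) -> Dict[str, Any]:
        state = super().get_state()
        state["best_val_loss"] = self.best_val_loss
        return state

    def update_state(self, state: Dict[str, Any]) -> None:
        super().update_state(state)
        # .get() lets checkpoints written before this field existed still load.
        self.best_val_loss = state.get("best_val_loss", float("inf"))
```

Calling super() in both overrides preserves the default epoch bookkeeping, so the subclass only has to handle the properties it adds.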
device
The torch device (either CPU or GPU) and the local rank.
Note: _fetch_rank() should already have been called before device is used.
train_step_batch_size
Returns micro batch sizes for training. Splits the batch into smaller step batches such that
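The exact constraints in the description above are cut off, so the following is only one plausible way to split a batch into micro batches, not the library's rule. It assumes two hypothetical inputs, a train_batch_size and a max_step_batch_size that fits in memory, and picks the fewest accumulation steps such that the step batch size divides the batch evenly and does not exceed the limit.

```python
from typing import Tuple


def split_batch(train_batch_size: int, max_step_batch_size: int) -> Tuple[int, int]:
    """Hypothetical helper returning (step_batch_size, accumulation_steps) with
    step_batch_size * accumulation_steps == train_batch_size and
    step_batch_size <= max_step_batch_size, using as few steps as possible."""
    if train_batch_size <= 0 or max_step_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    for steps in range(1, train_batch_size + 1):
        step_batch = train_batch_size // steps
        # Require an even split so every micro batch has the same size.
        if train_batch_size % steps == 0 and step_batch <= max_step_batch_size:
            return step_batch, steps
    # Not reachable: steps == train_batch_size always yields a step batch of 1.
    raise AssertionError("unreachable")
```

For example, a batch of 64 with a limit of 16 splits into 4 steps of 16, while a batch of 48 with a limit of 20 splits into 3 steps of 16 because 24 would exceed the limit.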
estimated_global_steps_per_epoch
Estimate the number of global steps per epoch.
Compute the maximum number of global steps as len(dataloader) // gradient_accumulation + 1. If max_train_steps_per_epoch is provided, return the minimum of the two.
Note: SequentialSampler is used to get the train dataloader regardless of the sampler provided by trainer_backend, since only the length of the dataloader is required.
Do not change this logic without testing thoroughly. A test case for it already exists.
TODO: simplify this by initializing the distributed environment before calling this and removing SequentialSampler.
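The documented formula can be checked with plain integers. In this sketch, num_batches is a hypothetical helper that mirrors the length of a batched PyTorch DataLoader (rounding up unless drop_last is set), and the "+ 1" is taken verbatim from the description above, which makes the result an upper bound: it is exact when the batch count is not a multiple of gradient_accumulation and one too high when it is.

```python
from typing import Optional


def num_batches(dataset_size: int, batch_size: int, drop_last: bool = False) -> int:
    # Length of a batched dataloader: a final partial batch counts unless dropped.
    if drop_last:
        return dataset_size // batch_size
    return (dataset_size + batch_size - 1) // batch_size


def estimated_global_steps_per_epoch(dataloader_len: int, gradient_accumulation: int,
                                     max_train_steps_per_epoch: Optional[int] = None) -> int:
    # Upper bound from the documented formula: len(dataloader) // gradient_accumulation + 1.
    steps = dataloader_len // gradient_accumulation + 1
    if max_train_steps_per_epoch is not None:
        steps = min(steps, max_train_steps_per_epoch)
    return steps
```

With 100 samples and a batch size of 8 there are 13 batches; accumulating over 4 gives 3 full global steps plus one partial step, and the formula returns 4. With 12 batches the true count is 3, but the formula still returns 4, which is the over-estimate the "maximum" wording implies.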