core.module_interface
Module Interface module:
This module contains the abstract classes CallbackInterface and ModuleInterface that can provide everything necessary for model training. Users should implement these abstract classes in their Scenarios.
#
Stage Objects
Stages: train, val, test
#
CallbackInterface Objects
A callback class used to add scenario-specific outputs, logging, or debugging during training.
#
on_begin_train_epoch
Hook before each training epoch (before the model forward pass).
Arguments:
global_step
int - Current global step.
epoch
int - Current training epoch.
#
on_end_train_step
Runs after the end of a global training step.
Arguments:
global_step
int - Current global step.
train_step_collated_outputs
list - All train step outputs in a list. If train_step returns loss, logits, then train_step_collated_outputs will be [loss_collated, logits_collated].
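To make the shape of train_step_collated_outputs concrete, here is a minimal sketch of how per-step outputs could be regrouped by position. It uses plain Python floats and lists in place of tensors so that it runs without PyTorch. The helper name collate_step_outputs is hypothetical, and the library's actual collation (for example, whether it concatenates or stacks tensors) is not specified in this reference.

```python
from typing import List, Sequence, Tuple


def collate_step_outputs(step_outputs: Sequence[Tuple]) -> List[list]:
    """Regroup per-step outputs so that each output position gets its own list.

    Hypothetical helper: [(loss_1, logits_1), (loss_2, logits_2)] becomes
    [[loss_1, loss_2], [logits_1, logits_2]].
    """
    return [list(group) for group in zip(*step_outputs)]


# Two train steps, each returning (loss, logits); plain values stand in for tensors.
step_outputs = [
    (1.0, [0.1, 0.7]),
    (0.5, [0.3, 0.2]),
]
loss_collated, logits_collated = collate_step_outputs(step_outputs)
```

With this grouping, a callback can read all losses from the first collated entry and all logits from the second.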
#
on_end_train_epoch
Hook after each training epoch.
Arguments:
global_step
int - Current global step.
train_step_collated_outputs
list - All train step outputs in a list. If train_step returns loss, logits, then train_step_collated_outputs will be [loss_collated, logits_collated].
#
on_end_backward
Hook after each backward pass.
Arguments:
global_step
int - Current global step.
loss_tensor
torch.Tensor - Undetached loss tensor.
#
on_end_val_epoch
Hook after each validation epoch.
Arguments:
global_step
int - Current global step.
val_step_collated_outputs
list - All val step outputs in a list. If val_step returns loss, logits, then val_step_collated_outputs will be [loss_collated, logits_collated].
key
str, optional - The id of the validation dataloader. Defaults to "default".
#
on_end_train
Hook after training finishes.
Arguments:
global_step
int - Current global step.
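As an illustration of how these hooks fit together, the following sketch records the mean training loss per epoch and the mean validation loss per dataloader key. It is a plain class rather than a real CallbackInterface subclass so that it has no external dependencies. The class name LossHistoryCallback is hypothetical, and the sketch assumes that collated outputs arrive as positional arguments with the losses first; adjust the signatures if the library passes them as a single list.

```python
from statistics import mean


class LossHistoryCallback:
    """Hypothetical callback; a real one would subclass CallbackInterface."""

    def __init__(self):
        self.current_epoch = None
        self.train_loss_by_epoch = {}
        self.val_loss_by_key = {}
        self.finished_at_step = None

    def on_begin_train_epoch(self, global_step, epoch):
        # Remember which epoch the upcoming training steps belong to.
        self.current_epoch = epoch

    def on_end_train_epoch(self, global_step, *train_step_collated_outputs):
        # Assumption: the first collated output holds the per-step losses.
        losses = train_step_collated_outputs[0]
        self.train_loss_by_epoch[self.current_epoch] = mean(losses)

    def on_end_val_epoch(self, global_step, *val_step_collated_outputs, key="default"):
        # One entry per validation dataloader, identified by its key.
        losses = val_step_collated_outputs[0]
        self.val_loss_by_key[key] = mean(losses)

    def on_end_train(self, global_step):
        self.finished_at_step = global_step


# Drive the hooks by hand, with floats standing in for loss tensors.
callback = LossHistoryCallback()
callback.on_begin_train_epoch(global_step=0, epoch=0)
callback.on_end_train_epoch(2, [1.0, 0.5], [[0.1, 0.7], [0.3, 0.2]])
callback.on_end_val_epoch(2, [0.25, 0.75], key="dev")
callback.on_end_train(global_step=2)
```

With real tensors, the mean would typically be computed with tensor operations instead of statistics.mean.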
#
ModuleInterface Objects
Interface for PyTorch modules.
This interface contains the model architecture in the form of a PyTorch nn.Module together with optimizers and schedulers, train and validation step recipes, and any callbacks.
Note: The forward function is overridden.
Note: Users are encouraged to override the train_step and val_step methods.
#
get_optimizers_schedulers
Returns the optimizers and schedulers used for training.
Returns:
Tuple[Iterable[torch.optim.Optimizer], Iterable]: list of optimizers and list of schedulers
#
get_train_dataloader
Returns a dataloader for the training loop. Called every epoch.
Arguments:
sampler
type - Data sampler type, a class derived from torch.utils.data.Sampler. Create a concrete sampler object before creating the dataloader.
batch_size
int - Batch size per step per device.
Returns:
torch.utils.data.DataLoader - Training dataloader.
Example:
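    # Assumes `from torch.utils.data import DataLoader` at module level.
    train_ds = self.data.get_train_dataset()
    # Instantiate the sampler type on the dataset before building the dataloader.
    dl = DataLoader(train_ds, batch_size=batch_size, collate_fn=self.collate_fn, sampler=sampler(train_ds))
    return dl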
#
get_val_dataloaders
Returns dataloader(s) for the validation loop. Supports multiple dataloaders, identified by key. The keys are passed to the callback functions. Called every epoch.
Arguments:
sampler
type - Data sampler type, a class derived from torch.utils.data.Sampler. Create a concrete sampler object before creating the dataloader.
batch_size
int - Validation batch size per step per device.
Returns:
Union[Dict[str, torch.utils.data.DataLoader], torch.utils.data.DataLoader]: A single dataloader, or a dictionary of dataloaders with the data id as key and the dataloader as value.
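The single-or-dictionary return type can be handled uniformly by wrapping a lone dataloader under the "default" key, which matches the default key of on_end_val_epoch. The sketch below shows one way a validation loop might do this. The helper names as_dataloader_dict and run_validation are hypothetical, plain lists stand in for DataLoader objects, and the library's real loop may differ.

```python
from typing import Any, Callable, Dict, Iterable, Union

Loaders = Union[Dict[str, Iterable[Any]], Iterable[Any]]


def as_dataloader_dict(loaders: Loaders) -> Dict[str, Iterable[Any]]:
    """Wrap a single dataloader as {"default": loader}; pass dictionaries through."""
    if isinstance(loaders, dict):
        return loaders
    return {"default": loaders}


def run_validation(loaders: Loaders, val_step: Callable, on_end_val_epoch: Callable,
                   global_step: int = 0) -> None:
    """Run val_step over every dataloader and report the outputs per key."""
    for key, loader in as_dataloader_dict(loaders).items():
        outputs = [val_step(global_step, batch) for batch in loader]
        on_end_val_epoch(global_step, outputs, key=key)


# Plain lists stand in for torch.utils.data.DataLoader objects.
seen = {}
run_validation(
    {"dev": [1, 2], "heldout": [3]},
    val_step=lambda global_step, batch: batch * 10,
    on_end_val_epoch=lambda global_step, outputs, key="default": seen.update({key: outputs}),
)
```

Returning a dictionary lets a callback tell apart metrics from, for example, a development set and a held-out set.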
#
get_test_dataloaders
Returns test dataloaders.
Arguments:
sampler
type - Data sampler type, a class derived from torch.utils.data.Sampler.
batch_size
int - Test batch size per step per device.
#
forward
torch.nn.Module's forward() function. Overridden to call train_step() or val_step() based on stage.
Arguments:
stage
Stage - train/val/test
global_step
int - Current global step.
batch
[type] - Output of the dataloader step.
device
Union[torch.device, str, int] - Device.
Raises:
AttributeError - If stage is not one of train, val, test.
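The stage-based dispatch described above can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: the Stage member names and values are guesses, DispatchingModule is a hypothetical stand-in that does not inherit from torch.nn.Module, and routing the test stage to test_step is assumed by analogy with the other two stages.

```python
from enum import Enum


class Stage(Enum):
    # Assumed member names and values; the reference only lists train, val, test.
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


class DispatchingModule:
    """Hypothetical stand-in for a ModuleInterface subclass."""

    def train_step(self, global_step, batch, device):
        return ("train", batch)

    def val_step(self, global_step, batch, device):
        return ("val", batch)

    def test_step(self, global_step, batch, device):
        return ("test", batch)

    def forward(self, stage, global_step, batch, device=None):
        # Route the call to the step method that matches the stage.
        if stage == Stage.TRAIN:
            return self.train_step(global_step, batch, device)
        if stage == Stage.VAL:
            return self.val_step(global_step, batch, device)
        if stage == Stage.TEST:
            return self.test_step(global_step, batch, device)
        raise AttributeError(f"Unsupported stage: {stage!r}")


module = DispatchingModule()
train_result = module.forward(Stage.TRAIN, global_step=0, batch=[1, 2])
val_result = module.forward(Stage.VAL, global_step=0, batch=[3])
```

Because forward does the routing, a subclass only needs to supply the step methods.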
#
train_step
Runs a single training step. The batch should be moved to the device before any operation.
Arguments:
global_step
int - Current global step.
batch
[type] - Output of the train dataloader step.
device
Union[torch.device, str, int] - Device.
Returns:
Union[torch.Tensor, Iterable[torch.Tensor]]: The first return value must be the loss tensor. More than one value can be returned, but all outputs must be tensors. Callbacks will collate all outputs.
#
val_step
Runs a single validation step.
Arguments:
global_step
int - Current global step.
batch
[type] - Output of the val dataloader step.
device
Union[torch.device, str, int] - Device.
Returns:
Union[torch.Tensor, Iterable[torch.Tensor]]: Values that need to be collected, such as loss and logits. All outputs must be tensors.
#
test_step
Runs a single test step.
Arguments:
global_step
int - Current global step.
batch
[type] - Output of the test dataloader step.
device
Union[torch.device, str, int] - Device.
#
get_state
Gets the current state of the module. Used for checkpointing.
Returns:
Dict - Dictionary of variables or objects to checkpoint.
#
update_state
Updates the module from a checkpointed state.
Arguments:
state
Dict - Output of get_state() during checkpointing.
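The contract between get_state and update_state is a round trip: whatever get_state returns at checkpoint time should be enough for update_state to restore the module. The sketch below shows that round trip with a hypothetical CountingModule that tracks two plain values. The attribute names are invented for illustration, and a real module would usually also include model weights in its state.

```python
import copy


class CountingModule:
    """Hypothetical module that tracks two values worth checkpointing."""

    def __init__(self):
        self.steps_seen = 0
        self.best_val_loss = float("inf")

    def get_state(self):
        # Everything needed to resume goes into one dictionary.
        return {"steps_seen": self.steps_seen, "best_val_loss": self.best_val_loss}

    def update_state(self, state):
        # Restore exactly the fields that get_state produced.
        self.steps_seen = state["steps_seen"]
        self.best_val_loss = state["best_val_loss"]


original = CountingModule()
original.steps_seen = 120
original.best_val_loss = 0.25
# deepcopy stands in for saving the state to disk and loading it back.
checkpoint = copy.deepcopy(original.get_state())

restored = CountingModule()
restored.update_state(checkpoint)
```

Keeping the two methods symmetric makes it easy to check that a restored module reports the same state as the one that was saved.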