core.module_interface
Module Interface module:
This module contains the abstract classes CallbackInterface and ModuleInterface, which together provide everything needed for model training. Users implement these abstract classes in their Scenarios.
Stage Objects
Stages: train, val, test
CallbackInterface Objects
A callback class used to add scenario specific outputs/logging/debugging during training.
on_begin_train_epoch
Hook before a training epoch (before the model forward pass).
Arguments:
global_step (int) - Current global step.
epoch (int) - Current training epoch.
on_end_train_step
Runs after the end of a global training step.
Arguments:
global_step (int) - Current global step.
train_step_collated_outputs (list) - All train step outputs in a list. If train_step returns loss and logits, train_step_collated_outputs will be [loss_collated, logits_collated].
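The collation described above can be pictured as regrouping per-step outputs by position. The following is a minimal pure-Python sketch, assuming each train_step returns a (loss, logits) pair; the helper name `collate_step_outputs` is hypothetical, and the real library collates tensors rather than Python lists.

```python
from typing import Any, List, Sequence


def collate_step_outputs(step_outputs: Sequence[Sequence[Any]]) -> List[List[Any]]:
    """Hypothetical helper: regroup [(loss_1, logits_1), (loss_2, logits_2), ...]
    into [[loss_1, loss_2, ...], [logits_1, logits_2, ...]].

    It only illustrates the shape of train_step_collated_outputs;
    it does not concatenate tensors.
    """
    if not step_outputs:
        return []
    # zip(*...) transposes step-major outputs into output-major lists.
    return [list(group) for group in zip(*step_outputs)]


# Two steps, each returning (loss, logits).
outputs = [(0.9, [0.1, 0.2]), (0.7, [0.3, 0.4])]
loss_collated, logits_collated = collate_step_outputs(outputs)
print(loss_collated)    # [0.9, 0.7]
print(logits_collated)  # [[0.1, 0.2], [0.3, 0.4]]
```

Because the first train_step output is always the loss, index 0 of the collated list holds the losses for every step.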
on_end_train_epoch
Hook after a training epoch.
Arguments:
global_step (int) - Current global step.
train_step_collated_outputs (list) - All train step outputs in a list. If train_step returns loss and logits, train_step_collated_outputs will be [loss_collated, logits_collated].
on_end_backward
Hook after each backward pass.
Arguments:
global_step (int) - Current global step.
loss_tensor (torch.Tensor) - Undetached loss tensor.
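Since the loss passed to on_end_backward is undetached, a callback that stores it should detach it first so it does not keep the autograd graph alive. A minimal sketch, assuming a hypothetical `LossRecorder` class; in a real Scenario it would subclass CallbackInterface, which is left out here to keep the example self-contained.

```python
import torch


class LossRecorder:
    """Hypothetical callback; in a Scenario it would subclass CallbackInterface."""

    def __init__(self) -> None:
        self.losses = []

    def on_end_backward(self, global_step: int, loss_tensor: torch.Tensor) -> None:
        # The loss is undetached: detach it and convert to a Python float
        # so the stored value does not hold on to the autograd graph.
        self.losses.append((global_step, loss_tensor.detach().item()))


# Simulate one forward/backward pass and call the hook by hand.
w = torch.tensor(2.0, requires_grad=True)
loss = (w * 3.0 - 1.0) ** 2  # (2 * 3 - 1) ** 2 = 25
loss.backward()

recorder = LossRecorder()
recorder.on_end_backward(global_step=1, loss_tensor=loss)
print(recorder.losses)  # [(1, 25.0)]
```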
on_end_val_epoch
Hook after a validation epoch.
Arguments:
global_step (int) - Current global step.
val_step_collated_outputs (list) - All val step outputs in a list. If val_step returns loss and logits, val_step_collated_outputs will be [loss_collated, logits_collated].
key (str, optional) - The id of the validation dataloader. Defaults to "default".
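The key argument lets one callback keep results from several validation dataloaders apart. A minimal sketch, assuming val_step returns (loss, logits) and that the collated outputs arrive as a single list argument; the class `ValLossTracker` and the keys "in_domain" and "out_of_domain" are hypothetical.

```python
from typing import Dict, List, Sequence


class ValLossTracker:
    """Hypothetical callback; in a Scenario it would subclass CallbackInterface."""

    def __init__(self) -> None:
        # Latest mean validation loss, keyed by validation dataloader id.
        self.mean_loss: Dict[str, float] = {}

    def on_end_val_epoch(
        self,
        global_step: int,
        val_step_collated_outputs: List[Sequence],
        key: str = "default",
    ) -> None:
        # Assumes val_step returned (loss, logits), so index 0 holds the losses.
        losses = val_step_collated_outputs[0]
        self.mean_loss[key] = sum(losses) / len(losses)


tracker = ValLossTracker()
tracker.on_end_val_epoch(100, [[0.5, 0.25], [[0.1], [0.9]]], key="in_domain")
tracker.on_end_val_epoch(100, [[1.0, 2.0], [[0.3], [0.7]]], key="out_of_domain")
print(tracker.mean_loss)  # {'in_domain': 0.375, 'out_of_domain': 1.5}
```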
on_end_train
Hook after training finishes.
Arguments:
global_step (int) - Current global step.
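To see how the hooks above relate to each other, the sketch below drives a recording callback through a toy loop. Both `RecordingCallback` and `run_toy_loop` are hypothetical: the hook order is inferred from the hook descriptions in this section, not taken from a real trainer, and running validation after every training epoch is an assumption.

```python
from typing import List


class RecordingCallback:
    """Hypothetical callback that records the order in which hooks fire."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_begin_train_epoch(self, global_step, epoch):
        self.events.append(f"begin_train_epoch {epoch}")

    def on_end_backward(self, global_step, loss_tensor):
        self.events.append(f"end_backward {global_step}")

    def on_end_train_step(self, global_step, train_step_collated_outputs):
        self.events.append(f"end_train_step {global_step}")

    def on_end_train_epoch(self, global_step, train_step_collated_outputs):
        self.events.append(f"end_train_epoch {global_step}")

    def on_end_val_epoch(self, global_step, val_step_collated_outputs, key="default"):
        self.events.append(f"end_val_epoch {key}")

    def on_end_train(self, global_step):
        self.events.append(f"end_train {global_step}")


def run_toy_loop(callback, epochs=1, steps_per_epoch=2):
    """Hypothetical driver; hook order is inferred from the docs above."""
    global_step = 0
    for epoch in range(epochs):
        callback.on_begin_train_epoch(global_step, epoch)
        epoch_losses = []
        for _ in range(steps_per_epoch):
            loss = 0.5  # stand-in for the loss returned by train_step
            callback.on_end_backward(global_step, loss)
            global_step += 1
            epoch_losses.append(loss)
            callback.on_end_train_step(global_step, [[loss]])
        callback.on_end_train_epoch(global_step, [epoch_losses])
        # Assumption: validation runs once after each training epoch.
        callback.on_end_val_epoch(global_step, [[0.4]], key="default")
    callback.on_end_train(global_step)


cb = RecordingCallback()
run_toy_loop(cb)
print(cb.events)
```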
ModuleInterface Objects
Interface for PyTorch modules.
This interface contains model architecture in the form of a PyTorch
nn.Module together with optimizers and schedules, train and validation
step recipes and any callbacks.
Note: The forward function is overridden.
Note: Users are encouraged to override the train_step and val_step
methods.
get_optimizers_schedulers
Returns the optimizers and schedulers used during training.
Returns:
Tuple[Iterable[torch.optim.Optimizer], Iterable]: list of optimizers and list of schedulers
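A minimal sketch of the expected return shape: one list of optimizers and one list of schedulers. The class `TinyModule`, the linear model, and the SGD/StepLR choices are illustrative assumptions; a real implementation would live in a ModuleInterface subclass.

```python
from typing import Iterable, Tuple

import torch


class TinyModule(torch.nn.Module):
    """Hypothetical stand-in for a ModuleInterface subclass."""

    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Linear(4, 2)

    def get_optimizers_schedulers(self) -> Tuple[Iterable[torch.optim.Optimizer], Iterable]:
        # One optimizer over all parameters, plus a scheduler that decays
        # the learning rate by 10x every 10 scheduler steps.
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]


optimizers, schedulers = TinyModule().get_optimizers_schedulers()
print(len(optimizers), len(schedulers))  # 1 1
```

Returning lists even for a single optimizer keeps the shape uniform with setups that need several optimizers.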
get_train_dataloader
Returns a dataloader for the training loop. Called every epoch.
Arguments:
sampler (type) - Data sampler type, a subclass of torch.utils.data.Sampler. Create a concrete sampler object from it before creating the dataloader.
batch_size (int) - Batch size per step per device.
Returns:
torch.utils.data.DataLoader - Training dataloader.
Example:
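```python
from torch.utils.data import DataLoader

def get_train_dataloader(self, sampler, batch_size):
    train_ds = self.data.get_train_dataset()
    # `sampler` is a Sampler type: instantiate it on the dataset
    return DataLoader(train_ds, batch_size=batch_size, collate_fn=self.collate_fn, sampler=sampler(train_ds))
```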
get_val_dataloaders
Returns the dataloader(s) for the validation loop. Supports multiple dataloaders keyed by data id; the keys are passed to the callback functions. Called every epoch.
Arguments:
sampler (type) - Data sampler type, a subclass of torch.utils.data.Sampler. Create a concrete sampler object from it before creating the dataloader.
batch_size (int) - Validation batch size per step per device.
Returns:
Union[Dict[str, torch.utils.data.DataLoader], torch.utils.data.DataLoader]: A single dataloader, or a dictionary mapping each data id to its dataloader.
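A minimal sketch of the dictionary form, written as a free function for brevity (in a Scenario it would be a method). The keys "in_domain" and "out_of_domain" and the random TensorDatasets are hypothetical placeholders for real validation data.

```python
from typing import Dict

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def get_val_dataloaders(sampler, batch_size: int) -> Dict[str, DataLoader]:
    """Hypothetical sketch: two validation sets, each keyed by a data id."""
    datasets = {
        "in_domain": TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10,))),
        "out_of_domain": TensorDataset(torch.randn(6, 4), torch.randint(0, 2, (6,))),
    }
    # `sampler` is a Sampler type; instantiate it once per dataset.
    return {
        key: DataLoader(ds, batch_size=batch_size, sampler=sampler(ds))
        for key, ds in datasets.items()
    }


loaders = get_val_dataloaders(SequentialSampler, batch_size=4)
print({key: len(dl) for key, dl in loaders.items()})  # {'in_domain': 3, 'out_of_domain': 2}
```

Each key would then arrive as the key argument of on_end_val_epoch.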
get_test_dataloaders
Returns test dataloaders.
Arguments:
sampler (type) - Data sampler type, a subclass of torch.utils.data.Sampler.
batch_size (int) - Test batch size per step per device.
forward
torch.nn.Module's forward() function. Overridden to call train_step(), val_step(), or test_step() based on stage.
Arguments:
stage (Stage) - train/val/test.
global_step (int) - Current global step.
batch - Output of the dataloader step.
device (Union[torch.device, str, int]) - Device.
Raises:
AttributeError - If stage is not one of train, val, or test.
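The dispatch described for forward() can be sketched in plain Python. The `Stage` enum below is a stand-in whose member names and values are assumptions, and `ToyModule` is hypothetical; its step methods return strings only so the routing is easy to observe.

```python
import enum


class Stage(enum.Enum):
    """Stand-in for the Stage object; member values are assumptions."""
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


class ToyModule:
    """Hypothetical module showing how forward() can dispatch on stage."""

    def train_step(self, global_step, batch, device):
        return f"train:{global_step}"

    def val_step(self, global_step, batch, device):
        return f"val:{global_step}"

    def test_step(self, global_step, batch, device):
        return f"test:{global_step}"

    def forward(self, stage, global_step, batch, device):
        # Route to the step recipe that matches the stage.
        if stage == Stage.TRAIN:
            return self.train_step(global_step, batch, device)
        if stage == Stage.VAL:
            return self.val_step(global_step, batch, device)
        if stage == Stage.TEST:
            return self.test_step(global_step, batch, device)
        raise AttributeError(f"Unknown stage: {stage}")


module = ToyModule()
print(module.forward(Stage.VAL, 5, None, "cpu"))  # val:5
```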
train_step
Runs a single training step. The batch should be moved to device before any operation.
Arguments:
global_step (int) - Current global step.
batch - Output of the train dataloader step.
device (Union[torch.device, str, int]) - Device.
Returns:
Union[torch.Tensor, Iterable[torch.Tensor]]: The first return value must be the loss tensor. More than one value can be returned, and all outputs must be tensors. Callbacks will collate all outputs.
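A minimal train_step sketch, assuming a batch of (inputs, labels) tensors and a hypothetical `Classifier` module. Because forward() is overridden on ModuleInterface, the sketch calls the inner network `self.net` directly. Detaching the logits before returning them is a design choice here, so collated outputs do not hold the autograd graph; the loss stays attached for the backward pass.

```python
import torch
import torch.nn.functional as F


class Classifier(torch.nn.Module):
    """Hypothetical ModuleInterface-style module with a train_step."""

    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(4, 3)

    def train_step(self, global_step, batch, device):
        # Move the batch to the target device before any other operation.
        inputs, labels = (t.to(device) for t in batch)
        logits = self.net(inputs)
        loss = F.cross_entropy(logits, labels)
        # The loss must come first; logits follow so callbacks can collate them.
        return loss, logits.detach()


torch.manual_seed(0)
module = Classifier()
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
loss, logits = module.train_step(global_step=0, batch=batch, device="cpu")
print(loss.dim(), tuple(logits.shape))  # 0 (8, 3)
```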
val_step
Runs a single validation step.
Arguments:
global_step (int) - Current global step.
batch - Output of the validation dataloader step.
device (Union[torch.device, str, int]) - Device.
Returns:
Union[torch.Tensor, Iterable[torch.Tensor]]: values that need to be collected - loss, logits etc. All outputs must be tensors
test_step
Runs a single test step.
Arguments:
global_step (int) - Current global step.
batch - Output of the test dataloader step.
device (Union[torch.device, str, int]) - Device.
get_state
Get the current state of the module, used for checkpointing.
Returns:
Dict - Dictionary of variables or objects to checkpoint.
update_state
Update the module from a checkpointed state.
Arguments:
state (Dict) - Output of get_state() during checkpointing.
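get_state and update_state form a round trip: whatever the first returns, the second must be able to restore. A minimal sketch with a hypothetical `CounterModule` that checkpoints a step counter and a best validation loss; a real PyTorch module would typically also include its `state_dict()`. The JSON round trip stands in for writing the checkpoint to disk.

```python
import json


class CounterModule:
    """Hypothetical module that checkpoints a step counter and best metric."""

    def __init__(self) -> None:
        self.steps_seen = 0
        self.best_val_loss = float("inf")

    def get_state(self) -> dict:
        # Only plain, serializable values go into the checkpoint.
        return {"steps_seen": self.steps_seen, "best_val_loss": self.best_val_loss}

    def update_state(self, state: dict) -> None:
        # Restore exactly the fields that get_state() produced.
        self.steps_seen = state["steps_seen"]
        self.best_val_loss = state["best_val_loss"]


original = CounterModule()
original.steps_seen = 42
original.best_val_loss = 0.31

# Simulate saving and loading a checkpoint.
saved = json.dumps(original.get_state())
restored = CounterModule()
restored.update_state(json.loads(saved))
print(restored.steps_seen, restored.best_val_loss)  # 42 0.31
```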