Trainer Backend module:
Currently we support:
These are TrainerBackends for the most common scenarios, available out of the box. Alternatively, a user can provide a custom TrainerBackend.
build_trainer_backend
Factory for trainer_backends.
Arguments:
str - TrainerBackend name. Possible choices are currently: sp, sp-amp, sp-amp-apex, ddp, ddp-amp, ddp-amp-apex
args (sequence) - TrainerBackend positional arguments
kwargs (dict) - TrainerBackend keyword arguments
TrainerBackendArguments Objects
Trainer Backend Arguments dataclass.
TrainerBackend Objects
Trainer Backend abstract class.
OutputCollector Objects
Responsible for collecting step outputs and storing them in memory across calls. Concatenates tensors from all steps across the first dimension.
collect
Coalesces train_step and val_step outputs. All tensors are concatenated across dimension 0.
If input is a torch.Tensor of dimension batch_size x y .., all_outputs will be List[torch.Tensor of dimension total_samples_till_now x y].
If input is a torch.Tensor of dimension 1 x 1, all_outputs will be List[torch.Tensor of dimension total_samples_till_now x 1].
If input is List[torch.Tensor], all_outputs will be List[torch.Tensor], with all tensors concatenated across dimension 0.
Arguments:
Union[torch.Tensor, Iterable[torch.Tensor]] - train_step, val_step outputs
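The concatenation rules above can be illustrated with a small collector. This is a minimal sketch, not the library's implementation: the class name `CollectorSketch` and the choice to detach tensors and move them to the CPU are assumptions, and it assumes every step returns the same number of tensors.

```python
from typing import Iterable, List, Union

import torch


class CollectorSketch:
    """Hypothetical collector that concatenates step outputs along dim 0."""

    def __init__(self) -> None:
        self.all_outputs: List[torch.Tensor] = []

    def collect(self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
        # Treat a single tensor as a one-element sequence.
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
        # Assumption: keep collected outputs off the autograd graph and on the CPU.
        outputs = [o.detach().cpu() for o in outputs]
        if not self.all_outputs:
            self.all_outputs = outputs
        else:
            # Position i of the new outputs is appended to position i seen so far.
            self.all_outputs = [
                torch.cat([seen, new], dim=0)
                for seen, new in zip(self.all_outputs, outputs)
            ]


collector = CollectorSketch()
for _ in range(3):
    # Two outputs per step: per-sample values (4 x 5) and a 1 x 1 loss.
    collector.collect([torch.zeros(4, 5), torch.ones(1, 1)])

print([tuple(t.shape) for t in collector.all_outputs])
```

After three steps the first collected tensor has 12 rows (3 steps of batch size 4) and the second has 3 rows (one 1 x 1 entry per step).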
SingleProcess Objects
Single Process Trainer Backend
__init__
Single process trainer_backend
process_global_step
Clip gradients and call optimizer + scheduler
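In plain PyTorch, a global step of this kind usually amounts to three calls in a fixed order. The sketch below assumes norm-based clipping with a hypothetical threshold `max_grad_norm`, and it also zeroes the gradients afterwards, which the description above does not state. The model, optimizer, and scheduler are placeholders chosen for illustration.

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
max_grad_norm = 1.0  # hypothetical clipping threshold


def global_step_sketch() -> None:
    """Clip gradients, then step the optimizer and the scheduler."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()
    # Assumption: gradients are cleared here so the next step starts fresh.
    optimizer.zero_grad()


loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
before = model.weight.detach().clone()
global_step_sketch()
print(optimizer.param_groups[0]["lr"])
```

Clipping has to happen before `optimizer.step()` because the optimizer reads the gradients as they are at that moment; the scheduler is stepped after the optimizer, which is the order PyTorch expects.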
get_state
Get the current state of the trainer_backend, used for checkpointing.
Returns:
dict - Dictionary of variables or objects to checkpoint.
update_state
Update the trainer_backend from a checkpointed state.
Arguments:
state (dict) : Output of get_state() during checkpointing
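The get_state / update_state pair forms a round trip: whatever one returns, the other must accept. A minimal sketch of that contract follows. The class `BackendStateSketch` and its two counters are hypothetical, and JSON stands in for whatever checkpoint format is actually used.

```python
import json
from typing import Any, Dict


class BackendStateSketch:
    """Hypothetical backend tracking two counters worth checkpointing."""

    def __init__(self) -> None:
        self.global_steps_finished = 0
        self.batches_processed = 0

    def get_state(self) -> Dict[str, Any]:
        # Everything needed to resume training goes into this dictionary.
        return {
            "global_steps_finished": self.global_steps_finished,
            "batches_processed": self.batches_processed,
        }

    def update_state(self, state: Dict[str, Any]) -> None:
        # Inverse of get_state: restore each saved field.
        self.global_steps_finished = state["global_steps_finished"]
        self.batches_processed = state["batches_processed"]


trained = BackendStateSketch()
trained.global_steps_finished = 120
trained.batches_processed = 480
checkpoint = json.dumps(trained.get_state())  # simulate saving to disk

resumed = BackendStateSketch()
resumed.update_state(json.loads(checkpoint))  # simulate loading
print(resumed.get_state())
```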
SingleProcessDpSgd Objects
Backend that supports differential privacy, using the Opacus library.
SingleProcessAmp Objects
SingleProcess + Native PyTorch AMP Trainer Backend
SingleProcessApexAmp Objects
SingleProcess + Apex AMP Trainer Backend
AbstractTrainerBackendDecorator Objects
Abstract class implementing the decorator design pattern.
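In the decorator pattern, a wrapper exposes the same interface as the object it wraps and forwards calls to it, adding behavior before or after. The sketch below shows the general shape only; `BackendSketch`, `PlainBackend`, `BackendDecoratorSketch`, `CountingDecorator`, and the `train_step` signature are hypothetical names, not the library's classes.

```python
from abc import ABC, abstractmethod
from typing import List


class BackendSketch(ABC):
    """Hypothetical minimal backend interface."""

    @abstractmethod
    def train_step(self, batch: List[float]) -> float:
        ...


class PlainBackend(BackendSketch):
    def train_step(self, batch: List[float]) -> float:
        # Placeholder "loss": the mean of the batch.
        return sum(batch) / len(batch)


class BackendDecoratorSketch(BackendSketch):
    """Base decorator: same interface, forwards every call to the wrapped backend."""

    def __init__(self, wrapped: BackendSketch) -> None:
        self.wrapped = wrapped

    def train_step(self, batch: List[float]) -> float:
        return self.wrapped.train_step(batch)


class CountingDecorator(BackendDecoratorSketch):
    """Concrete decorator that adds call counting around the wrapped step."""

    def __init__(self, wrapped: BackendSketch) -> None:
        super().__init__(wrapped)
        self.calls = 0

    def train_step(self, batch: List[float]) -> float:
        self.calls += 1
        return super().train_step(batch)


backend = CountingDecorator(PlainBackend())
result = backend.train_step([1.0, 2.0, 3.0])
print(result, backend.calls)
```

Because the decorator is itself a backend, wrappers can be stacked, and calling code does not need to know whether it holds a plain or a wrapped backend.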
DDPTrainerBackend Objects
Distributed Data Parallel TrainerBackend.
Wraps the ModuleInterface model with DistributedDataParallel, which handles gradient averaging across processes.
Note: Assumes initialized model parameters are consistent across processes, e.g. by using the same random seed in each process at the point of model initialization.
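The seeding requirement in the note can be checked in a single process by building the model twice. This is a minimal sketch: the helper `build_model` and the seed value are hypothetical, and the two constructions stand in for what two separate processes would each do.

```python
import torch


def build_model(seed: int) -> torch.nn.Module:
    """Seed the RNG immediately before constructing the model."""
    torch.manual_seed(seed)
    return torch.nn.Linear(4, 2)


# Stand-ins for the model as built independently in two processes.
rank0_model = build_model(seed=42)
rank1_model = build_model(seed=42)

same = all(
    torch.equal(a, b)
    for a, b in zip(rank0_model.parameters(), rank1_model.parameters())
)
print(same)
```

Seeding directly before construction matters: any random call made between seeding and model creation in one process but not the other would shift the generator state and produce different initial weights.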
setup_distributed_env
Set up the process group for distributed training.
cleanup
Destroy the process group used for distributed training.
gather_tensors_on_cpu
Gather tensors and move them to the CPU at a configurable frequency.
Moves the tensor to the CUDA device, applies all-gather, and moves the result back to the CPU. If distributed_training_args.gather_frequency is set, tensors are moved to CUDA in chunks of that size.
Arguments:
torch.Tensor - To be gathered.
Returns:
Gathered tensor on the CPU.
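Gathering in chunks bounds how much device memory is needed at once. The sketch below shows only the chunking logic, under several assumptions: the chunk size counts rows along dimension 0, the function `gather_on_cpu_sketch` and its parameters are hypothetical, and `fake_all_gather` is a stand-in that pretends every rank holds the same chunk. A real run would use `torch.distributed.all_gather` inside an initialized process group with a CUDA device.

```python
from typing import Callable, List, Optional

import torch


def fake_all_gather(chunk: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    """Hypothetical stand-in for all-gather: every rank contributes the same chunk."""
    return torch.cat([chunk] * world_size, dim=0)


def gather_on_cpu_sketch(
    x: torch.Tensor,
    gather_frequency: Optional[int] = None,
    device: str = "cpu",  # would be a CUDA device in a real distributed run
    all_gather: Callable[[torch.Tensor], torch.Tensor] = fake_all_gather,
) -> torch.Tensor:
    # Without a chunk size, the whole tensor is gathered in one go.
    if gather_frequency is None:
        chunks = [x]
    else:
        chunks = list(torch.split(x, gather_frequency, dim=0))

    gathered: List[torch.Tensor] = []
    for chunk in chunks:
        on_device = chunk.to(device)  # move one chunk to the device
        gathered.append(all_gather(on_device).cpu())  # gather, then back to CPU
    return torch.cat(gathered, dim=0)


x = torch.arange(10).reshape(5, 2)
out = gather_on_cpu_sketch(x, gather_frequency=2)
print(out.shape)
```

In this sketch the result is ordered chunk by chunk rather than rank by rank, so row order differs from a single unchunked gather even though the set of rows is the same.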
DPDDPTrainerBackend Objects
Distributed Data Parallel TrainerBackend with Differential Privacy.
Wraps the ModuleInterface model with DifferentiallyPrivateDistributedDataParallel, which handles gradient averaging across processes, along with virtual stepping.
Note: Assumes initialized model parameters are consistent across processes, e.g. by using the same random seed in each process at the point of model initialization.