core.trainer_backend

Trainer Backend module:

Currently we support:

1. SingleProcess
2. SingleProcess Amp
3. SingleProcess Apex-Amp
4. DDP
5. DDP Amp
6. DDP Apex-Amp

These TrainerBackends cover the most common scenarios and are available out of the box. Alternatively, a user can provide a custom TrainerBackend.

build_trainer_backend#

def build_trainer_backend(trainer_backend_name, *args, **kwargs)

Factory for trainer_backends

Arguments:

  • trainer_backend_name str - TrainerBackend Name. Possible choices are currently: sp, sp-amp, sp-amp-apex, ddp, ddp-amp, ddp-amp-apex
  • args sequence - TrainerBackend positional arguments
  • kwargs dict - TrainerBackend keyword arguments
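
For illustration, a minimal sketch of calling the factory; the import path below is an assumption based on this module's name, and the chosen backend name is just one of the documented choices.

```python
# Assumed import path; adjust to wherever this module lives in your install.
from pymarlin.core import trainer_backend

# Pick a backend by name ("sp" = single process). Positional and keyword
# arguments, if any, are forwarded to the backend's constructor.
backend = trainer_backend.build_trainer_backend("sp")
```

A custom TrainerBackend can be used instead by subclassing TrainerBackend and passing the instance directly wherever a backend is expected.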

TrainerBackendArguments Objects#

@dataclasses.dataclass
class TrainerBackendArguments()

Trainer Backend Arguments dataclass.

TrainerBackend Objects#

class TrainerBackend(ABC)

Trainer Backend abstract class.

OutputCollector Objects#

class OutputCollector()

Responsible for collecting step outputs and storing them in memory across calls. Concatenates tensors from all steps across the first dimension.

collect#

def collect(outputs: Union[torch.Tensor, Iterable[torch.Tensor]])

Coalesces train_step and val_step outputs. All tensors are concatenated across dimension 0:

  • if the input is a torch.Tensor of dimension batch_size x y .., all_outputs will be List[torch.Tensor of dimension total_samples_till_now x y]
  • if the input is a torch.Tensor of dimension 1 x 1, all_outputs will be List[torch.Tensor of dimension total_samples_till_now x 1]
  • if the input is a List[torch.Tensor], all_outputs will be List[torch.Tensor], with all tensors concatenated across dimension 0

Arguments:

  • outputs Union[torch.Tensor, Iterable[torch.Tensor]] - train_step or val_step outputs
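
A hypothetical usage sketch, assuming OutputCollector is importable from this module and exposes the collected outputs via an all_outputs attribute as described above:

```python
import torch

from pymarlin.core.trainer_backend import OutputCollector  # assumed import path

collector = OutputCollector()

# Simulate three train/val steps, each producing a batch of logits.
for _ in range(3):
    logits = torch.randn(8, 4)  # batch_size x num_classes
    collector.collect([logits])

# Per the docstring, tensors from all steps are concatenated across
# dimension 0, so the collected logits end up with shape (24, 4).
```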

SingleProcess Objects#

class SingleProcess(TrainerBackend)

Single Process Trainer Backend

__init__#

def __init__()

Single process trainer_backend

process_global_step#

def process_global_step(global_step_collector, callback)

Clip gradients and step the optimizer and scheduler.
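
The backend's actual implementation is not reproduced here; conceptually, a global step follows a pattern like this sketch (all names below are placeholders):

```python
import torch

def global_step_sketch(model, optimizer, scheduler, max_grad_norm=1.0):
    """Illustrative only: clip gradients, then step the optimizer and scheduler."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```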

get_state#

def get_state() -> dict

Get the current state of the trainer_backend, used for checkpointing.

Returns:

  • state_dict dict - Dictionary of variables or objects to checkpoint.

update_state#

def update_state(state) -> None

Update the trainer_backend from a checkpointed state.

Arguments:

  • state dict - Output of get_state() during checkpointing
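
A sketch of the checkpointing round trip these two methods support; the file name and surrounding code are illustrative:

```python
import torch

# `backend` is any TrainerBackend instance, e.g. from build_trainer_backend("sp").

# Save: capture the backend state alongside anything else being checkpointed.
torch.save({"trainer_backend": backend.get_state()}, "checkpoint.pt")

# Restore: feed the saved dict back into the backend.
checkpoint = torch.load("checkpoint.pt")
backend.update_state(checkpoint["trainer_backend"])
```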

SingleProcessDpSgd Objects#

class SingleProcessDpSgd(SingleProcess)

Backend which supports differential privacy using the Opacus library: https://opacus.ai/api/privacy_engine.html
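
For context, attaching the Opacus privacy engine typically looks roughly like the following (Opacus >= 1.0 API; this is not the backend's actual wiring):

```python
import torch
from opacus import PrivacyEngine

# Toy model, optimizer, and data loader purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(torch.randn(32, 10), torch.randint(0, 2, (32,)))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,  # noise added to the clipped, averaged gradients
    max_grad_norm=1.0,     # per-sample gradient clipping threshold
)
```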

SingleProcessAmp Objects#

class SingleProcessAmp(SingleProcess)

SingleProcess + Native PyTorch AMP Trainer Backend
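
For context, native PyTorch AMP follows this general pattern (illustrative sketch, not the backend's internal code):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, optimizer, batch, targets, loss_fn):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # run the forward pass in mixed precision
        loss = loss_fn(model(batch), targets)
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)            # unscale gradients and step if no inf/nan found
    scaler.update()                   # adjust the scale factor for the next step
    return loss
```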

SingleProcessApexAmp Objects#

class SingleProcessApexAmp(SingleProcessAmp)

SingleProcess + Apex AMP Trainer Backend

AbstractTrainerBackendDecorator Objects#

class AbstractTrainerBackendDecorator(TrainerBackend)

Abstract class implementing the decorator design pattern.
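
As a generic illustration of the delegation this pattern provides (hypothetical class; the real abstract base may require more methods, and its constructor is not shown on this page):

```python
class LoggingBackendDecorator:
    """Hypothetical decorator: wraps a TrainerBackend and delegates to it,
    adding behavior (here, logging) around selected calls."""

    def __init__(self, wrapped_backend):
        self.wrapped = wrapped_backend

    def get_state(self) -> dict:
        state = self.wrapped.get_state()
        print(f"checkpoint keys: {sorted(state)}")
        return state

    def __getattr__(self, name):
        # Anything not overridden is forwarded to the wrapped backend.
        return getattr(self.wrapped, name)
```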

DDPTrainerBackend Objects#

class DDPTrainerBackend(AbstractTrainerBackendDecorator)

Distributed Data Parallel TrainerBackend.

Wraps the ModuleInterface model with DistributedDataParallel, which handles gradient averaging across processes.

.. note: Assumes initialized model parameters are consistent across processes - e.g. by using the same random seed in each process at the point of model initialization.
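
To illustrate the note, a hypothetical snippet; it assumes the process group is already initialized (see setup_distributed_env below) and that one CUDA device is visible per process:

```python
import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ.get("LOCAL_RANK", 0))

# Same seed in every process so initial parameters match before wrapping.
torch.manual_seed(42)
model = torch.nn.Linear(10, 2).to(local_rank)

# DDP averages gradients across processes during backward().
ddp_model = DDP(model, device_ids=[local_rank])
```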

setup_distributed_env#

def setup_distributed_env()

Setup the process group for distributed training.

cleanup#

def cleanup()

Destroy the process group used for distributed training.
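
A minimal sketch of the setup/teardown these two methods correspond to; the environment variable values and choice of the nccl backend are illustrative:

```python
import os
import torch.distributed as dist

def setup(rank: int, world_size: int):
    # Placeholder rendezvous settings; real values come from the launcher/config.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
```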

gather_tensors_on_cpu#

def gather_tensors_on_cpu(x: torch.tensor)

Gather tensors and move them to CPU at a configurable frequency.

Move tensor to CUDA device, apply all-gather and move back to CPU. If distributed_training_args.gather_frequency is set, tensors are moved to CUDA in chunks of that size.

Arguments:

  • x torch.tensor - Tensor to be gathered.

Returns:

Gathered tensor on the CPU.
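
Conceptually the gather resembles this sketch (not the actual implementation; the gather_frequency chunking described above is omitted for brevity):

```python
import torch
import torch.distributed as dist

def gather_to_cpu_sketch(x: torch.Tensor) -> torch.Tensor:
    """Illustrative: all-gather a tensor via CUDA, then return the result on CPU."""
    x = x.cuda()
    gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x)
    return torch.cat(gathered, dim=0).cpu()
```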

DPDDPTrainerBackend Objects#

class DPDDPTrainerBackend(DDPTrainerBackend)

Distributed Data Parallel TrainerBackend with Differential Privacy.

Wraps the ModuleInterface model with DifferentiallyPrivateDistributedDataParallel, which handles gradient averaging across processes, along with virtual stepping.

.. note: Assumes initialized model parameters are consistent across processes - e.g. by using the same random seed in each process at the point of model initialization.