# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import itertools
import math
import os
import shutil
import sys
import time
from typing import Any, Dict, Iterator, Optional, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from overrides import overrides
from packaging import version
from torch.nn.parallel import DistributedDataParallel
from archai.api.trainer_base import TrainerBase
from archai.common.distributed_utils import all_reduce, sync_workers
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.datasets.nlp.nvidia_data_loader_utils import (
LMMultiFileIterator,
LMOrderedIterator,
)
from archai.datasets.nlp.nvidia_dataset_provider import NvidiaDatasetProvider
from archai.quantization.mixed_qat import MixedQAT
from archai.quantization.qat import prepare_with_qat, qat_to_float_modules
from archai.trainers.cyclic_cosine_scheduler import CyclicCosineDecayLR
from archai.trainers.lamb_optimizer import JITLamb, Lamb
from archai.trainers.nlp.nvidia_training_args import NvidiaTrainingArguments
# Module-level logger used by `save_checkpoint` and `NvidiaTrainer` below.
logger = OrderedDictLogger(source=__name__)
def save_checkpoint(
    output_dir: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    scaler: torch.cuda.amp.GradScaler,
    trainer_state: Dict[str, Any],
    fp16: bool,
    prefix: Optional[str] = "",
    save_all_checkpoints: Optional[bool] = False,
    is_best_model: Optional[bool] = False,
) -> None:
    """Save a checkpoint that holds enough information to resume the training.

    The checkpoint contains the model's configuration and state, the optimizer's state,
    the scheduler's state (if a scheduler is given), the scaler's state (if FP16 precision
    is used and a scaler is given), and the trainer's state.

    The checkpoint is always written as `<prefix>checkpoint-last.pt`. Only the worker
    that receives rank 0 from `sync_workers()` writes to disk.

    If `is_best_model` is `True`, the function will also save a copy of the checkpoint
    as `<prefix>checkpoint-best.pt`.

    If `save_all_checkpoints` is `True`, the function will also save a copy of the checkpoint
    with the step number (`trainer_state["step"]`) in the file name.

    Args:
        output_dir: Folder where checkpoint should be saved. It is created if missing.
        model: Instance of model. Must expose a `config` attribute.
        optimizer: Instance of optimizer.
        scheduler: Instance of scheduler (may be `None`).
        scaler: Instance of scaler (may be `None` when `fp16` is `False`).
        trainer_state: Current trainer state.
        fp16: Whether fp16 precision is used or not.
        prefix: Prefix which should be added to the checkpoint's file name.
        save_all_checkpoints: Whether all `eval_steps` steps should be saved.
        is_best_model: Whether best model should be saved.

    """

    # `prefix` is annotated as Optional, so tolerate an explicit `None`
    prefix = prefix or ""

    state = {
        "model_config": model.config,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict() if scheduler else None,
        # Guard against `fp16=True` being paired with a missing scaler
        "scaler_state": scaler.state_dict() if fp16 and scaler is not None else None,
        "trainer_state": trainer_state,
    }

    checkpoint_path = os.path.join(output_dir, f"{prefix}checkpoint-last.pt")

    with sync_workers() as rank:
        if rank == 0:
            # `torch.save` does not create missing parent folders
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            logger.info(f"Saving checkpoint: {checkpoint_path}")
            torch.save(state, checkpoint_path)

            # Additional copies of the checkpoint that was just written
            extra_checkpoint_names = []
            if is_best_model:
                extra_checkpoint_names.append(f"{prefix}checkpoint-best.pt")
            if save_all_checkpoints:
                extra_checkpoint_names.append(f"{prefix}checkpoint-{trainer_state['step']}.pt")

            for extra_name in extra_checkpoint_names:
                extra_path = os.path.join(output_dir, extra_name)
                logger.info(f"Saving checkpoint: {extra_path}")
                shutil.copy(checkpoint_path, extra_path)
class NvidiaTrainer(TrainerBase):
    """NVIDIA-based trainer."""

    def __init__(
        self,
        model: torch.nn.Module,
        args: Optional[NvidiaTrainingArguments] = None,
    ) -> None:
        """Initialize by verifying the model and training arguments, and loading dataset.

        Args:
            model: Model to be trained or evaluated.
            args: NVIDIA-based training arguments. If not provided, a default instance
                of `NvidiaTrainingArguments` will be used.

        """

        assert isinstance(model, torch.nn.Module), "`model` should be an instance of `torch.nn.Module`."
        self.model = model

        if args is None:
            args = NvidiaTrainingArguments("tmp_trainer")
        assert isinstance(args, NvidiaTrainingArguments), "`args` should be an instance of `NvidiaTrainingArguments`."
        self.args = args

        self.dataset_provider = NvidiaDatasetProvider(
            dataset_name=self.args.dataset_name,
            dataset_dir=self.args.dataset_dir,
            cache_dir=self.args.dataset_cache_dir,
            vocab_type=self.args.vocab_type,
            vocab_size=self.args.vocab_size,
            refresh_cache=self.args.dataset_refresh_cache,
        )

        self.model.to(self.args.device)

        # These are (re-)created by `train()`. They are initialized here so that
        # `load_checkpoint()` can be called before `train()` (as `fine_tune_qat()` does)
        # without hitting a missing attribute.
        self.optimizer = None
        self.scheduler = None
        self.scaler = None

        self.trainer_state = {
            "iterator": 0,
            "epoch": 0,
            "batch": 0,
            "step": 0,
            "best_eval_loss": 1e300,
            "log_history": [],
        }

    def load_checkpoint(self, checkpoint_file_path: str) -> Tuple[int, int, int, int]:
        """Load states from a checkpoint file.

        The model state and the trainer state are always restored. The optimizer,
        scheduler and scaler states are restored only when the corresponding object
        has already been created and the checkpoint holds a state for it.

        Args:
            checkpoint_file_path: Path to the checkpoint file.

        Returns:
            Current iterator, epoch, batch, and step values. All zeros if the
                checkpoint file does not exist.

        """

        try:
            checkpoint = torch.load(checkpoint_file_path, map_location=self.args.device)
        except FileNotFoundError:
            return 0, 0, 0, 0

        self.model.load_state_dict(checkpoint["model_state"])

        if self.optimizer is not None:
            self.optimizer.load_state_dict(checkpoint["optimizer_state"])

        # `save_checkpoint` stores `None` when no scheduler was in use
        scheduler_state = checkpoint.get("scheduler_state")
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)

        if self.args.fp16 and self.scaler is not None:
            # `save_checkpoint` writes the key "scaler_state"; "amp_state" is still
            # accepted as a fallback since this method previously read that key
            scaler_state = checkpoint.get("scaler_state")
            if scaler_state is None:
                scaler_state = checkpoint.get("amp_state")
            if scaler_state is not None:
                self.scaler.load_state_dict(scaler_state)

        self.trainer_state = checkpoint["trainer_state"]

        return (
            self.trainer_state["iterator"],
            self.trainer_state["epoch"],
            self.trainer_state["batch"],
            self.trainer_state["step"],
        )

    def _get_dataloader(self, split: str) -> Iterator:
        """Build the data iterator of a split ("train", "valid" or "test").

        Raises:
            RuntimeError: If the split or the dataset name is not supported.

        """

        if split == "train":
            input_ids = self.dataset_provider.get_train_dataset()
        elif split == "valid":
            input_ids = self.dataset_provider.get_val_dataset()
        elif split == "test":
            input_ids = self.dataset_provider.get_test_dataset()
        else:
            raise RuntimeError(f"Split: {split} is not supported yet.")

        if self.args.dataset_name in ["wt2", "wt103"] or self.args.dataset_name.startswith("olx_"):
            return LMOrderedIterator(
                input_ids,
                self.args.global_batch_size,
                self.args.seq_len,
                device=self.args.device,
            )
        elif self.args.dataset_name == "lm1b":
            # NOTE(review): `self.vocab` is not assigned anywhere in this class; confirm it
            # is provided by `TrainerBase` or set by the caller before using `lm1b`.
            return LMMultiFileIterator(
                input_ids,
                self.vocab,
                self.args.global_batch_size,
                self.args.seq_len,
                device=self.args.device,
            )
        else:
            raise RuntimeError(f"Dataset: {self.args.dataset_name} is not supported yet.")

    def _create_optimizer(self) -> None:
        """Create `self.optimizer` according to `args.optim` (case-insensitive).

        Raises:
            NotImplementedError: If the optimizer name is not supported.

        """

        optimizer_name = self.args.optim.lower()
        if optimizer_name == "sgd":
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.learning_rate, momentum=self.args.momentum)
        elif optimizer_name == "adam":
            self.optimizer = optim.Adam(
                self.model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay
            )
        elif optimizer_name == "adagrad":
            self.optimizer = optim.Adagrad(self.model.parameters(), lr=self.args.learning_rate)
        elif optimizer_name == "lamb":
            self.optimizer = Lamb(
                self.model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay
            )
        elif optimizer_name == "jitlamb":
            self.optimizer = JITLamb(
                self.model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay
            )
        else:
            raise NotImplementedError(f"Optimizer: {self.args.optim} is not implemented yet.")

    def _create_scaler(self) -> None:
        """Create `self.scaler` (a `GradScaler`) when FP16 is enabled, `None` otherwise."""

        self.scaler = None
        if self.args.fp16:
            self.scaler = torch.cuda.amp.GradScaler()

    def _active_scheduler_name(self) -> str:
        """Return the scheduler type in effect: the QAT one when QAT is enabled."""

        return self.args.lr_qat_scheduler_type if self.args.qat else self.args.lr_scheduler_type

    def _create_scheduler(self) -> None:
        """Create `self.scheduler` for the active scheduler type.

        `self.scheduler` is left as `None` for the "constant" type (and for any
        unrecognized name), which `save_checkpoint` and `load_checkpoint` accept.

        """

        # Always define the attribute so later accesses never fail
        self.scheduler = None

        scheduler_name = self._active_scheduler_name()
        warmup_steps = self.args.lr_scheduler_warmup_steps

        if scheduler_name == "cosine":
            if self.args.lr_scheduler_max_steps:
                max_steps = self.args.lr_scheduler_max_steps
            else:
                max_steps = self.args.max_steps

            # Warmup is applied manually in `_training_step`, hence the reduced horizon
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, max_steps - warmup_steps, eta_min=self.args.lr_scheduler_min_lr
            )

        elif scheduler_name == "inv_sqrt":

            def lr_lambda(step: int) -> float:
                if step == 0 and self.args.lr_scheduler_warmup_steps == 0:
                    return 1.0
                else:
                    return (
                        1.0 / (step**0.5)
                        if step > self.args.lr_scheduler_warmup_steps
                        else step / (self.args.lr_scheduler_warmup_steps**1.5)
                    )

            self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda)

        elif scheduler_name == "cyclic_cosine":
            # Uses `max_steps`, the attribute name used everywhere else in this class
            init_decay_steps = int((self.args.max_steps - warmup_steps) / 2)
            restart_interval = int((self.args.max_steps - warmup_steps) / 4)

            self.scheduler = CyclicCosineDecayLR(
                self.optimizer,
                init_decay_steps,
                self.args.lr_scheduler_min_lr,
                restart_interval,
                warmup_epochs=warmup_steps,
                warmup_start_lr=self.args.learning_rate * 0.01,
            )

        elif scheduler_name == "constant":
            # No scheduler object: only the manual warmup in `_training_step` applies
            pass

    def _setup_qat(self) -> None:
        """Prepare the model for QAT and/or wrap it with `MixedQAT`, as requested by `args`."""

        if self.args.qat:
            prepare_with_qat(self.model, onnx_compatible=True)

        if self.args.mixed_qat:
            self.model = MixedQAT(self.model)

    def _setup_distributed_training(self) -> None:
        """Set `self.dist_model`, wrapping the model according to `args.strategy`.

        "ddp" wraps with `DistributedDataParallel` (only if `torch.distributed` is
        initialized), "dp" wraps with `nn.DataParallel`; otherwise the plain model is used.

        """

        self.dist_model = self.model

        if self.args.strategy == "ddp" and torch.distributed.is_initialized():
            self.dist_model = DistributedDataParallel(
                self.model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=self.args.find_unused_parameters,
            )
        elif self.args.strategy == "dp":
            self.dist_model = nn.DataParallel(self.model, dim=1)

    def _training_step_chunk(
        self, input_ids: torch.LongTensor, labels: torch.LongTensor, autocast: torch.autocast
    ) -> float:
        """Run forward and backward on one gradient-accumulation chunk.

        Returns:
            The chunk loss, already divided by `gradient_accumulation_steps`.

        """

        with autocast:
            # NOTE(review): `labels` is not used; `input_ids` is passed as labels, which
            # presumably relies on the model shifting labels internally - confirm.
            loss = self.dist_model(input_ids, labels=input_ids)[0]
            loss = loss.float().mean().type_as(loss) / self.args.gradient_accumulation_steps

        if self.args.fp16:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        return loss.float().item()

    def _training_step(
        self,
        train_dataloader: Iterator,
        eval_dataloader: Iterator,
        iterator: int,
        epoch: int,
        start_batch: int,
        step: int,
    ) -> int:
        """Train over one pass of the training iterator (or until `max_steps`).

        Handles gradient accumulation, gradient clipping, learning rate annealing,
        periodic logging, and periodic evaluation with checkpointing.

        Args:
            train_dataloader: Training data iterator.
            eval_dataloader: Validation data iterator.
            iterator: Position from which the training iterator is resumed.
            epoch: Current epoch.
            start_batch: Batch index from which counting is resumed.
            step: Number of optimization steps performed so far.

        Returns:
            Number of optimization steps performed after this pass.

        """

        self.model.train()

        train_loss, log_step, n_labels_tokens = 0.0, 0, 0
        best_eval_loss = self.trainer_state["best_eval_loss"]

        # Same selection rule as `_create_scheduler`, so that the scheduler is
        # stepped with the protocol that matches the object that was created
        scheduler_name = self._active_scheduler_name()
        warmup_steps = self.args.lr_scheduler_warmup_steps

        start_time = time.time()

        # `lm1b` uses a different style of data loader
        if self.args.dataset_name != "lm1b":
            train_iterator = train_dataloader.get_fixlen_iter(start=iterator)
        else:
            train_iterator = train_dataloader

        # Support `bf16` based on PyTorch version and CUDA availability
        autocast = torch.autocast(self.args.device.type, enabled=self.args.fp16)
        if version.parse(torch.__version__) >= version.parse("1.10") and self.args.device.type != "cpu":
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            autocast = torch.cuda.amp.autocast(enabled=self.args.fp16, dtype=dtype)

        for batch, (input_ids, labels, _, _) in enumerate(train_iterator, start=start_batch + 1):
            log_step += 1
            n_labels_tokens += labels.numel()

            for param in self.model.parameters():
                param.grad = None

            # Split into chunks for gradient accumulation; iterating over the chunks
            # themselves avoids an IndexError when `torch.chunk` returns fewer chunks
            # than `gradient_accumulation_steps`
            input_ids_chunks = torch.chunk(input_ids, self.args.gradient_accumulation_steps, 0)
            labels_chunks = torch.chunk(labels, self.args.gradient_accumulation_steps, 0)

            for input_ids_chunk, labels_chunk in zip(input_ids_chunks, labels_chunks):
                train_loss += self._training_step_chunk(
                    input_ids_chunk.contiguous(),
                    labels_chunk.contiguous(),
                    autocast,
                )

            # Gradients must be unscaled before clipping when a scaler is in use
            if self.args.fp16:
                self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)

            if self.args.fp16:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            # Learning rate annealing
            step += 1
            if scheduler_name in ["cosine", "constant"]:
                if step < warmup_steps:
                    # Linear warmup, applied to the first parameter group only
                    curr_lr = self.args.learning_rate * step / warmup_steps
                    self.optimizer.param_groups[0]["lr"] = curr_lr
                elif scheduler_name == "cosine":
                    self.scheduler.step(step - warmup_steps)
            elif scheduler_name in ["inv_sqrt", "cyclic_cosine"]:
                self.scheduler.step(step)

            # Logging
            if step % self.args.logging_steps == 0:
                elapsed_time = time.time() - start_time

                lr = self.optimizer.param_groups[0]["lr"]

                loss = train_loss / log_step
                loss = all_reduce(loss, op="mean")

                batch_time = elapsed_time / log_step
                batch_time = all_reduce(batch_time, op="max")

                throughput = n_labels_tokens / elapsed_time
                throughput = all_reduce(throughput, op="sum")

                train_loss, log_step, n_labels_tokens = 0.0, 0, 0

                self.trainer_state["log_history"].append(
                    {
                        "epoch": epoch,
                        "learning_rate": lr,
                        "loss": loss,
                        "ppl": math.exp(loss),
                        "step": step,
                    }
                )

                logger.info(
                    f"Epoch: {epoch} | Step: {step} | "
                    f"Batch: {batch} / {train_dataloader.n_batch} | LR: {lr:.3e} | "
                    f"ms/batch: {batch_time*1000:.1f} | tok/s: {throughput:.0f} | "
                    f"Loss: {loss:.3f} | PPL: {math.exp(loss):.3f}"
                )

                start_time = time.time()

            do_periodic_eval = step % self.args.eval_steps == 0
            is_final_step = step == self.args.max_steps

            # Evaluation and checkpoint
            if (do_periodic_eval or is_final_step) and self.args.do_eval:
                eval_loss, eval_time = self._evaluation_step(eval_dataloader)
                eval_loss = all_reduce(eval_loss, op="mean")

                self.trainer_state["log_history"].append(
                    {
                        "epoch": epoch,
                        "eval_idx": (step // self.args.eval_steps) - 1,
                        "eval_runtime": eval_time,
                        "eval_loss": eval_loss,
                        "eval_ppl": math.exp(eval_loss),
                        "step": step,
                    }
                )

                logger.info(
                    f"Eval: {(step // self.args.eval_steps) - 1} | "
                    f"Step: {step} | Time: {eval_time:.2f}s | "
                    f"Loss: {eval_loss:.3f} | PPL: {math.exp(eval_loss):.3f}"
                )

                self.trainer_state["iterator"] = train_dataloader.last_iter
                self.trainer_state["epoch"] = epoch
                self.trainer_state["batch"] = batch
                self.trainer_state["step"] = step

                save_model = self.model
                prefix = ""

                # Model needs to be converted back to FP32 when using QAT; the conversion
                # is done on a copy so that the model being trained is left untouched.
                # Without QAT nothing is mutated, so no copy of the model is needed.
                if self.args.qat:
                    save_model = copy.deepcopy(self.model)
                    qat_to_float_modules(save_model)
                    prefix = "qat-"

                # Save original FP32 model when using MixedQAT
                if self.args.mixed_qat:
                    save_model = save_model.model
                    prefix = "mixed-qat-"

                # Check if current model is the best one
                is_best_model = eval_loss < best_eval_loss
                if is_best_model:
                    best_eval_loss = eval_loss
                    self.trainer_state["best_eval_loss"] = best_eval_loss

                save_checkpoint(
                    self.args.output_dir,
                    save_model,
                    self.optimizer,
                    self.scheduler,
                    self.scaler,
                    self.trainer_state,
                    self.args.fp16,
                    prefix=prefix,
                    save_all_checkpoints=self.args.save_all_checkpoints,
                    is_best_model=is_best_model,
                )

            if is_final_step:
                break

        return step

    @overrides
    def train(self, checkpoint_file_path: Optional[str] = "") -> Dict[str, Any]:
        """Train a model.

        Args:
            checkpoint_file_path: Path to the checkpoint that will be used
                to resume the training.

        Returns:
            Training-related metrics: `train_time` (seconds spent in the training loop)
                and `step` (number of optimization steps performed).

        """

        self._create_optimizer()
        self._create_scaler()
        self._create_scheduler()

        if checkpoint_file_path:
            iterator, start_epoch, start_batch, step = self.load_checkpoint(checkpoint_file_path)
        else:
            iterator, start_epoch, start_batch, step = 0, 0, 0, 0

        # Nothing left to train: the checkpoint already reached `max_steps`
        if step >= self.args.max_steps:
            sys.exit(1)

        self._setup_qat()
        self._setup_distributed_training()

        train_dataloader = self._get_dataloader("train")
        eval_dataloader = self._get_dataloader("valid")

        logger.info("Starting training ...")
        logger.debug(f"Training arguments: {self.args.to_dict()}")

        start_time = time.time()
        try:
            for epoch in itertools.count(start=start_epoch):
                if self.args.iterator_roll:
                    train_dataloader.roll(seed=self.args.seed + epoch)

                step = self._training_step(train_dataloader, eval_dataloader, iterator, epoch, start_batch, step)

                # Resume offsets only apply to the first epoch after loading a checkpoint
                iterator, start_batch = 0, 0

                if step == self.args.max_steps:
                    logger.info("End of training ...")
                    break
        except KeyboardInterrupt:
            logger.info("Exiting from training ...")
        end_time = time.time()

        train_time = end_time - start_time
        logger.info(f"Training time: {train_time:.3f} seconds")

        return {"train_time": train_time, "step": step}

    def _evaluation_step(self, eval_dataloader: Iterator) -> Tuple[float, float]:
        """Compute the token-weighted average loss over the warm batches of a data loader.

        The model is switched to evaluation mode and put back in training mode afterwards.

        Returns:
            Average loss and elapsed time in seconds.

        Raises:
            ZeroDivisionError: If the data loader yields no warm batch.

        """

        self.model.eval()

        eval_loss, n_tokens = 0.0, 0

        start_time = time.time()
        with torch.no_grad():
            for input_ids, _, _, warm in eval_dataloader:
                loss = self.model(input_ids, labels=input_ids)[0]
                tokens = input_ids.numel()

                # Only batches flagged as warm by the iterator contribute to the average
                if warm:
                    eval_loss += tokens * loss.float().mean().item()
                    n_tokens += tokens

        eval_loss /= n_tokens
        end_time = time.time()

        self.model.train()

        return eval_loss, end_time - start_time

    @overrides
    def evaluate(self, eval_dataloader: Optional[Iterator] = None) -> Dict[str, Any]:
        """Evaluate a model.

        Args:
            eval_dataloader: Evaluation-based data loader. If not supplied, it will
                default to the one available in pre-loaded dataset.

        Returns:
            Evaluation-related metrics.

        """

        if eval_dataloader is None:
            eval_dataloader = self._get_dataloader("test")

        eval_loss, eval_time = self._evaluation_step(eval_dataloader)

        eval_metrics = {
            "eval_time": eval_time,
            "eval_loss": eval_loss,
            "eval_ppl": math.exp(eval_loss),
            "eval_bpc": eval_loss / math.log(2),
        }

        return eval_metrics

    @overrides
    def predict(self) -> None:
        """Predict with a model."""

        raise NotImplementedError

    def fine_tune_qat(self, model: Optional[torch.nn.Module] = None, checkpoint_file_path: Optional[str] = "") -> None:
        """Fine-tune a model with QAT.

        Users are allowed to pass in a different model (e.g., without dropout) than the one
        instantiated with `NvidiaTrainer`, as well as a pre-trained checkpoint file to load
        the weights from a previous training.

        The training arguments held by this trainer are modified in place.

        Args:
            model: Model to be fine-tuned.
            checkpoint_file_path: Path to the checkpoint used to resume training.

        """

        # `is not None` rather than truthiness: container modules define `__len__`,
        # so an empty one would evaluate to `False`
        if model is not None:
            assert isinstance(model, torch.nn.Module), "`model` should be an instance of `torch.nn.Module`."
            self.model = model.to(self.args.device)

        # QAT-based arguments
        # NOTE(review): the learning rates are divided in place, so calling this method
        # repeatedly on the same trainer compounds the reduction.
        self.args.max_steps = 10000
        self.args.eval_steps = 1000
        self.args.optim = "adam"
        self.args.learning_rate /= 100
        self.args.lr_scheduler_min_lr /= 100
        self.args.lr_scheduler_warmup_steps = 1000
        self.args.qat = True
        self.args.mixed_qat = False

        # Re-load the checkpoint and perform the fine-tuning; `train()` is called without
        # a checkpoint, so it starts from step 0 with a fresh optimizer and scheduler.
        # NOTE(review): the loaded trainer state keeps the checkpoint's `best_eval_loss`.
        if checkpoint_file_path:
            self.load_checkpoint(checkpoint_file_path)
        self.train()