Source code for archai.trainers.nlp.hf_trainer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import shutil
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from overrides import overrides
from transformers.trainer import Trainer

from archai.api.trainer_base import TrainerBase
from archai.trainers.nlp.hf_training_args import DistillerTrainingArguments


class HfTrainer(Trainer, TrainerBase):
    """Hugging Face trainer."""

    @overrides
    def _rotate_checkpoints(self, use_mtime: Optional[bool] = False, output_dir: Optional[str] = None) -> None:
        """Rotate checkpoints and cache them to Azure Storage.

        The `use_mtime` argument is always set to `False` to avoid having multiple
        checkpoints with the same timestamp when retrieving them from Azure Storage.
        This is because Azure Storage does not support sub-second precision for
        file timestamps.

        Args:
            use_mtime: Whether to use mtime to sort the checkpoints.
            output_dir: Folder to output the checkpoints.

        """

        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Enforces use_mtime=False to avoid identical timestamps
        # when retrieving files from Azure Storage
        use_mtime = False

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True,
        # we could end up deleting the last checkpoint, which
        # we don't do to allow resuming
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            try:
                shutil.rmtree(checkpoint)
            except FileNotFoundError:
                pass
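
A minimal usage sketch for `HfTrainer`, which drops in where `transformers.Trainer` would be used; the model name, toy dataset, and output folder below are placeholders, not part of this module:

# Hypothetical example: train a small causal LM with HfTrainer so that
# `_rotate_checkpoints` keeps only `save_total_limit` checkpoints on disk.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

from archai.trainers.nlp.hf_trainer import HfTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Tiny toy dataset: identical sequences, labels equal to inputs (causal LM loss).
encodings = tokenizer(["hello world"] * 8)
train_dataset = Dataset.from_dict(
    {"input_ids": encodings["input_ids"], "labels": encodings["input_ids"]}
)

training_args = TrainingArguments(
    output_dir="checkpoints",      # placeholder output folder
    save_total_limit=2,            # older checkpoints are removed by `_rotate_checkpoints`
    save_steps=2,
    max_steps=4,
    per_device_train_batch_size=2,
)

trainer = HfTrainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
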
class HfDistillerTrainer(HfTrainer):
    """Hugging Face distillation-based trainer."""

    def __init__(self, teacher_model: torch.nn.Module, **kwargs) -> None:
        """Initialize Hugging Face distillation-based trainer.

        Args:
            teacher_model: Pre-trained teacher model.

        """

        self.teacher_model = teacher_model

        if "args" in kwargs:
            assert isinstance(
                kwargs["args"], DistillerTrainingArguments
            ), "`args` should be an instance of `DistillerTrainingArguments`."
        else:
            kwargs["args"] = DistillerTrainingArguments("tmp")

        super().__init__(**kwargs)
    @overrides
    def compute_loss(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, torch.Tensor],
        return_outputs: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Override the computation of the loss function.

        The loss is a weighted sum of the student's loss, as computed by the
        original `HfTrainer`, and the KL divergence between the student and
        teacher models.

        Args:
            model: Student model.
            inputs: Input tensors.
            return_outputs: Whether outputs should be returned.

        Returns:
            (loss, outputs) or the loss tensor.

        """

        student_outputs = model(**inputs)
        student_loss = student_outputs["loss"]
        student_logits = student_outputs["logits"]

        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs["logits"]

        # Compute the KL divergence and KD losses
        kl_loss = nn.KLDivLoss(reduction="batchmean")
        kl_divergence = kl_loss(
            F.log_softmax(student_logits / self.args.temperature, dim=-1),
            F.softmax(teacher_logits / self.args.temperature, dim=-1),
        )
        kd_loss = self.args.temperature**2 * kl_divergence

        # Weigh the final loss
        loss = self.args.alpha * student_loss + (1 - self.args.alpha) * kd_loss

        return (loss, student_outputs) if return_outputs else loss
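
A minimal distillation sketch, assuming `DistillerTrainingArguments` accepts `alpha` and `temperature` as constructor arguments (consistent with their use in `compute_loss` above); the teacher/student pair, dataset, and output folder are placeholders:

# Hypothetical example: distill GPT-2 (teacher) into DistilGPT-2 (student).
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from archai.trainers.nlp.hf_trainer import HfDistillerTrainer
from archai.trainers.nlp.hf_training_args import DistillerTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("gpt2")
teacher = AutoModelForCausalLM.from_pretrained("gpt2").eval()
student = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Tiny toy dataset; labels equal inputs so the student's causal LM loss is defined.
encodings = tokenizer(["hello world"] * 8)
train_dataset = Dataset.from_dict(
    {"input_ids": encodings["input_ids"], "labels": encodings["input_ids"]}
)

# Assumption: `alpha` and `temperature` are constructor kwargs of
# `DistillerTrainingArguments`, matching `self.args.alpha` / `self.args.temperature`
# read in `compute_loss`.
training_args = DistillerTrainingArguments(
    "distill-checkpoints",  # placeholder output folder
    alpha=0.5,              # weight between student loss and KD loss
    temperature=2.0,        # softmax temperature for the KL term
    max_steps=4,
    per_device_train_batch_size=2,
)

trainer = HfDistillerTrainer(
    teacher_model=teacher,
    model=student,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()

Note that `compute_loss` calls `self.teacher_model(**inputs)` directly, so the teacher must live on the same device as the training inputs.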