Source code for archai.trainers.nlp.hf_callbacks

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Dict, Optional

from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments


class BPCTrainerCallback(TrainerCallback):
    """A `TrainerCallback` that adds bits per character metrics to the logs."""

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the `BPCTrainerCallback` with custom arguments and keyword arguments."""

        super().__init__(*args, **kwargs)

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        """Add bits per character metrics to the training logs.

        Args:
            args: The training arguments.
            state: The trainer state.
            control: The trainer control.

        """

        current_log = state.log_history[-1]

        # Check whether the last log comes from the training step
        if "loss" in current_log:
            try:
                current_log["bpc"] = current_log["loss"] / math.log(2)
            except OverflowError:
                current_log["bpc"] = math.inf

        # Check whether the last log comes from the evaluation step
        if "eval_loss" in current_log:
            try:
                current_log["eval_bpc"] = current_log["eval_loss"] / math.log(2)
            except OverflowError:
                current_log["eval_bpc"] = math.inf

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: Optional[Dict[str, float]] = None,
        **kwargs,
    ) -> None:
        """Add bits per character metrics to the evaluation metrics.

        Args:
            args: The training arguments.
            state: The trainer state.
            control: The trainer control.
            metrics: The evaluation metrics.

        """

        # Check whether metrics have validation loss
        if "eval_loss" in metrics:
            try:
                metrics["eval_bpc"] = metrics["eval_loss"] / math.log(2)
            except OverflowError:
                metrics["eval_bpc"] = math.inf

        # Check whether metrics have testing loss
        if "test_loss" in metrics:
            try:
                metrics["test_bpc"] = metrics["test_loss"] / math.log(2)
            except OverflowError:
                metrics["test_bpc"] = math.inf
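
The callback is intended to be registered with Hugging Face's `Trainer` through its `callbacks` argument, after which the extra `bpc`/`eval_bpc` keys show up in the logged metrics. The following is a minimal sketch, assuming a `model`, `train_dataset`, and `eval_dataset` have already been prepared elsewhere (they are placeholders, not part of this module):

    from transformers import Trainer, TrainingArguments

    from archai.trainers.nlp.hf_callbacks import BPCTrainerCallback

    # `model`, `train_dataset`, and `eval_dataset` are hypothetical objects
    # assumed to be defined elsewhere (any causal LM setup works the same way).
    training_args = TrainingArguments(output_dir="output", logging_steps=100)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=[BPCTrainerCallback()],
    )
    trainer.train()  # training/evaluation logs now include "bpc" and "eval_bpc"
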
class PerplexityTrainerCallback(TrainerCallback):
    """A `TrainerCallback` that adds perplexity metrics to the logs."""

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the `PerplexityTrainerCallback` with custom arguments and keyword arguments."""

        super().__init__(*args, **kwargs)

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        """Add perplexity metrics to the training logs.

        Args:
            args: The training arguments.
            state: The trainer state.
            control: The trainer control.

        """

        current_log = state.log_history[-1]

        # Check whether the last log comes from the training step
        if "loss" in current_log:
            try:
                current_log["ppl"] = math.exp(current_log["loss"])
            except OverflowError:
                current_log["ppl"] = math.inf

        # Check whether the last log comes from the evaluation step
        if "eval_loss" in current_log:
            try:
                current_log["eval_ppl"] = math.exp(current_log["eval_loss"])
            except OverflowError:
                current_log["eval_ppl"] = math.inf

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: Optional[Dict[str, float]] = None,
        **kwargs,
    ) -> None:
        """Add perplexity metrics to the evaluation metrics.

        Args:
            args: The training arguments.
            state: The trainer state.
            control: The trainer control.
            metrics: The evaluation metrics.

        """

        # Check whether metrics have validation loss
        if "eval_loss" in metrics:
            try:
                metrics["eval_ppl"] = math.exp(metrics["eval_loss"])
            except OverflowError:
                metrics["eval_ppl"] = math.inf

        # Check whether metrics have testing loss
        if "test_loss" in metrics:
            try:
                metrics["test_ppl"] = math.exp(metrics["test_loss"])
            except OverflowError:
                metrics["test_ppl"] = math.inf
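
Since perplexity is simply the exponential of the cross-entropy loss (and bits per character is the loss divided by ln 2), the callback can also be exercised directly on a metrics dictionary without running a full training loop. The loss values below are made-up numbers used only to illustrate the transformation:

    import math

    from transformers import TrainingArguments
    from transformers.trainer_callback import TrainerControl, TrainerState

    from archai.trainers.nlp.hf_callbacks import PerplexityTrainerCallback

    callback = PerplexityTrainerCallback()
    metrics = {"eval_loss": 3.0, "test_loss": 2.5}  # illustrative losses in nats

    # on_evaluate mutates the metrics dictionary in place, adding the derived keys.
    callback.on_evaluate(TrainingArguments(output_dir="output"), TrainerState(), TrainerControl(), metrics=metrics)

    print(metrics["eval_ppl"])  # math.exp(3.0) ≈ 20.09
    print(metrics["test_ppl"])  # math.exp(2.5) ≈ 12.18
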