# Source code for archai.trainers.losses

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss

class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross entropy loss with label smoothing support.

    Integer class targets are converted into a smoothed distribution in which the
    true class receives ``1 - smoothing`` and every other class receives
    ``smoothing / (n_classes - 1)``. The loss is the cross entropy between that
    distribution and ``log_softmax(inputs)`` taken over the last dimension.

    """

    def __init__(
        self, weight: Optional[torch.Tensor] = None, reduction: Optional[str] = "mean", smoothing: Optional[float] = 0.0
    ) -> None:
        """Initialize the loss function.

        Args:
            weight: Weight for each class, multiplied into the per-class log-probabilities.
            reduction: Reduction method. ``"sum"`` and ``"mean"`` reduce to a scalar;
                any other value returns the unreduced per-sample losses.
            smoothing: Label smoothing factor, expected in ``[0, 1)`` (checked when
                the loss is computed).

        """
        # ``_WeightedLoss.__init__`` already stores ``weight`` (as a buffer, so it
        # follows ``.to()``/``.cuda()``) and ``reduction``; no need to re-assign them.
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing: Optional[float] = 0.0) -> torch.Tensor:
        """Convert integer class targets into label-smoothed one-hot distributions.

        Args:
            targets: Integer (``int64``) class indices, of any shape.
            n_classes: Number of classes.
            smoothing: Probability mass moved off the true class. ``None`` is treated as ``0.0``.

        Returns:
            Tensor of shape ``(*targets.shape, n_classes)`` on ``targets.device``, where
            the true class holds ``1 - smoothing`` and every other class holds
            ``smoothing / (n_classes - 1)``.

        Raises:
            ValueError: If ``smoothing`` is not in ``[0, 1)``.

        """
        if smoothing is None:
            smoothing = 0.0

        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not 0 <= smoothing < 1:
            raise ValueError(f"`smoothing` must be in [0, 1), but got {smoothing}.")

        with torch.no_grad():
            # For label smoothing, replace the 1-hot vector with a (1 - smoothing)-hot vector:
            # fill every class with smoothing / (n - 1), then overwrite the true class
            # (indexed along the new last dimension) with 1 - smoothing.
            smoothed = torch.empty(size=(*targets.shape, n_classes), device=targets.device)
            smoothed.fill_(smoothing / (n_classes - 1))
            smoothed.scatter_(-1, targets.unsqueeze(-1), 1.0 - smoothing)

        return smoothed

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed cross entropy.

        Args:
            inputs: Unnormalized logits whose last dimension indexes the classes.
            targets: Integer class indices with shape ``inputs.shape[:-1]``.

        Returns:
            A scalar tensor for ``reduction`` ``"sum"`` or ``"mean"``. For any other
            ``reduction``, the per-sample losses with shape ``inputs.shape[:-1]``.

        """
        smoothed_targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1), self.smoothing)
        lsm = F.log_softmax(inputs, -1)

        if self.weight is not None:
            # To support weighted targets: scale each class's log-probability by its weight
            lsm = lsm * self.weight.unsqueeze(0)

        loss = -(smoothed_targets * lsm).sum(-1)

        # NOTE: "mean" is a plain average over all per-sample losses. Unlike
        # ``F.cross_entropy`` with ``weight``, it is not normalized by the sum of the
        # target-class weights.
        if self.reduction == "sum":
            loss = loss.sum()
        elif self.reduction == "mean":
            loss = loss.mean()

        return loss