# Source code for archai.datasets.cv.transforms.lighting

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List

import torch


class Lighting:
    """Lighting transform (PCA-based color augmentation).

    Adds to an image a random per-channel offset built from the eigenvectors
    and eigenvalues of a 3-channel color covariance matrix, each principal
    component being scaled by a Gaussian coefficient.
    """

    def __init__(self, std: float, eigval: List[float], eigvec: List[List[float]]) -> None:
        """Initialize the lighting transform.

        Args:
            std: Standard deviation of the normal distribution used to sample
                one coefficient per principal component. ``0`` disables the
                transform.
            eigval: Eigenvalues of the covariance matrix (3 values).
            eigvec: Eigenvectors of the covariance matrix as a 3x3 matrix;
                column ``j`` is weighted by ``eigval[j]``.
        """
        self.std = std
        # ``as_tensor`` is the recommended replacement for the legacy
        # ``torch.Tensor(...)`` constructor; using the default dtype keeps the
        # stored tensors identical to what that constructor produced.
        self.eigval = torch.as_tensor(eigval, dtype=torch.get_default_dtype())
        self.eigvec = torch.as_tensor(eigvec, dtype=torch.get_default_dtype())

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Apply a random lighting shift to ``img``.

        Args:
            img: Floating-point image tensor whose third-from-last dimension
                has size 3, e.g. ``(3, H, W)`` or ``(N, 3, H, W)``.

        Returns:
            A new tensor equal to ``img`` plus a random per-channel offset,
            with the same shape, dtype and device as ``img``. When ``std`` is
            zero, ``img`` itself is returned.
        """
        if self.std == 0:
            return img

        # One Gaussian coefficient per principal component, allocated on the
        # image's device with the image's dtype.
        alpha = img.new_empty(3).normal_(0, self.std)

        # Both eigen-tensors must match the image's device and dtype: a CPU
        # ``eigval`` breaks CUDA inputs, and a float32 one would promote
        # half-precision results to float32.
        eigvec = self.eigvec.to(device=img.device, dtype=img.dtype)
        eigval = self.eigval.to(device=img.device, dtype=img.dtype)

        # rgb[i] = sum_j eigvec[i, j] * alpha[j] * eigval[j]
        rgb = eigvec.mul(alpha).mul(eigval).sum(1)

        # ``expand_as`` (rather than plain broadcasting) keeps the original
        # strict requirement that the channel axis sits third from the end.
        return img.add(rgb.view(3, 1, 1).expand_as(img))