Source code for archai.datasets.cv.transforms.brightness
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
from typing import List, Tuple, Union

import torch
class Brightness:
    """Brightness transform.

    Adds a constant offset to every element of the supplied image tensor(s)
    and clamps the result to the ``[0, 1]`` range.
    """

    def __init__(self, value: float) -> None:
        """Initialize the brightness transform.

        Args:
            value: Brightness factor, e.g., 0 = no change, 1 = completely white,
                -1 = completely black, <0 = darker, >0 = brighter. Values
                outside ``[-1, 1]`` are clipped to that range.

        """

        self.value = max(min(value, 1.0), -1.0)

    def __call__(self, *imgs: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Apply the brightness offset to one or more images.

        Args:
            *imgs: Image tensors to transform. The same offset is applied to
                each of them.

        Returns:
            The transformed tensor when exactly one image is given, otherwise
            a list with one transformed tensor per input, in input order.

        Raises:
            ValueError: If no image is supplied.

        """

        if not imgs:
            raise ValueError("Brightness requires at least one image tensor.")

        # The offset is added in floating point, the result is cast back to the
        # input's tensor type, and only then clamped to [0, 1].
        outputs = [torch.clamp(img.float().add(self.value).type(img.type()), 0, 1) for img in imgs]

        # Decide on the number of outputs, not on the last loop index: comparing
        # the last index against 1 dropped the second image of a two-image call.
        return outputs if len(outputs) > 1 else outputs[0]
class RandomBrightness:
    """Random brightness transform.

    On every call a brightness factor is sampled uniformly between the two
    configured bounds and applied to the given images through ``Brightness``.
    """

    def __init__(self, min_val: float, max_val: float) -> None:
        """Initialize the random brightness transform.

        Args:
            min_val: Minimum brightness factor.
            max_val: Maximum brightness factor.

        """

        # Stored as a (min, max) pair; a fresh factor is drawn per call.
        self.values = (min_val, max_val)

    def __call__(self, *imgs: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Apply one randomly drawn brightness factor to the given images.

        Args:
            *imgs: Image tensors, forwarded unchanged to ``Brightness``.

        Returns:
            Whatever ``Brightness`` returns for the given images.

        """

        lower = self.values[0]
        upper = self.values[1]

        # A single factor is sampled per call, so all images passed together
        # receive the same brightness change.
        factor = random.uniform(lower, upper)
        transform = Brightness(factor)

        return transform(*imgs)