# Source code for archai.datasets.cv.transforms.custom_cutout
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# https://github.com/quark0/darts/blob/master/cnn/utils.py
import numpy as np
import torch
class CustomCutout:
    """Custom-based cutout transform.

    Zeroes a randomly placed square patch of an image tensor. The patch spans
    ``length // 2`` pixels on each side of a uniformly sampled center pixel and
    is clipped at the image borders, so patches near an edge are smaller.

    """

    def __init__(self, length: int) -> None:
        """Initialize the custom-based cutout transform.

        Args:
            length: Length of the cutout.

        """

        self.length = length

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Zero out a random square patch of `img`, modifying it in place.

        Args:
            img: Image tensor whose last two dimensions are height and width
                (e.g. ``(C, H, W)``). The same patch is applied across all
                leading dimensions.

        Returns:
            The input tensor itself (same object), with the patch set to zero.

        """

        h, w = img.shape[-2:]
        half = self.length // 2

        # The center is drawn from NumPy's global RNG (row first, then column)
        # so that results under `np.random.seed` stay reproducible.
        y = np.random.randint(h)
        x = np.random.randint(w)

        # Clamp both ends to [0, size]; `int` keeps the slice bounds plain
        # Python integers rather than NumPy scalars.
        y1 = int(np.clip(y - half, 0, h))
        y2 = int(np.clip(y + half, 0, h))
        x1 = int(np.clip(x - half, 0, w))
        x2 = int(np.clip(x + half, 0, w))

        # Build the mask with the image's own dtype and device: an in-place
        # multiply cannot cast a float32 result back into an integer tensor,
        # nor combine tensors living on different devices.
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        mask[y1:y2, x1:x2] = 0

        # The (H, W) mask broadcasts over the leading dimensions of `img`.
        img *= mask

        return img