Source code for archai.supergraph.datasets.augmentation

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# code in this file is adapted from rpmcruz/autoaugment
# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py

import functools
import random
from collections import defaultdict
from typing import List, Union

import numpy as np
import PIL
import PIL.ImageDraw
import PIL.ImageEnhance
import PIL.ImageOps

from archai.common.ordered_dict_logger import get_global_logger
from archai.datasets.cv.transforms.custom_cutout import CustomCutout
from archai.supergraph.datasets.aug_policies import (
    fa_reduced_cifar10,
    fa_reduced_svhn,
    fa_resnet50_rimagenet,
)

# Logger obtained from get_global_logger(); add_named_augs() logs the selected
# augmentation and the cutout size through it.
logger = get_global_logger()
# When True, ShearX/ShearY/TranslateX/TranslateY/Rotate negate their magnitude
# whenever random.random() > 0.5.
_random_mirror = True


class Augmentation:
    """Callable transform that applies one randomly chosen policy to an image.

    ``policies`` is a sequence of policies. Each policy is an iterable of
    ``(name, probability, level)`` triples; the triples are forwarded to
    ``apply_augment``.
    """

    def __init__(self, policies):
        self.policies = policies

    def __call__(self, img):
        # A single policy is drawn per call; each op in it is skipped when the
        # random draw exceeds the op's probability.
        chosen_policy = random.choice(self.policies)
        for op_name, prob, level in chosen_policy:
            if random.random() > prob:
                continue
            img = apply_augment(img, op_name, level)
        return img
def add_named_augs(transform_train, aug: Union[List, str], cutout: int):
    """Attach a policy-based augmentation and optional cutout to a transform.

    Args:
        transform_train: Object exposing a ``transforms`` list. An
            ``Augmentation`` is inserted at index 0 and ``CustomCutout`` is
            appended at the end.
        aug: Either a list of policies used as-is, or the name of a built-in
            policy set. ``'default'``, ``'inception'`` and ``'inception320'``
            as well as falsy values add no policy augmentation.
        cutout: When greater than zero, ``CustomCutout(cutout)`` is appended.

    Returns:
        Always ``(None, None)``.

    Raises:
        ValueError: If ``aug`` is a non-empty value that is not a known name.
    """
    logger.info({'augmentation': aug})

    if isinstance(aug, list):
        policies = aug
    elif not aug:
        policies = None
    else:
        # Name -> factory; the factory is only called for the matching name.
        named_policies = (
            ('fa_reduced_cifar10', fa_reduced_cifar10),
            ('fa_reduced_imagenet', fa_resnet50_rimagenet),
            ('fa_reduced_svhn', fa_reduced_svhn),
            ('arsaug', arsaug_policy),
            ('autoaug_cifar10', autoaug_paper_cifar10),
            ('autoaug_extend', autoaug_policy),
        )
        for policy_name, make_policies in named_policies:
            if aug == policy_name:
                policies = make_policies()
                break
        else:
            if aug not in ['default', 'inception', 'inception320']:
                raise ValueError('Augmentations not found: %s' % aug)
            policies = None

    if policies is not None:
        transform_train.transforms.insert(0, Augmentation(policies))

    # add cutout transform
    # TODO: use PyTorch built-in cutout
    logger.info({'cutout': cutout})
    if cutout > 0:
        transform_train.transforms.append(CustomCutout(cutout))

    # TODO: recheck: total_aug remains None in original fastaug code
    return None, None
def ShearX(img, v):
    """Apply an affine transform to ``img`` with x-shear coefficient ``v``.

    ``v`` must lie in [-0.3, 0.3]. While ``_random_mirror`` is enabled, the
    sign of ``v`` is flipped whenever ``random.random()`` exceeds 0.5.
    """
    assert -0.3 <= v <= 0.3
    flip_sign = _random_mirror and random.random() > 0.5
    shear = -v if flip_sign else v
    coeffs = (1, shear, 0, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)
def ShearY(img, v):
    """Apply an affine transform to ``img`` with y-shear coefficient ``v``.

    ``v`` must lie in [-0.3, 0.3]. While ``_random_mirror`` is enabled, the
    sign of ``v`` is flipped whenever ``random.random()`` exceeds 0.5.
    """
    assert -0.3 <= v <= 0.3
    flip_sign = _random_mirror and random.random() > 0.5
    shear = -v if flip_sign else v
    coeffs = (1, 0, 0, shear, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)
def TranslateX(img, v):
    """Translate ``img`` horizontally by a fraction ``v`` of ``img.size[0]``.

    ``v`` must lie in [-0.45, 0.45]. While ``_random_mirror`` is enabled, the
    sign of ``v`` is flipped whenever ``random.random()`` exceeds 0.5.
    """
    assert -0.45 <= v <= 0.45
    flip_sign = _random_mirror and random.random() > 0.5
    fraction = -v if flip_sign else v
    # Convert the relative offset into pixels.
    offset = fraction * img.size[0]
    coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)
def TranslateY(img, v):
    """Translate ``img`` vertically by a fraction ``v`` of ``img.size[1]``.

    ``v`` must lie in [-0.45, 0.45]. While ``_random_mirror`` is enabled, the
    sign of ``v`` is flipped whenever ``random.random()`` exceeds 0.5.
    """
    assert -0.45 <= v <= 0.45
    flip_sign = _random_mirror and random.random() > 0.5
    fraction = -v if flip_sign else v
    # Convert the relative offset into pixels.
    offset = fraction * img.size[1]
    coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)
def TranslateXAbs(img, v):
    """Translate ``img`` horizontally by ``v`` pixels.

    ``v`` must lie in [0, 10]. The direction is chosen at random: the sign is
    flipped whenever ``random.random()`` exceeds 0.5 (``_random_mirror`` is
    not consulted here).
    """
    assert 0 <= v <= 10
    offset = -v if random.random() > 0.5 else v
    coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)
def TranslateYAbs(img, v):
    """Translate ``img`` vertically by ``v`` pixels.

    ``v`` must lie in [0, 10]. The direction is chosen at random: the sign is
    flipped whenever ``random.random()`` exceeds 0.5 (``_random_mirror`` is
    not consulted here).
    """
    assert 0 <= v <= 10
    offset = -v if random.random() > 0.5 else v
    coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, PIL.Image.AFFINE, coeffs)
def Rotate(img, v):
    """Rotate ``img`` by ``v`` via ``img.rotate``.

    ``v`` must lie in [-30, 30]. While ``_random_mirror`` is enabled, the sign
    of ``v`` is flipped whenever ``random.random()`` exceeds 0.5.
    """
    assert -30 <= v <= 30
    flip_sign = _random_mirror and random.random() > 0.5
    angle = -v if flip_sign else v
    return img.rotate(angle)
def AutoContrast(img, _):
    """Return ``PIL.ImageOps.autocontrast(img)``; the level argument is unused."""
    result = PIL.ImageOps.autocontrast(img)
    return result
def Invert(img, _):
    """Return ``PIL.ImageOps.invert(img)``; the level argument is unused."""
    result = PIL.ImageOps.invert(img)
    return result
def Equalize(img, _):
    """Return ``PIL.ImageOps.equalize(img)``; the level argument is unused."""
    result = PIL.ImageOps.equalize(img)
    return result
def Flip(img, _):
    """Return ``PIL.ImageOps.mirror(img)``; the level argument is unused.

    This op is not from the paper.
    """
    result = PIL.ImageOps.mirror(img)
    return result
def Solarize(img, v):
    """Solarize ``img`` with threshold ``v``, which must lie in [0, 256]."""
    assert 0 <= v <= 256
    threshold = v
    return PIL.ImageOps.solarize(img, threshold)
def Posterize(img, v):
    """Posterize ``img`` to ``int(v)`` bits; ``v`` must lie in [4, 8]."""
    assert 4 <= v <= 8
    bits = int(v)
    return PIL.ImageOps.posterize(img, bits)
def Posterize2(img, v):
    """Posterize ``img`` to ``int(v)`` bits; ``v`` must lie in [0, 4]."""
    assert 0 <= v <= 4
    bits = int(v)
    return PIL.ImageOps.posterize(img, bits)
def Contrast(img, v):
    """Enhance the contrast of ``img`` by factor ``v`` in [0.1, 1.9]."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(v)
def Color(img, v):
    """Enhance the colour of ``img`` by factor ``v`` in [0.1, 1.9]."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(v)
def Brightness(img, v):
    """Enhance the brightness of ``img`` by factor ``v`` in [0.1, 1.9]."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(v)
def Sharpness(img, v):
    """Enhance the sharpness of ``img`` by factor ``v`` in [0.1, 1.9]."""
    assert 0.1 <= v <= 1.9
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(v)
def Cutout(img, v):
    """Cut out a square whose side is a fraction ``v`` of ``img.size[0]``.

    ``v`` must lie in [0, 0.2]. A non-positive ``v`` returns ``img`` itself;
    otherwise the work is delegated to ``CutoutAbs`` with the side length
    converted to pixels.
    """
    assert 0.0 <= v <= 0.2
    return img if v <= 0. else CutoutAbs(img, v * img.size[0])
def CutoutAbs(img, v):
    """Fill a ``v``-pixel square at a random position on a copy of ``img``.

    Args:
        img: Image exposing ``size`` and ``copy()`` that
            ``PIL.ImageDraw.Draw`` can draw on. It is never modified.
        v: Side length of the square in pixels. A negative value disables the
            op.

    Returns:
        ``img`` itself when ``v < 0``; otherwise a copy of ``img`` with the
        rectangle painted in the colour ``(125, 123, 114)``.
    """
    if v < 0:
        return img

    w, h = img.size
    # Sample the square's centre over the whole image. Both bounds are given
    # explicitly: a single positional argument to np.random.uniform is `low`
    # (with `high` left at 1.0), which would never pick the first pixel
    # row/column and depends on behaviour numpy documents as undefined.
    center_x = np.random.uniform(0, w)
    center_y = np.random.uniform(0, h)

    # Top-left corner, clamped to the image; the far corner is clamped too.
    x0 = int(max(0, center_x - v / 2.))
    y0 = int(max(0, center_y - v / 2.))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    color = (125, 123, 114)
    patched = img.copy()
    PIL.ImageDraw.Draw(patched).rectangle((x0, y0, x1, y1), color)
    return patched
def SamplePairing(imgs):
    """Build an op that blends an image with a randomly picked entry of ``imgs``.

    The returned function takes ``(img1, v)``, converts the picked entry with
    ``PIL.Image.fromarray`` and returns ``PIL.Image.blend(img1, img2, v)``.
    The original range note for ``v`` is [0, 0.4].
    """
    def f(img1, v):
        pick = np.random.choice(len(imgs))
        partner = PIL.Image.fromarray(imgs[pick])
        return PIL.Image.blend(img1, partner, v)

    return f
def augment_list(for_autoaug=True):
    """Return the available ops as ``(function, low, high)`` triples.

    The first 15 entries are always present. With ``for_autoaug`` set, four
    extra ops used by the AutoAugment-style policies are appended.
    """
    ops = [
        (ShearX, -0.3, 0.3),  # 0
        (ShearY, -0.3, 0.3),  # 1
        (TranslateX, -0.45, 0.45),  # 2
        (TranslateY, -0.45, 0.45),  # 3
        (Rotate, -30, 30),  # 4
        (AutoContrast, 0, 1),  # 5
        (Invert, 0, 1),  # 6
        (Equalize, 0, 1),  # 7
        (Solarize, 0, 256),  # 8
        (Posterize, 4, 8),  # 9
        (Contrast, 0.1, 1.9),  # 10
        (Color, 0.1, 1.9),  # 11
        (Brightness, 0.1, 1.9),  # 12
        (Sharpness, 0.1, 1.9),  # 13
        (Cutout, 0, 0.2),  # 14
        # (SamplePairing(imgs), 0, 0.4),  # 15
    ]
    if for_autoaug:
        ops.extend([
            (CutoutAbs, 0, 20),  # compatible with auto-augment
            (Posterize2, 0, 4),
            (TranslateXAbs, 0, 10),
            (TranslateYAbs, 0, 10),
        ])
    return ops
# Registry keyed by op function name -> (function, low, high), built from the
# full augment_list() (including the AutoAugment-only ops).
_augment_dict = {op.__name__: (op, low, high) for op, low, high in augment_list()}
def get_augment(name):
    """Return the ``(function, low, high)`` entry registered under ``name``.

    Raises:
        KeyError: If ``name`` is not present in ``_augment_dict``.
    """
    entry = _augment_dict[name]
    return entry
def apply_augment(img, name, level):
    """Run the op called ``name`` on a copy of ``img``.

    ``level`` is interpolated linearly into the op's ``(low, high)`` range
    (``level * (high - low) + low``) before being passed to the op.
    """
    augment_fn, low, high = get_augment(name)
    working_copy = img.copy()
    magnitude = level * (high - low) + low
    return augment_fn(working_copy, magnitude)
def arsaug_policy():
    """Return the hard-coded 'arsaug' policy list.

    The result is a flat list of 35 sub-policies (7 groups of 5), each a list
    of two ``(op_name, probability, level)`` triples. A fresh list is built
    on every call.
    """
    groups = [
        # exp0_0
        [
            [('Solarize', 0.66, 0.34), ('Equalize', 0.56, 0.61)],
            [('Equalize', 0.43, 0.06), ('AutoContrast', 0.66, 0.08)],
            [('Color', 0.72, 0.47), ('Contrast', 0.88, 0.86)],
            [('Brightness', 0.84, 0.71), ('Color', 0.31, 0.74)],
            [('Rotate', 0.68, 0.26), ('TranslateX', 0.38, 0.88)],
        ],
        # exp0_1
        [
            [('TranslateY', 0.88, 0.96), ('TranslateY', 0.53, 0.79)],
            [('AutoContrast', 0.44, 0.36), ('Solarize', 0.22, 0.48)],
            [('AutoContrast', 0.93, 0.32), ('Solarize', 0.85, 0.26)],
            [('Solarize', 0.55, 0.38), ('Equalize', 0.43, 0.48)],
            [('TranslateY', 0.72, 0.93), ('AutoContrast', 0.83, 0.95)],
        ],
        # exp0_2
        [
            [('Solarize', 0.43, 0.58), ('AutoContrast', 0.82, 0.26)],
            [('TranslateY', 0.71, 0.79), ('AutoContrast', 0.81, 0.94)],
            [('AutoContrast', 0.92, 0.18), ('TranslateY', 0.77, 0.85)],
            [('Equalize', 0.71, 0.69), ('Color', 0.23, 0.33)],
            [('Sharpness', 0.36, 0.98), ('Brightness', 0.72, 0.78)],
        ],
        # exp0_3
        [
            [('Equalize', 0.74, 0.49), ('TranslateY', 0.86, 0.91)],
            [('TranslateY', 0.82, 0.91), ('TranslateY', 0.96, 0.79)],
            [('AutoContrast', 0.53, 0.37), ('Solarize', 0.39, 0.47)],
            [('TranslateY', 0.22, 0.78), ('Color', 0.91, 0.65)],
            [('Brightness', 0.82, 0.46), ('Color', 0.23, 0.91)],
        ],
        # exp0_4
        [
            [('Cutout', 0.27, 0.45), ('Equalize', 0.37, 0.21)],
            [('Color', 0.43, 0.23), ('Brightness', 0.65, 0.71)],
            [('ShearX', 0.49, 0.31), ('AutoContrast', 0.92, 0.28)],
            [('Equalize', 0.62, 0.59), ('Equalize', 0.38, 0.91)],
            [('Solarize', 0.57, 0.31), ('Equalize', 0.61, 0.51)],
        ],
        # exp0_5
        [
            [('TranslateY', 0.29, 0.35), ('Sharpness', 0.31, 0.64)],
            [('Color', 0.73, 0.77), ('TranslateX', 0.65, 0.76)],
            [('ShearY', 0.29, 0.74), ('Posterize', 0.42, 0.58)],
            [('Color', 0.92, 0.79), ('Equalize', 0.68, 0.54)],
            [('Sharpness', 0.87, 0.91), ('Sharpness', 0.93, 0.41)],
        ],
        # exp0_6
        [
            [('Solarize', 0.39, 0.35), ('Color', 0.31, 0.44)],
            [('Color', 0.33, 0.77), ('Color', 0.25, 0.46)],
            [('ShearY', 0.29, 0.74), ('Posterize', 0.42, 0.58)],
            [('AutoContrast', 0.32, 0.79), ('Cutout', 0.68, 0.34)],
            [('AutoContrast', 0.67, 0.91), ('AutoContrast', 0.73, 0.83)],
        ],
    ]
    # Flatten the groups in order, exactly like concatenating them.
    return [sub_policy for group in groups for sub_policy in group]
def autoaug2arsaug(f):
    """Decorator that rescales AutoAugment-style levels for ``apply_augment``.

    ``f`` must be a zero-argument callable returning policies made of
    ``(name, probability, level)`` triples. For every triple the level is
    first mapped to the op's native magnitude (ops without a mapping keep the
    level unchanged) and then expressed as a fraction of the op's
    ``(low, high)`` range reported by ``get_augment``.

    Args:
        f: Policy factory to wrap.

    Returns:
        A zero-argument function returning the converted policies. It keeps
        the name and docstring of ``f`` via ``functools.wraps``.
    """
    @functools.wraps(f)
    def autoaug():
        # Unknown op names fall back to the identity mapping.
        mapper = defaultdict(lambda: lambda x: x)
        mapper.update({
            'ShearX': lambda x: float_parameter(x, 0.3),
            'ShearY': lambda x: float_parameter(x, 0.3),
            'TranslateX': lambda x: int_parameter(x, 10),
            'TranslateY': lambda x: int_parameter(x, 10),
            'Rotate': lambda x: int_parameter(x, 30),
            'Solarize': lambda x: 256 - int_parameter(x, 256),
            'Posterize2': lambda x: 4 - int_parameter(x, 4),
            'Contrast': lambda x: float_parameter(x, 1.8) + .1,
            'Color': lambda x: float_parameter(x, 1.8) + .1,
            'Brightness': lambda x: float_parameter(x, 1.8) + .1,
            'Sharpness': lambda x: float_parameter(x, 1.8) + .1,
            'CutoutAbs': lambda x: int_parameter(x, 20),
        })

        def low_high(name, prev_value):
            # Position of the native magnitude inside the op's (low, high) range.
            _, low, high = get_augment(name)
            return float(prev_value - low) / (high - low)

        return [
            [(name, pr, low_high(name, mapper[name](level))) for name, pr, level in policy]
            for policy in f()
        ]

    return autoaug
@autoaug2arsaug
def autoaug_paper_cifar10():
    """Return 25 hard-coded sub-policies for the 'autoaug_cifar10' setting.

    Each sub-policy is a list of two ``(op_name, probability, level)``
    triples; the ``autoaug2arsaug`` decorator converts the levels when the
    function is called.
    """
    paper_policies = [
        [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
        [('Rotate', 0.7, 2), ('TranslateXAbs', 0.3, 9)],
        [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
        [('ShearY', 0.5, 8), ('TranslateYAbs', 0.7, 9)],
        [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
        [('ShearY', 0.2, 7), ('Posterize2', 0.3, 7)],
        [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
        [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
        [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
        [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)],
        [('Color', 0.7, 7), ('TranslateXAbs', 0.5, 8)],
        [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
        [('TranslateYAbs', 0.4, 3), ('Sharpness', 0.2, 6)],
        [('Brightness', 0.9, 6), ('Color', 0.2, 6)],
        [('Solarize', 0.5, 2), ('Invert', 0.0, 3)],
        [('Equalize', 0.2, 0), ('AutoContrast', 0.6, 0)],
        [('Equalize', 0.2, 8), ('Equalize', 0.6, 4)],
        [('Color', 0.9, 9), ('Equalize', 0.6, 6)],
        [('AutoContrast', 0.8, 4), ('Solarize', 0.2, 8)],
        [('Brightness', 0.1, 3), ('Color', 0.7, 0)],
        [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
        [('TranslateYAbs', 0.9, 9), ('TranslateYAbs', 0.7, 9)],
        [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
        [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
        [('TranslateYAbs', 0.7, 9), ('AutoContrast', 0.9, 1)],
    ]
    return paper_policies
@autoaug2arsaug
def autoaug_policy():
    """AutoAugment policies found on Cifar.

    Returns a flat list of 95 sub-policies (19 groups of 5), each a list of
    two ``(op_name, probability, level)`` triples; the ``autoaug2arsaug``
    decorator converts the levels when the function is called.
    """
    groups = [
        # exp0_0
        [
            [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
            [('Rotate', 0.7, 2), ('TranslateXAbs', 0.3, 9)],
            [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
            [('ShearY', 0.5, 8), ('TranslateYAbs', 0.7, 9)],
            [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
        ],
        # exp0_1
        [
            [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
            [('TranslateYAbs', 0.9, 9), ('TranslateYAbs', 0.7, 9)],
            [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
            [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
            [('TranslateYAbs', 0.7, 9), ('AutoContrast', 0.9, 1)],
        ],
        # exp0_2
        [
            [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
            [('TranslateYAbs', 0.7, 9), ('TranslateYAbs', 0.7, 9)],
            [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
            [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
            [('TranslateYAbs', 0.7, 9), ('TranslateYAbs', 0.7, 9)],
        ],
        # exp0_3
        [
            [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
            [('TranslateYAbs', 0.8, 9), ('TranslateYAbs', 0.9, 9)],
            [('AutoContrast', 0.8, 0), ('TranslateYAbs', 0.7, 9)],
            [('TranslateYAbs', 0.2, 7), ('Color', 0.9, 6)],
            [('Equalize', 0.7, 6), ('Color', 0.4, 9)],
        ],
        # exp1_0
        [
            [('ShearY', 0.2, 7), ('Posterize2', 0.3, 7)],
            [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
            [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
            [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
            [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)],
        ],
        # exp1_1
        [
            [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
            [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
            [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
            [('TranslateYAbs', 0.2, 4), ('Sharpness', 0.3, 3)],
            [('Brightness', 0.0, 8), ('Color', 0.8, 8)],
        ],
        # exp1_2
        [
            [('Solarize', 0.2, 6), ('Color', 0.8, 6)],
            [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
            [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
            [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
            [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)],
        ],
        # exp1_3
        [
            [('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
            [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
            [('Contrast', 0.5, 1), ('TranslateYAbs', 0.2, 9)],
            [('AutoContrast', 0.6, 5), ('TranslateYAbs', 0.0, 9)],
            [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)],
        ],
        # exp1_4
        [
            [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
            [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
            [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
            [('Color', 0.3, 7), ('Color', 0.2, 4)],
            [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)],
        ],
        # exp1_5
        [
            [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
            [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
            [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
            [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
            [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)],
        ],
        # exp1_6
        [
            [('Equalize', 0.8, 4), ('TranslateYAbs', 0.8, 9)],
            [('TranslateYAbs', 0.8, 9), ('TranslateYAbs', 0.6, 9)],
            [('TranslateYAbs', 0.9, 0), ('TranslateYAbs', 0.5, 9)],
            [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
            [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)],
        ],
        # exp2_0
        [
            [('Color', 0.7, 7), ('TranslateXAbs', 0.5, 8)],
            [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
            [('TranslateYAbs', 0.4, 3), ('Sharpness', 0.2, 6)],
            [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
            [('Solarize', 0.5, 2), ('Invert', 0.0, 3)],
        ],
        # exp2_1
        [
            [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
            [('CutoutAbs', 0.2, 4), ('Equalize', 0.1, 1)],
            [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
            [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
            [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)],
        ],
        # exp2_2
        [
            [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
            [('TranslateYAbs', 0.3, 6), ('CutoutAbs', 0.3, 3)],
            [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
            [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
            [('Equalize', 0.8, 2), ('Invert', 0.4, 0)],
        ],
        # exp2_3
        [
            [('Equalize', 0.9, 5), ('Color', 0.7, 0)],
            [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
            [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
            [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
            [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)],
        ],
        # exp2_4
        [
            [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
            [('TranslateXAbs', 0.3, 0), ('TranslateXAbs', 0.6, 0)],
            [('Equalize', 0.5, 9), ('TranslateYAbs', 0.6, 7)],
            [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
            [('Equalize', 0.8, 6), ('Invert', 0.3, 6)],
        ],
        # exp2_5
        [
            [('AutoContrast', 0.3, 9), ('CutoutAbs', 0.5, 3)],
            [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
            [('ShearX', 0.0, 3), ('Posterize2', 0.0, 3)],
            [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
            [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)],
        ],
        # exp2_6
        [
            [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
            [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
            [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
            [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
            [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)],
        ],
        # exp2_7
        [
            [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
            [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
            [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
            [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
            [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)],
        ],
    ]
    # Flatten in group order: exp0_*, then exp1_*, then exp2_*.
    return [sub_policy for group in groups for sub_policy in group]
# Top of the level scale used by float_parameter(): a level equal to
# _PARAMETER_MAX maps to the full ``maxval``.
_PARAMETER_MAX = 10
def float_parameter(level, maxval):
    """Scale ``level`` to a float magnitude: ``level * maxval / _PARAMETER_MAX``."""
    scaled = float(level) * maxval
    return scaled / _PARAMETER_MAX
def int_parameter(level, maxval):
    """Return ``float_parameter(level, maxval)`` truncated with ``int()``."""
    magnitude = float_parameter(level, maxval)
    return int(magnitude)
def no_duplicates(f):
    """Decorator that filters the policies returned by ``f``.

    Args:
        f: Zero-argument callable returning a list of policies.

    Returns:
        A zero-argument function returning ``remove_deplicates(f())``. It
        keeps the name and docstring of ``f`` via ``functools.wraps``.
    """
    @functools.wraps(f)
    def wrap_remove_duplicates():
        return remove_deplicates(f())

    return wrap_remove_duplicates
def remove_deplicates(policies):
    """Drop policies whose sequence of op names was already seen.

    Two policies count as duplicates when their op names, joined with ``'_'``
    in order, are equal; probabilities and levels are ignored. The first
    occurrence is kept and the input order is preserved.
    """
    seen_keys = set()
    unique_policies = []
    for ops in policies:
        signature = '_'.join(op[0] for op in ops)
        if signature not in seen_keys:
            seen_keys.add(signature)
            unique_policies.append(ops)
    return unique_policies
def policy_decoder(augment, num_policy, num_op):
    """Decode a flat mapping of search results into a list of policies.

    For every policy index ``i`` and op index ``j`` the mapping is read at
    ``'policy_i_j'`` (index into ``augment_list(False)``), ``'prob_i_j'`` and
    ``'level_i_j'``.

    Returns:
        ``num_policy`` policies, each a list of ``num_op``
        ``(op_name, probability, level)`` triples.
    """
    op_list = augment_list(False)

    def decode_op(i, j):
        # All three lookups happen before the op table is indexed.
        op_idx = augment['policy_%d_%d' % (i, j)]
        op_prob = augment['prob_%d_%d' % (i, j)]
        op_level = augment['level_%d_%d' % (i, j)]
        return op_list[op_idx][0].__name__, op_prob, op_level

    return [
        [decode_op(i, j) for j in range(num_op)]
        for i in range(num_policy)
    ]