# archai.supergraph.algos.divnas.divop
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Iterator, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from overrides import overrides
from torch import nn
from archai.common.common import get_conf
from archai.common.utils import zip_eq
from archai.supergraph.nas.arch_params import ArchParams
from archai.supergraph.nas.model_desc import OpDesc
from archai.supergraph.nas.operations import Op
# TODO: a reduction cell might have its output reduced by 2^1 = 2X due to
# stride 2 through the input nodes, however FactorizedReduce only does a
# 4X reduction. Is this correct?
class DivOp(Op):
"""The output of DivOp is weighted output of all allowed primitives.
"""
PRIMITIVES = [
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect', # identity
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
'none' # this must be at the end so top1 doesn't choose it
]
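
    # A minimal sketch (not from the original source) of the mixture that
    # forward() computes below: softmax the architecture weights (alphas) and
    # take the weighted sum of each primitive's output. The tensors here are
    # hypothetical stand-ins for real op outputs.
    #
    #   import torch
    #   import torch.nn.functional as F
    #   alphas = torch.tensor([0.5, 1.0, 1.5])                 # one weight per primitive
    #   outs = [torch.full((1,), float(i)) for i in range(3)]  # fake op outputs
    #   w = F.softmax(alphas, dim=0)                           # mixing weights, sum to 1
    #   mixed = sum(wi * oi for wi, oi in zip(w, outs))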
# def _indices_of_notallowed(self):
# ''' computes indices of notallowed ops in PRIMITIVES '''
# self._not_allowed_indices = []
# for op_name in self.NOTALLOWED:
# self._not_allowed_indices.append(self.PRIMITIVES.index(op_name))
# self._not_allowed_indices = sorted(self._not_allowed_indices, reverse=True)
# def _create_mapping_valid_to_orig(self):
# ''' Creates a list with indices of the valid ops to the original list '''
# self._valid_to_orig = []
# for i, prim in enumerate(self.PRIMITIVES):
# if prim in self.NOTALLOWED:
# continue
# else:
# self._valid_to_orig.append(i)
def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
affine:bool):
super().__init__()
# assume last PRIMITIVE is 'none'
assert DivOp.PRIMITIVES[-1] == 'none'
conf = get_conf()
trainer = conf['nas']['search']['divnas']['archtrainer']
finalizer = conf['nas']['search']['finalizer']
if trainer == 'noalpha' and finalizer == 'default':
raise NotImplementedError('noalpha trainer is not implemented for the default finalizer')
if trainer != 'noalpha':
self._setup_arch_params(arch_params)
else:
self._alphas = None
self._ops = nn.ModuleList()
for primitive in DivOp.PRIMITIVES:
op = Op.create(
OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
affine=affine, arch_params=None)
self._ops.append(op)
# various state variables for diversity
self._collect_activations = False
self._forward_counter = 0
self._batch_activs = None
#self._indices_of_notallowed()
#self._create_mapping_valid_to_orig()
@property
def collect_activations(self)->bool:
return self._collect_activations
@collect_activations.setter
def collect_activations(self, to_collect:bool)->None:
self._collect_activations = to_collect
@property
    def activations(self)->Optional[List[np.ndarray]]:
return self._batch_activs
@property
def num_primitive_ops(self)->int:
return len(self.PRIMITIVES)
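
    # Hedged usage sketch (not from the original source): a caller is expected
    # to toggle `collect_activations`, run a forward pass, and then read
    # `activations`. `divop` and `x` below are hypothetical stand-ins.
    #
    #   divop.collect_activations = True
    #   _ = divop(x)                       # forward pass caches activations
    #   per_op_activs = divop.activations  # numpy arrays, 'none' op excluded
    #   divop.collect_activations = False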
    @overrides
def forward(self, x):
# save activations to object
if self._collect_activations:
self._forward_counter += 1
activs = [op(x) for op in self._ops]
# delete the activation for none type
# as we don't consider it
activs = activs[:-1]
self._batch_activs = [t.cpu().detach().numpy() for t in activs]
        if self._alphas is not None:
asm = F.softmax(self._alphas[0], dim=0)
result = sum(w * op(x) for w, op in zip(asm, self._ops))
else:
result = sum(op(x) for op in self._ops)
return result
    @overrides
def ops(self)->Iterator[Tuple['Op', float]]: # type: ignore
        return iter(sorted(zip_eq(self._ops,
                                  self._alphas[0] if self._alphas is not None
                                  else [math.nan] * len(self._ops)),
                           key=lambda t: t[1], reverse=True))
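
    # Illustrative sketch (not from the original source) of consuming ops():
    # it yields (op, alpha) pairs sorted by alpha in descending order; alphas
    # are NaN when the 'noalpha' trainer left self._alphas as None.
    #
    #   for op, alpha in divop.ops():      # `divop` is a hypothetical instance
    #       print(type(op).__name__, float(alpha))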
# def get_valid_op_desc(self, index:int)->OpDesc:
# ''' index: index in the valid index list '''
# assert index <= self.num_valid_div_ops
# orig_index = self._valid_to_orig[index]
# desc, _ = self._ops[orig_index].finalize()
# return desc
    @overrides
def finalize(self) -> Tuple[OpDesc, Optional[float]]:
        ''' DivNAS with the default finalizer option needs this override, else
        the finalizer in the base class returns the whole DivOp '''
with torch.no_grad():
# select except 'none' op
val, i = torch.topk(self._alphas[0][:-1], 1)
desc, _ = self._ops[i].finalize()
return desc, float(val.item())
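
    # A hedged sketch (not from the original source) of the selection above:
    # because 'none' is the last primitive, slicing alphas[:-1] before topk
    # guarantees it can never be the finalized op.
    #
    #   import torch
    #   alphas = torch.tensor([0.1, 0.9, 0.3, 5.0])  # hypothetical; last is 'none'
    #   val, i = torch.topk(alphas[:-1], 1)          # only first three considered
    #   assert i.item() == 1                         # picks 0.9, ignoring 'none'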
    @overrides
def can_drop_path(self) -> bool:
return False
def _setup_arch_params(self, arch_params:Optional[ArchParams])->None:
# do we have shared arch params?
if arch_params is None:
# create our own arch params
            new_p = nn.Parameter( # TODO: use a better init than small random noise?
1.0e-3*torch.randn(len(self.PRIMITIVES)), requires_grad=True)
self.create_arch_params([('alphas', new_p)])
else:
assert arch_params.has_kind('alphas')
self.set_arch_params(arch_params)
        # we store alphas in a list so PyTorch doesn't register them
self._alphas = list(self.arch_params().param_by_kind('alphas'))
assert len(self._alphas)==1
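
    # A small sketch (not from the original source) of the effect of the
    # 1.0e-3 * randn init above: tiny alphas make the initial softmax nearly
    # uniform, so no primitive dominates before training starts.
    #
    #   import torch
    #   import torch.nn.functional as F
    #   alphas = 1.0e-3 * torch.randn(8)
    #   w = F.softmax(alphas, dim=0)   # every entry is close to 1/8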