Source code for archai.supergraph.algos.darts.mixed_op

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Iterator, Optional, Tuple

import torch
import torch.nn.functional as F
from overrides import overrides
from torch import nn

from archai.common.utils import zip_eq
from archai.supergraph.nas.arch_params import ArchParams
from archai.supergraph.nas.model_desc import OpDesc
from archai.supergraph.nas.operations import Op

# TODO: reduction cell might have its output reduced by 2^1=2X due to
#   stride 2 through input nodes; however, FactorizedReduce does only
#   4X reduction. Is this correct?


class MixedOp(Op):
    """The output of MixedOp is the weighted sum of the outputs of all allowed primitives."""

    PRIMITIVES = [
        'max_pool_3x3',
        'avg_pool_3x3',
        'skip_connect',  # identity
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
        'none'  # this must be at the end so top1 doesn't choose it
    ]

    def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
                 affine:bool):
        super().__init__()

        # assume the last primitive is 'none'
        assert MixedOp.PRIMITIVES[-1] == 'none'

        self._ops = nn.ModuleList()
        for primitive in MixedOp.PRIMITIVES:
            op = Op.create(
                OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
                affine=affine, arch_params=None)
            self._ops.append(op)

        # we do this at the end so that we can capture all arch params registered by
        # any previous child modules
        self._setup_arch_params(arch_params)
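    # forward below implements the DARTS continuous relaxation: the discrete
    # choice among primitives is softened into a weighted sum
    #   o(x) = sum_i softmax(alpha)_i * op_i(x)
    # so the output stays differentiable w.r.t. the architecture parameters.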
    @overrides
    def forward(self, x):
        asm = F.softmax(self._alphas[0], dim=0)
        return sum(w * op(x) for w, op in zip(asm, self._ops))
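    # finalize discretizes the mixed op: the primitive with the highest alpha
    # wins. For example, if the alphas were
    #   [0.1, 0.3, 2.0, 0.2, 0.1, 0.1, 0.1, -5.0]
    # then topk over alphas[:-1] returns (2.0, index 2), selecting
    # 'skip_connect'; the trailing 'none' entry is sliced off so it can never
    # be chosen.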
    @overrides
    def finalize(self) -> Tuple[OpDesc, Optional[float]]:
        with torch.no_grad():
            # select the top op, excluding the trailing 'none'
            val, i = torch.topk(self._alphas[0][:-1], 1)
            desc, _ = self._ops[i].finalize()
            return desc, float(val.item())
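    # MixedOp opts out of drop path, presumably because during search every
    # candidate path must contribute to the weighted sum; drop path is
    # typically applied only when training the final derived architecture.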
    @overrides
    def can_drop_path(self) -> bool:
        return False
    @overrides
    def ops(self) -> Iterator[Tuple['Op', float]]:  # type: ignore
        # yield (op, alpha) pairs sorted by alpha, highest first
        return iter(sorted(zip_eq(self._ops, self._alphas[0]),
                           key=lambda t: t[1], reverse=True))
    def _setup_arch_params(self, arch_params:Optional[ArchParams]) -> None:
        # do we have shared arch params?
        if arch_params is None:
            # create our own arch params
            new_p = nn.Parameter(  # TODO: use better init than uniform random?
                1.0e-3*torch.randn(len(MixedOp.PRIMITIVES)), requires_grad=True)
            self.create_arch_params([('alphas', new_p)])
        else:
            assert arch_params.has_kind('alphas')
            self.set_arch_params(arch_params)

        # store alphas in a plain list so PyTorch doesn't register them as module parameters
        self._alphas = list(self.arch_params().param_by_kind('alphas'))
        assert len(self._alphas) == 1
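For readers outside the archai codebase, the following minimal, self-contained sketch mirrors what MixedOp does. ToyMixedOp and its stand-in primitives are hypothetical names, not part of this module; they only illustrate the softmax mixing in forward and the small-random init of the alphas in _setup_arch_params.

import torch
import torch.nn.functional as F
from torch import nn

class ToyMixedOp(nn.Module):
    """Illustration only: mixes stand-in primitives the way MixedOp does."""
    def __init__(self):
        super().__init__()
        # stand-ins for the DARTS primitives
        self._ops = nn.ModuleList([nn.Identity(), nn.ReLU(), nn.Tanh()])
        # architecture parameters with small random init, as in _setup_arch_params
        self.alphas = nn.Parameter(1.0e-3 * torch.randn(len(self._ops)))

    def forward(self, x):
        w = F.softmax(self.alphas, dim=0)  # mixing weights, sum to 1
        return sum(wi * op(x) for wi, op in zip(w, self._ops))

x = torch.randn(2, 4)
print(ToyMixedOp()(x).shape)  # torch.Size([2, 4])

Unlike MixedOp, the toy registers its alphas directly as a module parameter; archai instead keeps them in a plain list precisely so that model weights and architecture parameters can be handed to separate optimizers during bi-level search.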