# Source code for archai.supergraph.nas.dag_edge
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Optional
import torch
from overrides import overrides
from torch import nn
from archai.supergraph.nas.arch_module import ArchModule
from archai.supergraph.nas.model_desc import EdgeDesc
from archai.supergraph.nas.operations import DropPath_, Op
class DagEdge(ArchModule):
    """A single edge of the architecture DAG.

    An edge wraps an `Op` (created from `desc.op_desc`) and records which
    node outputs (`desc.input_ids`) feed it. When drop-path is requested and
    the op supports it, the op is wrapped in an `nn.Sequential` with a
    `DropPath_` module so path dropping is applied during training.
    """

    def __init__(self, desc:EdgeDesc, affine:bool, droppath:bool,
                 template_edge:Optional['DagEdge'])->None:
        """Create the edge's op, optionally sharing arch params with a template.

        Args:
            desc: edge description (op to create and its input node ids).
            affine: whether batch-norm/affine params in the op are learnable.
            droppath: request drop-path wrapping (applied only if the op
                supports it via `can_drop_path()`).
            template_edge: if given, the new op shares that edge's
                architecture parameters (weight sharing across cells).
        """
        super().__init__()
        # We may need to wrap the op if droppath is needed.
        self._wrapped = self._op = Op.create(desc.op_desc, affine,
            template_edge.op().arch_params() if template_edge is not None else None)
        if droppath and self._op.can_drop_path():
            # Drop-path only makes sense while training; modules are
            # constructed in training mode by default.
            assert self.training
            self._wrapped = nn.Sequential(self._op, DropPath_())
        self.input_ids = desc.input_ids
        self.desc = desc

    @overrides
    def forward(self, inputs:List[torch.Tensor]):
        """Apply the (possibly drop-path-wrapped) op to this edge's inputs.

        Single-input edges receive a bare tensor; multi-input edges receive a
        list. If the edge consumes all inputs, the list is passed as-is to
        avoid building a new one (perf).
        """
        if len(self.input_ids)==1:
            return self._wrapped(inputs[self.input_ids[0]])
        elif len(self.input_ids) == len(inputs): # for perf
            return self._wrapped(inputs)
        else:
            return self._wrapped([inputs[i] for i in self.input_ids])

    def op(self)->Op:
        """Return the underlying op (unwrapped, without any DropPath_)."""
        return self._op