Source code for archai.supergraph.nas.dag_edge

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List, Optional

import torch
from overrides import overrides
from torch import nn

from archai.supergraph.nas.arch_module import ArchModule
from archai.supergraph.nas.model_desc import EdgeDesc
from archai.supergraph.nas.operations import DropPath_, Op


class DagEdge(ArchModule):
    """One edge of a cell DAG: applies a single op to selected entries of a state list.

    The op is built from ``desc.op_desc`` via ``Op.create``. When ``droppath`` is
    requested and the op reports ``can_drop_path()``, the op is wrapped as
    ``nn.Sequential(op, DropPath_())``; ``forward`` calls the (possibly wrapped)
    module, while ``op()`` always returns the bare op.
    """

    def __init__(self, desc:EdgeDesc, affine:bool, droppath:bool,
                 template_edge:Optional['DagEdge'])->None:
        """Build the edge's op from its description.

        Args:
            desc: edge description; supplies ``op_desc`` (passed to ``Op.create``)
                and ``input_ids`` (indices into the ``inputs`` list given to ``forward``).
            affine: forwarded unchanged to ``Op.create``.
            droppath: if True and the op supports it, wrap the op with ``DropPath_``.
            template_edge: optional existing edge whose op's ``arch_params()`` are
                forwarded to ``Op.create`` (presumably so both ops share architecture
                parameters -- confirm against ``Op.create``); ``None`` passes ``None``.
        """
        super().__init__()
        template_arch_params = (template_edge.op().arch_params()
                                if template_edge is not None else None)
        op = Op.create(desc.op_desc, affine, template_arch_params)
        # Keep this order (_wrapped, then _op): nn.Module registers submodules in
        # assignment order, and that order decides which prefix named_parameters()
        # reports for parameters reachable through both attributes.
        self._wrapped = op
        self._op = op
        # we may need to wrap the op if droppath is needed
        if droppath and self._op.can_drop_path():
            assert self.training
            self._wrapped = nn.Sequential(self._op, DropPath_())
        self.input_ids = desc.input_ids
        self.desc = desc

    @overrides
    def forward(self, inputs:List[torch.Tensor]):
        """Apply the (possibly drop-path-wrapped) op to the inputs chosen by ``input_ids``.

        With exactly one input id, the op receives that single tensor; otherwise it
        receives a list of tensors in the order given by ``input_ids``.
        """
        ids = self.input_ids
        if len(ids) == 1:
            return self._wrapped(inputs[ids[0]])
        # Fast path: when ids are exactly 0..len(inputs)-1 the selection equals
        # ``inputs`` itself, so skip building a new list. Comparing lengths alone is
        # not enough -- permuted or repeated ids (e.g. [1, 0] or [0, 0]) would be ignored.
        if len(ids) == len(inputs) and all(i == k for k, i in enumerate(ids)):
            return self._wrapped(inputs)
        return self._wrapped([inputs[i] for i in ids])

    def op(self)->Op:
        """Return the underlying op (never the ``DropPath_`` wrapper)."""
        return self._op