# Source code for archai.supergraph.nas.cell
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Iterable, List, Optional
from overrides import EnforceOverrides, overrides
from torch import nn
from archai.supergraph.nas.arch_module import ArchModule
from archai.supergraph.nas.dag_edge import DagEdge
from archai.supergraph.nas.model_desc import CellDesc, NodeDesc
from archai.supergraph.nas.operations import Op
class Cell(ArchModule, EnforceOverrides):
    """A cell built from a :class:`CellDesc`: two stem ops, a DAG of nodes, and a post op.

    ``self.dag[i]`` is an ``nn.ModuleList`` holding the :class:`DagEdge`
    objects for node ``i``. In :meth:`forward`, each edge receives the full
    list of states computed so far. That list holds the two stem outputs
    followed by the outputs of all earlier nodes. A node's output is the sum
    of its edges' outputs. ``self.post_op`` receives the complete list of
    states.
    """

    def __init__(self, desc:CellDesc,
                 affine:bool, droppath:bool,
                 trainables_from:Optional['Cell']): # template cell, if any, to use for arch params
        """Build stem ops, the node DAG and the post op from ``desc``.

        Args:
            desc: Cell description. It must have exactly two stems.
            affine: Forwarded to every ``Op.create`` call and to every ``DagEdge``.
            droppath: Forwarded to every ``DagEdge``.
            trainables_from: Optional template cell whose arch params are
                to be used. Its ``dag[i][j]`` is passed as ``template_edge``
                to the corresponding edge of this cell. It is assumed to have
                the same DAG shape as ``desc`` (not checked here).

        Raises:
            AssertionError: If ``desc.stems`` does not have exactly 2 entries.
        """
        super().__init__()

        # some of these members are public as finalizer needs access
        self.desc = desc

        # TODO: support any number of stems
        assert len(desc.stems)==2, "Cell compiler currently only supports 2 stems"
        self.s0_op = Op.create(desc.stems[0], affine=affine)
        self.s1_op = Op.create(desc.stems[1], affine=affine)

        self.dag = Cell._create_dag(desc.nodes(),
                                    affine=affine, droppath=droppath,
                                    trainables_from=trainables_from)

        self.post_op = Op.create(desc.post_op, affine=affine)

    @staticmethod
    def _create_dag(nodes_desc:List[NodeDesc],
                    affine:bool, droppath:bool,
                    trainables_from:Optional['Cell'])->nn.ModuleList:
        """Create the cell DAG as a ModuleList of per-node edge ModuleLists.

        Args:
            nodes_desc: Node descriptions, in DAG order.
            affine: Forwarded to each ``DagEdge``.
            droppath: Forwarded to each ``DagEdge``.
            trainables_from: Optional template cell. When given, edge ``[i][j]``
                of its DAG is passed as ``template_edge`` for edge ``[i][j]``
                here.

        Returns:
            ``nn.ModuleList`` where ``dag[i][j]`` is the ``DagEdge`` built from
            ``nodes_desc[i].edges[j]``.
        """
        # Test for None explicitly. Truthiness of a module object is not a
        # reliable "was a template supplied?" check: it becomes False for
        # any template type that defines __len__ and happens to be empty.
        template_dag = trainables_from.dag if trainables_from is not None else None

        dag = nn.ModuleList()
        for i, node_desc in enumerate(nodes_desc):
            edges:nn.ModuleList = nn.ModuleList()
            dag.append(edges)
            # Nodes with no edges are allowed; forward() handles that case.
            for j, edge_desc in enumerate(node_desc.edges):
                template_edge = template_dag[i][j] if template_dag is not None else None
                edges.append(DagEdge(edge_desc,
                                     affine=affine, droppath=droppath,
                                     template_edge=template_edge))
        return dag

    def ops(self)->Iterable[Op]:
        """Yield ``edge.op()`` for every edge, node by node, in edge order."""
        for node in self.dag:
            for edge in node:
                yield edge.op()

    @overrides
    def forward(self, s0, s1):
        """Run the cell on two inputs.

        Args:
            s0: Input to the first stem op.
            s1: Input to the second stem op.

        Returns:
            ``self.post_op`` applied to the list of all states. That list holds
            the two stem outputs followed by one output per DAG node.
        """
        s0 = self.s0_op(s0)
        s1 = self.s1_op(s1)

        states = [s0, s1]
        for node in self.dag:
            # TODO: we should probably do average here otherwise output will
            # blow up as number of primitives grows
            # TODO: Current assumption is that each edge has k channel
            # output so node output is k channel as well
            # This won't allow for arbitrary edges.
            if len(node):
                o = sum(edge(states) for edge in node)
            else:
                # Node with no edges: pass the most recent state through
                # unchanged in value. The `+ 0.0` produces a new result
                # rather than appending the same object again.
                o = states[-1] + 0.0
            states.append(o)

        # TODO: Below assumes same shape except for channels but this won't
        # happen for max pool etc shapes? Also, remove hard coded 2.
        return self.post_op(states)