# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from typing import List, Optional, Tuple
from overrides import EnforceOverrides
from archai.common.config import Config
from archai.supergraph.nas.model_desc import (
AuxTowerDesc,
CellDesc,
CellType,
ConvMacroParams,
ModelDesc,
NodeDesc,
OpDesc,
TensorShape,
TensorShapes,
TensorShapesList,
)
from archai.supergraph.nas.operations import Op, StemBase
class ModelDescBuilder(EnforceOverrides):
    """Builds a :class:`ModelDesc` for a cell-based network from a model
    description config.

    The macro structure is a DARTS-style stack of cells: model stems, then
    ``n_cells`` cells of which ``n_reductions`` are reduction cells (halve HxW,
    double channels), then a model pool op and a logits op. If a ``template``
    :class:`ModelDesc` is supplied, new cells copy their node/edge structure
    from the template's first regular and first reduction cell.
    """

    def get_reduction_indices(self, conf_model_desc: Config) -> List[int]:
        """Returns cell indices which reduce HxW and double channels."""

        n_cells: int = conf_model_desc['n_cells']
        n_reductions: int = conf_model_desc['n_reductions']

        # this satisfies N R N R N pattern; this need not be enforced but
        # we are doing it now for sanity
        assert n_cells >= n_reductions * 2 + 1

        # for each reduction, we create one index
        # for cifar and imagenet, reductions=2 creates cuts at n//3, n*2//3
        return list(n_cells * (i + 1) // (n_reductions + 1)
                    for i in range(n_reductions))

    def get_node_channels(self, conf_model_desc: Config) -> List[List[int]]:
        """Returns array of channels for each node in each cell. All nodes
        are assumed to have same output channels as input channels."""

        conf_model_stems = self.get_conf_model_stems()
        conf_cell = self.get_conf_cell()

        init_node_ch: int = conf_model_stems['init_node_ch']
        n_cells = conf_model_desc['n_cells']
        n_nodes = conf_cell['n_nodes']

        # same channels for all nodes in a cell
        cell_node_channels: List[List[int]] = []
        # channels for the first cell
        node_channels = init_node_ch

        for ci in range(n_cells):
            # if reduction cell then double the node channels
            if self.get_cell_type(ci) == CellType.Reduction:
                node_channels *= 2
            # all nodes in a cell have same channels
            nodes_channels = [node_channels for ni in range(n_nodes)]
            cell_node_channels.append(nodes_channels)

        return cell_node_channels

    def get_conf_cell(self) -> Config:
        """Returns the cell sub-config of the model description."""
        return self.conf_model_desc['cell']

    def get_conf_dataset(self) -> Config:
        """Returns the dataset sub-config of the model description."""
        return self.conf_model_desc['dataset']

    def get_conf_model_stems(self) -> Config:
        """Returns the model-stems sub-config of the model description."""
        return self.conf_model_desc['model_stems']

    def _init_build(self, conf_model_desc: Config,
                    template: Optional[ModelDesc] = None) -> None:
        """Caches the config/template and precomputes the per-cell
        bookkeeping (cell templates, reduction/normal indices, node
        channels) used by the rest of the build."""

        self.conf_model_desc = conf_model_desc
        self.template = template
        # if template model desc is specified then setup regular and reduction cell templates
        self._cell_templates = self.create_cell_templates(template)

        n_cells = conf_model_desc['n_cells']

        # for each reduction, we create one index
        # for cifar and imagenet, reductions=2 creates cuts at n//3, n*2//3
        self._reduction_indices = self.get_reduction_indices(conf_model_desc)
        self._normal_indices = [i for i in range(n_cells)
                                if i not in self._reduction_indices]
        self.node_channels = self.get_node_channels(conf_model_desc)

    def build(self, conf_model_desc: Config,
              template: Optional[ModelDesc] = None) -> ModelDesc:
        """Main entry point for the class: builds and returns the full
        :class:`ModelDesc` (stems, cells, aux towers, pool and logits ops)."""

        self._init_build(conf_model_desc, template)

        self.pre_build(conf_model_desc)

        # input shape for the stem has same channels as channels in image
        # -1 indicates actual dimensions are not known
        ds_ch = self.get_conf_dataset()['channels']
        in_shapes = [[[ds_ch, -1, -1, -1]]]

        # create model stems
        model_stems = self.build_model_stems(in_shapes, conf_model_desc)

        # create cell descriptions
        cell_descs, aux_tower_descs = self.build_cells(in_shapes, conf_model_desc)

        model_pool_op = self.build_model_pool(in_shapes, conf_model_desc)
        logits_op = self.build_logits_op(in_shapes, conf_model_desc)

        return ModelDesc(conf_model_desc, model_stems, model_pool_op, cell_descs,
                         aux_tower_descs, logits_op)

    def build_cells(self, in_shapes: TensorShapesList, conf_model_desc: Config) \
            -> Tuple[List[CellDesc], List[Optional[AuxTowerDesc]]]:
        """Builds all cell descriptions and their (possibly ``None``)
        auxiliary-tower descriptions, appending each cell's output shape
        to ``in_shapes``."""

        conf_cell = self.get_conf_cell()
        n_cells = conf_model_desc['n_cells']

        cell_descs, aux_tower_descs = [], []

        # create list of output shapes for cells that starts with model stem
        for ci in range(n_cells):
            cell_desc = self.build_cell(in_shapes, conf_cell, ci)
            # get first tensor output of last cell
            aux_tower_desc = self.build_aux_tower(in_shapes[-1][0], conf_model_desc, ci)
            cell_descs.append(cell_desc)
            aux_tower_descs.append(aux_tower_desc)

        return cell_descs, aux_tower_descs

    def get_node_count(self, cell_index: int) -> int:
        """Returns the number of nodes in the cell at ``cell_index``."""
        return len(self.node_channels[cell_index])

    def build_cell(self, in_shapes: TensorShapesList, conf_cell: Config,
                   cell_index: int) -> CellDesc:
        """Builds one cell: its two stems, its nodes (from scratch or from
        the template) and its post op; appends the cell's output shape to
        ``in_shapes`` and returns the :class:`CellDesc`."""

        stem_shapes, stems = self.build_cell_stems(in_shapes, conf_cell, cell_index)
        cell_type = self.get_cell_type(cell_index)

        if self.template is None:
            node_count = self.get_node_count(cell_index)
            in_shape = stem_shapes[0]  # input shape to nodes is same as cell stem
            out_shape = stem_shapes[0]  # we ask nodes to keep the output shape same
            node_shapes, nodes = self.build_nodes(stem_shapes, conf_cell,
                                                  cell_index, cell_type,
                                                  node_count, in_shape, out_shape)
        else:
            node_shapes, nodes = self.build_nodes_from_template(stem_shapes,
                                                                conf_cell, cell_index)

        post_op_shape, post_op_desc = self.build_cell_post_op(stem_shapes,
                                                              node_shapes,
                                                              conf_cell, cell_index)

        cell_desc = CellDesc(
            id=cell_index, cell_type=self.get_cell_type(cell_index),
            conf_cell=conf_cell,
            stems=stems, stem_shapes=stem_shapes,
            nodes=nodes, node_shapes=node_shapes,
            post_op=post_op_desc, out_shape=post_op_shape,
            trainables_from=self.get_trainables_from(cell_index)
        )

        # output same shape twice to indicate s0 and s1 inputs for next cell
        in_shapes.append([post_op_shape])

        return cell_desc

    def get_trainables_from(self, cell_index: int) -> int:
        """Returns the index of the cell whose trainables are shared by the
        cell at ``cell_index`` (first reduction cell for reduction cells,
        first regular cell for regular cells)."""

        cell_type = self.get_cell_type(cell_index)
        if cell_type == CellType.Reduction:
            return self._reduction_indices[0]
        if cell_type == CellType.Regular:
            return self._normal_indices[0]
        raise RuntimeError(f'Cannot get cell for shared trainables because cell_type "{cell_type}" is not recognized')

    def get_ch(self, shape: TensorShape) -> int:
        """Returns the channel count (first element) of a tensor shape."""
        return int(shape[0])

    def build_cell_stems(self, in_shapes: TensorShapesList, conf_cell: Config,
                         cell_index: int) \
            -> Tuple[TensorShapes, List[OpDesc]]:
        """Builds the two preprocessing stems (s0 from the 2nd-previous
        module output, s1 from the previous one) and returns their output
        shapes and op descriptions."""

        # expect two stems, both should have same channels
        # TODO: support multiple stems
        assert len(in_shapes) >= 2, "we must have outputs from at least two previous modules"

        # Get channels for previous two layers.
        # At start we have only one layer, i.e., model stems.
        # Typically model stems should have same channel count but for imagenet we do
        # reduction at model stem so stem1 will have twice channels as stem0
        p_ch_out = self.get_ch(in_shapes[-1][0])
        pp_ch_out = self.get_ch(in_shapes[-2][0])

        # was the previous layer a reduction layer?
        reduction_p = p_ch_out == pp_ch_out * 2 or in_shapes[-2][0][2] == in_shapes[-1][0][2] * 2

        # find out the node channels for this cell
        node_ch_out = self.node_channels[cell_index][0]  # init with first node in cell

        # Cell stems will take prev channels and output same channels as nodes would.
        # If prev cell was reduction then we need to increase channels of prev-prev
        # by 2X. This is done by prepr_reduce stem.
        s0_op = OpDesc('prepr_reduce' if reduction_p else 'prepr_normal',
                       params={
                           'conv': ConvMacroParams(pp_ch_out, node_ch_out)
                       }, in_len=1, trainables=None)

        s1_op = OpDesc('prepr_normal',
                       params={
                           'conv': ConvMacroParams(p_ch_out, node_ch_out)
                       }, in_len=1, trainables=None)

        # output two shapes with proper channels setup
        # for default model desc, cell stems have same shapes and channels
        out_shape0 = copy.deepcopy(in_shapes[-1][0])
        # set channels and reset shapes to -1 to indicate unknown
        # for imagenet HxW would be floating point numbers due to one input reduced
        out_shape0[0], out_shape0[2], out_shape0[3] = node_ch_out, -1, -1
        out_shape1 = copy.deepcopy(out_shape0)

        return [out_shape0, out_shape1], [s0_op, s1_op]

    def build_nodes_from_template(self, stem_shapes: TensorShapes, conf_cell: Config,
                                  cell_index: int) \
            -> Tuple[TensorShapes, List[NodeDesc]]:
        """Builds this cell's nodes by cloning the edges of the matching
        cell template (trainables cleared, conv params reset to this cell's
        stem channels)."""

        cell_template = self.get_cell_template(cell_index)

        assert cell_template is not None

        cell_type = self.get_cell_type(cell_index)
        assert cell_template.cell_type == cell_type

        nodes: List[NodeDesc] = []
        for n in cell_template.nodes():
            edges_copy = [e.clone(
                # use new macro params
                conv_params=ConvMacroParams(self.get_ch(stem_shapes[0]),
                                            self.get_ch(stem_shapes[0])),
                # TODO: check for compatibility?
                clear_trainables=True
            ) for e in n.edges]
            nodes.append(NodeDesc(edges=edges_copy, conv_params=n.conv_params))

        out_shapes = [copy.deepcopy(stem_shapes[0]) for _ in cell_template.nodes()]

        return out_shapes, nodes

    def build_nodes(self, stem_shapes: TensorShapes, conf_cell: Config,
                    cell_index: int, cell_type: CellType, node_count: int,
                    in_shape: TensorShape, out_shape: TensorShape) \
            -> Tuple[TensorShapes, List[NodeDesc]]:
        """Builds ``node_count`` nodes with empty edges; subclasses override
        this to populate edges (e.g. with search ops)."""

        # default: create nodes with empty edges
        nodes: List[NodeDesc] = [NodeDesc(edges=[],
                                          conv_params=ConvMacroParams(
                                              self.get_ch(in_shape),
                                              self.get_ch(out_shape)))
                                 for _ in range(node_count)]

        out_shapes = [copy.deepcopy(out_shape) for _ in range(node_count)]

        return out_shapes, nodes

    def create_cell_templates(self, template: Optional[ModelDesc]) \
            -> List[Optional[CellDesc]]:
        """Returns ``[normal_template, reduction_template]`` extracted from
        ``template`` (each ``None`` if no such cell exists or no template)."""

        normal_template, reduction_template = None, None
        if template is not None:
            # find first regular and reduction cells and set them as
            # the template that we will use. When we create new cells
            # we will fill them up with nodes from these templates
            for cell_desc in template.cell_descs():
                if normal_template is None and \
                        cell_desc.cell_type == CellType.Regular:
                    normal_template = cell_desc
                if reduction_template is None and \
                        cell_desc.cell_type == CellType.Reduction:
                    reduction_template = cell_desc
        return [normal_template, reduction_template]

    def build_model_pool(self, in_shapes: TensorShapesList, conf_model_desc: Config) \
            -> OpDesc:
        """Builds the model-level pooling op (shape-preserving) and appends
        its output shape to ``in_shapes``."""

        model_post_op = conf_model_desc['model_post_op']
        last_shape = in_shapes[-1][0]

        in_shapes.append([copy.deepcopy(last_shape)])

        return OpDesc(model_post_op,
                      params={'conv': ConvMacroParams(self.get_ch(last_shape),
                                                      self.get_ch(last_shape))},
                      in_len=1, trainables=None)

    def build_logits_op(self, in_shapes: TensorShapesList, conf_model_desc: Config) -> OpDesc:
        """Builds the final linear classifier op from the last output
        channels to ``n_classes``."""

        n_classes = self.get_conf_dataset()['n_classes']

        return OpDesc('linear',
                      params={'n_ch': in_shapes[-1][0][0],
                              'n_classes': n_classes},
                      in_len=1, trainables=None)

    def get_cell_template(self, cell_index: int) -> Optional[CellDesc]:
        """Returns the cell template matching the cell type at
        ``cell_index`` (may be ``None`` when no template was supplied)."""

        cell_type = self.get_cell_type(cell_index)

        if cell_type == CellType.Regular:
            return self._cell_templates[0]
        if cell_type == CellType.Reduction:
            return self._cell_templates[1]
        raise RuntimeError(f'Cannot get cell template because cell_type "{cell_type}" is not recognized')

    def get_cell_type(self, cell_index: int) -> CellType:
        """Returns whether the cell at ``cell_index`` is a reduction or a
        regular cell."""

        # For darts, n_cells=8 so we build [N N R N N R N N] structure
        # Notice that this will result in only 2 reduction cells no matter
        # total number of cells. Original resnet actually have 3 reduction cells.
        # Between two reduction cells we have regular cells.
        return CellType.Reduction if cell_index in self._reduction_indices \
            else CellType.Regular

    def _post_op_ch(self, post_op_name: str, node_shapes: TensorShapes) \
            -> Tuple[int, int, int]:
        """Returns ``(op_ch_in, cell_ch_out, out_states)`` for the cell post
        op given its name and the node output shapes."""

        node_count = len(node_shapes)
        node_ch_out = self.get_ch(node_shapes[-1])

        # we take all available node outputs as input to post op
        # if no nodes exist then we will use cell stem outputs
        # Note that for reduction cell stems wxh is larger than node wxh which
        # means we cannot use cell stem outputs with node outputs because
        # concat will fail
        # TODO: remove hard coding of 2
        out_states = node_count if node_count else 2

        # number of input channels to the cell post op
        op_ch_in = out_states * node_ch_out
        # number of output channels for the cell post op
        if post_op_name == 'concate_channels':
            cell_ch_out = op_ch_in
        elif post_op_name == 'proj_channels':
            cell_ch_out = node_ch_out
        else:
            raise RuntimeError(f'Unsupported cell_post_op: {post_op_name}')
        return op_ch_in, cell_ch_out, out_states

    def build_cell_post_op(self, stem_shapes: TensorShapes,
                           node_shapes: TensorShapes, conf_cell: Config,
                           cell_index: int) \
            -> Tuple[TensorShape, OpDesc]:
        """Builds the cell's post op (e.g. channel concat or projection) and
        returns its output shape and op description."""

        post_op_name = conf_cell['cell_post_op']
        op_ch_in, cell_ch_out, out_states = self._post_op_ch(post_op_name,
                                                             node_shapes)

        post_op_desc = OpDesc(post_op_name,
                              {
                                  'conv': ConvMacroParams(op_ch_in, cell_ch_out),
                                  'out_states': out_states
                              },
                              in_len=1, trainables=None, children=None)

        out_shape = copy.deepcopy(node_shapes[-1])
        out_shape[0] = cell_ch_out

        return out_shape, post_op_desc

    def build_aux_tower(self, out_shape: TensorShape, conf_model_desc: Config,
                        cell_index: int) -> Optional[AuxTowerDesc]:
        """Returns an :class:`AuxTowerDesc` for the cell at 2/3 depth when
        aux loss is enabled and the network has >1 reduction, else ``None``."""

        n_classes = self.get_conf_dataset()['n_classes']
        n_cells = conf_model_desc['n_cells']
        n_reductions = conf_model_desc['n_reductions']
        aux_tower_stride = conf_model_desc['aux_tower_stride']
        aux_weight = conf_model_desc['aux_weight']

        # TODO: shouldn't we be adding aux tower at *every* 1/3rd?
        if aux_weight and n_reductions > 1 and cell_index == 2 * n_cells // 3:
            return AuxTowerDesc(self.get_ch(out_shape), n_classes, aux_tower_stride)
        return None

    def build_model_stems(self, in_shapes: TensorShapesList,
                          conf_model_desc: Config) -> List[OpDesc]:
        """Builds the model stem ops and appends one output shape per stem
        to ``in_shapes`` (HxW scaled by each stem's reduction factor)."""

        # TODO: why do we need stem_multiplier?
        # TODO: in original paper stems are always affine

        conf_model_stems = self.get_conf_model_stems()

        init_node_ch: int = conf_model_stems['init_node_ch']
        stem_multiplier: int = conf_model_stems['stem_multiplier']
        ops: List[str] = conf_model_stems['ops']

        out_channels = init_node_ch * stem_multiplier

        conv_params = ConvMacroParams(self.get_ch(in_shapes[-1][0]),  # channels of first input tensor
                                      out_channels)

        stems = [OpDesc(name=op_name, params={'conv': conv_params},
                        in_len=1, trainables=None)
                 for op_name in ops]

        # get reduction factors done by each stem, typically they should be same but for
        # imagenet they can differ
        stem_reductions = ModelDescBuilder._stem_reductions(stems)

        # Each cell takes input from previous and 2nd previous cells.
        # To be consistent we create two outputs for model stems: [[s1, s0], [s0, s1]]
        # This way when we access first element of each output we get s1, s0.
        # Normally s0==s1 but for networks like imagenet, s0 will have twice the channels
        # of s1.
        for stem_reduction in stem_reductions:
            in_shapes.append([[out_channels, -1, -1.0 / stem_reduction, -1.0 / stem_reduction]])

        return stems

    @staticmethod
    def _stem_reductions(stems: List[OpDesc]) -> List[int]:
        """Returns the spatial reduction factor of each stem op."""

        # create stem ops to find out reduction factors
        ops = [Op.create(stem, affine=False) for stem in stems]
        assert all(isinstance(op, StemBase) for op in ops)
        return list(op.reduction for op in ops)

    def pre_build(self, conf_model_desc: Config) -> None:
        """Hook for accomplishing any setup before build starts."""
        pass

    def seed_cell(self, model_desc: ModelDesc) -> None:
        """Hook to prepare model as seed model before search iterations start."""
        pass