# Source code for archai.supergraph.algos.petridish.petridish_model_desc_builder
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from overrides import overrides
from archai.common.config import Config
from archai.supergraph.algos.petridish.petridish_op import PetridishOp, TempIdentityOp
from archai.supergraph.algos.random.random_model_desc_builder import (
RandomModelDescBuilder,
)
from archai.supergraph.nas.operations import Op
class PetridishModelBuilder(RandomModelDescBuilder):
    """Model-description builder for the Petridish algorithm.

    Inherits all building behavior from :class:`RandomModelDescBuilder` and
    only extends the pre-build step to register Petridish-specific op
    factories with :class:`Op`.
    """

    @overrides
    def pre_build(self, conf_model_desc: Config) -> None:
        """Run the parent's pre-build step, then register Petridish ops.

        Three op factories are registered via ``Op.register_op``, in order:

        * ``'petridish_normal_op'``    -> ``PetridishOp(op_desc, arch_params, False, affine)``
        * ``'petridish_reduction_op'`` -> ``PetridishOp(op_desc, arch_params, True, affine)``
        * ``'temp_identity_op'``       -> ``TempIdentityOp(op_desc)``

        Args:
            conf_model_desc: Model-description config; forwarded unchanged to
                the parent's ``pre_build``.
        """
        super().pre_build(conf_model_desc)

        def petridish_factory(is_reduction: bool):
            # Bind the normal/reduction flag per variant. The returned callable
            # uses the same (op_desc, arch_params, affine) signature as every
            # other factory registered in this module.
            def make_op(op_desc, arch_params, affine):
                return PetridishOp(op_desc, arch_params, is_reduction, affine)
            return make_op

        def make_temp_identity(op_desc, arch_params, affine):
            # arch_params and affine are accepted only to match the shared
            # factory signature; TempIdentityOp is built from op_desc alone.
            return TempIdentityOp(op_desc)

        # Ordered (name, factory) pairs; registration order matches the
        # original normal -> reduction -> temp-identity sequence.
        op_factories = (
            ('petridish_normal_op', petridish_factory(False)),
            ('petridish_reduction_op', petridish_factory(True)),
            ('temp_identity_op', make_temp_identity),
        )
        for op_name, factory in op_factories:
            Op.register_op(op_name, factory)

    # NOTE(review): disabled override kept for reference. It refers to names
    # (copy, Tuple, List, TensorShape, TensorShapes, CellType, NodeDesc,
    # ConvMacroParams, OpDesc, EdgeDesc) that this module does not import;
    # add those imports before re-enabling it.
    # @overrides
    # def build_nodes(self, stem_shapes:TensorShapes, conf_cell:Config,
    # cell_index:int, cell_type:CellType, node_count:int,
    # in_shape:TensorShape, out_shape:TensorShape) \
    # ->Tuple[TensorShapes, List[NodeDesc]]:
    # # For petridish we add one node with identity to s1.
    # # This will be our seed model to start with.
    # # Later in PetridishSearcher, we will add one more node in parent after each sampling.
    # assert in_shape[0]==out_shape[0]
    # reduction = (cell_type==CellType.Reduction)
    # # channels for conv filters
    # conv_params = ConvMacroParams(in_shape[0], out_shape[0])
    # # identity op to connect S1 to the node
    # op_desc = OpDesc('skip_connect',
    # params={'conv': conv_params,
    # 'stride': 2 if reduction else 1},
    # in_len=1, trainables=None, children=None)
    # edge = EdgeDesc(op_desc, input_ids=[1])
    # new_node = NodeDesc(edges=[edge], conv_params=conv_params)
    # nodes = [new_node]
    # # each node has same out channels as in channels
    # out_shapes = [copy.deepcopy(out_shape) for _ in nodes]
    # return out_shapes, nodes