# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Note: All classes in this file needs to be deepcopy compatible because
descs are used as template to create copies by macro builder.
"""
import copy
import os
import pathlib
from enum import Enum
from typing import List, Mapping, Optional, Union
import torch
import yaml
from archai.common import utils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
# Module-level logger obtained from get_global_logger(); used by ModelDesc.load.
logger = get_global_logger()
# Each tensor shape is a list of numbers (int or float).
# A layer can output multiple tensors, so its shapes are TensorShapes.
# The list of all layer outputs is a TensorShapesList.
TensorShape=List[Union[int, float]]
TensorShapes=List[TensorShape]
TensorShapesList=List[TensorShapes]
class ConvMacroParams:
    """Holds parameters that may be altered by macro architecture.

    Attributes:
        ch_in: stored as given (presumably the input channel count).
        ch_out: stored as given (presumably the output channel count).
    """

    def __init__(self, ch_in: int, ch_out: int) -> None:
        self.ch_in, self.ch_out = ch_in, ch_out

    def clone(self) -> 'ConvMacroParams':
        """Return an independent deep copy of this object."""
        return copy.deepcopy(self)
class OpDesc:
    """Op description that is in each edge.

    Attributes:
        name: name of the op.
        in_len: expected number of inputs (EdgeDesc asserts that it equals
            the length of its ``input_ids``).
        params: parameters specific to the op needed to construct it.
        trainables: trainable state of the op, or None once cleared.
        children: optional child op descriptions; entries may be None.
        children_ins: optional list of ints stored alongside ``children``.
    """

    def __init__(self, name: str, params: dict, in_len: int,
                 trainables: Optional[Mapping],
                 children: Optional[List['OpDesc']] = None,
                 children_ins: Optional[List[int]] = None) -> None:
        self.name = name
        self.in_len = in_len
        self.params = params  # parameters specific to op needed to construct it
        self.trainables = trainables  # TODO: make this private due to clear_trainable
        # If op is keeping any child op then it should save it in children.
        # This way we can control state_dict of children.
        self.children = children
        self.children_ins = children_ins

    def clone(self, clone_trainables=True) -> 'OpDesc':
        """Return a deep copy, optionally with all trainables cleared.

        Args:
            clone_trainables: when False, the copy (and its children) have
                their trainables reset to None.
        """
        cloned = copy.deepcopy(self)
        if not clone_trainables:
            cloned.clear_trainables()
        return cloned

    def clear_trainables(self) -> None:
        """Set trainables to None on this op and, recursively, its children."""
        self.trainables = None
        for child in self.children or ():
            # state_dict/load_state_dict accept None entries in children,
            # so skip them here as well instead of raising AttributeError.
            if child is not None:
                child.clear_trainables()

    def state_dict(self) -> dict:
        """Return the trainables of this op and of its children.

        The 'children' entry is None when the op has no children list;
        otherwise it holds one entry per child (None for a None child).
        """
        if self.children is None:
            children_state = None
        else:
            children_state = [None if child is None else child.state_dict()
                              for child in self.children]
        return {'trainables': self.trainables, 'children': children_state}

    def load_state_dict(self, state_dict) -> None:
        """Restore trainables from a dict produced by ``state_dict``.

        The children lists of this op and of ``state_dict`` must either both
        be None or have the same length (checked with an assert).
        """
        self.trainables = state_dict['trainables']
        c, cs = self.children, state_dict['children']
        assert (c is None and cs is None) or \
            (c is not None and cs is not None and len(c) == len(cs))
        # zip cannot iterate None, so stop here when there are no children
        if c is None and cs is None:
            return
        for cx, csx in utils.zip_eq(c, cs):
            if cx is not None and csx is not None:
                cx.load_state_dict(csx)
class EdgeDesc:
    """Edge description between two nodes in the cell.

    Attributes:
        op_desc: description of the op on this edge.
        input_ids: list of ints, one per op input (its length must equal
            ``op_desc.in_len``).
    """

    def __init__(self, op_desc: 'OpDesc', input_ids: List[int]) -> None:
        assert op_desc.in_len == len(input_ids)
        self.op_desc = op_desc
        self.input_ids = input_ids

    def clone(self, conv_params: Optional['ConvMacroParams'],
              clear_trainables: bool) -> 'EdgeDesc':
        """Return an independent copy of this edge.

        Args:
            conv_params: when not None, stored under ``params['conv']`` of the
                cloned op description; when None the cloned params are kept.
            clear_trainables: when True, the cloned op's trainables are cleared.
        """
        # edge cloning is same as deep copy except that we do it through
        # constructor for future proofing any additional future rules and
        # that we allow overriding conv_params and clearing weights.
        # input_ids is copied too so the clone never aliases our list.
        e = EdgeDesc(self.op_desc.clone(), list(self.input_ids))
        # op_desc should have params set from cloning. If no override supplied
        # then don't change it
        if conv_params is not None:
            e.op_desc.params['conv'] = conv_params
        if clear_trainables:
            e.op_desc.clear_trainables()
        return e

    def clear_trainables(self) -> None:
        """Clear the trainables of the op on this edge."""
        self.op_desc.clear_trainables()

    def state_dict(self) -> dict:
        """Return ``{'op_desc': <state dict of the op>}``."""
        return {'op_desc': self.op_desc.state_dict()}

    def load_state_dict(self, state_dict) -> None:
        """Restore the op's state from ``state_dict['op_desc']``."""
        self.op_desc.load_state_dict(state_dict['op_desc'])
class NodeDesc:
    """Description of a node: its incoming edges and its conv parameters.

    Attributes:
        edges: list of edge descriptions belonging to this node.
        conv_params: macro parameters stored with the node.
    """

    def __init__(self, edges: List['EdgeDesc'],
                 conv_params: 'ConvMacroParams') -> None:
        self.edges = edges
        self.conv_params = conv_params

    def clone(self):
        """Return an independent copy of this node.

        Edges are cloned without overriding their conv params and without
        clearing trainables; ``conv_params`` is deep-copied so that the clone
        shares no mutable state with this node.
        """
        # don't override conv_params or reset learned weights
        cloned_edges = [e.clone(conv_params=None, clear_trainables=False)
                        for e in self.edges]
        return NodeDesc(edges=cloned_edges,
                        conv_params=copy.deepcopy(self.conv_params))

    def clear_trainables(self) -> None:
        """Clear the trainables of every edge of this node."""
        for edge in self.edges:
            edge.clear_trainables()

    def state_dict(self) -> dict:
        """Return ``{'edges': [<state dict of each edge>]}``."""
        return {'edges': [e.state_dict() for e in self.edges]}

    def load_state_dict(self, state_dict) -> None:
        """Restore edge states pairwise from ``state_dict['edges']``.

        Pairing uses plain ``zip``, so surplus entries on either side are
        silently ignored.
        """
        # NOTE(review): CellDesc and ModelDesc pair their items with
        # utils.zip_eq; plain zip here tolerates a length mismatch —
        # confirm whether that leniency is intended.
        for e, es in zip(self.edges, state_dict['edges']):
            e.load_state_dict(es)
class AuxTowerDesc:
    """Plain record of the values describing an auxiliary tower.

    Attributes:
        ch_in: stored as given (presumably the tower's input channel count).
        n_classes: stored as given (presumably the number of output classes).
        stride: stored as given.
    """

    def __init__(self, ch_in: int, n_classes: int, stride: int) -> None:
        self.ch_in = ch_in
        self.n_classes = n_classes
        self.stride = stride
class CellType(Enum):
    """The two kinds of cell a description can declare."""

    Regular = 'regular'
    Reduction = 'reduction'
class CellDesc:
    """Description of one cell: stems, nodes, post-op and their shapes.

    Attributes:
        id: identifier of the cell (ModelDesc requires ids to be unique).
        cell_type: the cell's CellType.
        conf_cell: configuration object stored with the cell.
        stems: op descriptions of the cell's stems.
        stem_shapes: tensor shapes stored for the stems.
        node_shapes: tensor shapes stored for the nodes.
        post_op: op description applied after the nodes.
        out_shape: tensor shape stored for the cell output.
        trainables_from: int stored as given (presumably the id of the cell
            whose trainables this cell uses — TODO confirm).
    """

    def __init__(self, id: int, cell_type: 'CellType', conf_cell: 'Config',
                 stems: List['OpDesc'], stem_shapes: 'TensorShapes',
                 nodes: List['NodeDesc'], node_shapes: 'TensorShapes',
                 post_op: 'OpDesc', out_shape: 'TensorShape',
                 trainables_from: int) -> None:
        self.cell_type = cell_type
        self.id = id
        self.conf_cell = conf_cell
        self.stems = stems
        self.stem_shapes = stem_shapes
        self.out_shape = out_shape  # reset_nodes below assigns it again
        self.trainables_from = trainables_from
        self.reset_nodes(nodes, node_shapes, post_op, out_shape)

    def clone(self, id: int) -> 'CellDesc':
        """Return a deep copy of this cell carrying the given ``id``."""
        c = copy.deepcopy(self)  # note that trainables_from is also cloned
        c.id = id
        return c

    def clear_trainables(self) -> None:
        """Clear the trainables of all stems, all nodes and the post-op."""
        for stem in self.stems:
            stem.clear_trainables()
        for node in self.nodes():
            node.clear_trainables()
        self.post_op.clear_trainables()

    def state_dict(self) -> dict:
        """Return id, cell type, shapes and the state dicts of all parts."""
        return {
            'id': self.id,
            'cell_type': self.cell_type,
            'stems': [s.state_dict() for s in self.stems],
            'stem_shapes': self.stem_shapes,
            'nodes': [n.state_dict() for n in self.nodes()],
            'node_shapes': self.node_shapes,
            'post_op': self.post_op.state_dict(),
            'out_shape': self.out_shape,
        }

    def load_state_dict(self, state_dict) -> None:
        """Restore shapes and the state of stems, nodes and post-op.

        The id and cell type in ``state_dict`` must match this cell
        (checked with asserts); stems and nodes are paired with
        ``utils.zip_eq``.
        """
        assert self.id == state_dict['id']
        assert self.cell_type == state_dict['cell_type']
        for s, ss in utils.zip_eq(self.stems, state_dict['stems']):
            s.load_state_dict(ss)
        self.stem_shapes = state_dict['stem_shapes']
        for n, ns in utils.zip_eq(self.nodes(), state_dict['nodes']):
            n.load_state_dict(ns)
        self.node_shapes = state_dict['node_shapes']
        self.post_op.load_state_dict(state_dict['post_op'])
        self.out_shape = state_dict['out_shape']

    def reset_nodes(self, nodes: List['NodeDesc'], node_shapes: 'TensorShapes',
                    post_op: 'OpDesc', out_shape: 'TensorShape') -> None:
        """Replace the nodes, node shapes, post-op and output shape."""
        self._nodes = nodes
        self.node_shapes = node_shapes
        self.post_op = post_op
        self.out_shape = out_shape

    def nodes(self) -> List['NodeDesc']:
        """Return the list of node descriptions (not a copy)."""
        return self._nodes

    def all_empty(self) -> bool:
        """Return True when there are no nodes or no node has any edge."""
        # all() over an empty sequence is True, which covers the no-node case
        return all(len(n.edges) == 0 for n in self._nodes)

    def all_full(self) -> bool:
        """Return True when there is at least one node and each has an edge."""
        return len(self._nodes) > 0 and all(len(n.edges) > 0 for n in self._nodes)
class ModelDesc:
    """Description of a whole model: stems, cells, aux towers, pool and logits.

    Attributes:
        conf_model_desc: the configuration the description was built from.
        ds_ch: ``conf_model_desc['dataset']['channels']``.
        n_classes: ``conf_model_desc['dataset']['n_classes']``.
        params: ``conf_model_desc['params']`` converted with ``to_dict()``.
        max_final_edges: ``conf_model_desc['max_final_edges']``.
        model_stems: op descriptions of the model stems.
        pool_op: op description of the pooling op.
        logits_op: op description of the logits op.
        aux_tower_descs: one optional aux tower description per cell.
    """

    def __init__(self, conf_model_desc: 'Config', model_stems: List['OpDesc'],
                 pool_op: 'OpDesc', cell_descs: List['CellDesc'],
                 aux_tower_descs: List[Optional['AuxTowerDesc']],
                 logits_op: 'OpDesc') -> None:
        self.conf_model_desc = conf_model_desc
        conf_dataset = conf_model_desc['dataset']
        self.ds_ch: int = conf_dataset['channels']
        self.n_classes: int = conf_dataset['n_classes']
        self.params = conf_model_desc['params'].to_dict()
        self.max_final_edges: int = conf_model_desc['max_final_edges']

        self.model_stems, self.pool_op = model_stems, pool_op
        self.logits_op = logits_op
        self.reset_cells(cell_descs, aux_tower_descs)

    def reset_cells(self, cell_descs: List['CellDesc'],
                    aux_tower_descs: List[Optional['AuxTowerDesc']]) -> None:
        """Replace the cells and their aux towers.

        Both lists must have the same length and cell ids must be unique
        (checked with asserts).
        """
        assert len(cell_descs) == len(aux_tower_descs)
        # every cell should have unique ID so we can tell where arch params are shared
        assert len(set(c.id for c in cell_descs)) == len(cell_descs)
        self._cell_descs = cell_descs
        self.aux_tower_descs = aux_tower_descs

    def clear_trainables(self) -> None:
        """Clear the trainables of stems, pool op, logits op and all cells."""
        for stem in self.model_stems:
            stem.clear_trainables()
        for op_desc in (self.pool_op, self.logits_op):
            op_desc.clear_trainables()
        for cell_desc in self._cell_descs:
            cell_desc.clear_trainables()

    def cell_descs(self) -> List['CellDesc']:
        """Return the list of cell descriptions (not a copy)."""
        return self._cell_descs

    def cell_type_count(self, cell_type: 'CellType') -> int:
        """Return how many cells have the given ``cell_type``."""
        return sum(1 for c in self._cell_descs if c.cell_type == cell_type)

    def clone(self) -> 'ModelDesc':
        """Return an independent deep copy of this description."""
        return copy.deepcopy(self)

    def has_aux_tower(self) -> bool:
        """Return True when at least one cell has an aux tower description."""
        return any(aux is not None for aux in self.aux_tower_descs)

    def all_empty(self) -> bool:
        """Return True when there are no cells or every cell is all-empty."""
        # all() over an empty sequence is True, which covers the no-cell case
        return all(c.all_empty() for c in self._cell_descs)

    def all_full(self) -> bool:
        """Return True when there is at least one cell and every cell is all-full."""
        return len(self._cell_descs) > 0 and \
            all(c.all_full() for c in self._cell_descs)

    def state_dict(self) -> dict:
        """Return the state dicts of cells, stems, pool op and logits op."""
        return {
            'cell_descs': [c.state_dict() for c in self.cell_descs()],
            'model_stems': [stem.state_dict() for stem in self.model_stems],
            'pool_op': self.pool_op.state_dict(),
            'logits_op': self.logits_op.state_dict(),
        }

    def load_state_dict(self, state_dict) -> None:
        """Restore state from a dict produced by ``state_dict``.

        Cells and stems are paired with ``utils.zip_eq``.
        """
        for c, cs in utils.zip_eq(self.cell_descs(), state_dict['cell_descs']):
            c.load_state_dict(cs)
        for stem, state in utils.zip_eq(self.model_stems, state_dict['model_stems']):
            stem.load_state_dict(state)
        self.pool_op.load_state_dict(state_dict['pool_op'])
        self.logits_op.load_state_dict(state_dict['logits_op'])

    def save(self, filename: str, save_trainables=False) -> Optional[str]:
        """Write this description as YAML, optionally with its trainables.

        Args:
            filename: destination of the YAML file; when empty nothing is
                written.
            save_trainables: when True, ``state_dict()`` is additionally
                saved with ``torch.save`` next to the YAML file, using the
                same path with a '.pth' suffix.

        Returns:
            The path passed through ``utils.full_path``, or ``filename``
            unchanged when it was empty.
        """
        if not filename:
            return filename

        filename = utils.full_path(filename)
        if save_trainables:
            state_dict = self.state_dict()
            pt_filepath = ModelDesc._pt_filepath(filename)
            torch.save(state_dict, pt_filepath)

        # save yaml; trainables are cleared on a copy so self stays intact
        cloned = self.clone()
        cloned.clear_trainables()
        utils.write_string(filename, yaml.dump(cloned))
        return filename

    @staticmethod
    def _pt_filepath(desc_filepath: str) -> str:
        """Return ``desc_filepath`` with its extension replaced by '.pth'."""
        return str(pathlib.Path(desc_filepath).with_suffix('.pth'))

    @staticmethod
    def load(filename: str, load_trainables=False) -> 'ModelDesc':
        """Load a description from YAML, optionally restoring trainables.

        Args:
            filename: path of the YAML description file.
            load_trainables: when True and a file with the same path but a
                '.pth' suffix exists, its contents are passed to
                ``load_state_dict``; a missing '.pth' file is not an error.

        Raises:
            RuntimeError: when the description file does not exist.
        """
        filename = utils.full_path(filename)
        if not filename or not os.path.exists(filename):
            # the sentences were previously concatenated without spaces
            raise RuntimeError("Model description file is not found. "
                               "Typically this file should be generated from the search. "
                               "Please copy this file to '{}'".format(filename))

        logger.info({'final_desc_filename': filename})
        # NOTE(security): yaml.Loader can construct arbitrary Python objects,
        # which is what rebuilds the desc classes here. Only load description
        # files that come from a trusted source.
        with open(filename, 'r') as f:
            model_desc = yaml.load(f, Loader=yaml.Loader)

        if load_trainables:
            # look for pth file that should have pytorch parameters state_dict
            pt_filepath = ModelDesc._pt_filepath(filename)
            if os.path.exists(pt_filepath):
                # NOTE(review): the saved state contains CellType enum values
                # (see CellDesc.state_dict); PyTorch releases that default
                # torch.load to weights_only=True may refuse to unpickle it —
                # confirm against the torch version in use.
                state_dict = torch.load(pt_filepath, map_location=torch.device('cpu'))
                model_desc.load_state_dict(state_dict)
            # else no need to restore weights
        return model_desc