# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Iterable, Optional, Tuple

import numpy as np
import torch
from overrides import overrides
from torch import Tensor, nn

from archai.common import ml_utils
from archai.supergraph.nas.arch_module import ArchModule
from archai.supergraph.nas.cell import Cell
from archai.supergraph.nas.model_desc import AuxTowerDesc, CellDesc, ModelDesc
from archai.supergraph.nas.operations import DropPath_, Op


class Model(ArchModule):
    def __init__(self, model_desc:ModelDesc, droppath:bool, affine:bool):
        super().__init__()

        # some of these fields are public as the finalizer needs access to them
        self.desc = model_desc

        # TODO: support any number of stems
        assert len(model_desc.model_stems)==2, "Model compiler currently only supports 2 stems"

        stem0_op = Op.create(model_desc.model_stems[0], affine=affine)
        stem1_op = Op.create(model_desc.model_stems[1], affine=affine)
        self.model_stems = nn.ModuleList((stem0_op, stem1_op))

        self.cells = nn.ModuleList()
        self._aux_towers = nn.ModuleList()

        for i, (cell_desc, aux_tower_desc) in \
                enumerate(zip(model_desc.cell_descs(), model_desc.aux_tower_descs)):
            self._build_cell(cell_desc, aux_tower_desc, droppath, affine)

        # adaptive pooling with output size 1x1
        self.pool_op = Op.create(model_desc.pool_op, affine=affine)

        # ch_p records the last cell's output channels, so it gives the
        # input channel count for the logits op
        self.logits_op = Op.create(model_desc.logits_op, affine=affine)

    def _build_cell(self, cell_desc:CellDesc,
                    aux_tower_desc:Optional[AuxTowerDesc],
                    droppath:bool, affine:bool)->None:
        # if trainables_from points to an earlier cell, share that cell's
        # weights instead of creating new ones
        trainables_from = None if cell_desc.trainables_from==cell_desc.id \
                          else self.cells[cell_desc.trainables_from]

        cell = Cell(cell_desc, affine=affine, droppath=droppath,
                    trainables_from=trainables_from)
        self.cells.append(cell)

        # not every cell has an aux tower; append None to keep lists aligned
        self._aux_towers.append(AuxTower(aux_tower_desc)
                                if aux_tower_desc else None)

    def summary(self)->dict:
        all_arch_params = list(self.all_owned()
                               .param_by_kind(kind=None))
        return {
            'cell_count': len(self.cells),
            #'cell_params': [ml_utils.param_size(c) for c in self.cells]
            'params': ml_utils.param_size(self),
            'arch_params_len': len(all_arch_params),
            'arch_params_numel': np.sum([a.numel() for a in all_arch_params]),
            'ops': np.sum([len(n.edges) for c in self.desc.cell_descs() for n in c.nodes()]),
        }

    def ops(self)->Iterable[Op]:
        for cell in self.cells:
            for op in cell.ops():
                yield op

    @overrides
    def forward(self, x)->Tuple[Tensor, Optional[Tensor]]:
        #print(torch.cuda.memory_allocated()/1.0e6)
        s0 = self.model_stems[0](x)
        #print(torch.cuda.memory_allocated()/1.0e6)
        s1 = self.model_stems[1](x)
        #print(-1, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)

        logits_aux = None
        for ci, (cell, aux_tower) in enumerate(zip(self.cells, self._aux_towers)):
            #print(s0.shape, s1.shape, end='')
            s0, s1 = s1, cell.forward(s0, s1)
            #print(ci, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)

            # TODO: this mimics darts but won't work for multiple aux towers
            if aux_tower is not None and self.training:
                logits_aux = aux_tower(s1)
                #print(ci, 'aux', logits_aux.shape)

        # s1 is now the last cell's output
        out = self.pool_op(s1)
        logits = self.logits_op(out) # flatten

        #print(-1, 'out', out.shape)
        #print(-1, 'logits', logits.shape)

        return logits, logits_aux

    def device_type(self)->str:
        return next(self.parameters()).device.type

    def drop_path_prob(self, p:float):
        """Set the drop path probability.

        This is called externally so that all `DropPath_` modules receive the
        new probability. Typically the probability is adjusted every epoch.
        """
        for module in self.modules():
            if isinstance(module, DropPath_):
                module.p = p
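
# A minimal sketch of a per-epoch drop-path schedule (illustrative only: the
# trainer loop, `max_drop_path_prob` and the schedule shape come from the
# training recipe, not from this module). DARTS-style training commonly scales
# the probability linearly with the epoch:
#
#   for epoch in range(epochs):
#       model.drop_path_prob(max_drop_path_prob * epoch / epochs)
#       ...  # train one epoch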


class AuxTower(nn.Module):
    def __init__(self, aux_tower_desc:AuxTowerDesc):
        """Assumes input size of 14x14."""
        # TODO: assert input size?
        super().__init__()

        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=aux_tower_desc.stride, padding=0, count_include_pad=False),
            nn.Conv2d(aux_tower_desc.ch_in, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # TODO: this batchnorm was omitted in the original implementation due to a typo
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        )
        self.logits_op = nn.Linear(768, aux_tower_desc.n_classes)

    def forward(self, x:torch.Tensor):
        x = self.features(x)
        x = self.logits_op(x.view(x.size(0), -1))
        return x
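
# A minimal sketch of how the two forward() outputs are typically combined
# during training (illustrative: `criterion`, `aux_weight` and the surrounding
# loop come from the training recipe, not from this module; DARTS-style
# recipes weight the auxiliary loss by 0.4):
#
#   logits, logits_aux = model(x)
#   loss = criterion(logits, y)
#   if logits_aux is not None:
#       loss = loss + aux_weight * criterion(logits_aux, y)
#   loss.backward()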