Source code for archai.supergraph.algos.divnas.divnas_finalizers

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, List

import torch
from overrides import overrides
from torch import nn

from archai.common.common import get_conf
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.algos.divnas.analyse_activations import compute_brute_force_sol
from archai.supergraph.algos.divnas.divnas_cell import Divnas_Cell
from archai.supergraph.algos.divnas.divop import DivOp
from archai.supergraph.datasets.data import get_data
from archai.supergraph.nas.cell import Cell
from archai.supergraph.nas.finalizers import Finalizers
from archai.supergraph.nas.model import Model
from archai.supergraph.nas.model_desc import CellDesc, EdgeDesc, ModelDesc, NodeDesc

logger = get_global_logger()


class DivnasFinalizers(Finalizers):

    @overrides
    def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:
        logger.pushd('finalize')

        # get config and train data loader
        # TODO: confirm this is correct in case you get silent bugs
        conf = get_conf()
        conf_loader = conf['nas']['search']['loader']
        data_loaders = get_data(conf_loader)
        assert data_loaders.train_dl is not None

        # wrap all cells in the model
        self._divnas_cells: Dict[int, Divnas_Cell] = {}
        for _, cell in enumerate(model.cells):
            divnas_cell = Divnas_Cell(cell)
            self._divnas_cells[id(cell)] = divnas_cell

        # go through all edges in the DAG and if they are of divop
        # type then set them to collect activations
        sigma = conf['nas']['search']['divnas']['sigma']
        for _, dcell in enumerate(self._divnas_cells.values()):
            dcell.collect_activations(DivOp, sigma)

        # now we need to run one evaluation epoch to collect activations
        # we do it on cpu otherwise we might run into memory issues
        # later we can redo the whole logic in pytorch itself
        # at the end of this each node in a cell will have the covariance
        # matrix of all incoming edges' ops
        model = model.cpu()
        model.eval()
        with torch.no_grad():
            for _ in range(1):
                for _, (x, _) in enumerate(data_loaders.train_dl):
                    _, _ = model(x), None
                # now you can go through and update the
                # node covariances in every cell
                for dcell in self._divnas_cells.values():
                    dcell.update_covs()

        logger.popd()

        return super().finalize_model(model, to_cpu, restore_device)

    @overrides
    def finalize_cell(self, cell: Cell, cell_index: int,
                      model_desc: ModelDesc, *args, **kwargs) -> CellDesc:
        # first finalize each node, we will need to recreate node desc with final version
        max_final_edges = model_desc.max_final_edges

        node_descs: List[NodeDesc] = []
        dcell = self._divnas_cells[id(cell)]
        assert len(cell.dag) == len(list(dcell.node_covs.values()))

        for i, node in enumerate(cell.dag):
            node_cov = dcell.node_covs[id(node)]
            node_desc = self.finalize_node(node, i, cell.desc.nodes()[i],
                                           max_final_edges, node_cov)
            node_descs.append(node_desc)

        # (optional) clear out all activation collection information
        dcell.clear_collect_activations()

        desc = cell.desc
        finalized = CellDesc(
            id=desc.id, cell_type=desc.cell_type, conf_cell=desc.conf_cell,
            stems=[cell.s0_op.finalize()[0], cell.s1_op.finalize()[0]],
            stem_shapes=desc.stem_shapes,
            nodes=node_descs, node_shapes=desc.node_shapes,
            post_op=cell.post_op.finalize()[0], out_shape=desc.out_shape,
            trainables_from=desc.trainables_from
        )
        return finalized

    @overrides
    def finalize_node(self, node: nn.ModuleList, node_index: int,
                      node_desc: NodeDesc, max_final_edges: int,
                      cov, *args, **kwargs) -> NodeDesc:
        # node is a list of edges
        assert len(node) >= max_final_edges

        # covariance matrix shape must be square 2-D
        assert len(cov.shape) == 2
        assert cov.shape[0] == cov.shape[1]

        # the number of primitive operators has to be greater
        # than or equal to the maximum number of final edges allowed
        assert cov.shape[0] >= max_final_edges

        # get total number of ops incoming to this node
        num_ops = sum([edge._op.num_valid_div_ops for edge in node])

        # and collect some bookkeeping indices
        edge_num_and_op_ind = []
        for j, edge in enumerate(node):
            if type(edge._op) == DivOp:
                for k in range(edge._op.num_valid_div_ops):
                    edge_num_and_op_ind.append((j, k))

        assert len(edge_num_and_op_ind) == num_ops

        # run brute force set selection algorithm
        max_subset, max_mi = compute_brute_force_sol(cov, max_final_edges)

        # convert the cov indices to edge descs
        selected_edges = []
        for ind in max_subset:
            edge_ind, op_ind = edge_num_and_op_ind[ind]
            op_desc = node[edge_ind]._op.get_valid_op_desc(op_ind)
            new_edge = EdgeDesc(op_desc, node[edge_ind].input_ids)
            selected_edges.append(new_edge)

        # for edge in selected_edges:
        #     self.finalize_edge(edge)

        return NodeDesc(selected_edges, node_desc.conv_params)
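

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of this module): within an archai
# DivNAS search run, a trained supergraph `model` (a Model instance) would be
# handed to DivnasFinalizers to produce the final architecture description;
# the variable names below are illustrative only, and only finalize_model()
# comes from this module.
#
#   finalizers = DivnasFinalizers()
#   final_desc = finalizers.finalize_model(model, to_cpu=True, restore_device=True)
#   # final_desc is a ModelDesc describing the selected edges/ops per cell
# ---------------------------------------------------------------------------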