# Source code for archai.supergraph.algos.divnas.divnas_rank_finalizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
from overrides import overrides
from torch import nn

from archai.common.common import get_conf, get_expdir
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.algos.divnas.analyse_activations import compute_brute_force_sol
from archai.supergraph.algos.divnas.divnas_cell import Divnas_Cell
from archai.supergraph.algos.divnas.divop import DivOp
from archai.supergraph.datasets.data import get_data
from archai.supergraph.nas.cell import Cell
from archai.supergraph.nas.finalizers import Finalizers
from archai.supergraph.nas.model import Model
from archai.supergraph.nas.model_desc import CellDesc, EdgeDesc, ModelDesc, NodeDesc
from archai.supergraph.nas.operations import Zero
from archai.supergraph.utils.heatmap import heatmap

# Module-level logger shared by all methods in this file; pushd/popd below
# create nested logging scopes around the finalization pass.
logger = get_global_logger()


class DivnasRankFinalizers(Finalizers):
    """Finalizers that rank candidate edges by architecture weight (alpha) and
    then select the final subset of edges per node with a brute-force
    mutual-information maximization over the collected activation covariances.
    """

    @overrides
    def finalize_model(self, model: Model, to_cpu=True, restore_device=True) -> ModelDesc:
        """Collect per-node activation covariances over one pass of the search
        train set, then delegate to the base class finalization.

        Side effect: populates ``self._divnas_cells`` mapping each ``Cell`` to
        its ``Divnas_Cell`` wrapper, which ``finalize_cell``/``finalize_node``
        read afterwards.
        """
        logger.pushd('finalize')

        # Get config and the search-time train data loader.
        conf = get_conf()
        conf_loader = conf['nas']['search']['loader']
        data_loaders = get_data(conf_loader)
        assert data_loaders.train_dl is not None

        # Wrap all cells in the model so activations can be captured.
        self._divnas_cells: Dict[Cell, Divnas_Cell] = {}
        for cell in model.cells:
            self._divnas_cells[cell] = Divnas_Cell(cell)

        # Go through all edges in the DAG and, for those of DivOp type,
        # set them to collect activations with the configured kernel sigma.
        sigma = conf['nas']['search']['divnas']['sigma']
        for dcell in self._divnas_cells.values():
            dcell.collect_activations(DivOp, sigma)

        # Run one evaluation epoch to collect activations. This is done on
        # CPU to avoid GPU memory issues (could later be redone in pytorch
        # itself). At the end, each node in a cell holds the covariance
        # matrix of all incoming edges' ops.
        model = model.cpu()
        model.eval()
        with torch.no_grad():
            for x, _ in data_loaders.train_dl:
                model(x)
                # Fold this batch's activations into the node covariances
                # of every cell.
                for dcell in self._divnas_cells.values():
                    dcell.update_covs()

        logger.popd()

        return super().finalize_model(model, to_cpu, restore_device)

    @overrides
    def finalize_cell(self, cell: Cell, cell_index: int,
                      model_desc: ModelDesc, *args, **kwargs) -> CellDesc:
        """Finalize one cell: finalize each node (which prunes its edges),
        then rebuild the cell description with the finalized nodes, stems
        and post-op.
        """
        # First finalize each node; we will need to recreate the node desc
        # with the final version.
        logger.info(f'cell id {cell.desc.id}')

        max_final_edges = model_desc.max_final_edges

        node_descs: List[NodeDesc] = []
        dcell = self._divnas_cells[cell]
        # One covariance matrix must have been collected per DAG node.
        assert len(cell.dag) == len(list(dcell.node_covs.values()))

        for i, node in enumerate(cell.dag):
            node_cov = dcell.node_covs[id(node)]
            logger.info(f'node {i}')
            node_desc = self.finalize_node(node, i, cell.desc.nodes()[i],
                                           max_final_edges, node_cov, cell, i)
            node_descs.append(node_desc)

        # (optional) clear out all activation collection information
        dcell.clear_collect_activations()

        desc = cell.desc
        finalized = CellDesc(
            id=desc.id, cell_type=desc.cell_type, conf_cell=desc.conf_cell,
            stems=[cell.s0_op.finalize()[0], cell.s1_op.finalize()[0]],
            stem_shapes=desc.stem_shapes,
            nodes=node_descs, node_shapes=desc.node_shapes,
            post_op=cell.post_op.finalize()[0],
            out_shape=desc.out_shape,
            trainables_from=desc.trainables_from)
        return finalized

    @overrides
    def finalize_node(self, node: nn.ModuleList, node_index: int,
                      node_desc: NodeDesc, max_final_edges: int,
                      cov: np.ndarray, cell: Cell, node_id: int,
                      *args, **kwargs) -> NodeDesc:
        """Select ``max_final_edges`` ops for one node.

        Candidate ops (all non-Zero primitives on the node's incoming edges)
        are sorted by alpha; the top half is kept and the final subset is
        chosen by brute-force maximization of mutual information over the
        corresponding covariance submatrix. Also saves a covariance heatmap
        of the top ops to the experiment directory as a diagnostic.
        """
        # node is a list of edges
        assert len(node) >= max_final_edges

        # Covariance matrix must be square 2-D.
        assert len(cov.shape) == 2
        assert cov.shape[0] == cov.shape[1]
        # The number of primitive operators has to be greater than or equal
        # to the maximum number of final edges allowed.
        assert cov.shape[0] >= max_final_edges

        # Get the order and alpha of all ops other than 'none'.
        in_ops = [(edge, op, alpha, i) for i, edge in enumerate(node)
                  for op, alpha in edge._op.ops() if not isinstance(op, Zero)]
        assert len(in_ops) >= max_final_edges

        # Order all the ops by alpha, descending.
        in_ops_sorted = sorted(in_ops, key=lambda in_op: in_op[2], reverse=True)

        # Keep under consideration the top half of the ops (but never fewer
        # than max_final_edges).
        num_to_keep = max(max_final_edges, len(in_ops_sorted) // 2)
        top_ops = in_ops_sorted[:num_to_keep]

        # Get the covariance submatrix of the top ops only.
        cov_inds = []
        for _, op, _, _ in top_ops:
            ind = self._divnas_cells[cell].node_num_to_node_op_to_cov_ind[node_id][op]
            cov_inds.append(ind)
        cov_top_ops = cov[np.ix_(cov_inds, cov_inds)]

        assert len(cov_inds) == len(top_ops)
        assert len(top_ops) >= max_final_edges
        assert cov_top_ops.shape[0] == cov_top_ops.shape[1]
        assert len(cov_top_ops.shape) == 2

        # Run the brute-force set selection algorithm on the top ops only.
        max_subset, max_mi = compute_brute_force_sol(cov_top_ops, max_final_edges)

        # Note that elements of max_subset are indices into top_ops only.
        selected_edges = []
        for ind in max_subset:
            edge, op, alpha, edge_num = top_ops[ind]
            op_desc, _ = op.finalize()
            new_edge = EdgeDesc(op_desc, edge.input_ids)
            logger.info(f'selected edge: {edge_num}, op: {op_desc.name}')
            selected_edges.append(new_edge)

        # Save diagnostic information (covariance heatmap) to disk.
        expdir = get_expdir()
        heatmap(cov_top_ops, fmt='.1g', cmap='coolwarm')
        savename = os.path.join(
            expdir, f'cell_{cell.desc.id}_node_{node_id}_cov.png')
        plt.savefig(savename)

        logger.info('')
        return NodeDesc(selected_edges, node_desc.conv_params)