Source code for archai.supergraph.algos.divnas.divnas_cell

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import defaultdict
from typing import Dict, List

import numpy as np

import archai.supergraph.algos.divnas.analyse_activations as aa
from archai.supergraph.nas.cell import Cell
from archai.supergraph.nas.operations import Op, Zero


[docs]class Divnas_Cell(): ''' Wrapper cell class for divnas specific modifications ''' def __init__(self, cell:Cell): self._cell = cell self._collect_activations = False self._edgeoptype = None self._sigma = None self._counter = 0 self.node_covs:Dict[int, np.array] = {} self.node_num_to_node_op_to_cov_ind:Dict[int, Dict[Op, int]] = {}
[docs] def collect_activations(self, edgeoptype, sigma:float)->None: self._collect_activations = True self._edgeoptype = edgeoptype self._sigma = sigma # collect bookkeeping info for i, node in enumerate(self._cell.dag): node_op_to_cov_ind:Dict[Op, int] = {} counter = 0 for edge in node: for op, alpha in edge._op.ops(): if isinstance(op, Zero): continue node_op_to_cov_ind[op] = counter counter += 1 self.node_num_to_node_op_to_cov_ind[i] = node_op_to_cov_ind # go through all edges in the DAG and if they are of edgeoptype # type then set them to collect activations for i, node in enumerate(self._cell.dag): # initialize the covariance matrix for this node num_ops = 0 for edge in node: if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == self._edgeoptype: num_ops += edge._op.num_primitive_ops - 1 edge._op.collect_activations = True self.node_covs[id(node)] = np.zeros((num_ops, num_ops))
[docs] def update_covs(self): assert self._collect_activations for _, node in enumerate(self._cell.dag): # TODO: convert to explicit ordering all_activs = [] for j, edge in enumerate(node): if type(edge._op) == self._edgeoptype: activs = edge._op.activations all_activs.append(activs) # update covariance matrix activs_converted = self._convert_activations(all_activs) new_cov = aa.compute_rbf_kernel_covariance(activs_converted, sigma=self._sigma) updated_cov = (self._counter * self.node_covs[id(node)] + new_cov) / (self._counter + 1) self.node_covs[id(node)] = updated_cov
[docs] def clear_collect_activations(self): for _, node in enumerate(self._cell.dag): for edge in node: if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == self._edgeoptype: edge._op.collect_activations = False self._collect_activations = False self._edgeoptype = None self._sigma = None self._node_covs = {}
def _convert_activations(self, all_activs:List[List[np.array]])->List[np.array]: ''' Converts to the format needed by covariance computing functions Input all_activs: List[List[np.array]]. Outer list len is num_edges. Inner list is of num_ops length. Each element in inner list is [batch_size, x, y, z] ''' num_ops = len(all_activs[0]) for activs in all_activs: assert num_ops == len(activs) all_edge_list = [] for edge in all_activs: obsv_dict = defaultdict(list) # assumption edge_np will be (num_ops, batch_size, x, y, z) edge_np = np.array(edge) for op in range(edge_np.shape[0]): for b in range(edge_np.shape[1]): feat = edge_np[op][b] feat = feat.flatten() obsv_dict[op].append(feat) feature_list = [*range(num_ops)] for key in obsv_dict.keys(): feat = np.array(obsv_dict[key]) feature_list[key] = feat all_edge_list.extend(feature_list) return all_edge_list