# Source code for archai.supergraph.algos.nasbench101.model_matrix
import copy
from typing import List
import numpy as np
def prune(model_matrix: np.ndarray, vertex_ops: List[str]):
    """Prune the extraneous parts of the graph.

    General procedure:
      1) Remove parts of graph not connected to input.
      2) Remove parts of graph not connected to output.
      3) Reorder the vertices so that they are consecutive after steps 1 and 2.

    These 3 steps can be combined by deleting the rows and columns of the
    vertices that are not reachable from both the input and output (in reverse).

    Vertex 0 is treated as the input and the last vertex as the output.

    Args:
        model_matrix: Square adjacency matrix that is zero on and below the
            diagonal; a non-zero entry ``[src, dst]`` is an edge src -> dst.
        vertex_ops: One entry per vertex, in the same order as the matrix
            rows/columns.

    Returns:
        A ``(model_matrix, vertex_ops)`` tuple containing only the vertices
        that lie on some input-to-output path. Both are new objects; the
        arguments are left unmodified.

    Raises:
        ValueError: If ``model_matrix`` is not a square 2-D matrix, its size
            does not match ``len(vertex_ops)``, or it is not upper triangular.
        RuntimeError: If fewer than two vertices survive pruning, i.e. the
            input is not connected to the output.
    """
    shape = np.shape(model_matrix)
    # Validate rank before indexing into `shape`, so that 0-d input reports
    # the intended ValueError instead of an IndexError.
    if len(shape) != 2 or shape[0] != shape[1]:
        raise ValueError('model_matrix must be square')
    num_vertices = shape[0]
    if num_vertices != len(vertex_ops):
        raise ValueError('length of vertex_ops must match model_matrix dimensions')
    if not _is_upper_triangular(model_matrix):
        raise ValueError('model_matrix must be upper triangular')

    # DFS forward from input. Edges only go from lower to higher indices
    # (upper-triangular matrix), so only columns right of `top` are scanned.
    visited_from_input = {0}
    frontier = [0]
    while frontier:
        top = frontier.pop()
        for v in range(top + 1, num_vertices):
            if model_matrix[top, v] and v not in visited_from_input:
                visited_from_input.add(v)
                frontier.append(v)

    # DFS backward from output, following edges in reverse (rows above `top`).
    visited_from_output = {num_vertices - 1}
    frontier = [num_vertices - 1]
    while frontier:
        top = frontier.pop()
        for v in range(0, top):
            if model_matrix[v, top] and v not in visited_from_output:
                visited_from_output.add(v)
                frontier.append(v)

    # Any vertex that isn't connected to both input and output is extraneous to
    # the computation graph.
    extraneous = set(range(num_vertices)) - (visited_from_input & visited_from_output)

    # If the non-extraneous graph is less than 2 vertices, the input is not
    # connected to the output and the spec is invalid.
    if len(extraneous) > num_vertices - 2:
        raise RuntimeError(
            f'Cannot build model because there are {len(extraneous)} extraneous '
            f'vertices {sorted(extraneous)}, which is larger than total '
            f'vertices {num_vertices}-2')

    # np.delete returns a new array, so the caller's matrix is never mutated
    # and no explicit copy is needed beforehand.
    drop = sorted(extraneous)
    pruned_matrix = np.delete(model_matrix, drop, axis=0)
    pruned_matrix = np.delete(pruned_matrix, drop, axis=1)

    # Build a fresh list of the surviving ops, preserving vertex order.
    pruned_ops = [op for index, op in enumerate(vertex_ops) if index not in extraneous]
    return pruned_matrix, pruned_ops
def _is_upper_triangular(model_matrix:np.ndarray):
# TODO: just use np.allclose(mat, np.triu(mat))
"""True if matrix is 0 on diagonal and below."""
for src in range(np.shape(model_matrix)[0]):
for dst in range(0, src + 1):
if model_matrix[src, dst] != 0:
return False
return True