Source code for archai.supergraph.nas.vis_model_desc
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Tuple
from graphviz import Digraph
from archai.common.ordered_dict_logger import get_global_logger
from archai.common.utils import first_or_default
from archai.supergraph.nas.model_desc import CellDesc, CellType, ModelDesc
logger = get_global_logger()
[docs]def draw_model_desc(model_desc:ModelDesc, filepath:str=None, caption:str=None)\
->Tuple[Optional[Digraph],Optional[Digraph]]:
normal_cell_desc = first_or_default((c for c in model_desc.cell_descs() \
if c.cell_type == CellType.Regular), None)
reduced_cell_desc = first_or_default((c for c in model_desc.cell_descs() \
if c.cell_type == CellType.Reduction), None)
g_normal = draw_cell_desc(normal_cell_desc,
filepath+'-normal' if filepath else None,
caption) if normal_cell_desc is not None else None
g_reduct = draw_cell_desc(reduced_cell_desc,
filepath+'-reduced' if filepath else None,
caption) if reduced_cell_desc is not None else None
return g_normal, g_reduct
[docs]def draw_cell_desc(cell_desc:CellDesc, filepath:str=None, caption:str=None
)->Digraph:
""" make DAG plot and optionally save to filepath as .png """
edge_attr = {
'fontsize': '20',
'fontname': 'times'
}
node_attr = {
'style': 'filled',
'shape': 'rect',
'align': 'center',
'fontsize': '20',
'height': '0.5',
'width': '0.5',
'penwidth': '2',
'fontname': 'times'
}
g = Digraph(
format='png',
edge_attr=edge_attr,
node_attr=node_attr,
engine='dot')
g.body.extend(['rankdir=LR'])
# input nodes
# TODO: remove only two input node as assumption
g.node("c_{k-2}", fillcolor='darkseagreen2')
g.node("c_{k-1}", fillcolor='darkseagreen2')
# intermediate nodes
n_nodes = len(cell_desc.nodes())
for i in range(n_nodes):
g.node(str(i), fillcolor='lightblue')
for i, node in enumerate(cell_desc.nodes()):
for edge in node.edges:
op, js = edge.op_desc.name, edge.input_ids
for j in js:
if j == 0:
u = "c_{k-2}"
elif j == 1:
u = "c_{k-1}"
else:
u = str(j-2)
v = str(i)
g.edge(u, v, label=op, fillcolor="gray")
# output node
g.node("c_{k}", fillcolor='palegoldenrod')
for i in range(n_nodes):
g.edge(str(i), "c_{k}", fillcolor="gray")
# add image caption
if caption:
g.attr(label=caption, overlap='false', fontsize='20', fontname='times')
if filepath:
g.render(filepath, view=False)
logger.info(f'plot_filename: {filepath}')
return g