# Source code for archai.supergraph.algos.nasbench101.model_builder

import logging
from typing import List

import torch
from torch import nn

from archai.common import ml_utils
from archai.supergraph.algos.nasbench101.model import Network
from archai.supergraph.algos.nasbench101.model_spec import ModelSpec

# Example cell: one operation label per vertex. The list has one entry for each
# row/column of EXAMPLE_DESC_MATRIX below (7 vertices). The first and last
# labels mark the cell's input and output vertices.
EXAMPLE_VERTEX_OPS = [
    'input',            # vertex 0
    'conv1x1-bn-relu',  # vertex 1
    'conv3x3-bn-relu',  # vertex 2
    'conv3x3-bn-relu',  # vertex 3
    'conv3x3-bn-relu',  # vertex 4
    'maxpool3x3',       # vertex 5
    'output',           # vertex 6
]

# 0/1 connectivity matrix for the example cell above. It is upper-triangular,
# and the last (output) row is all zeros. Presumably entry [i][j] == 1 marks an
# edge from vertex i to vertex j; confirm against ModelSpec.
EXAMPLE_DESC_MATRIX = [
    [0, 1, 1, 1, 0, 1, 0],  # input -> 1, 2, 3, 5
    [0, 0, 0, 0, 0, 0, 1],  # 1 -> output
    [0, 0, 0, 0, 0, 0, 1],  # 2 -> output
    [0, 0, 0, 0, 1, 0, 0],  # 3 -> 4
    [0, 0, 0, 0, 0, 0, 1],  # 4 -> output
    [0, 0, 0, 0, 0, 0, 1],  # 5 -> output
    [0, 0, 0, 0, 0, 0, 0],  # output: no outgoing edges
]

def build(desc_matrix: List[List[int]], vertex_ops: List[str], device=None,
          stem_out_channels=128, num_stacks=3, num_modules_per_stack=3,
          num_labels=10) -> nn.Module:
    """Build a NASBench-101 ``Network`` for a cell description and move it to a device.

    Args:
        desc_matrix: Connectivity matrix of the cell, in the same layout as
            ``EXAMPLE_DESC_MATRIX``. Passed to ``ModelSpec`` together with
            ``vertex_ops``.
        vertex_ops: One operation label per vertex, in the same form as
            ``EXAMPLE_VERTEX_OPS``.
        device: Device the model is moved to. When ``None``, CUDA is used if
            ``torch.cuda.is_available()`` reports it, otherwise CPU. Any value
            accepted by ``nn.Module.to`` can be passed.
        stem_out_channels: Forwarded to ``Network``.
        num_stacks: Forwarded to ``Network``.
        num_modules_per_stack: Forwarded to ``Network``.
        num_labels: Forwarded to ``Network`` (presumably the number of output
            classes; confirm against ``Network``).

    Returns:
        The constructed ``Network``, already moved to the selected device.
    """
    model_spec = ModelSpec(desc_matrix, vertex_ops)
    model = Network(model_spec, stem_out_channels, num_stacks,
                    num_modules_per_stack, num_labels)

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info('Model parameters: %s', ml_utils.param_size(model))

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # nn.Module.to moves parameters/buffers in place (it also returns self).
    model.to(device)
    return model