Source code for archai.supergraph.algos.manual.manual_evaluater

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import importlib
import sys

from overrides import overrides
from torch import nn

from archai.common import ml_utils
from archai.common.config import Config
from archai.common.ordered_dict_logger import get_global_logger
from archai.supergraph.nas.evaluater import Evaluater
from archai.supergraph.nas.model import Model
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder

logger = get_global_logger()


[docs]class ManualEvaluater(Evaluater):
[docs] @overrides def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder, final_desc_filename=None, full_desc_filename=None)->nn.Module: # region conf vars dataset_name = conf_eval['loader']['dataset']['name'] # if explicitly passed in then don't get from conf if not final_desc_filename: final_desc_filename = conf_eval['final_desc_filename'] model_factory_spec = conf_eval['model_factory_spec'] # endregion assert model_factory_spec return self._model_from_factory(model_factory_spec, dataset_name)
def _model_from_factory(self, model_factory_spec:str, dataset_name:str)->Model: splitted = model_factory_spec.rsplit('.', 1) function_name = splitted[-1] if len(splitted) > 1: module_name = splitted[0] else: module_name = self._default_module_name(dataset_name, function_name) module = importlib.import_module(module_name) if module_name else sys.modules[__name__] function = getattr(module, function_name) model = function() logger.info({'model_factory':True, 'module_name': module_name, 'function_name': function_name, 'params': ml_utils.param_size(model)}) return model