# Source code for archai.discrete_search.evaluators.benchmark.natsbench_tss

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re
from typing import Any, Dict, Optional

import nats_bench
from overrides import overrides

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import ModelEvaluator
from archai.discrete_search.search_spaces.benchmark.natsbench_tss import (
    NatsbenchTssSearchSpace
)


class NatsbenchMetric(ModelEvaluator):
    """Evaluate a model using a metric from the NATS-Bench API."""

    def __init__(
        self,
        search_space: NatsbenchTssSearchSpace,
        metric_name: str,
        epochs: Optional[int] = None,
        raise_not_found: Optional[bool] = True,
        more_info_kwargs: Optional[Dict[str, Any]] = None,
        cost_info_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the evaluator.

        Args:
            search_space: Search space to use.
            metric_name: Metric to use. Looked up first in the result of
                `nats_bench.api.NatsBenchAPI.get_more_info`, then in the result of
                `nats_bench.api.NatsBenchAPI.get_cost_info`.
            epochs: Number of epochs to use. If None, uses the default number of epochs.
            raise_not_found: If True, raises an error if the architecture does not belong to the search space.
            more_info_kwargs: Additional arguments to pass to `nats_bench.api.NatsBenchAPI.get_more_info`.
            cost_info_kwargs: Additional arguments to pass to `nats_bench.api.NatsBenchAPI.get_cost_info`.

        """
        assert isinstance(
            search_space, NatsbenchTssSearchSpace
        ), "This objective function only works with architectures from NatsbenchTssSearchSpace"

        self.search_space = search_space
        self.metric_name = metric_name
        self.epochs = epochs

        # Architecture ids handled by this evaluator must start with "natsbench-tss-<index>";
        # the captured integer is the NATS-Bench architecture index.
        self.archid_pattern = re.compile(r"natsbench-tss-([0-9]+)")
        self.api = nats_bench.create(str(self.search_space.natsbench_location), "tss", fast_mode=True, verbose=False)

        self.raise_not_found = raise_not_found
        self.more_info_kwargs = more_info_kwargs or dict()
        self.cost_info_kwargs = cost_info_kwargs or dict()

        # Running sum of the benchmark's reported train + test time for architectures
        # whose metric was served by `get_more_info`.
        self.total_time_spent = 0

    @overrides
    def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> Optional[float]:
        """Look up `self.metric_name` for the NATS-Bench architecture encoded in `model.archid`.

        Args:
            model: Model whose `archid` should match the `natsbench-tss-<index>` pattern.
            budget: Optional epoch budget. When truthy it is truncated to `int` and passed as
                `iepoch` to `get_more_info`; when None or 0, `self.epochs` is used instead.

        Returns:
            The metric value reported by the benchmark, or None when the architecture id does not
            match and `raise_not_found` is False.

        Raises:
            ValueError: If the architecture id does not match and `raise_not_found` is True.
            KeyError: If `metric_name` is found in neither the "more info" nor the "cost info" results.

        """
        natsbench_id = self.archid_pattern.match(model.archid)
        # A falsy budget (None/0) is left untouched so `budget or self.epochs` below falls back to `self.epochs`.
        budget = int(budget) if budget else budget

        if not natsbench_id:
            if self.raise_not_found:
                raise ValueError(
                    f"Architecture {model.archid} does not belong to the NatsBench search space. "
                    "Please refer to "
                    "`archai.discrete_search.search_spaces.benchmark.natsbench_tss.NatsbenchTssSearchSpace` to "
                    "use the Natsbench search space."
                )

            return None

        arch_index = int(natsbench_id.group(1))

        info = self.api.get_more_info(
            arch_index,
            dataset=self.search_space.base_dataset,
            iepoch=budget or self.epochs,
            **self.more_info_kwargs,
        )

        cost_info = self.api.get_cost_info(
            arch_index, dataset=self.search_space.base_dataset, **self.cost_info_kwargs
        )

        if self.metric_name in info:
            result = info[self.metric_name]
            self.total_time_spent += info["train-all-time"] + info["test-all-time"]
        elif self.metric_name in cost_info:
            # Bug fix: the value must come from `cost_info`, not `info` (where it is absent).
            result = cost_info[self.metric_name]
        else:
            # Report every metric that would have been accepted, from both lookups.
            available = list(info.keys()) + [key for key in cost_info.keys() if key not in info]
            raise KeyError(f"`metric_name` {self.metric_name} not found. Available metrics = {str(available)}")

        return result