# Source code for archai.discrete_search.api.model_evaluator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import abstractmethod
from typing import List, Optional

from overrides import EnforceOverrides

from archai.api.dataset_provider import DatasetProvider
from archai.discrete_search.api.archai_model import ArchaiModel


class ModelEvaluator(EnforceOverrides):
    """Abstract class for synchronous model evaluators.

    Evaluators are general-use classes used to evaluate architectures in given
    criteria (task performance, speed, size, etc.).

    Subclasses of `ModelEvaluator` are expected to implement
    `ModelEvaluator.evaluate`.

    Synchronous evaluators are computed by search algorithms sequentially.
    For parallel / async. execution, please refer to
    `archai.api.AsyncModelEvaluator`.

    For a list of built-in evaluators, please check
    `archai.discrete_search.evaluators`.

    Examples:
        >>> class MyValTaskAccuracy(ModelEvaluator):
        >>>     def __init__(self, dataset: DatasetProvider, batch_size: int = 32):
        >>>         self.dataset = dataset
        >>>         self.batch_size = batch_size
        >>>
        >>>     @overrides
        >>>     def get_name(self) -> str:
        >>>         return f'MyValTaskAccuracy_on_{self.dataset.get_name()}'
        >>>
        >>>     @overrides
        >>>     def evaluate(self, arch: ArchaiModel, budget: Optional[float] = None):
        >>>         _, val_data = self.dataset.get_train_val_datasets()
        >>>         val_dl = torch.utils.data.DataLoader(val_data, batch_size=self.batch_size)
        >>>
        >>>         with torch.no_grad():
        >>>             labels = np.concatenate([y for _, y in val_dl], axis=0)
        >>>             preds = np.concatenate(
        >>>                 [arch.arch(x).cpu().numpy() for x, _ in val_dl],
        >>>                 axis=0
        >>>             )
        >>>
        >>>         return np.mean(labels == preds)
        >>>
        >>> class NumberOfModules(ModelEvaluator):
        >>>     @overrides
        >>>     def evaluate(self, arch: ArchaiModel, budget: Optional[float] = None):
        >>>         return len(list(arch.arch.modules()))

    """

    @abstractmethod
    def evaluate(self, arch: ArchaiModel, budget: Optional[float] = None) -> float:
        """Evaluate an `ArchaiModel` instance, optionally using a budget value.

        Args:
            arch: Model to be evaluated.
            budget: A budget multiplier value, used by search algorithms like
                `SuccessiveHalving` to specify how much compute should be spent
                in this evaluation. In order to use this type of search
                algorithm, the implementation of `evaluate()` must use the
                passed `budget` value accordingly.

        Returns:
            Evaluation result.

        """

        pass
class AsyncModelEvaluator(EnforceOverrides):
    """Abstract class for asynchronous model evaluators.

    Evaluators are general-use classes used to evaluate architectures in given
    criteria (task performance, speed, size, etc.).

    Unlike `archai.api.ModelEvaluator`, `AsyncModelEvaluator` evaluates models
    in asynchronous fashion, by sending evaluation jobs to a queue and fetching
    the results later.

    Subclasses of `AsyncModelEvaluator` are expected to implement
    `AsyncModelEvaluator.send(arch: ArchaiModel, budget: Optional[float])`
    and `AsyncModelEvaluator.fetch_all()`.

    `AsyncModelEvaluator.send` is a non-blocking call that schedules an
    evaluation job for a given (model, budget) pair.
    `AsyncModelEvaluator.fetch_all` is a blocking call that waits and gathers
    the results from current evaluation jobs and cleans the job queue.

    For a list of built-in evaluators, please check
    `archai.discrete_search.evaluators`.

    Examples:
        >>> my_obj = MyAsyncObj(dataset)  # My AsyncModelEvaluator subclass
        >>>
        >>> # Non blocking calls
        >>> my_obj.send(model_1, budget=None)
        >>> my_obj.send(model_2, budget=None)
        >>> my_obj.send(model_3, budget=None)
        >>>
        >>> # Blocking call
        >>> eval_results = my_obj.fetch_all()
        >>> assert len(eval_results) == 3
        >>>
        >>> # Job queue is reset after `fetch_all` method
        >>> my_obj.send(model_4, budget=None)
        >>> assert len(my_obj.fetch_all()) == 1

    """

    @abstractmethod
    def send(self, arch: ArchaiModel, budget: Optional[float] = None) -> None:
        """Send an evaluation job for a given (model, budget) pair.

        Args:
            arch: Model to be evaluated.
            budget: A budget multiplier value, used by search algorithms like
                `SuccessiveHalving` to specify how much compute should be spent
                in this evaluation. In order to use this type of search
                algorithm, the implementation of `send()` must use the passed
                `budget` value accordingly.

        """

        pass

    @abstractmethod
    def fetch_all(self) -> List[Optional[float]]:
        """Fetch all evaluation results from the job queue.

        Returns:
            List of evaluation results. Each result is a `float` or `None` if
            evaluation job failed.

        """

        pass