Source code for archai.discrete_search.evaluators.functional

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Optional

from overrides import overrides

from archai.api.dataset_provider import DatasetProvider
from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import ModelEvaluator


class EvaluationFunction(ModelEvaluator):
    """Custom function evaluator.

    This evaluator is used to wrap a custom evaluation function, so that any
    callable can be used wherever a `ModelEvaluator` is expected.

    """

    def __init__(self, evaluation_fn: Callable) -> None:
        """Initialize the evaluator.

        Args:
            evaluation_fn: Evaluation function. It is invoked positionally as
                `evaluation_fn(model, budget)`, where `model` is the `ArchaiModel`
                being evaluated and `budget` is an optional float (`None` when the
                caller supplies no budget), and it is expected to output a float.

        """

        # Stored as-is; it is only called later, from `evaluate`.
        self.evaluation_fn = evaluation_fn

    @overrides
    def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> float:
        """Evaluate a model by delegating to the wrapped evaluation function.

        Args:
            model: Model to be evaluated.
            budget: Optional budget value, forwarded unchanged to the wrapped function.

        Returns:
            Value returned by the wrapped evaluation function.

        """

        return self.evaluation_fn(model, budget)