Source code for archai.discrete_search.evaluators.ray

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, List, Optional, Union

import ray
from overrides import overrides

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import (
    AsyncModelEvaluator,
    ModelEvaluator,
)


def _wrap_metric_calculate(class_method) -> Callable:
    """Wrap an evaluation callable in a plain two-argument function.

    The returned function simply forwards its arguments positionally, so it can be
    handed to `ray.remote` as a standalone function. No state is kept between calls.

    Args:
        class_method: Callable invoked as `class_method(arch, budget)`.

    Returns:
        Function `(arch, budget=None)` that returns whatever `class_method` returns.

    """

    # The inner function returns the evaluation result itself, not a callable
    def _calculate(arch: ArchaiModel, budget: Optional[float] = None) -> float:
        return class_method(arch, budget)

    return _calculate


class RayParallelEvaluator(AsyncModelEvaluator):
    """Wraps a `ModelEvaluator` object into an `AsyncModelEvaluator` with parallel execution using Ray.

    `RayParallelEvaluator` expects a stateless objective function as input, meaning that
    any `ModelEvaluator.evaluate(arch, ...)` will not alter the state of `obj` or `arch` in any way.

    """

    def __init__(
        self, obj: ModelEvaluator, timeout: Optional[float] = None, force_stop: Optional[bool] = False, **ray_kwargs
    ) -> None:
        """Initialize the evaluator.

        Args:
            obj: A `ModelEvaluator` object.
            timeout: Timeout for receiving results from Ray. If None, then Ray will wait indefinitely for results.
                If timeout is reached, then incomplete tasks are canceled and returned as None.
            force_stop: If incomplete tasks (within `timeout` seconds) should be force-killed. If set to `False`,
                Ray will just send a `KeyboardInterrupt` signal to the process.
            **ray_kwargs: Key-value arguments for ray.remote(), e.g: num_gpus, num_cpus, max_task_retries.

        """

        assert isinstance(obj, ModelEvaluator), "`obj` must be a `ModelEvaluator` instance."

        # Wraps `obj.evaluate` as a standalone function. This only works with stateless evaluators
        if ray_kwargs:
            self.compute_fn = ray.remote(**ray_kwargs)(_wrap_metric_calculate(obj.evaluate))
        else:
            self.compute_fn = ray.remote(_wrap_metric_calculate(obj.evaluate))

        self.timeout = timeout
        self.force_stop = force_stop

        # References to the tasks submitted since the last `fetch_all()` call, in submission order
        self.object_refs = []

    @overrides
    def send(self, arch: ArchaiModel, budget: Optional[float] = None) -> None:
        """Submit an evaluation task to Ray and store its reference.

        Args:
            arch: Model forwarded to the wrapped evaluation function.
            budget: Optional budget forwarded to the wrapped evaluation function.

        """

        self.object_refs.append(self.compute_fn.remote(arch, budget))

    @overrides
    def fetch_all(self) -> List[Union[float, None]]:
        """Collect the results of every task submitted since the last call.

        Results are returned in submission order. When `self.timeout` is None, this blocks
        until all tasks finish. Otherwise, only the tasks that finish within `self.timeout`
        seconds contribute a result; the remaining ones are canceled and reported as None.

        The list of pending references is always cleared, even when fetching raises.

        Returns:
            One entry per submitted task: its result, or None if it did not finish in time.

        """

        pending_refs = self.object_refs

        try:
            # Only `None` means "wait indefinitely"; a timeout of 0 must still take the timed path
            if self.timeout is None:
                return ray.get(pending_refs)

            results = [None] * len(pending_refs)

            # Maps each object from the pending list to its index
            ref2idx = {ref: i for i, ref in enumerate(pending_refs)}

            # Gets all results available within `self.timeout` seconds
            complete_objs, incomplete_objs = ray.wait(
                pending_refs, timeout=self.timeout, num_returns=len(pending_refs)
            )
            partial_results = ray.get(complete_objs)

            # Update results with the partial results fetched
            for ref, result in zip(complete_objs, partial_results):
                results[ref2idx[ref]] = result

            # Cancels incomplete jobs
            for incomplete_obj in incomplete_objs:
                ray.cancel(incomplete_obj, force=self.force_stop)

            return results
        finally:
            # Resets evaluator state so a failed fetch does not leak stale references into the next batch
            self.object_refs = []