# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import ray
from overrides import overrides
from archai.api.dataset_provider import DatasetProvider
from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import (
AsyncModelEvaluator,
ModelEvaluator,
)
from archai.discrete_search.api.search_space import DiscreteSearchSpace
from archai.common.file_utils import TemporaryFiles
def _ray_wrap_training_fn(training_fn: Callable) -> Callable:
    """Wrap a training function so it also returns the trained model.

    Ray tasks receive serialized copies of their arguments, so the trained
    `arch` (whose weights may have been updated in the worker) must be returned
    explicitly for the caller to sync weights back to its local model reference.

    Args:
        training_fn: Function with signature
            `(arch, dataset, budget, training_state) -> (metric, training_state)`.

    Returns:
        Function with the same inputs that returns `(arch, metric, training_state)`.

    """

    def _stateful_training_fn(
        arch: ArchaiModel, dataset: DatasetProvider, budget: float, training_state: Optional[Dict[str, Any]] = None
    ) -> Tuple[ArchaiModel, float, Dict[str, Any]]:
        metric_result, training_state = training_fn(arch, dataset, budget, training_state)
        return arch, metric_result, training_state

    return _stateful_training_fn
class ProgressiveTraining(ModelEvaluator):
    """Progressive training evaluator.

    Runs a user-provided training function and caches each architecture's
    training state (e.g. optimizer state) keyed by `archid`, so that repeated
    evaluations of the same architecture resume training instead of
    starting from scratch.
    """

    def __init__(self, search_space: DiscreteSearchSpace, dataset: DatasetProvider, training_fn: Callable) -> None:
        """Initialize the evaluator.

        Args:
            search_space: Search space.
            dataset: Dataset provider passed to `training_fn` on every evaluation.
            training_fn: Training function with signature
                `(arch, dataset, budget, training_state) -> (metric, training_state)`.

        """
        self.search_space = search_space
        self.training_fn = training_fn
        self.dataset = dataset

        # Training state buffer (e.g optimizer state) for each architecture id
        self.training_states: Dict[str, Any] = {}

    @overrides
    def evaluate(self, arch: ArchaiModel, budget: Optional[float] = None) -> float:
        """Train `arch` (resuming from any cached state) and return its metric.

        Args:
            arch: Architecture to evaluate.
            budget: Optional training budget forwarded to `training_fn`.

        Returns:
            Metric value produced by `training_fn`.

        """
        # Tries to retrieve previous training state so training resumes
        # from where the last evaluation of this architecture stopped
        tr_state = self.training_states.get(arch.archid, None)

        # Computes metric and updates training state
        metric_result, updated_tr_state = self.training_fn(arch, self.dataset, budget, tr_state)
        self.training_states[arch.archid] = updated_tr_state

        return metric_result
class RayProgressiveTraining(AsyncModelEvaluator):
    """Progressive training evaluator using Ray.

    Dispatches training jobs as Ray remote tasks via `send` and collects their
    metrics with `fetch_all`, caching per-architecture training state so later
    evaluations resume training. Because Ray workers operate on copies, trained
    weights are synced back to the original model references after each fetch.
    """

    def __init__(
        self,
        search_space: DiscreteSearchSpace,
        dataset: DatasetProvider,
        training_fn: Callable,
        timeout: Optional[float] = None,
        force_stop: Optional[bool] = False,
        **ray_kwargs
    ) -> None:
        """Initialize the evaluator.

        Args:
            search_space: Search space.
            dataset: Dataset provider passed to `training_fn` on every job.
            training_fn: Training function with signature
                `(arch, dataset, budget, training_state) -> (metric, training_state)`.
            timeout: Timeout (seconds) for fetching results.
            force_stop: If True, forces to stop all training jobs when fetching results.
            **ray_kwargs: Extra options (e.g. `num_gpus`) forwarded to `ray.remote`.

        """
        self.search_space = search_space
        self.dataset = dataset

        # `ray.remote(**kwargs)` returns a decorator, while `ray.remote(fn)`
        # wraps the function directly — hence the two call forms below.
        if ray_kwargs:
            self.compute_fn = ray.remote(**ray_kwargs)(_ray_wrap_training_fn(training_fn))
        else:
            self.compute_fn = ray.remote(_ray_wrap_training_fn(training_fn))

        self.timeout = timeout
        self.force_stop = force_stop

        # Buffer that stores original model references from `send` calls
        # to update weights after training is complete
        self.models: List[ArchaiModel] = []

        # Ray training job object refs
        self.results_ref: List[Any] = []

        # Training state buffer (e.g optimizer state) for each architecture id
        self.training_states: Dict[str, Any] = {}

    @overrides
    def send(self, arch: ArchaiModel, budget: Optional[float] = None) -> None:
        """Launch an asynchronous training job for `arch`.

        Args:
            arch: Architecture to train; its reference is kept so trained
                weights can be copied back in `fetch_all`.
            budget: Optional training budget forwarded to `training_fn`.

        """
        # Stores original model reference
        self.models.append(arch)

        # Resumes from any previously cached training state for this archid
        current_tr_state = self.training_states.get(arch.archid, None)
        self.results_ref.append(self.compute_fn.remote(arch, self.dataset, budget, current_tr_state))

    @overrides
    def fetch_all(self) -> List[Union[float, None]]:
        """Collect results from all pending jobs and reset the job buffers.

        Jobs that do not finish within `self.timeout` seconds are cancelled
        (forcefully if `force_stop` is set) and yield no metric. Trained
        weights are copied back into the original model references.

        Returns:
            Metric values of the completed jobs, in submission order.

        """
        results = [None] * len(self.results_ref)

        # Fetchs training job results
        if not self.timeout:
            # No timeout configured: block until every job has finished
            results = ray.get(self.results_ref, timeout=self.timeout)
        else:
            # Maps each object from the object_refs list to its index
            ref2idx = {ref: i for i, ref in enumerate(self.results_ref)}

            # Gets all results available within `self.timeout` seconds.
            complete_objs, incomplete_objs = ray.wait(
                self.results_ref, timeout=self.timeout, num_returns=len(self.results_ref)
            )
            partial_results = ray.get(complete_objs)

            for ref, result in zip(complete_objs, partial_results):
                results[ref2idx[ref]] = result

            # Jobs that missed the deadline are cancelled so they do not
            # keep consuming cluster resources
            for incomplete_obj in incomplete_objs:
                ray.cancel(incomplete_obj, force=self.force_stop)

        # Gathers metrics and syncs local references
        metric_results = []

        for job_id, job_results in enumerate(results):
            if job_results:
                trained_model, job_metric, training_state = job_results

                # Syncs model weights from the worker's copy back into the
                # original model reference captured in `send`.
                # On windows you cannot open a named temporary file a second time.
                temp_file_name = None
                with TemporaryFiles() as tmp:
                    temp_file_name = tmp.get_temp_file()
                    self.search_space.save_model_weights(trained_model, temp_file_name)
                    self.search_space.load_model_weights(self.models[job_id], temp_file_name)

                # Syncs training state
                self.training_states[trained_model.archid] = training_state

                metric_results.append(job_metric)

        # Resets model and job buffers
        self.models = []
        self.results_ref = []

        return metric_results