Source code for archai.discrete_search.algos.successive_halving

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import random
from pathlib import Path
from typing import Optional

from overrides import overrides

from archai.api.dataset_provider import DatasetProvider
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.discrete_search.api.search_objectives import SearchObjectives
from archai.discrete_search.api.search_results import SearchResults
from archai.discrete_search.api.search_space import DiscreteSearchSpace
from archai.discrete_search.api.searcher import Searcher
from archai.discrete_search.utils.multi_objective import get_non_dominated_sorting

logger = OrderedDictLogger(source=__name__)


[docs]class SuccessiveHalvingSearch(Searcher): """Successive Halving algorithm""" def __init__( self, search_space: DiscreteSearchSpace, objectives: SearchObjectives, dataset_provider: DatasetProvider, output_dir: str, num_iters: Optional[int] = 10, init_num_models: Optional[int] = 10, init_budget: Optional[float] = 1.0, budget_multiplier: Optional[float] = 2.0, seed: Optional[int] = 1, ) -> None: """Initialize the Successive Halving. Args: search_space: Discrete search space. search_objectives: Search objectives. dataset_provider: Dataset provider. output_dir: Output directory. num_iters: Number of iterations. init_num_models: Number of initial models to evaluate. init_budget: Initial budget. budget_multiplier: Budget multiplier. seed: Random seed. """ super(SuccessiveHalvingSearch, self).__init__() assert isinstance(search_space, DiscreteSearchSpace) # Search parameters self.search_space = search_space self.objectives = objectives self.dataset_provider = dataset_provider self.output_dir = Path(output_dir) self.num_iters = num_iters self.init_num_models = init_num_models self.init_budget = init_budget self.budget_multiplier = budget_multiplier self.output_dir.mkdir(exist_ok=True) # Utils self.iter_num = 0 self.num_sampled_models = 0 self.seed = seed self.search_state = SearchResults(search_space, objectives) self.rng = random.Random(seed) self.output_dir.mkdir(exist_ok=True, parents=True)
[docs] @overrides def search(self) -> SearchResults: current_budget = self.init_budget population = [self.search_space.random_sample() for _ in range(self.init_num_models)] selected_models = population for i in range(self.num_iters): if len(selected_models) <= 1: logger.info(f"Search ended. Architecture selected: {selected_models[0].archid}") self.search_space.save_arch(selected_models[0], self.output_dir / "final_model") break self.on_start_iteration(i + 1) logger.info(f"Iteration {i+1}/{self.num_iters}") logger.info(f"Evaluating {len(selected_models)} models with budget {current_budget} ...") results = self.objectives.eval_all_objs( selected_models, budgets={obj_name: current_budget for obj_name in self.objectives.objectives}, ) # Logs results and saves iteration models self.search_state.add_iteration_results( selected_models, results, extra_model_data={"budget": [current_budget] * len(selected_models)} ) models_dir = self.output_dir / f"models_iter_{self.iter_num}" models_dir.mkdir(exist_ok=True) for model in selected_models: self.search_space.save_arch(model, str(models_dir / f"{model.archid}")) self.search_state.save_search_state(str(self.output_dir / f"search_state_{self.iter_num}.csv")) self.search_state.save_all_2d_pareto_evolution_plots(self.output_dir) # Keeps only the best `1/self.budget_multiplier` NDS frontiers logger.info("Choosing models for the next iteration ...") nds_frontiers = get_non_dominated_sorting(selected_models, results, self.objectives) nds_frontiers = nds_frontiers[: int(len(nds_frontiers) * 1 / self.budget_multiplier)] selected_models = [model for frontier in nds_frontiers for model in frontier["models"]] logger.info(f"Kept {len(selected_models)} models for next iteration.") # Update parameters for next iteration self.iter_num += 1 current_budget = current_budget * self.budget_multiplier return self.search_state