Source code for archai.discrete_search.algos.random_search

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import random
from pathlib import Path
from typing import List, Optional

from overrides import overrides

from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.search_objectives import SearchObjectives
from archai.discrete_search.api.search_results import SearchResults
from archai.discrete_search.api.search_space import DiscreteSearchSpace
from archai.discrete_search.api.searcher import Searcher

logger = OrderedDictLogger(source=__name__)


[docs]class RandomSearch(Searcher): """Random search algorithm. It evaluates random samples from the search space in each iteration until `num_iters` is reached. """ def __init__( self, search_space: DiscreteSearchSpace, search_objectives: SearchObjectives, output_dir: str, num_iters: Optional[int] = 10, samples_per_iter: Optional[int] = 10, clear_evaluated_models: Optional[bool] = True, save_pareto_model_weights: bool = True, seed: Optional[int] = 1, ): """Initialize the random search algorithm. Args: search_space: Discrete search space. search_objectives: Search objectives. output_dir: Output directory. num_iters: Number of iterations. samples_per_iter: Number of samples per iteration. clear_evaluated_models (bool, optional): Optimizes memory usage by clearing the architecture of `ArchaiModel` after each iteration. Defaults to True. save_pareto_model_weights: If `True`, saves the weights of the pareto models. Defaults to True. seed: Random seed. """ super(RandomSearch, self).__init__() assert isinstance( search_space, DiscreteSearchSpace ), f"{str(search_space.__class__)} is not compatible with {str(self.__class__)}" self.iter_num = 0 self.search_space = search_space self.so = search_objectives self.output_dir = Path(output_dir) self.output_dir.mkdir(exist_ok=True, parents=True) # Algorithm settings self.num_iters = num_iters self.samples_per_iter = samples_per_iter # Utils self.clear_evaluated_models = clear_evaluated_models self.save_pareto_model_weights = save_pareto_model_weights self.search_state = SearchResults(search_space, self.so) self.seed = seed self.rng = random.Random(seed) self.seen_archs = set() self.num_sampled_archs = 0 assert self.samples_per_iter > 0 assert self.num_iters > 0
[docs] def sample_models(self, num_models: int, patience: Optional[int] = 5) -> List[ArchaiModel]: """Sample models from the search space. Args: num_models: Number of models to sample. patience: Number of tries to sample a valid model. Returns: List of sampled models. """ nb_tries, valid_sample = 0, [] while len(valid_sample) < num_models and nb_tries < patience: sample = [self.search_space.random_sample() for _ in range(num_models)] _, valid_indices = self.so.validate_constraints(sample) valid_sample += [sample[i] for i in valid_indices if sample[i].archid not in self.seen_archs] return valid_sample[:num_models]
[docs] @overrides def search(self) -> SearchResults: for i in range(self.num_iters): self.iter_num = i + 1 self.on_start_iteration(self.iter_num) logger.info(f"Iteration {i+1}/{self.num_iters}") logger.info(f"Sampling {self.samples_per_iter} random models ...") unseen_pop = self.sample_models(self.samples_per_iter) # Calculates objectives logger.info(f"Calculating search objectives {list(self.so.objective_names)} for {len(unseen_pop)} models ...") results = self.so.eval_all_objs(unseen_pop) self.search_state.add_iteration_results(unseen_pop, results) # Records evaluated archs to avoid computing the same architecture twice self.seen_archs.update([m.archid for m in unseen_pop]) # update the pareto frontier logger.info("Updating Pareto frontier ...") pareto = self.search_state.get_pareto_frontier()["models"] logger.info(f"Found {len(pareto)} members.") # Saves search iteration results self.search_state.save_search_state(str(self.output_dir / f"search_state_{self.iter_num}.csv")) self.search_state.save_pareto_frontier_models( str(self.output_dir / f"pareto_models_iter_{self.iter_num}"), save_weights=self.save_pareto_model_weights ) self.search_state.save_all_2d_pareto_evolution_plots(str(self.output_dir)) # Clears models from memory if needed if self.clear_evaluated_models: logger.info("Optimzing memory usage ...") [model.clear() for model in unseen_pop] return self.search_state