# Source code for archai.discrete_search.algos.local_search
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
from pathlib import Path
from typing import List, Optional
from overrides import overrides
from tqdm import tqdm
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.search_objectives import SearchObjectives
from archai.discrete_search.api.search_results import SearchResults
from archai.discrete_search.api.search_space import EvolutionarySearchSpace
from archai.discrete_search.api.searcher import Searcher
# Module-level logger shared by everything in this file; it is created with
# this module's `__name__` as its `source`.
logger = OrderedDictLogger(source=__name__)
class LocalSearch(Searcher):
    """Local search driven by mutations of the current Pareto frontier.

    Every iteration evaluates the models that have not been evaluated yet,
    recomputes the Pareto frontier from the accumulated results and mutates the
    frontier members to obtain the candidates for the next iteration.
    """

    def __init__(
        self,
        search_space: EvolutionarySearchSpace,
        search_objectives: SearchObjectives,
        output_dir: str,
        num_iters: Optional[int] = 10,
        init_num_models: Optional[int] = 10,
        initial_population_paths: Optional[List[str]] = None,
        mutations_per_parent: Optional[int] = 1,
        clear_evaluated_models: bool = True,
        save_pareto_model_weights: bool = True,
        seed: Optional[int] = 1,
    ):
        """Local search algorithm. In each iteration, the algorithm generates a new population by
        mutating the current Pareto frontier. The process is repeated until `num_iters` is reached.

        The output directory is created (including parents) if it does not exist.

        Args:
            search_space (EvolutionarySearchSpace): Discrete search space compatible with
                evolutionary algorithms.
            search_objectives (SearchObjectives): Search objectives.
            output_dir (str): Output directory.
            num_iters (int, optional): Number of search iterations. Defaults to 10.
            init_num_models (int, optional): Number of initial models. Defaults to 10.
            initial_population_paths (Optional[List[str]], optional): Paths to initial population.
                If None, then `init_num_models` random models are used. Defaults to None.
            mutations_per_parent (int, optional): Number of mutations per parent. Defaults to 1.
            clear_evaluated_models (bool, optional): If `True`, `clear()` is called on every
                evaluated `ArchaiModel` at the end of each iteration to reduce memory usage.
                Defaults to True.
            save_pareto_model_weights (bool, optional): If `True`, saves the weights of the
                pareto models. Defaults to True.
            seed (int, optional): Random seed. Defaults to 1.

        Raises:
            AssertionError: If `search_space` is not an `EvolutionarySearchSpace`, or if
                `init_num_models` or `num_iters` is not strictly positive.
        """
        super().__init__()

        assert isinstance(
            search_space, EvolutionarySearchSpace
        ), f"{str(search_space.__class__)} is not compatible with {str(self.__class__)}"

        self.iter_num = 0
        self.search_space = search_space
        self.so = search_objectives
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        # Algorithm settings
        self.num_iters = num_iters
        self.init_num_models = init_num_models
        self.initial_population_paths = initial_population_paths
        self.mutations_per_parent = mutations_per_parent

        # Utils
        self.clear_evaluated_models = clear_evaluated_models
        self.save_pareto_model_weights = save_pareto_model_weights
        self.search_state = SearchResults(search_space, self.so)
        self.seed = seed
        self.rng = random.Random(seed)
        # Architecture ids that have already been evaluated (filled in `search`).
        self.seen_archs = set()
        self.num_sampled_archs = 0

        assert self.init_num_models > 0
        assert self.num_iters > 0

    def sample_models(self, num_models: int, patience: Optional[int] = 5) -> List[ArchaiModel]:
        """Sample models from the search space.

        Each try draws `num_models` random samples and keeps the ones reported
        as valid by `SearchObjectives.validate_constraints`. Sampling stops as
        soon as enough valid models were collected or `patience` tries were made.

        Args:
            num_models: Number of models to sample.
            patience: Maximum number of sampling tries.

        Returns:
            List of sampled models. It holds at most `num_models` entries and may
            hold fewer (possibly none) when `patience` is exhausted first.
        """
        nb_tries, valid_sample = 0, []

        while len(valid_sample) < num_models and nb_tries < patience:
            # Count the try up-front so the loop is always bounded by `patience`,
            # even if no sample ever satisfies the constraints.
            nb_tries += 1

            sample = [self.search_space.random_sample() for _ in range(num_models)]

            # Only the indices of the valid samples are needed here.
            _, valid_indices = self.so.validate_constraints(sample)
            valid_sample += [sample[i] for i in valid_indices]

        return valid_sample[:num_models]

    def mutate_parents(
        self, parents: List[ArchaiModel], mutations_per_parent: Optional[int] = 1, patience: Optional[int] = 20
    ) -> List[ArchaiModel]:
        """Mutate parents to generate new models.

        For every parent, mutations are drawn until `mutations_per_parent`
        distinct, valid and not previously evaluated architectures were found,
        or until `patience` mutation attempts were made for that parent.

        Args:
            parents: List of parent models.
            mutations_per_parent: Number of mutations to apply to each parent.
            patience: Maximum number of mutation attempts per parent. Every
                attempt counts, including the ones rejected as invalid.

        Returns:
            List of mutated models, de-duplicated by `archid` across all parents.
        """
        mutations = {}

        for p in tqdm(parents, desc="Mutating parents"):
            candidates = {}
            nb_tries = 0

            while len(candidates) < mutations_per_parent and nb_tries < patience:
                # Count the attempt before any early `continue`; otherwise a parent
                # that only yields invalid mutations would never exhaust `patience`.
                nb_tries += 1

                mutated_model = self.search_space.mutate(p)
                mutated_model.metadata["parent"] = p.archid

                if not self.so.is_model_valid(mutated_model):
                    continue

                # Skip architectures that were already evaluated in a previous iteration.
                if mutated_model.archid not in self.seen_archs:
                    mutated_model.metadata["generation"] = self.iter_num
                    candidates[mutated_model.archid] = mutated_model

            # Keyed by archid, so identical mutations of different parents collapse.
            mutations.update(candidates)

        return list(mutations.values())

    @overrides
    def search(self) -> SearchResults:
        """Run the local search and return the accumulated search results.

        The initial population is loaded from `initial_population_paths` when
        given, otherwise it is sampled randomly. Each iteration evaluates the
        pending models, stores the results, saves the search state, the Pareto
        models and the 2D Pareto evolution plots under `output_dir`, and mutates
        the Pareto frontier to produce the next pending models. The loop stops
        early when there is nothing left to evaluate.

        Returns:
            The `SearchResults` object holding the results of all iterations.
        """
        self.iter_num = 0

        if self.initial_population_paths:
            logger.info(f"Loading initial population from {len(self.initial_population_paths)} architectures ...")
            unseen_pop = [self.search_space.load_arch(path) for path in self.initial_population_paths]
        else:
            logger.info(f"Using {self.init_num_models} random architectures as the initial population ...")
            unseen_pop = self.sample_models(self.init_num_models)

        self.all_pop = unseen_pop

        for i in range(self.num_iters):
            self.iter_num = i + 1

            # Per-iteration hook; not defined in this class (presumably inherited from `Searcher`).
            self.on_start_iteration(self.iter_num)
            logger.info(f"Iteration {i+1}/{self.num_iters}")

            if not unseen_pop:
                logger.info("No models to evaluate. Stopping search ...")
                break

            # Calculates objectives
            logger.info(f"Calculating search objectives {list(self.so.objective_names)} for {len(unseen_pop)} models ...")
            results = self.so.eval_all_objs(unseen_pop)

            self.search_state.add_iteration_results(
                unseen_pop,
                results,
                # Mutation info
                extra_model_data={
                    "parent": [p.metadata.get("parent", None) for p in unseen_pop],
                },
            )

            # Records evaluated archs to avoid computing the same architecture twice
            self.seen_archs.update([m.archid for m in unseen_pop])

            # update the pareto frontier
            logger.info("Updating Pareto frontier ...")
            pareto = self.search_state.get_pareto_frontier()["models"]
            logger.info(f"Found {len(pareto)} members.")

            # Saves search iteration results
            self.search_state.save_search_state(str(self.output_dir / f"search_state_{self.iter_num}.csv"))
            self.search_state.save_pareto_frontier_models(
                str(self.output_dir / f"pareto_models_iter_{self.iter_num}"),
                save_weights=self.save_pareto_model_weights,
            )
            self.search_state.save_all_2d_pareto_evolution_plots(str(self.output_dir))

            # Clears models from memory if needed
            if self.clear_evaluated_models:
                logger.info("Optimzing memory usage ...")
                for model in unseen_pop:
                    model.clear()

            # mutate random 'k' subsets of the parents
            # while ensuring the mutations fall within
            # desired constraint limits
            unseen_pop = self.mutate_parents(pareto, self.mutations_per_parent)
            logger.info(f"Mutation: {len(unseen_pop)} new models.")

            # update the set of architectures ever visited
            self.all_pop.extend(unseen_pop)

        return self.search_state