# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
from pathlib import Path
from typing import List, Optional
from overrides import overrides
from tqdm import tqdm
from archai.common.ordered_dict_logger import OrderedDictLogger
from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.search_objectives import SearchObjectives
from archai.discrete_search.api.search_results import SearchResults
from archai.discrete_search.api.search_space import EvolutionarySearchSpace
from archai.discrete_search.api.searcher import Searcher
# Module-level logger for the searcher in this file; `source=__name__` presumably
# tags log entries with this module's name (see OrderedDictLogger to confirm).
logger = OrderedDictLogger(source=__name__)
class EvolutionParetoSearch(Searcher):
    """Evolutionary multi-objective search algorithm that greedily
    evolves Pareto frontier models.

    It starts from an evaluated random subset of models. In each iteration, the algorithm
    evaluates a new subset of models generated from mutations (`mutations_per_parent`) and
    crossovers (`num_crossovers`) of the current Pareto frontier, plus a new random subset
    of models (`num_random_mix`). The process is repeated until `num_iters` is reached.

    """

    def __init__(
        self,
        search_space: EvolutionarySearchSpace,
        search_objectives: SearchObjectives,
        output_dir: str,
        num_iters: Optional[int] = 10,
        init_num_models: Optional[int] = 10,
        initial_population_paths: Optional[List[str]] = None,
        num_random_mix: Optional[int] = 5,
        max_unseen_population: Optional[int] = 100,
        mutations_per_parent: Optional[int] = 1,
        num_crossovers: Optional[int] = 5,
        clear_evaluated_models: bool = True,
        save_pareto_model_weights: bool = True,
        seed: Optional[int] = 1,
    ) -> None:
        """Initialize the evolutionary search algorithm.

        Args:
            search_space: Discrete search space compatible with evolutionary algorithms.
            search_objectives: Search objectives.
            output_dir: Output directory. It is created (including parents) if missing.
            num_iters: Number of iterations.
            init_num_models: Number of initial models to evaluate.
            initial_population_paths: List of paths to the initial population of models.
                If `None` (or empty), `init_num_models` random models are used.
            num_random_mix: Number of random models to mix with the population in each iteration.
            max_unseen_population: Maximum number of unseen models to evaluate in each iteration.
            mutations_per_parent: Number of distinct mutations generated for each Pareto frontier member.
            num_crossovers: Total number of crossovers generated per iteration.
            clear_evaluated_models: Optimizes memory usage by clearing the architecture
                of `ArchaiModel` after each iteration. Defaults to True
            save_pareto_model_weights: If `True`, saves the weights of the pareto models. Defaults to True
            seed: Random seed used by the searcher's own RNG (crossover pairing and
                population shuffling).

        Raises:
            AssertionError: If `search_space` is not an `EvolutionarySearchSpace`, or if
                `init_num_models`, `num_iters`, `num_random_mix` or `max_unseen_population`
                is not positive.

        """
        super().__init__()

        assert isinstance(
            search_space, EvolutionarySearchSpace
        ), f"{str(search_space.__class__)} is not compatible with {str(self.__class__)}"

        self.iter_num = 0
        self.search_space = search_space
        self.so = search_objectives
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        # Algorithm settings
        self.num_iters = num_iters
        self.init_num_models = init_num_models
        self.initial_population_paths = initial_population_paths
        self.num_random_mix = num_random_mix
        self.max_unseen_population = max_unseen_population
        self.mutations_per_parent = mutations_per_parent
        self.num_crossovers = num_crossovers

        # Utils
        self.clear_evaluated_models = clear_evaluated_models
        self.save_pareto_model_weights = save_pareto_model_weights
        self.search_state = SearchResults(search_space, self.so)
        self.seed = seed
        self.rng = random.Random(seed)
        # archids of every architecture already evaluated; used to avoid re-evaluation.
        self.seen_archs = set()
        self.num_sampled_archs = 0

        assert self.init_num_models > 0
        assert self.num_iters > 0
        assert self.num_random_mix > 0
        assert self.max_unseen_population > 0

    def sample_models(self, num_models: int, patience: Optional[int] = 5) -> List[ArchaiModel]:
        """Sample valid models from the search space.

        Batches of `num_models` random architectures are drawn and filtered by the
        search constraints until enough valid models are collected or `patience`
        batches have been drawn.

        Args:
            num_models: Number of models to sample.
            patience: Maximum number of sampling rounds.

        Returns:
            List of at most `num_models` sampled models that satisfy the constraints
            (shorter if `patience` runs out first).

        """
        nb_tries, valid_sample = 0, []

        while len(valid_sample) < num_models and nb_tries < patience:
            sample = [self.search_space.random_sample() for _ in range(num_models)]

            # The second value returned by `validate_constraints` is used as the
            # indices of the samples that satisfy the constraints.
            _, valid_indices = self.so.validate_constraints(sample)
            valid_sample += [sample[i] for i in valid_indices]
            nb_tries += 1

        return valid_sample[:num_models]

    def mutate_parents(
        self, parents: List[ArchaiModel], mutations_per_parent: Optional[int] = 1, patience: Optional[int] = 20
    ) -> List[ArchaiModel]:
        """Mutate parents to generate new models.

        Args:
            parents: List of parent models.
            mutations_per_parent: Target number of distinct, valid and previously unseen
                mutations to generate for each parent.
            patience: Maximum number of mutation attempts per parent.

        Returns:
            List of mutated models, deduplicated by `archid`.

        """
        mutations = {}

        for p in tqdm(parents, desc="Mutating parents"):
            candidates = {}
            nb_tries = 0

            while len(candidates) < mutations_per_parent and nb_tries < patience:
                nb_tries += 1
                mutated_model = self.search_space.mutate(p)

                # Discard mutations that violate the search constraints.
                if not self.so.is_model_valid(mutated_model):
                    continue

                # Skip architectures that were already evaluated in previous iterations.
                if mutated_model.archid in self.seen_archs:
                    continue

                mutated_model.metadata["parent"] = p.archid
                mutated_model.metadata["generation"] = self.iter_num
                candidates[mutated_model.archid] = mutated_model

            # Keyed by archid, so identical mutations of different parents collapse to one model.
            mutations.update(candidates)

        return list(mutations.values())

    def crossover_parents(
        self, parents: List[ArchaiModel], num_crossovers: Optional[int] = 1, patience: Optional[int] = 30
    ) -> List[ArchaiModel]:
        """Crossover parents to generate new models.

        Args:
            parents: List of parent models. At least two are needed; otherwise no
                children are produced.
            num_crossovers: Number of parent pairs to cross over.
            patience: Number of extra crossover attempts per pair when the result is
                missing or invalid.

        Returns:
            List of valid, previously unseen children, deduplicated by `archid`.

        """
        children, children_ids = [], set()

        # A crossover needs two distinct parents.
        if len(parents) < 2:
            return children

        # Each pair holds two distinct parents, but the same pair may be drawn more than once.
        pairs = [self.rng.sample(parents, 2) for _ in range(num_crossovers)]

        for p1, p2 in pairs:
            child = self.search_space.crossover([p1, p2])
            # `crossover` may return None; such results are treated as invalid and retried.
            is_valid = child is not None and self.so.is_model_valid(child)
            nb_tries = 0

            while not is_valid and nb_tries < patience:
                child = self.search_space.crossover([p1, p2])
                is_valid = child is not None and self.so.is_model_valid(child)
                nb_tries += 1

            if not is_valid:
                continue

            # Skip duplicates within this batch and architectures evaluated before.
            if child.archid in children_ids or child.archid in self.seen_archs:
                continue

            child.metadata["generation"] = self.iter_num
            child.metadata["parents"] = f"{p1.archid},{p2.archid}"
            children.append(child)
            children_ids.add(child.archid)

        return children

    def on_calc_task_accuracy_end(self, current_pop: List[ArchaiModel]) -> None:
        """Callback function called right after calc_task_accuracy().

        No-op by default; meant to be overridden. NOTE(review): `search` in this class
        never invokes it, so it is presumably triggered by subclasses or external code.

        """

    def on_search_iteration_start(self, current_pop: List[ArchaiModel]) -> None:
        """Callback function called right before each search iteration.

        No-op by default; meant to be overridden. `current_pop` is the population about
        to be evaluated in the iteration.

        """

    def select_next_population(self, current_pop: List[ArchaiModel]) -> List[ArchaiModel]:
        """Select the next population from the current population.

        When the population exceeds `max_unseen_population`, it is shuffled with the
        searcher's seeded RNG before truncation. Without the shuffle, the cap would always
        drop whichever group (crossovers, mutations, random samples) was concatenated last.

        Args:
            current_pop: Current population. The caller's list is not modified.

        Returns:
            Next population, containing at most `max_unseen_population` models.

        """
        if len(current_pop) <= self.max_unseen_population:
            return list(current_pop)

        shuffled = list(current_pop)
        self.rng.shuffle(shuffled)
        return shuffled[: self.max_unseen_population]

    @overrides
    def search(self) -> SearchResults:
        """Run the evolutionary Pareto search for `num_iters` iterations.

        Each iteration performs these steps:

        1. Evaluate the unseen population and record the results.
        2. Update and save the Pareto frontier.
        3. Build the next unseen population from mutations and crossovers of the
           frontier plus fresh random samples.

        Returns:
            The accumulated search results.

        Raises:
            Exception: If evaluating the objectives returns no results.

        """
        self.iter_num = 0

        if self.initial_population_paths:
            logger.info(f"Loading initial population from {len(self.initial_population_paths)} architectures ...")
            unseen_pop = [self.search_space.load_arch(path) for path in self.initial_population_paths]
        else:
            logger.info(f"Using {self.init_num_models} random architectures as the initial population ...")
            unseen_pop = self.sample_models(self.init_num_models)

        # Every architecture scheduled for evaluation across all iterations
        # (a copy, so later extensions do not alias the initial population list).
        self.all_pop = list(unseen_pop)

        for i in range(self.num_iters):
            self.iter_num = i + 1
            logger.info(f"Iteration {i+1}/{self.num_iters}")
            self.on_search_iteration_start(unseen_pop)

            # Calculates objectives
            logger.info(f"Calculating search objectives {list(self.so.objective_names)} for {len(unseen_pop)} models ...")
            results = self.so.eval_all_objs(unseen_pop)
            if len(results) == 0:
                raise Exception("Search is finding no valid models")

            self.search_state.add_iteration_results(
                unseen_pop,
                results,
                # Mutation and crossover info
                extra_model_data={
                    "parent": [p.metadata.get("parent", None) for p in unseen_pop],
                    "parents": [p.metadata.get("parents", None) for p in unseen_pop],
                },
            )

            # Records evaluated archs to avoid computing the same architecture twice
            self.seen_archs.update(m.archid for m in unseen_pop)

            # Update the Pareto frontier
            logger.info("Updating Pareto frontier ...")
            pareto = self.search_state.get_pareto_frontier()["models"]
            logger.info(f"Found {len(pareto)} members.")

            # Saves search iteration results
            # NOTE: There is a dependency on these file naming schemas on archai.common.notebook_helper
            self.search_state.save_search_state(str(self.output_dir / f"search_state_{self.iter_num}.csv"))
            self.search_state.save_pareto_frontier_models(
                str(self.output_dir / f"pareto_models_iter_{self.iter_num}"),
                save_weights=self.save_pareto_model_weights,
            )

            # Optimizes memory usage by clearing architectures from memory
            if self.clear_evaluated_models:
                logger.info("Optimzing memory usage ...")
                for model in unseen_pop:
                    model.clear()

            parents = pareto
            logger.info(f"Choosing {len(parents)} parents ...")

            # Mutate the parents while ensuring the mutations
            # fall within the desired constraint limits
            mutated = self.mutate_parents(parents, self.mutations_per_parent)
            logger.info(f"Mutation: {len(mutated)} new models.")

            # Cross over random pairs of parents while ensuring the
            # children fall within the desired constraint limits
            crossovered = self.crossover_parents(parents, self.num_crossovers)
            logger.info(f"Crossover: {len(crossovered)} new models.")

            # Add some random samples to the parent mix to mitigate local minima
            rand_mix = self.sample_models(self.num_random_mix)
            unseen_pop = crossovered + mutated + rand_mix

            # Shuffle-and-cap the population for the next stage
            logger.info(f"Total unseen population: {len(unseen_pop)}.")
            unseen_pop = self.select_next_population(unseen_pop)
            logger.info(f"Total unseen population after `max_unseen_population` restriction: {len(unseen_pop)}.")

            # Update the collection of architectures ever visited
            self.all_pop.extend(unseen_pop)

        return self.search_state