# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from abc import ABC
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import validator

from olive.common.config_utils import ConfigBase, validate_config
from olive.evaluator.metric import MetricResult
from olive.strategy.search_algorithm import REGISTRY, SearchAlgorithm
from olive.strategy.search_parameter import SearchParameter
from olive.strategy.search_results import SearchResults

logger = logging.getLogger(__name__)

_VALID_EXECUTION_ORDERS = ["joint", "pass-by-pass"]


class SearchStrategyConfig(ConfigBase):
    execution_order: str
    search_algorithm: str
    search_algorithm_config: Optional[ConfigBase] = None
    output_model_num: Optional[int] = None
    stop_when_goals_met: bool = False
    max_iter: Optional[int] = None
    max_time: Optional[int] = None
@validator("execution_order", pre=True)
def _validate_execution_order(cls, v):
if v not in _VALID_EXECUTION_ORDERS:
raise ValueError(f"Unknown execution order: {v}")
return v
@validator("search_algorithm", pre=True)
def _validate_search_algorithm(cls, v):
if v not in REGISTRY:
raise ValueError(f"Unknown search algorithm: {v}")
return v
@validator("search_algorithm_config", pre=True, always=True)
def _validate_search_algorithm_config(cls, v, values):
if "search_algorithm" not in values:
raise ValueError("Invalid search_algorithm")
config_class = REGISTRY[values["search_algorithm"]].get_config_class()
return validate_config(v, ConfigBase, config_class)
@validator("stop_when_goals_met", "max_iter", "max_time", pre=True)
def _validate_stop_when_goals_met(cls, v, values, field):
if "execution_order" not in values:
raise ValueError("Invalid execution_order")
if v and values["execution_order"] != "joint":
logger.info(f"{field.name} is only supported for joint execution order. Ignoring...")
return field.default
return v
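
# A minimal illustration of creating a strategy from a config (values are assumptions; the
# search_algorithm name must be a key in REGISTRY, e.g. "random" if such an algorithm is registered):
#   config = SearchStrategyConfig(execution_order="joint", search_algorithm="random")
#   strategy = SearchStrategy(config)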


class SearchStrategy(ABC):
    """
    Search strategy that drives the search over the given search spaces using the configured
    search algorithm and execution order.
    """
def __init__(self, config: Union[Dict[str, Any], SearchStrategyConfig]):
self._config = validate_config(config, SearchStrategyConfig)
self._initialized = False
self.exit_criteria_met = False
def initialize(
self,
search_spaces_list: List[Tuple[str, Dict[str, SearchParameter]]],
init_model_id: str,
objective_dict: Dict[str, dict],
):
"""
Initialize the search strategy.
search_spaces_list: list of tuples of format (search_space_name, {param_name: SearchParameter})
objective_dict: dictionary of format {objective_name: {"higher_is_better": bool, "goal": float}}
"""
self._objective_dict = objective_dict
# search spaces
self._spaces_order = [search_space[0] for search_space in search_spaces_list]
self._spaces_dict = {search_space[0]: search_space[1] for search_space in search_spaces_list}
        # search space dictionaries for the passes are grouped based on execution_order
self._spaces_groups = self._group_search_spaces(self._spaces_order)
self._done_spaces_groups = []
self._active_spaces_group = None
# state
self._searchers: Dict[Any, SearchAlgorithm] = {}
self._search_results: Dict[Any, SearchResults] = {}
self._init_model_ids: Dict[Any, str] = {}
self._best_search_points = {}
# initialize the first search space
self._next_search_group(init_model_id)
self._initialized = True
def _group_search_spaces(self, search_space_names: List[str]):
"""
Group search spaces based on execution order.
"""
# joint: all passes grouped together
# pass-by-pass: each pass is a separate group
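        # e.g. with hypothetical pass names ["pass_a", "pass_b"]:
        #   joint:        [["pass_a", "pass_b"]]
        #   pass-by-pass: [["pass_a"], ["pass_b"]]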
if self._config.execution_order == "joint":
search_spaces_groups = [search_space_names]
elif self._config.execution_order == "pass-by-pass":
search_spaces_groups = [[search_space_name] for search_space_name in search_space_names]
else:
raise ValueError(f"Unknown execution order: {self._config.execution_order}")
return search_spaces_groups
def _next_search_group(self, init_model_id: Optional[str] = None) -> Optional[SearchAlgorithm]:
"""
Get the next search space group and initialize the search algorithm.
"""
# TODO: organize the state better and make execution order more flexible using a graph
if self._active_spaces_group is not None:
self._done_spaces_groups.append(self._active_spaces_group)
            # legacy, will update once SearchResults has an info function
sorted_model_ids, sorted_search_points, sorted_results = self._search_results[
tuple(self._active_spaces_group)
].sort_search_points(apply_goals=True)
if sorted_model_ids is None:
logger.warning(
f"No models in this search group {self._active_spaces_group} met the goals. Sorting the models"
" without applying goals..."
)
sorted_model_ids, sorted_search_points, sorted_results = self._search_results[
tuple(self._active_spaces_group)
].sort_search_points(apply_goals=False)
            # TODO: this is a hack to get the best search point for the current search space group
            # it works for joint execution order, but not for pass-by-pass
if sorted_search_points and sorted_results:
best_search_point = (
sorted_search_points[0],
list(sorted_results[0].values()),
sorted_model_ids[0],
)
self._best_search_points[tuple(self._active_spaces_group)] = best_search_point
init_model_id = best_search_point[2][-1]
if len(self._spaces_groups) == 0:
self._active_spaces_group = None
return None
if init_model_id is None and self._active_spaces_group is None:
raise ValueError("init_model_id must be provided for the first search group")
if init_model_id is None and self._active_spaces_group is not None:
raise ValueError(
f"The previous search group {self._active_spaces_group} has no output models that were created and"
" evaluated successfully. Cannot continue."
)
self._active_spaces_group = self._spaces_groups.pop(0)
self._searchers[tuple(self._active_spaces_group)] = self._create_searcher(self._active_spaces_group)
self._search_results[tuple(self._active_spaces_group)] = SearchResults(self._objective_dict)
self._init_model_ids[tuple(self._active_spaces_group)] = init_model_id
return self._active_spaces_group
def _create_searcher(self, search_space_names: List[str]) -> SearchAlgorithm:
"""
Create a search algorithm.
"""
search_spaces_dict = {space_name: deepcopy(self._spaces_dict[space_name]) for space_name in search_space_names}
objectives = list(self._objective_dict.keys())
higher_is_betters = [self._objective_dict[objective]["higher_is_better"] for objective in objectives]
if self._config.search_algorithm in REGISTRY:
searcher = REGISTRY[self._config.search_algorithm](
search_spaces_dict, objectives, higher_is_betters, self._config.search_algorithm_config
)
else:
raise ValueError(f"Unknown search algorithm: {self._config.search_algorithm}")
return searcher
def next_step(self) -> Optional[Dict[str, Any]]:
"""
Get the next step in the search
"""
if not self._initialized:
raise ValueError("Search strategy is not initialized")
if self.exit_criteria_met:
self._next_search_group()
# if there is no active searcher, we are done
if self._active_spaces_group is None:
return None
# get the next search point from the active searcher
search_point = self._searchers[tuple(self._active_spaces_group)].suggest()
# if there are no more search points, move to the next search space group
if search_point is None:
self._next_search_group()
return self.next_step()
return {
"search_point": search_point,
"model_id": self._init_model_ids[tuple(self._active_spaces_group)],
"passes": [(space_name, search_point[space_name]) for space_name in self._active_spaces_group],
}
def record_feedback_signal(
self,
search_point: Dict[str, Dict[str, Any]],
signal: MetricResult,
model_ids: List[str],
should_prune: bool = False,
):
"""
Record the feedback signal for the given search point.
"""
if not self._initialized:
raise ValueError("Search strategy is not initialized")
self._search_results[tuple(self._active_spaces_group)].record(search_point, signal, model_ids)
self._searchers[tuple(self._active_spaces_group)].report(search_point, signal, should_prune)
def check_exit_criteria(self, iter_num, time_diff, metric_signal):
"""
        Check whether the search strategy should exit.
"""
self.exit_criteria_met = False
if not self._config.stop_when_goals_met:
            # when stop_when_goals_met is False, exit criteria are not checked here;
            # goals are still applied later when sorting the search points
return
# early exit is not supported for pass-by-pass execution order currently
if self._config.execution_order == "pass-by-pass":
return
if self._config.max_iter is not None and iter_num > self._config.max_iter:
self.exit_criteria_met = True
return
if self._config.max_time is not None and time_diff > self._config.max_time:
self.exit_criteria_met = True
return
if metric_signal == {}:
return
self.exit_criteria_met = self._config.stop_when_goals_met and self._search_results[
tuple(self._active_spaces_group)
].check_goals(metric_signal)
def get_output_model_num(self):
return self._config.output_model_num
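

# A minimal sketch of how a caller (e.g. an engine) might drive this strategy end to end. The
# names run_passes_and_evaluate, search_spaces_list, init_model_id, objective_dict, iter_num,
# and time_diff are placeholders for illustration only, and "random" assumes such an algorithm
# is registered in REGISTRY:
#
#   strategy = SearchStrategy({"execution_order": "joint", "search_algorithm": "random"})
#   strategy.initialize(search_spaces_list, init_model_id, objective_dict)
#   while True:
#       step = strategy.next_step()
#       if step is None:
#           break
#       signal, model_ids = run_passes_and_evaluate(step["model_id"], step["passes"])
#       strategy.record_feedback_signal(step["search_point"], signal, model_ids)
#       strategy.check_exit_criteria(iter_num, time_diff, signal)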