Source code for olive.engine

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from olive.cache import clean_cache, clean_evaluation_cache, clean_pass_run_cache, create_cache, get_cache_sub_dirs
from olive.common.config_utils import ConfigBase, validate_config
from olive.common.utils import hash_dict
from olive.evaluator.metric import Metric
from olive.evaluator.olive_evaluator import OliveEvaluator, OliveEvaluatorConfig
from olive.model import ModelConfig, OliveModel
from olive.passes.olive_pass import Pass
from olive.strategy.search_strategy import SearchStrategy, SearchStrategyConfig
from olive.systems.common import SystemType
from olive.systems.local import LocalSystem
from olive.systems.olive_system import OliveSystem
from olive.systems.system_config import SystemConfig

logger = logging.getLogger(__name__)

# pass search-point/config was pruned due to an invalid config or a failed run
PRUNED_CONFIG = "pruned-config"


class EngineConfig(ConfigBase):
    search_strategy: SearchStrategyConfig = None
    host: SystemConfig = None
    target: SystemConfig = None
    model_io_config: Dict[str, List] = None
    evaluator: OliveEvaluatorConfig = None
    cache_dir: Union[Path, str] = ".olive-cache"
    clean_cache: bool = False
    clean_evaluation_cache: bool = False

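For illustration only (this example is not part of the module source), a configuration passed to the engine is validated against the EngineConfig schema above; fields that are left out fall back to the defaults declared on the class. A minimal sketch:

# Illustrative sketch of an EngineConfig-style dictionary; only fields declared
# on EngineConfig above are meaningful. Omitted fields ("search_strategy",
# "host", "target", "model_io_config", "evaluator") default to None.
example_engine_config = {
    "cache_dir": ".olive-cache",      # root directory for the model/run/evaluation caches
    "clean_cache": False,             # remove the entire cache before running
    "clean_evaluation_cache": False,  # remove only cached evaluation results
}
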
class Engine:
    """
    The engine executes the registered Olive passes, facilitates evaluation of the output
    models using the provided evaluation criteria, and produces the output model(s).
    """

    def __init__(
        self,
        config: Union[Dict[str, Any], EngineConfig] = None,
        search_strategy: Optional[SearchStrategy] = None,
        host: Optional[OliveSystem] = None,
        evaluator: Optional[OliveEvaluator] = None,
    ):
        self._config = validate_config(config, EngineConfig)

        if search_strategy is not None:
            self._search_strategy = search_strategy
        else:
            assert self._config.search_strategy is not None, "Search strategy must be provided"
            self._search_strategy = SearchStrategy(self._config.search_strategy)

        # default host
        if host is not None:
            self.host = host
        elif self._config.host is not None:
            self.host = self._config.host.create_system()
        else:
            self.host = LocalSystem()

        # Dictionary to keep track of separate hosts for a pass optionally provided by the user.
        self.hosts = {}

        # default evaluator
        self._evaluator = None
        if evaluator is not None:
            self._evaluator = evaluator
        elif self._config.evaluator is not None:
            self._evaluator = self._config.evaluator.create_evaluator()

        # Dictionary to keep track of separate evaluators for a pass optionally provided by the user.
        self._evaluators = {}

        # dictionary of passes
        self._passes = {}
        # list of pass names in the order they were registered
        self._pass_order = []

        self._clean_pass_run_cache = {}

        self._initialized = False

    def initialize(self):
        """
        Initialize engine state. This should be done before running the registered passes.
        """
        cache_dir = self._config.cache_dir
        if self._config.clean_cache:
            clean_cache(cache_dir)
        if self._config.clean_evaluation_cache:
            clean_evaluation_cache(cache_dir)

        self._model_cache_path, self._run_cache_path, self._evaluation_cache_path = get_cache_sub_dirs(cache_dir)
        create_cache(cache_dir)

        # initialize counters
        # we do this before cleaning pass run caches to ensure we don't reuse model numbers even if the model was
        # deleted from the cache
        self._new_model_number = 0
        model_jsons = list(self._model_cache_path.glob("*_*.json"))
        if len(model_jsons) > 0:
            self._new_model_number = max([int(json_file.stem.split("_")[0]) for json_file in model_jsons]) + 1

        # initialize search spaces
        self._pass_search_spaces = []
        for pass_name in self._pass_order:
            p = self._passes[pass_name]
            self._pass_search_spaces.append((pass_name, p.search_space()))

            # clean run cache if requested
            # removes all run cache for pass type and all children elements
            clean_run_cache = self._clean_pass_run_cache[pass_name]
            if clean_run_cache:
                clean_pass_run_cache(p.__class__.__name__, cache_dir)

        self._initialized = True

    def register(
        self,
        p: Pass,
        name: str = None,
        host: OliveSystem = None,
        evaluator: OliveEvaluator = None,
        clean_run_cache: bool = False,
    ):
        """
        Register a pass.
        """
        if name is not None:
            assert name not in self._passes, f"Pass with name {name} already registered"
        else:
            # generate a unique name from the pass class name
            id = 0
            while True:
                name = p.__class__.__name__
                if id > 0:
                    name = f"{name}_{id}"
                if name not in self._passes:
                    break
                id += 1

        self._passes[name] = p
        self._pass_order.append(name)
        self.hosts[name] = host
        self._evaluators[name] = evaluator
        self._clean_pass_run_cache[name] = clean_run_cache

    def run(self, input_model: OliveModel, verbose: bool = False):
        """
        Run all the registered Olive passes on the input model and produce one or more candidate models.
        """
        if not self._initialized:
            self.initialize()

        # hash the input model
        input_model_id = self._init_input_model(input_model)

        # get objective_dict
        evaluator = self.evaluator_for_pass(self._pass_order[-1])
        objective_dict = self.resolve_objectives(input_model, input_model_id, evaluator.metrics, verbose)

        # initialize the search strategy
        self._search_strategy.initialize(self._pass_search_spaces, input_model_id, objective_dict)

        # record start time
        start_time = time.time()
        iter_num = 0
        while True:
            iter_num += 1

            # get the next step
            should_prune = False
            next_step = self._search_strategy.next_step()

            # if no more steps, break
            if next_step is None:
                break

            # get the model id of the first input model
            model_id = next_step["model_id"]
            if model_id == input_model_id:
                model = input_model
            else:
                model = self._load_model(model_id)

            if verbose:
                logger.info(f"Step {iter_num} with search point {next_step['search_point']} ...")

            # run all the passes in the step
            model_ids = []
            for pass_id, pass_search_point in next_step["passes"]:
                if verbose:
                    logger.info(f"Running pass {pass_id} with search point {pass_search_point} ...")

                if input_model.is_aml_model and not self.host_for_pass(pass_id).system_type == SystemType.AzureML:
                    error_msg = "Azure ML model only supports AzureMLSystem for Olive Pass"
                    logger.error(error_msg)
                    raise Exception(error_msg)

                model, model_id = self._run_pass(pass_id, pass_search_point, model, model_id, verbose)
                if model == PRUNED_CONFIG:
                    should_prune = True
                    logger.info("Pruned")
                    break
                model_ids.append(model_id)

            signal = {}
            if not should_prune:
                # evaluate the model
                try:
                    signal = self._evaluate_model(model, model_id, self.evaluator_for_pass(pass_id), verbose)
                except Exception as e:
                    logger.error(f"Evaluation failed: {e}")
                    raise e
                if verbose:
                    logger.info(f"Signal: {signal}")

            # record feedback signal
            self._search_strategy.record_feedback_signal(next_step["search_point"], signal, model_ids, should_prune)

            time_diff = time.time() - start_time
            self._search_strategy.check_exit_criteria(iter_num, time_diff, signal)

        # import json
        # for i, key in enumerate(self._search_strategy._search_results):
        #     json.dump(
        #         self._search_strategy._search_results[key].to_json(), open(f"search_results_{i}.json", "w"), indent=4
        #     )

        return self._search_strategy.get_best_execution()

    def resolve_objectives(
        self, input_model: OliveModel, input_model_id: str, metrics: List[Metric], verbose: bool = False
    ) -> Dict[str, Dict[str, Any]]:
        """
        Return a dictionary of objectives and their higher_is_better and goal values.

        {objective_name: {"higher_is_better": bool, "goal": float}}
        """
        goals = self.resolve_goals(input_model, input_model_id, metrics, verbose)
        objective_dict = {
            metric.name: {"higher_is_better": metric.higher_is_better, "goal": goals.get(metric.name)}
            for metric in metrics
        }
        return objective_dict

    def resolve_goals(
        self, input_model: OliveModel, input_model_id: str, metrics: List[Metric], verbose: bool = False
    ) -> Dict[str, float]:
        """
        Resolve the goals of the given metrics into thresholds for the given model.
        """
        goals = {}
        multipliers = {}
        for metric in metrics:
            if metric.goal is not None:
                goals[metric.name] = metric.goal
                multipliers[metric.name] = 1 if metric.higher_is_better else -1
        if verbose and len(goals) > 0:
            logger.info(f"Resolving goals: {goals}")

        # compute baseline for input model if needed
        baseline = {}
        for _, goal in goals.items():
            if goal.type != "threshold":
                assert self._evaluator is not None, "Default evaluator must be provided to resolve goals"
                if verbose:
                    logger.info("Computing baseline for metrics ...")
                baseline = self._evaluate_model(input_model, input_model_id, self._evaluator, verbose=False)
                break
        if verbose and len(baseline) > 0:
            logger.info(f"Baseline: {baseline}")

        # resolve goals to thresholds
        resolved_goals = {}
        for name, goal in goals.items():
            # TODO: make the logic cleaner
            if goal.type == "threshold":
                resolved_goals[name] = goal.value
            elif goal.type == "max-degradation":
                resolved_goals[name] = baseline[name] - multipliers[name] * goal.value
            elif goal.type == "min-improvement":
                resolved_goals[name] = baseline[name] + multipliers[name] * goal.value
            elif goal.type == "percent-max-degradation":
                resolved_goals[name] = baseline[name] * (1 - multipliers[name] * goal.value / 100)
            elif goal.type == "percent-min-improvement":
                resolved_goals[name] = baseline[name] * (1 + multipliers[name] * goal.value / 100)
        if verbose and len(resolved_goals) > 0:
            logger.info(f"Resolved goals: {resolved_goals}")

        return resolved_goals

    def host_for_pass(self, pass_id: str):
        host = self.hosts[pass_id]
        if host is None:
            return self.host
        return host

    def evaluator_for_pass(self, pass_id: str):
        """
        Return the evaluator for the given pass.
        """
        e = self._evaluators[pass_id]
        if e is None:
            return self._evaluator
        return e

    def _get_new_model_number(self):
        """
        Get a new model number.
        """
        while True:
            new_model_number = self._new_model_number
            self._new_model_number += 1
            if list(self._model_cache_path.glob(f"{new_model_number}_*.json")) == []:
                break
        return new_model_number

    def _cache_model(self, model: Union[OliveModel, str], model_id: str):
        """
        Cache the model in the cache directory.
        """
        if model == PRUNED_CONFIG:
            model_json = {}
        else:
            model_json = model.to_json()
        model_json_path = self._model_cache_path / f"{model_id}.json"
        try:
            json.dump(model_json, open(model_json_path, "w"))
        except Exception as e:
            logger.error(f"Failed to cache model: {e}")

    def _load_model(self, model_id: str) -> Union[OliveModel, str]:
        """
        Load the model from the cache directory.
        """
        model_json_path = self._model_cache_path / f"{model_id}.json"
        try:
            model_json = json.load(open(model_json_path, "r"))
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return None

        if model_json == {}:
            return PRUNED_CONFIG

        model = ModelConfig.from_json(model_json).create_model()
        return model

    def _init_input_model(self, input_model: OliveModel):
        """
        Initialize the input model.
        """
        model_hash = hash_dict(input_model.to_json())

        # cache the model
        self._cache_model(input_model, model_hash)

        return model_hash

    def _cache_run(self, pass_name: str, pass_config: dict, input_model_id: str, output_model_id: str):
        """
        Cache the run in the cache directory.
        """
        run_json = {
            "pass_name": pass_name,
            "pass_config": pass_config,
            "input_model_id": input_model_id,
            "output_model_id": output_model_id,
        }
        input_model_number = input_model_id.split("_")[0]
        run_json_path = self._run_cache_path / f"{pass_name}-{input_model_number}-{hash_dict(pass_config)}.json"
        try:
            json.dump(run_json, open(run_json_path, "w"))
        except Exception as e:
            logger.error(f"Failed to cache run: {e}")

    def _load_run(self, input_model_id: str, pass_name: str, pass_config: dict):
        """
        Load the run from the cache directory.
        """
        input_model_number = input_model_id.split("_")[0]
        run_json_path = self._run_cache_path / f"{pass_name}-{input_model_number}-{hash_dict(pass_config)}.json"
        if Path(run_json_path).exists():
            try:
                run_json = json.load(open(run_json_path, "r"))
                output_model_id = run_json["output_model_id"]
            except Exception as e:
                logger.error(f"Failed to load run: {e}")
                output_model_id = None
            return output_model_id
        else:
            return None

    def _run_pass(
        self, pass_id: str, pass_search_point: dict, input_model: OliveModel, input_model_id: str, verbose: bool
    ):
        """
        Run a pass on the input model.
        """
        # pass
        p = self._passes[pass_id]
        pass_name = p.__class__.__name__
        pass_config = p.config_at_search_point(pass_search_point)
        pass_config = p.serialize_config(pass_config)

        # load run from cache if it exists
        output_model_id = self._load_run(input_model_id, pass_name, pass_config)
        if output_model_id is not None:
            if verbose:
                logger.info("Loading model from cache ...")
            output_model = self._load_model(output_model_id)
            if output_model is not None:
                return output_model, output_model_id

        # new model id
        input_model_number = input_model_id.split("_")[0]
        output_model_id = f"{self._get_new_model_number()}_{pass_name}-{input_model_number}-{hash_dict(pass_config)}"
        output_model_path = str(self._model_cache_path / f"{output_model_id}")

        # prune if invalid search_point
        if not p.validate_search_point(pass_search_point):
            output_model = PRUNED_CONFIG
        else:
            # run pass
            try:
                host = self.host_for_pass(pass_id)
                output_model = host.run_pass(p, input_model, output_model_path, pass_search_point)
            except Exception:
                output_model = PRUNED_CONFIG
                # TODO: for the time being, we need to catch all exceptions to make the
                # search process robust. We need to rethrow the exception only when
                # it is not pass specific, for example, for Olive bugs and user errors.
                logger.error("Pass run failed.", exc_info=True)

        # cache model
        self._cache_model(output_model, output_model_id)

        # cache run
        self._cache_run(pass_name, pass_config, input_model_id, output_model_id)

        return output_model, output_model_id

    def _cache_evaluation(self, model_id: str, signal: dict):
        """
        Cache the evaluation in the cache directory.
""" evaluation_json = { "model_id": model_id, "signal": signal, } evaluation_json_path = self._evaluation_cache_path / f"{model_id}.json" try: json.dump(evaluation_json, open(evaluation_json_path, "w")) except Exception as e: logger.error(f"Failed to cache evaluation: {e}") def _load_evaluation(self, model_id: str): """ Load the evaluation from the cache directory. """ evaluation_json_path = self._evaluation_cache_path / f"{model_id}.json" if Path(evaluation_json_path).exists(): try: evaluation_json = json.load(open(evaluation_json_path, "r")) signal = evaluation_json["signal"] except Exception as e: logger.error(f"Failed to load evaluation: {e}") signal = None return signal else: return None def _evaluate_model(self, model: OliveModel, model_id: str, evaluator: OliveEvaluator, verbose: bool): """ Evaluate a model. """ if verbose: logger.info("Evaluating output model ...") # load evaluation from cache if it exists signal = self._load_evaluation(model_id) if signal is not None: if verbose: logger.info("Loading evaluation from cache ...") return signal # evaluate model signal = evaluator.evaluate(model) # cache evaluation self._cache_evaluation(model_id, signal) return signal