Source code for olive.systems.python_environment.python_environment_system

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import json
import logging
import os
import platform
import shutil
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from olive.common.constants import OS
from olive.common.utils import run_subprocess
from olive.evaluator.metric_result import MetricResult
from olive.model import ModelConfig
from olive.systems.common import AcceleratorConfig, SystemType
from olive.systems.olive_system import OliveSystem
from olive.systems.system_config import PythonEnvironmentTargetUserConfig
from olive.systems.utils import create_new_environ, get_package_name_from_ep, run_available_providers_runner

if TYPE_CHECKING:
    from olive.evaluator.olive_evaluator import OliveEvaluatorConfig
    from olive.hardware.accelerator import AcceleratorSpec
    from olive.passes.olive_pass import Pass


logger = logging.getLogger(__name__)


[docs]class PythonEnvironmentSystem(OliveSystem): system_type = SystemType.PythonEnvironment def __init__( self, python_environment_path: Union[Path, str] = None, environment_variables: Dict[str, str] = None, prepend_to_path: List[str] = None, accelerators: List[AcceleratorConfig] = None, olive_managed_env: bool = False, requirements_file: Union[Path, str] = None, hf_token: bool = None, ): if python_environment_path is None: raise ValueError("python_environment_path is required for PythonEnvironmentSystem.") super().__init__(accelerators=accelerators, hf_token=hf_token) self.config = PythonEnvironmentTargetUserConfig(**locals()) self.environ = create_new_environ( python_environment_path=python_environment_path, environment_variables=environment_variables, prepend_to_path=prepend_to_path, ) if olive_managed_env: if platform.system() == OS.LINUX: temp_dir = os.path.join(os.environ.get("HOME", ""), "tmp") if not os.path.exists(temp_dir): os.makedirs(temp_dir) self.environ["TMPDIR"] = temp_dir else: self.environ["TMPDIR"] = tempfile.TemporaryDirectory().name # pylint: disable=consider-using-with self.executable = shutil.which("python", path=self.environ["PATH"]) # available eps. This will be populated the first time self.get_supported_execution_providers() is called. # used for caching the available eps self.available_eps = None # path to inference script parent_dir = Path(__file__).parent.resolve() self.pass_runner_path = parent_dir / "pass_runner.py" self.evaluation_runner_path = parent_dir / "evaluation_runner.py" def _run_command(self, script_path: Path, config_jsons: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Run a script with the given config jsons and return the output json.""" with tempfile.TemporaryDirectory() as tmp_dir: tmp_dir_path = Path(tmp_dir).resolve() # command to run command = [self.executable, str(script_path)] # write config jsons to files for key, config_json in config_jsons.items(): config_json_path = tmp_dir_path / f"{key}.json" with config_json_path.open("w") as f: json.dump(config_json, f, indent=4) command.extend([f"--{key}", str(config_json_path)]) # add extra args for key, value in kwargs.items(): if value is None: continue command.extend([f"--{key}", str(value)]) # output path output_path = tmp_dir_path / "output.json" command.extend(["--output_path", str(output_path)]) # run the command _, stdout, _ = run_subprocess(command, env=self.environ, check=True) log_stdout(stdout) with output_path.open() as f: output = json.load(f) return output def run_pass( self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str, point: Optional[Dict[str, Any]] = None, ) -> ModelConfig: """Run the pass on the model at a specific point in the search space.""" point = point or {} config = the_pass.config_at_search_point(point) pass_config = the_pass.to_json(check_object=True) pass_config["config"].update(the_pass.serialize_config(config, check_object=True)) config_jsons = { "model_config": model_config.to_json(check_object=True), "pass_config": pass_config, } output_model_json = self._run_command( self.pass_runner_path, config_jsons, tempdir=tempfile.tempdir, output_model_path=output_model_path, ) return ModelConfig.parse_obj(output_model_json) def evaluate_model( self, model_config: ModelConfig, evaluator_config: "OliveEvaluatorConfig", accelerator: "AcceleratorSpec" ) -> MetricResult: """Evaluate the model.""" config_jsons = { "model_config": model_config.to_json(check_object=True), "evaluator_config": evaluator_config.to_json(check_object=True), "accelerator_config": accelerator.to_json(), } metric_results = self._run_command(self.evaluation_runner_path, config_jsons, tempdir=tempfile.tempdir) return MetricResult.parse_obj(metric_results) def get_supported_execution_providers(self) -> List[str]: """Get the available execution providers.""" if self.available_eps: return self.available_eps self.available_eps = run_available_providers_runner(self.environ) return self.available_eps def install_requirements(self, accelerator: "AcceleratorSpec"): """Install required packages.""" # install common packages common_requirements_file = Path(__file__).parent.resolve() / "common_requirements.txt" packages = [ f"-r {common_requirements_file}", ] if self.config.requirements_file: # install user requirements packages.append(f"-r {self.config.requirements_file}") # install onnxruntime package onnxruntime_package = get_package_name_from_ep(accelerator.execution_provider)[0] packages.append(onnxruntime_package) _, stdout, _ = run_subprocess( f"{self.executable} -m pip install --cache-dir {self.environ['TMPDIR']} {' '.join(packages)}", env=self.environ, check=True, ) log_stdout(stdout) _, stdout, _ = run_subprocess( f"{self.executable} -m pip show {onnxruntime_package}", env=self.environ, check=True ) log_stdout(stdout) def remove(self): vitual_env_path = Path(self.config.python_environment_path).resolve().parent try: shutil.rmtree(vitual_env_path) logger.info("Virtual environment '%s' removed.", vitual_env_path) except FileNotFoundError: pass if platform.system() == OS.LINUX: try: shutil.rmtree(self.environ["TMPDIR"]) logger.info("Temporary directory '%s' removed.", self.environ["TMPDIR"]) except FileNotFoundError: pass
def log_stdout(stdout: str): for line in stdout.splitlines(): logger.debug("%s", line.strip())