Source code for olive.systems.docker.docker_system

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import copy
import json
import logging
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import docker
from docker.errors import BuildError, ContainerError

import olive.systems.docker.utils as docker_utils
from olive.common.config_utils import validate_config
from olive.evaluator.metric import Metric, MetricResult
from olive.hardware import Device
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import ModelConfig
from olive.passes import Pass
from olive.systems.common import LocalDockerConfig, SystemType
from olive.systems.olive_system import OliveSystem

logger = logging.getLogger(__name__)


class DockerSystem(OliveSystem):
    """Olive system that evaluates models inside a local Docker container.

    On construction with a ``local_docker_config`` the configured image is looked up through the
    Docker client and, if it does not exist, built either from a user supplied Dockerfile or from
    the ``Dockerfile`` that sits next to this module plus a requirements file.
    """

    system_type = SystemType.Docker

    # File name of the default Dockerfile located in the same directory as this module.
    BASE_DOCKERFILE = "Dockerfile"

    def __init__(
        self,
        local_docker_config: Optional[Union[Dict[str, Any], LocalDockerConfig]] = None,
        accelerators: Optional[List[str]] = None,
        is_dev: bool = False,
        olive_managed_env: bool = False,
        requirements_file: Optional[Union[Path, str]] = None,
    ):
        """Create the system and make sure the Docker image is available.

        Args:
            local_docker_config: Dict or ``LocalDockerConfig``; validated with ``validate_config``.
                When falsy, no Docker client, image or run parameters are set up here.
            accelerators: Forwarded to ``OliveSystem.__init__``.
            is_dev: When true, ``_run_container`` adds the mount produced by
                ``docker_utils.create_dev_mount`` to the container volumes.
            olive_managed_env: Forwarded to ``OliveSystem.__init__``. Cannot be combined with
                ``local_docker_config``.
            requirements_file: Stored on the instance as ``self.requirements_file``.

        Raises:
            ValueError: If both ``local_docker_config`` and ``olive_managed_env`` are given.
            BuildError: If building the image fails (the build log is logged first).
        """
        super().__init__(accelerators=accelerators, olive_managed_env=olive_managed_env)
        logger.info("Initializing Docker System...")
        self.is_dev = is_dev
        self.requirements_file = requirements_file

        if local_docker_config and olive_managed_env:
            raise ValueError("Olive managed environment is not supported if local_docker_config is provided.")
        if not local_docker_config:
            return

        self.docker_client = docker.from_env()
        local_docker_config = validate_config(local_docker_config, LocalDockerConfig)
        self.run_params = local_docker_config.run_params
        image_name = local_docker_config.image_name

        try:
            self.image = self.docker_client.images.get(image_name)
            logger.info(f"Image {local_docker_config.image_name} found")
        except docker.errors.ImageNotFound:
            if local_docker_config.build_context_path and local_docker_config.dockerfile:
                # User supplied build context and Dockerfile.
                build_context_path = local_docker_config.build_context_path
                logger.info(f"Building image from Dockerfile {build_context_path}/{local_docker_config.dockerfile}")
                self.image = self._build_image(
                    build_context_path=build_context_path,
                    dockerfile=local_docker_config.dockerfile,
                    image_name=image_name,
                    build_args=local_docker_config.build_args,
                )
            elif local_docker_config.requirements_file_path:
                # Default Dockerfile: stage it together with the requirements file in a
                # throw-away directory that serves as the build context.
                logger.info(
                    f"Building image from Olive default Dockerfile with buildargs {local_docker_config.build_args} "
                    f"requirements.txt {local_docker_config.requirements_file_path}"
                )
                dockerfile_path = str(Path(__file__).resolve().parent / self.BASE_DOCKERFILE)
                with tempfile.TemporaryDirectory() as tempdir:
                    shutil.copy2(dockerfile_path, tempdir)
                    shutil.copy2(local_docker_config.requirements_file_path, tempdir)
                    self.image = self._build_image(
                        build_context_path=tempdir,
                        dockerfile=self.BASE_DOCKERFILE,
                        image_name=image_name,
                        build_args=local_docker_config.build_args,
                    )
            else:
                # Nothing to build from: self.image stays unset, exactly as before, but the
                # situation is now visible in the log instead of being swallowed silently.
                logger.warning(
                    f"Image {image_name} was not found and neither a build context with a Dockerfile "
                    "nor a requirements file was provided, so no image is available."
                )

    def _build_image(self, build_context_path, dockerfile, image_name, build_args):
        """Build and tag a Docker image, logging the build output, and return the image.

        On ``BuildError`` the error and its build log are logged at ERROR level and the
        exception is re-raised.
        """
        try:
            image, build_logs = self.docker_client.images.build(
                path=build_context_path,
                dockerfile=dockerfile,
                tag=image_name,
                buildargs=build_args,
            )
        except BuildError as e:
            logger.error(f"Image build failed with error: {e}")
            _print_docker_logs(e.build_log, logging.ERROR)
            raise
        logger.info(f"Image {image_name} build successfully.")
        _print_docker_logs(build_logs, logging.DEBUG)
        return image

    @staticmethod
    def _merge_default_environment(environment, defaults: Dict[str, str]):
        """Return a dict with ``environment`` plus every default whose key is missing or falsy.

        ``environment`` may be ``None``, a dict, or a list/tuple of ``"KEY=VALUE"`` strings.
        The input object is never modified. Values of any other type are returned unchanged.
        """
        if environment is None:
            merged = {}
        elif isinstance(environment, dict):
            # Copy so that the caller's (possibly shared) configuration dict is not mutated.
            merged = dict(environment)
        elif isinstance(environment, (list, tuple)):
            merged = {}
            for entry in environment:
                # Split on the first "=" only so values that contain "=" stay intact.
                key, separator, value = entry.partition("=")
                # NOTE(review): an entry without "=" is mapped to None, which the Docker SDK is
                # expected to send as a bare variable name again - confirm against the SDK.
                merged[key] = value if separator else None
        else:
            return environment

        for key, value in defaults.items():
            if not merged.get(key):
                merged[key] = value
        return merged

    def run_pass(
        self,
        the_pass: Pass,
        model: ModelConfig,
        data_root: str,
        output_model_path: str,
        point: Optional[Dict[str, Any]] = None,
    ) -> ModelConfig:
        """Run the pass on the model at a specific point in the search space.

        Not supported by this system: logs a warning and raises ``NotImplementedError``.
        """
        logger.warning("DockerSystem.run_pass is not implemented yet.")
        raise NotImplementedError

    def evaluate_model(
        self, model_config: ModelConfig, data_root: str, metrics: List[Metric], accelerator: AcceleratorSpec
    ) -> Dict[str, Any]:
        """Evaluate the model in a container and return the parsed metric result.

        The container writes its result as JSON into a temporary directory; the loaded content
        is passed to ``MetricResult.parse_obj``. If the result file does not exist, ``None`` is
        passed to ``MetricResult.parse_obj`` instead.
        """
        container_root_path = Path("/olive-ws/")
        with tempfile.TemporaryDirectory() as tempdir:
            metrics_res = None
            metric_json = self._run_container(
                tempdir, model_config, data_root, metrics, accelerator, container_root_path
            )
            if metric_json.is_file():
                with metric_json.open() as f:
                    metrics_res = json.load(f)
            return MetricResult.parse_obj(metrics_res)

    def _run_container(
        self,
        tempdir,
        model_config: ModelConfig,
        data_root: str,
        metrics: List[Metric],
        accelerator: AcceleratorSpec,
        container_root_path: Path,
    ):
        """Run the evaluation script in a container and return the local path of its result file.

        Args:
            tempdir: Local scratch directory used for the config file and the output mount.
            model_config: Model to evaluate; a deep copy is handed to the mount helpers.
            data_root: Data root forwarded to ``docker_utils.create_metric_volumes_list``.
            metrics: Metrics to evaluate; a deep copy is handed to the mount helpers.
            accelerator: Target accelerator; a GPU device request is added for ``Device.GPU``.
            container_root_path: Root directory inside the container for all mounts.

        Returns:
            ``Path`` to ``eval_res.json`` inside the local output directory.

        Raises:
            ContainerError: If the container exits with a non-zero status code. The collected
                container logs are included in the error message.
        """
        eval_output_path = "eval_output"
        eval_output_name = "eval_res.json"

        # Collect every volume mount string the container needs.
        volumes_list = []
        eval_file_mount_path, eval_file_mount_str = docker_utils.create_eval_script_mount(container_root_path)
        volumes_list.append(eval_file_mount_str)

        if self.is_dev:
            _, dev_mount_str = docker_utils.create_dev_mount(tempdir, container_root_path)
            volumes_list.append(dev_mount_str)

        # Deep copies are passed to the helpers so the caller's objects are left alone.
        model_config_copy = copy.deepcopy(model_config)
        model_mounts, model_mount_str_list = docker_utils.create_model_mount(
            model_config=model_config_copy, container_root_path=container_root_path
        )
        volumes_list += model_mount_str_list

        metrics_copy = copy.deepcopy(metrics)
        volumes_list = docker_utils.create_metric_volumes_list(
            data_root=data_root,
            metrics=metrics_copy,
            container_root_path=container_root_path,
            mount_list=volumes_list,
        )

        config_mount_path, config_file_mount_str = docker_utils.create_config_file(
            tempdir=tempdir,
            model_config=model_config_copy,
            metrics=metrics_copy,
            container_root_path=container_root_path,
            model_mounts=model_mounts,
        )
        volumes_list.append(config_file_mount_str)

        output_local_path, output_mount_path, output_mount_str = docker_utils.create_output_mount(
            tempdir=tempdir,
            docker_eval_output_path=eval_output_path,
            container_root_path=container_root_path,
        )
        volumes_list.append(output_mount_str)

        logger.debug(f"The volumes list is {volumes_list}")

        eval_command = docker_utils.create_evaluate_command(
            eval_script_path=eval_file_mount_path,
            config_path=config_mount_path,
            output_path=output_mount_path,
            output_name=eval_output_name,
            accelerator=accelerator,
        )

        run_command = docker_utils.create_run_command(run_params=self.run_params)
        # The environment is passed explicitly, so take it out of the generic run arguments and
        # make sure the defaults are present regardless of whether it was a list or a dict.
        environment = self._merge_default_environment(
            run_command.pop("environment", {}), {"PYTHONPYCACHEPREFIX": "/tmp"}
        )

        logger.debug(f"Running container with eval command: {eval_command}")
        accelerator_kwargs = {}
        if accelerator.accelerator_type == Device.GPU:
            accelerator_kwargs["device_requests"] = [docker.types.DeviceRequest(capabilities=[["gpu"]])]
        # Two separate unpackings keep the TypeError on duplicate keywords (e.g. a user supplied
        # device_requests together with a GPU accelerator) instead of silently overriding one.
        container = self.docker_client.containers.run(
            image=self.image,
            command=eval_command,
            volumes=volumes_list,
            detach=True,
            environment=environment,
            **accelerator_kwargs,
            **run_command,
        )

        docker_logs = []
        try:
            for line in container.logs(stream=True):
                # containers.logs can accept stdout/stderr as arguments, but it doesn't work
                # as we cannot ensure that all the logs will be printed in the correct channel(out/err)
                # so, we collect all the logs and print them in the end if there is an error.
                # Undecodable bytes are replaced rather than aborting the evaluation.
                log = line.decode(errors="replace").strip()
                logger.debug(log)
                docker_logs.append(log)
            exit_code = container.wait()["StatusCode"]
        finally:
            # Always clean up; force covers the case where the container is still running
            # because streaming or waiting was interrupted.
            container.remove(force=True)

        if exit_code != 0:
            error_msg = "\n".join(docker_logs)
            raise ContainerError(
                container, exit_code, eval_command, self.image, f"Docker container evaluation failed with: {error_msg}"
            )
        logger.debug("Docker container evaluation completed successfully")

        return Path(output_local_path) / f"{eval_output_name}"

    def remove(self):
        """Force-remove the image identified by the first tag of ``self.image``."""
        image_tag = self.image.tags[0]
        self.docker_client.images.remove(image_tag, force=True)
        logger.info(f"Image {image_tag} removed successfully.")


def _print_docker_logs(logs, level=logging.DEBUG):
    """Emit all Docker log entries as one newline-joined message at ``level``.

    For an entry containing ``"stream"`` the stripped string form of that value is used;
    any other entry contributes its own stripped string form.
    """
    rendered = [str(entry["stream"] if "stream" in entry else entry).strip() for entry in logs]
    logger.log(level, "\n".join(rendered))