Coverage for mlos_bench/mlos_bench/environments/local/local_fileshare_env.py: 88%
56 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Scheduler-side Environment to run scripts locally
7and upload/download data to the shared storage.
8"""
10import logging
12from datetime import datetime
13from string import Template
14from typing import Any, Dict, List, Generator, Iterable, Mapping, Optional, Tuple
16from mlos_bench.services.base_service import Service
17from mlos_bench.services.types.local_exec_type import SupportsLocalExec
18from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
19from mlos_bench.environments.status import Status
20from mlos_bench.environments.local.local_env import LocalEnv
21from mlos_bench.tunables.tunable import TunableValue
22from mlos_bench.tunables.tunable_groups import TunableGroups
24_LOG = logging.getLogger(__name__)
class LocalFileShareEnv(LocalEnv):
    """
    Scheduler-side Environment that runs scripts locally
    and uploads/downloads data to the shared file storage.
    """

    def __init__(self,
                 *,
                 name: str,
                 config: dict,
                 global_config: Optional[dict] = None,
                 tunables: Optional[TunableGroups] = None,
                 service: Optional[Service] = None):
        """
        Create a new application environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
            `LocalFileShareEnv` must also have at least some of the following
            parameters: {setup, upload, run, download, teardown,
            dump_params_file, read_results_file}
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # The same service object must provide both local execution and
        # file share operations; keep a separately typed reference to each.
        assert self._service is not None and isinstance(self._service, SupportsLocalExec), \
            "LocalEnv requires a service that supports local execution"
        self._local_exec_service: SupportsLocalExec = self._service

        assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \
            "LocalEnv requires a service that supports file upload/download operations"
        self._file_share_service: SupportsFileShareOps = self._service

        # Pre-compile the from/to path templates once; the actual
        # parameter values are substituted later, in `_expand()`.
        self._upload = self._template_from_to("upload")
        self._download = self._template_from_to("download")

    def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]:
        """
        Convert a list of {"from": "...", "to": "..."} to a list of pairs
        of string.Template objects so that we can plug in self._params into it later.

        Parameters
        ----------
        config_key : str
            Key of the `self.config` entry that holds the list of from/to dicts.
            A missing key yields an empty list.

        Returns
        -------
        templates : List[Tuple[Template, Template]]
            (from, to) pairs of path templates, in the order of the config.
        """
        return [
            (Template(d['from']), Template(d['to']))
            for d in self.config.get(config_key, [])
        ]

    @staticmethod
    def _expand(from_to: Iterable[Tuple[Template, Template]],
                params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]:
        """
        Substitute $var parameters in from/to path templates.
        Return a generator of (str, str) pairs of paths.

        `safe_substitute` is used, so placeholders that have no matching
        key in `params` are left in the path as-is instead of raising.
        """
        return (
            (path_from.safe_substitute(params), path_to.safe_substitute(params))
            for (path_from, path_to) in from_to
        )

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Run setup scripts locally and upload the scripts and data to the shared storage.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable OS and application parameters along with their
            values. In a local environment these could be used to prepare a config
            file on the scheduler prior to transferring it to the remote environment,
            for instance.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        self._is_ready = super().setup(tunables, global_config)
        if self._is_ready:
            assert self._temp_dir is not None
            params = self._get_env_params(restrict=False)
            # Make the temporary directory available as $PWD in the path templates.
            params["PWD"] = self._temp_dir
            for (path_from, path_to) in self._expand(self._upload, params):
                self._file_share_service.upload(self._params, self._config_loader_service.resolve_path(
                    path_from, extra_paths=[self._temp_dir]), path_to)
        return self._is_ready

    def _download_files(self, ignore_missing: bool = False) -> None:
        """
        Download files from the shared storage.

        Parameters
        ----------
        ignore_missing : bool
            If True, log a warning when a download raises `FileNotFoundError`
            and proceed with downloading the other files.
            If False, log the warning and re-raise the `FileNotFoundError`.
            Any other exception is always logged and re-raised.
        """
        assert self._temp_dir is not None
        params = self._get_env_params(restrict=False)
        # Make the temporary directory available as $PWD in the path templates.
        params["PWD"] = self._temp_dir
        for (path_from, path_to) in self._expand(self._download, params):
            try:
                self._file_share_service.download(self._params,
                                                  path_from, self._config_loader_service.resolve_path(
                                                      path_to, extra_paths=[self._temp_dir]))
            except FileNotFoundError:
                _LOG.warning("Cannot download: %s", path_from)
                if not ignore_missing:
                    # Bare `raise` re-raises the active exception as-is.
                    raise
            except Exception:
                _LOG.exception("Cannot download %s to %s", path_from, path_to)
                raise

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Download benchmark results from the shared storage
        and run post-processing scripts locally.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        # Missing files are an error here: `ignore_missing` defaults to False.
        self._download_files()
        return super().run()

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Download the files from the shared storage, tolerating missing ones,
        and return the status reported by the parent class.

        Returns
        -------
        (status, timestamp, telemetry) : (Status, datetime, list)
            The 3-tuple returned by the parent class `status()` method.
        """
        self._download_files(ignore_missing=True)
        return super().status()