Coverage for mlos_bench/mlos_bench/environments/local/local_fileshare_env.py: 87%
55 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Scheduler-side Environment to run scripts locally and upload/download data to the
shared storage.

The optional "upload" and "download" config sections are lists of
{"from": ..., "to": ...} path templates whose $variables are substituted with the
environment parameters before each transfer.
"""
9import logging
10from datetime import datetime
11from string import Template
12from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple
14from mlos_bench.environments.local.local_env import LocalEnv
15from mlos_bench.environments.status import Status
16from mlos_bench.services.base_service import Service
17from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
18from mlos_bench.services.types.local_exec_type import SupportsLocalExec
19from mlos_bench.tunables.tunable import TunableValue
20from mlos_bench.tunables.tunable_groups import TunableGroups
# Module-level logger named after this module.
_LOG = logging.getLogger(__name__)
class LocalFileShareEnv(LocalEnv):
    """Scheduler-side Environment that runs scripts locally and uploads/downloads data
    to the shared file storage.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new application environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
            `LocalFileShareEnv` must also have at least some of the following
            parameters: {setup, upload, run, download, teardown,
            dump_params_file, read_results_file}
            The optional "upload" and "download" entries are lists of
            {"from": "...", "to": "..."} path templates.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            A service object (e.g., providing methods to deploy or reboot a VM,
            etc.). Although typed as optional, it must support both local
            execution and file share operations; otherwise an AssertionError
            is raised.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        assert self._service is not None and isinstance(
            self._service, SupportsLocalExec
        ), "LocalEnv requires a service that supports local execution"
        self._local_exec_service: SupportsLocalExec = self._service

        assert self._service is not None and isinstance(
            self._service, SupportsFileShareOps
        ), "LocalFileShareEnv requires a service that supports file upload/download operations"
        self._file_share_service: SupportsFileShareOps = self._service

        # Pre-parse the transfer specs once; they are expanded with fresh
        # parameters on every setup()/download.
        self._upload = self._template_from_to("upload")
        self._download = self._template_from_to("download")

    def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]:
        """Convert a list of {"from": "...", "to": "..."} to a list of pairs of
        string.Template objects so that environment parameters can be substituted
        into them later (see `_expand`).

        Parameters
        ----------
        config_key : str
            Config section to read (e.g., "upload" or "download").
            A missing section yields an empty list.

        Returns
        -------
        from_to : List[Tuple[Template, Template]]
            (source, destination) path templates, in config order.
        """
        return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])]

    @staticmethod
    def _expand(
        from_to: Iterable[Tuple[Template, Template]],
        params: Mapping[str, TunableValue],
    ) -> Generator[Tuple[str, str], None, None]:
        """
        Substitute $var parameters in from/to path templates.

        Uses `safe_substitute`, so placeholders with no matching key in
        `params` are left in the path verbatim instead of raising.

        Return a generator of (str, str) pairs of paths.
        """
        return (
            (path_from.safe_substitute(params), path_to.safe_substitute(params))
            for (path_from, path_to) in from_to
        )

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Run setup scripts locally and upload the scripts and data to the shared storage.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable OS and application parameters along with their
            values. In a local environment these could be used to prepare a config
            file on the scheduler prior to transferring it to the remote environment,
            for instance.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        self._is_ready = super().setup(tunables, global_config)
        if self._is_ready:
            assert self._temp_dir is not None
            params = self._get_env_params(restrict=False)
            # Allow templates to refer to the local temp dir as $PWD.
            params["PWD"] = self._temp_dir
            for path_from, path_to in self._expand(self._upload, params):
                self._file_share_service.upload(
                    self._params,
                    # Relative local sources are resolved against the temp dir too.
                    self._config_loader_service.resolve_path(
                        path_from,
                        extra_paths=[self._temp_dir],
                    ),
                    path_to,
                )
        return self._is_ready

    def _download_files(self, ignore_missing: bool = False) -> None:
        """
        Download files from the shared storage.

        Parameters
        ----------
        ignore_missing : bool
            If False (default), raise an exception when some file cannot be
            downloaded because it does not exist.
            If True, log a warning for each missing file and proceed with
            downloading the other files.

        Raises
        ------
        FileNotFoundError
            If a source file is missing and `ignore_missing` is False.
        Exception
            Any other download error is logged and re-raised regardless of
            `ignore_missing`.
        """
        assert self._temp_dir is not None
        params = self._get_env_params(restrict=False)
        # Allow templates to refer to the local temp dir as $PWD.
        params["PWD"] = self._temp_dir
        for path_from, path_to in self._expand(self._download, params):
            try:
                self._file_share_service.download(
                    self._params,
                    path_from,
                    self._config_loader_service.resolve_path(
                        path_to,
                        extra_paths=[self._temp_dir],
                    ),
                )
            except FileNotFoundError:
                _LOG.warning("Cannot download: %s", path_from)
                if not ignore_missing:
                    raise
            except Exception:
                _LOG.exception("Cannot download %s to %s", path_from, path_to)
                raise

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Download benchmark results from the shared storage and run post-processing
        scripts locally.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.

        Raises
        ------
        FileNotFoundError
            If any of the configured "download" sources is missing.
        """
        self._download_files()
        return super().run()

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Download whatever files are currently available from the shared storage
        (missing ones are skipped with a warning), then return the status
        reported by the parent environment.

        Returns
        -------
        (status, timestamp, telemetry) : (Status, datetime.datetime, list)
            Passed through unchanged from the parent `status()`; the list holds
            (timestamp, name, value) entries (presumably telemetry read from the
            downloaded files -- confirm against LocalEnv.status()).
        """
        self._download_files(ignore_missing=True)
        return super().status()