Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%

139 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A hierarchy of benchmark environments. 

7""" 

8 

9import abc 

10import json 

11import logging 

12from datetime import datetime 

13from types import TracebackType 

14from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union 

15from typing_extensions import Literal 

16 

17from pytz import UTC 

18 

19from mlos_bench.config.schemas import ConfigSchema 

20from mlos_bench.dict_templater import DictTemplater 

21from mlos_bench.environments.status import Status 

22from mlos_bench.services.base_service import Service 

23from mlos_bench.tunables.tunable import TunableValue 

24from mlos_bench.tunables.tunable_groups import TunableGroups 

25from mlos_bench.util import instantiate_from_config, merge_parameters 

26 

27if TYPE_CHECKING: 

28 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

29 

30_LOG = logging.getLogger(__name__) 

31 

32 

33class Environment(metaclass=abc.ABCMeta): 

34 # pylint: disable=too-many-instance-attributes 

35 """ 

36 An abstract base of all benchmark environments. 

37 """ 

38 

@classmethod
def new(cls,
        *,
        env_name: str,
        class_name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
        ) -> "Environment":
    """
    Build an environment object of the class named by `class_name`.

    Parameters
    ----------
    env_name: str
        Human-readable name of the environment; forwarded to the
        constructor as `name`.
    class_name: str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        It is expected to be derived from the `Environment` class.
    config : dict
        Free-format dictionary with the benchmark environment
        configuration. It is forwarded as the `config` constructor
        argument of the class specified by `class_name`.
    global_config : dict
        Optional free-format dictionary of global parameters;
        forwarded unchanged to the constructor.
    tunables : TunableGroups
        Optional collection of groups of tunable parameters;
        forwarded unchanged to the constructor.
    service: Service
        Optional service object; forwarded unchanged to the constructor.

    Returns
    -------
    env : Environment
        Whatever `instantiate_from_config` returns for the given class
        name and constructor arguments.
    """
    assert issubclass(cls, Environment)
    # Keyword arguments handed over to the target class constructor.
    ctor_kwargs = {
        "name": env_name,
        "config": config,
        "global_config": global_config,
        "tunables": tunables,
        "service": service,
    }
    return instantiate_from_config(cls, class_name, **ctor_kwargs)

88 

def __init__(self,
             *,
             name: str,
             config: dict,
             global_config: Optional[dict] = None,
             tunables: Optional[TunableGroups] = None,
             service: Optional[Service] = None):
    """
    Initialize the environment from its config.

    Parameters
    ----------
    name: str
        Human-readable name of the environment.
    config : dict
        Free-format dictionary with the benchmark environment
        configuration. The "const_args", "tunable_params", and
        "required_args" sections are read from it (each is optional here).
    global_config : dict
        Optional free-format dictionary of global parameters. It is merged
        into the "const_args" values and used as an extra source for
        `$var` expansion; its "tunable_params_map" entry (if any) is used
        to expand `$group` references.
    tunables : TunableGroups
        A collection of groups of tunable parameters. If omitted,
        a warning is logged and an empty `TunableGroups` is used.
    service: Service
        An optional service object.
    """
    # Validate first, so that no state is set for an invalid config.
    self._validate_json_config(config, name)

    # Identity and raw configuration.
    self.name = name
    self.config = config

    # Service handle plus context/readiness bookkeeping.
    self._service = service
    self._service_context: Optional[Service] = None
    self._is_ready = False
    self._in_context = False

    # NOTE: when "const_args" is present, this is the very same dict object
    # that lives inside `config` (no copy is made here).
    self._const_args: Dict[str, TunableValue] = config.get("const_args", {})

    if _LOG.isEnabledFor(logging.DEBUG):
        service_dump = self._service.pprint() if self._service else None
        _LOG.debug("Environment: '%s' Service: %s", name, service_dump)

    if tunables is None:
        _LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name)
        tunables = TunableGroups()

    # Resolve `$group` references into concrete tunable group names.
    declared_groups = config.get("tunable_params", [])
    groups_map = (global_config or {}).get("tunable_params_map", {})
    group_names = self._expand_groups(declared_groups, groups_map)
    _LOG.debug("Tunable groups for: '%s' :: %s", name, group_names)

    self._tunable_params = tunables.subgroup(group_names)

    # If a parameter comes from the tunables, do not require it in the const_args or globals
    tunable_names = self._tunable_params.get_param_values().keys()
    required_keys = set(config.get("required_args", [])).difference(tunable_names)
    merge_parameters(dest=self._const_args, source=global_config, required_keys=required_keys)
    self._const_args = self._expand_vars(self._const_args, global_config or {})

    self._params = self._combine_tunables(self._tunable_params)
    _LOG.debug("Parameters for '%s' :: %s", name, self._params)

    if _LOG.isEnabledFor(logging.DEBUG):
        config_dump = json.dumps(self.config, indent=2)
        _LOG.debug("Config for: '%s'\n%s", name, config_dump)

154 

def _validate_json_config(self, config: dict, name: str) -> None:
    """
    Validate `config` against the environment config schema.

    A minimal json-like config is rebuilt from the class FQN, the `name`
    and the `config` (the latter two only when they are non-empty) and
    handed to `ConfigSchema.ENVIRONMENT.validate`. This allows checking
    configs that were provided outside the file loading mechanism.
    """
    klass = self.__class__
    json_config: dict = {"class": klass.__module__ + "." + klass.__name__}
    # Empty name/config are left out of the reconstructed document.
    for (key, value) in (("name", name), ("config", config)):
        if value:
            json_config[key] = value
    ConfigSchema.ENVIRONMENT.validate(json_config)

169 

@staticmethod
def _expand_groups(groups: Iterable[str],
                   groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]:
    """
    Replace every `$name` entry with the tunable group(s) it maps to.

    Parameters
    ----------
    groups : Iterable[str]
        Names of the groups of tunables; entries that start with `$`
        are looked up (without the `$`) in `groups_exp`.
    groups_exp : dict
        Maps variable names to either a single group name (str)
        or a sequence of group names.

    Returns
    -------
    groups : List[str]
        A flat list of group names, in the order they were encountered.

    Raises
    ------
    KeyError
        If a `$name` entry has no corresponding key in `groups_exp`.
    """
    expanded: List[str] = []
    for group in groups:
        # Plain names are passed through unchanged.
        if group[:1] != "$":
            expanded.append(group)
            continue
        var_name = group[1:]
        if var_name not in groups_exp:
            raise KeyError(f"Expected tunable group name ${var_name} undefined in {groups_exp}")
        mapped = groups_exp[var_name]
        if isinstance(mapped, str):
            # A single name must not be split into characters.
            expanded.append(mapped)
        else:
            expanded.extend(mapped)
    return expanded

200 

@staticmethod
def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict:
    """
    Expand `$var` references in `params` via `DictTemplater`,
    with `global_config` passed as the extra source of values.
    """
    templater = DictTemplater(params)
    return templater.expand_vars(extra_source_dict=global_config)

207 

@property
def _config_loader_service(self) -> "SupportsConfigLoading":
    """
    The `config_loader_service` attribute of the attached service.
    Asserts that a service object is present.
    """
    service = self._service
    assert service is not None
    return service.config_loader_service

212 

def __enter__(self) -> 'Environment':
    """
    Enter the benchmarking context of the environment.

    Asserts that the context is not already active, enters the context
    of the attached service (if there is one), and marks the environment
    as being in context.
    """
    _LOG.debug("Environment START :: %s", self)
    assert not self._in_context
    service = self._service
    if service:
        # Keep whatever the service returns, so that __exit__ can close it.
        self._service_context = service.__enter__()
    self._in_context = True
    return self

223 

def __exit__(self, ex_type: Optional[Type[BaseException]],
             ex_val: Optional[BaseException],
             ex_tb: Optional[TracebackType]) -> Literal[False]:
    """
    Leave the benchmarking context of the environment.

    The service context (if one was entered) is exited as well. An
    `Exception` raised while exiting it is logged and re-raised only
    after the context state of the environment has been reset.

    Returns
    -------
    suppress : Literal[False]
        Always False, i.e., exceptions from the `with` body are not suppressed.
    """
    deferred_error = None
    if ex_val is not None:
        assert ex_type and ex_val
        _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
    else:
        _LOG.debug("Environment END :: %s", self)
    assert self._in_context
    if self._service_context:
        try:
            self._service_context.__exit__(ex_type, ex_val, ex_tb)
        # pylint: disable=broad-exception-caught
        except Exception as ex:
            _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
            # Postpone the raise until the bookkeeping below is done.
            deferred_error = ex
        finally:
            self._service_context = None
    self._in_context = False
    if deferred_error:
        raise deferred_error
    return False  # Do not suppress exceptions

250 

def __str__(self) -> str:
    """Return the name of the environment."""
    env_name = self.name
    return env_name

253 

def __repr__(self) -> str:
    """Return the class name and the (quoted) environment name."""
    class_name = self.__class__.__name__
    return f"{class_name} :: '{self.name}'"

256 

def pprint(self, indent: int = 4, level: int = 0) -> str:
    """
    Pretty-print the environment.

    Parameters
    ----------
    indent : int
        Number of spaces per indentation level. Default is 4.
    level : int
        Current level of indentation. Default is 0.

    Returns
    -------
    pretty : str
        The `repr()` of the environment, prefixed with
        `indent * level` spaces.
    """
    padding = " " * indent * level
    return padding + repr(self)

276 

def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
    """
    Combine the values of the tunables with the constant arguments.

    The values are requested only for the covariant groups of this
    environment, and are written into a (shallow) copy of the constant
    arguments, so `self._const_args` itself is not passed to `tunables`.

    Parameters
    ----------
    tunables : TunableGroups
        A collection of groups of tunable parameters
        along with the parameters' values.

    Returns
    -------
    params : Dict[str, TunableValue]
        The result of `tunables.get_param_values()` for the groups
        of this environment.
    """
    own_groups = list(self._tunable_params.get_covariant_group_names())
    base_params = self._const_args.copy()
    return tunables.get_param_values(group_names=own_groups, into_params=base_params)

297 

@property
def tunable_params(self) -> TunableGroups:
    """
    Tunable parameters of this environment.

    Returns
    -------
    tunables : TunableGroups
        The subgroup of tunables selected for this environment
        in the constructor.
    """
    own_tunables = self._tunable_params
    return own_tunables

309 

@property
def parameters(self) -> Dict[str, TunableValue]:
    """
    Current key/value pairs of all environment parameters.

    Returns
    -------
    parameters : Dict[str, TunableValue]
        The combined parameters as computed in the constructor
        or by the latest `.setup()` call. The internal dictionary
        itself is returned, not a copy.
    """
    current_params = self._params
    return current_params

322 

def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
    """
    Assign new tunable values and recompute the environment parameters.

    Must be invoked inside the environment's context (asserted).

    Parameters
    ----------
    tunables : TunableGroups
        A collection of tunable parameters along with their values.
    global_config : dict
        Optional free-format dictionary of global parameters;
        merged into the combined parameters of the environment.

    Returns
    -------
    is_success : bool
        This base implementation always returns True.
    """
    _LOG.info("Setup %s :: %s", self, tunables)
    assert isinstance(tunables, TunableGroups)

    # Make sure we create a context before invoking setup/run/status/teardown
    assert self._in_context

    # Copy the new values for the groups of this environment only.
    own_groups = list(self._tunable_params.get_covariant_group_names())
    new_values = tunables.get_param_values(own_groups)
    self._tunable_params.assign(new_values)

    # Log whether anything has changed.
    # (Derived classes still have to check `self._tunable_params.is_updated()`).
    is_updated = self._tunable_params.is_updated()
    if not _LOG.isEnabledFor(logging.DEBUG):
        _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)
    else:
        # In debug mode, also report the status of each group separately.
        per_group = {
            grp: self._tunable_params.is_updated([grp])
            for grp in self._tunable_params.get_covariant_group_names()
        }
        _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, per_group)

    # Combine tunables, const_args, and global config into `self._params`:
    self._params = self._combine_tunables(tunables)
    merge_parameters(dest=self._params, source=global_config)

    if _LOG.isEnabledFor(logging.DEBUG):
        params_dump = json.dumps(self._params, indent=2)
        _LOG.debug("Combined parameters:\n%s", params_dump)

    return True

371 

def teardown(self) -> None:
    """
    Mark the environment as not ready.

    Must be invoked inside the environment's context (asserted).
    This base implementation only logs the call and resets
    the readiness flag.
    """
    _LOG.info("Teardown %s", self)
    # Make sure we create a context before invoking setup/run/status/teardown
    assert self._in_context
    self._is_ready = False

382 

def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
    """
    Report the current status of the environment with no output data.

    Must be invoked inside the environment's context (asserted).
    This base implementation does not execute anything: it queries
    `.status()` and drops the telemetry part of the result.

    Returns
    -------
    (status, timestamp, output) : (Status, datetime, dict)
        3-tuple of (Status, timestamp, output) values, where `status`
        and `timestamp` come from `.status()`, and `output` is always
        None in this base implementation.
    """
    # Make sure we create a context before invoking setup/run/status/teardown
    assert self._in_context
    current_status, timestamp, _telemetry = self.status()
    return (current_status, timestamp, None)

402 

def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
    """
    Check the status of the benchmark environment.

    Must be invoked inside the environment's context (asserted).

    Returns
    -------
    (benchmark_status, timestamp, telemetry) : (Status, datetime, list)
        3-tuple of (benchmark status, timestamp, telemetry) values.
        `benchmark_status` is READY if the readiness flag is set, and
        PENDING otherwise (a warning is logged in that case).
        `timestamp` is the current time in UTC.
        `telemetry` is always an empty list in this base implementation.
    """
    # Make sure we create a context before invoking setup/run/status/teardown
    assert self._in_context
    timestamp = datetime.now(UTC)
    if not self._is_ready:
        _LOG.warning("Environment not ready: %s", self)
        return (Status.PENDING, timestamp, [])
    return (Status.READY, timestamp, [])