Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%

137 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A hierarchy of benchmark environments.""" 

6 

7import abc 

8import json 

9import logging 

10from datetime import datetime 

11from types import TracebackType 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Dict, 

16 Iterable, 

17 List, 

18 Literal, 

19 Optional, 

20 Sequence, 

21 Tuple, 

22 Type, 

23 Union, 

24) 

25 

26from pytz import UTC 

27 

28from mlos_bench.config.schemas import ConfigSchema 

29from mlos_bench.dict_templater import DictTemplater 

30from mlos_bench.environments.status import Status 

31from mlos_bench.services.base_service import Service 

32from mlos_bench.tunables.tunable import TunableValue 

33from mlos_bench.tunables.tunable_groups import TunableGroups 

34from mlos_bench.util import instantiate_from_config, merge_parameters 

35 

36if TYPE_CHECKING: 

37 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

38 

39_LOG = logging.getLogger(__name__) 

40 

41 

class Environment(metaclass=abc.ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """An abstract base of all benchmark environments."""

    @classmethod
    def new(  # pylint: disable=too-many-arguments
        cls,
        *,
        env_name: str,
        class_name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ) -> "Environment":
        """
        Factory method for a new environment with a given config.

        Parameters
        ----------
        env_name: str
            Human-readable name of the environment.
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.environments.remote.HostEnv".
            Must be derived from the `Environment` class.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. It will be passed as a constructor parameter of
            the class specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).

        Returns
        -------
        env : Environment
            An instance of the `Environment` class initialized with `config`.
        """
        assert issubclass(cls, Environment)
        return instantiate_from_config(
            cls,
            class_name,
            name=env_name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).
        """
        # Validate first, so that a malformed config fails before any state is set.
        self._validate_json_config(config, name)
        self.name = name
        self.config = config
        self._service = service
        self._service_context: Optional[Service] = None
        self._is_ready = False
        self._in_context = False
        # NOTE(review): when "const_args" is present this is the very same dict
        # object as `config["const_args"]`; if `merge_parameters` below updates
        # `dest` in place (its return value is not used here), the caller's
        # config is modified as well - confirm this is intended.
        self._const_args: Dict[str, TunableValue] = config.get("const_args", {})

        # Guard the call: `pprint()` is only evaluated when DEBUG is enabled.
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Environment: '%s' Service: %s",
                name,
                self._service.pprint() if self._service else None,
            )

        if tunables is None:
            _LOG.warning(
                (
                    "No tunables provided for %s. "
                    "Tunable inheritance across composite environments may be broken."
                ),
                name,
            )
            tunables = TunableGroups()

        # TODO: add user docstrings for these in the module
        groups = self._expand_groups(
            config.get("tunable_params", []),
            (global_config or {}).get("tunable_params_map", {}),
        )
        _LOG.debug("Tunable groups for: '%s' :: %s", name, groups)

        self._tunable_params = tunables.subgroup(groups)

        # If a parameter comes from the tunables, do not require it in the const_args or globals
        req_args = set(config.get("required_args", [])) - set(
            self._tunable_params.get_param_values().keys()
        )
        merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
        self._const_args = self._expand_vars(self._const_args, global_config or {})

        self._params = self._combine_tunables(self._tunable_params)
        _LOG.debug("Parameters for '%s' :: %s", name, self._params)

        # Guard the call: `json.dumps` is only evaluated when DEBUG is enabled.
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))

    def _validate_json_config(self, config: dict, name: str) -> None:
        """
        Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading mechanism.

        Parameters
        ----------
        config : dict
            The environment configuration; included under the "config" key
            only when it is non-empty.
        name : str
            The environment name; included under the "name" key only when
            it is non-empty.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if name:
            json_config["name"] = name
        if config:
            json_config["config"] = config
        ConfigSchema.ENVIRONMENT.validate(json_config)

    @staticmethod
    def _expand_groups(
        groups: Iterable[str],
        groups_exp: Dict[str, Union[str, Sequence[str]]],
    ) -> List[str]:
        """
        Expand `$tunable_group` into actual names of the tunable groups.

        Parameters
        ----------
        groups : Iterable[str]
            Names of the groups of tunables, maybe with `$` prefix (subject to expansion).
        groups_exp : dict
            A dictionary that maps dollar variables for tunable groups to the lists
            of actual tunable groups IDs.

        Returns
        -------
        groups : List[str]
            A flat list of tunable groups IDs for the environment.

        Raises
        ------
        KeyError
            If a `$`-prefixed group name has no entry in `groups_exp`.
        """
        res: List[str] = []
        for grp in groups:
            # Slicing (rather than indexing) keeps an empty string from raising.
            if grp[:1] == "$":
                tunable_group_name = grp[1:]
                if tunable_group_name not in groups_exp:
                    # Both literals need the `f` prefix; otherwise the second
                    # one would print the text "{groups_exp}" verbatim.
                    raise KeyError(
                        f"Expected tunable group name ${tunable_group_name} "
                        f"undefined in {groups_exp}"
                    )
                add_groups = groups_exp[tunable_group_name]
                # A single string is one group ID, not a sequence of characters.
                res += [add_groups] if isinstance(add_groups, str) else add_groups
            else:
                res.append(grp)
        return res

    @staticmethod
    def _expand_vars(
        params: Dict[str, TunableValue],
        global_config: Dict[str, TunableValue],
    ) -> dict:
        """Expand `$var` into actual values of the variables."""
        return DictTemplater(params).expand_vars(extra_source_dict=global_config)

    @property
    def _config_loader_service(self) -> "SupportsConfigLoading":
        """
        Get the config loader exposed by the attached service.

        Asserts that a service was provided to this environment.
        """
        assert self._service is not None
        return self._service.config_loader_service

    def __enter__(self) -> "Environment":
        """Enter the environment's benchmarking context."""
        _LOG.debug("Environment START :: %s", self)
        # Entering the same environment context twice is not supported.
        assert not self._in_context
        if self._service:
            self._service_context = self._service.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of the benchmarking environment.

        Exits the service context (if one was entered) and resets the context
        flags. An exception raised while exiting the service context is logged
        and re-raised after the flags are reset. Exceptions from the `with`
        body are never suppressed.
        """
        ex_throw: Optional[Exception] = None
        if ex_val is None:
            _LOG.debug("Environment END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        if self._service_context:
            try:
                self._service_context.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
                # Defer the raise until the environment state has been reset.
                ex_throw = ex
            finally:
                self._service_context = None
        self._in_context = False
        if ex_throw:
            raise ex_throw
        return False  # Do not suppress exceptions

    def __str__(self) -> str:
        """Return the name of the environment."""
        return self.name

    def __repr__(self) -> str:
        """Return the class name and the environment name."""
        return f"{self.__class__.__name__} :: '{self.name}'"

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment configuration. For composite environments, print
        all children environments as well.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
            Default output is the same as `__repr__`.
        """
        return f'{" " * indent * level}{repr(self)}'

    def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
        """
        Plug tunable values into the base config. If the tunable group is unknown,
        ignore it (it might belong to another environment). This method should never
        mutate the original config or the tunables.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of groups of tunable parameters
            along with the parameters' values.

        Returns
        -------
        params : Dict[str, Union[int, float, str]]
            Free-format dictionary that contains the new environment configuration.
        """
        return tunables.get_param_values(
            # Only the groups that belong to this environment are read.
            group_names=list(self._tunable_params.get_covariant_group_names()),
            # Pass a copy so that `self._const_args` itself is not written to.
            into_params=self._const_args.copy(),
        )

    @property
    def tunable_params(self) -> TunableGroups:
        """
        Get the configuration space of the given environment.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant groups of tunable parameters.
        """
        return self._tunable_params

    @property
    def parameters(self) -> Dict[str, TunableValue]:
        """
        Key/value pairs of all environment parameters (i.e., `const_args` and
        `tunable_params`). Note that before `.setup()` is called, all tunables will be
        set to None.

        Returns
        -------
        parameters : Dict[str, TunableValue]
            Key/value pairs of all environment parameters
            (i.e., `const_args` and `tunable_params`).
        """
        return self._params

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up a new benchmark environment, if necessary. This method must be
        idempotent, i.e., calling it several times in a row should be equivalent to a
        single call.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        _LOG.info("Setup %s :: %s", self, tunables)
        assert isinstance(tunables, TunableGroups)

        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context

        # Assign new values to the environment's tunable parameters:
        groups = list(self._tunable_params.get_covariant_group_names())
        self._tunable_params.assign(tunables.get_param_values(groups))

        # Write to the log whether the environment needs to be reset.
        # (Derived classes still have to check `self._tunable_params.is_updated()`).
        is_updated = self._tunable_params.is_updated()
        if _LOG.isEnabledFor(logging.DEBUG):
            # At DEBUG level also report the update flag of every group.
            _LOG.debug(
                "Env '%s': Tunable groups reset = %s :: %s",
                self,
                is_updated,
                {
                    name: self._tunable_params.is_updated([name])
                    for name in self._tunable_params.get_covariant_group_names()
                },
            )
        else:
            _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)

        # Combine tunables, const_args, and global config into `self._params`:
        self._params = self._combine_tunables(tunables)
        merge_parameters(dest=self._params, source=global_config)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))

        return True

    def teardown(self) -> None:
        """
        Tear down the benchmark environment.

        This method must be idempotent, i.e., calling it several times in a row should
        be equivalent to a single call.
        """
        _LOG.info("Teardown %s", self)
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        self._is_ready = False

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Execute the run script for this environment.

        For instance, this may start a new experiment, download results, reconfigure
        the environment, etc. Details are configurable via the environment config.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        # The base implementation only reports the current status, with no output.
        (status, timestamp, _) = self.status()
        return (status, timestamp, None)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        timestamp = datetime.now(UTC)
        if self._is_ready:
            return (Status.READY, timestamp, [])
        _LOG.warning("Environment not ready: %s", self)
        return (Status.PENDING, timestamp, [])