Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%

137 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A hierarchy of benchmark environments.""" 

6 

7import abc 

8import json 

9import logging 

10from datetime import datetime 

11from types import TracebackType 

12from typing import ( 

13 TYPE_CHECKING, 

14 Any, 

15 Dict, 

16 Iterable, 

17 List, 

18 Literal, 

19 Optional, 

20 Sequence, 

21 Tuple, 

22 Type, 

23 Union, 

24) 

25 

26from pytz import UTC 

27 

28from mlos_bench.config.schemas import ConfigSchema 

29from mlos_bench.dict_templater import DictTemplater 

30from mlos_bench.environments.status import Status 

31from mlos_bench.services.base_service import Service 

32from mlos_bench.tunables.tunable import TunableValue 

33from mlos_bench.tunables.tunable_groups import TunableGroups 

34from mlos_bench.util import instantiate_from_config, merge_parameters 

35 

36if TYPE_CHECKING: 

37 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

38 

# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)

40 

41 

class Environment(metaclass=abc.ABCMeta):
    """An abstract base of all benchmark environments."""

    # pylint: disable=too-many-instance-attributes

    @classmethod
    def new(  # pylint: disable=too-many-arguments
        cls,
        *,
        env_name: str,
        class_name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ) -> "Environment":
        """
        Factory method for a new environment with a given config.

        Parameters
        ----------
        env_name: str
            Human-readable name of the environment.
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.environments.remote.HostEnv".
            Must be derived from the `Environment` class.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. It will be passed as a constructor parameter of
            the class specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).

        Returns
        -------
        env : Environment
            An instance of the `Environment` class initialized with `config`.
        """
        assert issubclass(cls, Environment)
        return instantiate_from_config(
            cls,
            class_name,
            name=env_name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. It is validated against the environment JSON schema.
            The optional "tunable_params", "const_args", and "required_args"
            sections are read here; each defaults to empty when absent.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
            Its optional "tunable_params_map" entry is used to expand
            `$`-prefixed tunable group names.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
            If None, an empty collection is used and a warning is logged.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).
        """
        self._validate_json_config(config, name)
        self.name = name
        self.config = config
        self._service = service
        self._service_context: Optional[Service] = None
        self._is_ready = False
        self._in_context = False
        # NOTE(review): when present, this is the same dict object as
        # config["const_args"]. If merge_parameters() updates `dest` in place,
        # the caller's config is modified as well. Confirm that this is intended.
        self._const_args: Dict[str, TunableValue] = config.get("const_args", {})

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Environment: '%s' Service: %s",
                name,
                self._service.pprint() if self._service else None,
            )

        if tunables is None:
            _LOG.warning(
                (
                    "No tunables provided for %s. "
                    "Tunable inheritance across composite environments may be broken."
                ),
                name,
            )
            tunables = TunableGroups()

        # Resolve `$var` group references via the global "tunable_params_map".
        groups = self._expand_groups(
            config.get("tunable_params", []),
            (global_config or {}).get("tunable_params_map", {}),
        )
        _LOG.debug("Tunable groups for: '%s' :: %s", name, groups)

        self._tunable_params = tunables.subgroup(groups)

        # If a parameter comes from the tunables, do not require it in the const_args or globals
        req_args = set(config.get("required_args", [])) - set(
            self._tunable_params.get_param_values().keys()
        )
        merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
        self._const_args = self._expand_vars(self._const_args, global_config or {})

        self._params = self._combine_tunables(self._tunable_params)
        _LOG.debug("Parameters for '%s' :: %s", name, self._params)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))

    def _validate_json_config(self, config: dict, name: str) -> None:
        """
        Validate `config` against the environment JSON schema.

        Reconstructs a basic json config that this class might have been
        instantiated from, so that configs provided outside the file loading
        mechanism are validated too.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if name:
            json_config["name"] = name
        if config:
            json_config["config"] = config
        ConfigSchema.ENVIRONMENT.validate(json_config)

    @staticmethod
    def _expand_groups(
        groups: Iterable[str],
        groups_exp: Dict[str, Union[str, Sequence[str]]],
    ) -> List[str]:
        """
        Expand `$tunable_group` into actual names of the tunable groups.

        Parameters
        ----------
        groups : List[str]
            Names of the groups of tunables, maybe with `$` prefix (subject to expansion).
        groups_exp : dict
            A dictionary that maps dollar variables for tunable groups to the lists
            of actual tunable groups IDs.

        Returns
        -------
        groups : List[str]
            A flat list of tunable groups IDs for the environment.

        Raises
        ------
        KeyError
            If a `$`-prefixed name has no entry in `groups_exp`.
        """
        res: List[str] = []
        for grp in groups:
            # Slicing (rather than grp[0]) keeps an empty name from raising IndexError.
            if grp[:1] == "$":
                tunable_group_name = grp[1:]
                if tunable_group_name not in groups_exp:
                    raise KeyError(
                        f"Expected tunable group name ${tunable_group_name} "
                        f"undefined in {groups_exp}"
                    )
                add_groups = groups_exp[tunable_group_name]
                # A mapping value may be a single group name or a sequence of names.
                res += [add_groups] if isinstance(add_groups, str) else add_groups
            else:
                res.append(grp)
        return res

    @staticmethod
    def _expand_vars(
        params: Dict[str, TunableValue],
        global_config: Dict[str, TunableValue],
    ) -> dict:
        """Expand `$var` into actual values of the variables."""
        return DictTemplater(params).expand_vars(extra_source_dict=global_config)

    @property
    def _config_loader_service(self) -> "SupportsConfigLoading":
        """The config loader service exposed by `self._service` (which must be set)."""
        assert self._service is not None
        return self._service.config_loader_service

    def __enter__(self) -> "Environment":
        """Enter the environment's benchmarking context."""
        _LOG.debug("Environment START :: %s", self)
        assert not self._in_context
        if self._service:
            self._service_context = self._service.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of the benchmarking environment.

        Exits the service context (if one was entered), then marks the environment
        as out of context. If exiting the service context raises, the error is
        logged and re-raised after the environment state has been reset.

        Returns
        -------
        Literal[False]
            Always False, so exceptions from the `with` body are not suppressed.
        """
        ex_throw: Optional[Exception] = None
        if ex_val is None:
            _LOG.debug("Environment END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        if self._service_context:
            try:
                self._service_context.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
                ex_throw = ex
            finally:
                self._service_context = None
        self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False  # Do not suppress exceptions

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"{self.__class__.__name__} :: '{self.name}'"

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment configuration. For composite environments, print
        all children environments as well.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
            Default output is the same as `__repr__`.
        """
        return f'{" " * indent * level}{repr(self)}'

    def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
        """
        Plug tunable values into the base config. If the tunable group is unknown,
        ignore it (it might belong to another environment). This method should never
        mutate the original config or the tunables.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of groups of tunable parameters
            along with the parameters' values.

        Returns
        -------
        params : Dict[str, Union[int, float, str]]
            Free-format dictionary that contains the new environment configuration.
        """
        # Copy const_args so the tunable values are layered onto a fresh dict.
        return tunables.get_param_values(
            group_names=list(self._tunable_params.get_covariant_group_names()),
            into_params=self._const_args.copy(),
        )

    @property
    def tunable_params(self) -> TunableGroups:
        """
        Get the configuration space of the given environment.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant groups of tunable parameters.
        """
        return self._tunable_params

    @property
    def parameters(self) -> Dict[str, TunableValue]:
        """
        Key/value pairs of all environment parameters (i.e., `const_args` and
        `tunable_params`). Note that before `.setup()` is called, all tunables will be
        set to None.

        Returns
        -------
        parameters : Dict[str, TunableValue]
            Key/value pairs of all environment parameters
            (i.e., `const_args` and `tunable_params`).
        """
        return self._params

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up a new benchmark environment, if necessary. This method must be
        idempotent, i.e., calling it several times in a row should be equivalent to a
        single call.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        _LOG.info("Setup %s :: %s", self, tunables)
        assert isinstance(tunables, TunableGroups)

        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context

        # Assign new values to the environment's tunable parameters:
        groups = list(self._tunable_params.get_covariant_group_names())
        self._tunable_params.assign(tunables.get_param_values(groups))

        # Write to the log whether the environment needs to be reset.
        # (Derived classes still have to check `self._tunable_params.is_updated()`).
        is_updated = self._tunable_params.is_updated()
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Env '%s': Tunable groups reset = %s :: %s",
                self,
                is_updated,
                {
                    name: self._tunable_params.is_updated([name])
                    for name in self._tunable_params.get_covariant_group_names()
                },
            )
        else:
            _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)

        # Combine tunables, const_args, and global config into `self._params`:
        self._params = self._combine_tunables(tunables)
        merge_parameters(dest=self._params, source=global_config)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))

        return True

    def teardown(self) -> None:
        """
        Tear down the benchmark environment.

        This method must be idempotent, i.e., calling it several times in a row should
        be equivalent to a single call.
        """
        _LOG.info("Teardown %s", self)
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        self._is_ready = False

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Execute the run script for this environment.

        For instance, this may start a new experiment, download results, reconfigure
        the environment, etc. Details are configurable via the environment config.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        # The base implementation only reports status and produces no output.
        (status, timestamp, _) = self.status()
        return (status, timestamp, None)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        timestamp = datetime.now(UTC)
        if self._is_ready:
            return (Status.READY, timestamp, [])
        _LOG.warning("Environment not ready: %s", self)
        return (Status.PENDING, timestamp, [])