Coverage for mlos_bench/mlos_bench/environments/composite_env.py: 89%

97 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Composite benchmark environment. 

7""" 

8 

9import logging 

10from datetime import datetime 

11 

12from types import TracebackType 

13from typing import Any, Dict, List, Optional, Tuple, Type 

14from typing_extensions import Literal 

15 

16from mlos_bench.services.base_service import Service 

17from mlos_bench.environments.status import Status 

18from mlos_bench.environments.base_environment import Environment 

19from mlos_bench.tunables.tunable import TunableValue 

20from mlos_bench.tunables.tunable_groups import TunableGroups 

21 

# Module-level logger, named after this module via `__name__`.
_LOG = logging.getLogger(name=__name__)

23 

24 

class CompositeEnv(Environment):
    """
    Composite benchmark environment.

    Holds an ordered list of child environments (built from the "include_children"
    and "children" config sections) and drives them as a single environment:
    context entry, setup, run and status go through the children in order, while
    teardown and context exit go through them in reverse order.
    """

    def __init__(self,
                 *,
                 name: str,
                 config: dict,
                 global_config: Optional[dict] = None,
                 tunables: Optional[TunableGroups] = None,
                 service: Optional[Service] = None):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the environment
            configuration. Must have a "children" section.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        ValueError
            If neither "include_children" nor "children" yields any child environment.
        """
        super().__init__(name=name, config=config, global_config=global_config,
                         tunables=tunables, service=service)

        # By default, the Environment includes only the tunables explicitly specified
        # in the "tunable_params" section of the config. `CompositeEnv`, however, must
        # retain all tunables from its children environments plus the ones that come
        # from the "include_tunables".
        tunables = tunables.copy() if tunables else TunableGroups()

        _LOG.debug("Build composite environment '%s' START: %s", self, tunables)
        self._children: List[Environment] = []
        # Values returned by each child's `__enter__()`; non-empty only while in context.
        self._child_contexts: List[Environment] = []

        # To support trees of composite environments (e.g. for multiple VM experiments),
        # each CompositeEnv gets a copy of the original global config and adjusts it with
        # the `const_args` specific to it (without overriding keys already present).
        global_config = (global_config or {}).copy()
        for (key, val) in self._const_args.items():
            global_config.setdefault(key, val)

        for child_config_file in config.get("include_children", []):
            for env in self._config_loader_service.load_environment_list(
                    child_config_file, tunables, global_config, self._const_args, self._service):
                self._add_child(env, tunables)

        for child_config in config.get("children", []):
            env = self._config_loader_service.build_environment(
                child_config, tunables, global_config, self._const_args, self._service)
            self._add_child(env, tunables)

        _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)

        if not self._children:
            raise ValueError("At least one child environment must be present")

    def __enter__(self) -> Environment:
        """
        Enter the context of every child environment (in order), then of this one.

        If entering a child or the base environment raises, the children that were
        already entered are exited in reverse order (receiving the exception info)
        before the exception propagates, so none is left inside its context.
        """
        entered: List[Environment] = []
        try:
            for env in self._children:
                entered.append(env.__enter__())
            self._child_contexts = entered
            return super().__enter__()
        except BaseException as ex:
            self._child_contexts = []
            # `entered` holds one result per successfully entered child, in order,
            # so the first len(entered) children are the ones to roll back.
            for env in reversed(self._children[:len(entered)]):
                try:
                    env.__exit__(type(ex), ex, ex.__traceback__)
                # pylint: disable=broad-exception-caught
                except Exception as cleanup_ex:
                    _LOG.error("Exception while exiting child environment '%s': %s",
                               env, cleanup_ex)
            raise

    def __exit__(self, ex_type: Optional[Type[BaseException]],
                 ex_val: Optional[BaseException],
                 ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Exit the context of all children (in reverse order), then of this environment.

        Every child gets a chance to exit even if another child's `__exit__` raises.
        If any child raised, the exception from the earliest such child in `children`
        is re-raised after the base environment has exited. The incoming exception
        (if any) is never suppressed.
        """
        ex_throw = None
        for env in reversed(self._children):
            try:
                env.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting child environment '%s': %s", env, ex)
                ex_throw = ex
        self._child_contexts = []
        super().__exit__(ex_type, ex_val, ex_tb)
        if ex_throw:
            raise ex_throw
        return False

    @property
    def children(self) -> List[Environment]:
        """
        Return the list of child environments.
        """
        return self._children

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment and its children.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output at each level. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
        """
        return super().pprint(indent, level) + '\n' + '\n'.join(
            child.pprint(indent, level + 1) for child in self._children)

    def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
        """
        Add a new child environment to the composite environment.
        This method is called from the constructor only.

        The child's tunables are merged both into this environment's own tunable
        params and into `tunables`, which is passed on to the children built later.
        """
        _LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
        self._children.append(env)
        self._tunable_params.merge(env.tunable_params)
        tunables.merge(env.tunable_params)

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up the children environments.

        The base environment is set up first, then each child in order. Setup stops
        at the first failure: children after a failing one are not set up.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if all children setup() operations are successful,
            false otherwise.
        """
        assert self._in_context
        self._is_ready = super().setup(tunables, global_config) and all(
            env_context.setup(tunables, global_config) for env_context in self._child_contexts)
        return self._is_ready

    def teardown(self) -> None:
        """
        Tear down the children environments. This method is idempotent,
        i.e., calling it several times is equivalent to a single call.
        The environments are being torn down in the reverse order,
        followed by the base environment.
        """
        assert self._in_context
        for env_context in reversed(self._child_contexts):
            env_context.teardown()
        super().teardown()

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Submit a new experiment to the environment.

        Runs the base environment first; if its status is not ready, that result is
        returned as is. Otherwise runs each child in order and stops at the *first*
        child whose status is not good, returning that child's status and timestamp
        with no metrics. If all children succeed, returns the status and timestamp of
        the *last* child together with the metrics of all children merged into one
        dict (a later child's value wins on duplicate keys).

        Returns
        -------
        (status, timestamp, output) : (Status, datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        _LOG.info("Run: %s", self._children)
        (status, timestamp, metrics) = super().run()
        if not status.is_ready():
            return (status, timestamp, metrics)

        joint_metrics: Dict[str, TunableValue] = {}
        for env_context in self._child_contexts:
            _LOG.debug("Child env. run: %s", env_context)
            (status, timestamp, metrics) = env_context.run()
            _LOG.debug("Child env. run results: %s :: %s %s", env_context, status, metrics)
            if not status.is_good():
                _LOG.info("Run failed: %s :: %s", self, status)
                return (status, timestamp, None)
            joint_metrics.update(metrics or {})

        _LOG.info("Run completed: %s :: %s %s", self, status, joint_metrics)
        # Return the status and the timestamp of the last child environment.
        return (status, timestamp, joint_metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Queries the base environment first; if its status is not ready, that result
        is returned as is. Otherwise queries every child and concatenates their
        telemetry. The returned status is that of the first child whose status is not
        good, or else that of the last child; the timestamp is always the last child's.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, telemetry) = super().status()
        if not status.is_ready():
            return (status, timestamp, telemetry)

        joint_telemetry: List[Tuple[datetime, str, Any]] = []
        final_status: Optional[Status] = None
        for env_context in self._child_contexts:
            (status, timestamp, telemetry) = env_context.status()
            _LOG.debug("Child env. status: %s :: %s", env_context, status)
            joint_telemetry.extend(telemetry)
            if not status.is_good() and final_status is None:
                final_status = status

        # Explicit None check: do not rely on the truthiness of Status members.
        if final_status is None:
            final_status = status
        _LOG.info("Final status: %s :: %s", self, final_status)
        # Return the status of the first failed child (or the last child if none failed)
        # and the timestamp of the last child environment.
        return (final_status, timestamp, joint_telemetry)