Coverage for mlos_bench/mlos_bench/schedulers/trial_runner.py: 94%

85 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Simple class to run an individual Trial on a given Environment.""" 

6 

7import logging 

8from datetime import datetime 

9from types import TracebackType 

10from typing import Any, Literal 

11 

12from pytz import UTC 

13 

14from mlos_bench.environments.base_environment import Environment 

15from mlos_bench.environments.status import Status 

16from mlos_bench.event_loop_context import EventLoopContext 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.services.config_persistence import ConfigPersistenceService 

19from mlos_bench.services.local.local_exec import LocalExecService 

20from mlos_bench.services.types import SupportsConfigLoading 

21from mlos_bench.storage.base_storage import Storage 

22from mlos_bench.tunables.tunable_groups import TunableGroups 

23from mlos_bench.tunables.tunable_types import TunableValue 

24 

# Logger for this module; records are emitted under this module's dotted name.
_LOG: logging.Logger = logging.getLogger(__name__)

26 

27 

class TrialRunner:
    """
    Simple class to help run an individual Trial on an environment.

    TrialRunner manages the lifecycle of a single trial, including setup, run, teardown,
    and async status polling via EventLoopContext background threads.

    Multiple TrialRunners can be used in a multi-processing pool to run multiple trials
    in parallel, for instance.

    Notes
    -----
    The EventLoopContext is currently created in ``__init__`` but not yet entered,
    exited, or used for polling (see the TODOs in ``__enter__``, ``__exit__`` and
    ``run_trial``); trials are run synchronously for now.
    """

    @classmethod
    def create_from_json(
        cls,
        *,
        config_loader: Service,
        env_json: str,
        svcs_json: str | list[str] | None = None,
        num_trial_runners: int = 1,
        tunable_groups: TunableGroups | None = None,
        global_config: dict[str, Any] | None = None,
    ) -> list["TrialRunner"]:
        # pylint: disable=too-many-arguments
        """
        Create a list of TrialRunner instances, and their associated Environments and
        Services, from JSON configurations.

        Since each TrialRunner instance is independent, they can be run in parallel,
        and hence must each get their own copy of the Environment and Services to
        operate on.

        The global_config is shared across all TrialRunners, but each copy gets its
        own unique trial_runner_id (the copy is shallow, so nested values are shared).

        Parameters
        ----------
        config_loader : Service
            A service instance capable of loading configuration (i.e., SupportsConfigLoading).
        env_json : str
            JSON file or string representing the environment configuration.
        svcs_json : str | list[str] | None
            JSON file(s) or string(s) representing the Services configuration.
            A single str is treated as a one-element list.
        num_trial_runners : int
            Number of TrialRunner instances to create. Default is 1.
        tunable_groups : TunableGroups | None
            TunableGroups instance to use as the parent Tunables for the
            environment. Each Environment receives its own ``.copy()`` of it.
            Default is None (an empty TunableGroups).
        global_config : dict[str, Any] | None
            Global configuration parameters. Default is None (an empty dict).

        Returns
        -------
        list[TrialRunner]
            A list of TrialRunner instances created from the provided configuration,
            with trial_runner_ids 1..num_trial_runners.

        Raises
        ------
        AssertionError
            If ``config_loader`` does not implement SupportsConfigLoading.
        """
        assert isinstance(config_loader, SupportsConfigLoading)
        # Normalize svcs_json to a list of configs. A bare str must be wrapped rather
        # than passed through: load_services() consumes an iterable of configs, so a
        # str would otherwise be iterated character by character.
        svcs_list: list[str]
        if not svcs_json:
            svcs_list = []
        elif isinstance(svcs_json, str):
            svcs_list = [svcs_json]
        else:
            svcs_list = list(svcs_json)
        tunable_groups = tunable_groups or TunableGroups()
        global_config = global_config or {}
        trial_runners: list[TrialRunner] = []
        for trial_runner_id in range(1, num_trial_runners + 1):  # use 1-based indexing
            # Make a fresh Environment and Services copy for each TrialRunner.
            # Give each global_config copy its own unique trial_runner_id.
            # This is important in case multiple TrialRunners are running in parallel.
            global_config_copy = global_config.copy()
            global_config_copy["trial_runner_id"] = trial_runner_id
            # Each Environment's parent service starts with at least a
            # LocalExecService in addition to the ConfigLoader.
            parent_service: Service = ConfigPersistenceService(
                config={"config_path": config_loader.get_config_paths()},
                global_config=global_config_copy,
            )
            parent_service = LocalExecService(parent=parent_service)
            parent_service = config_loader.load_services(
                svcs_list,
                global_config_copy,
                parent_service,
            )
            env = config_loader.load_environment(
                env_json,
                tunable_groups.copy(),
                global_config_copy,
                service=parent_service,
            )
            trial_runners.append(TrialRunner(trial_runner_id, env))
        return trial_runners

    def __init__(self, trial_runner_id: int, env: Environment) -> None:
        """
        Wrap an Environment in a TrialRunner.

        Parameters
        ----------
        trial_runner_id : int
            Identifier of this runner; must equal the Environment's
            ``parameters["trial_runner_id"]``.
        env : Environment
            The Environment that trials will be run on.

        Raises
        ------
        AssertionError
            If the Environment's trial_runner_id parameter does not match.
        """
        self._trial_runner_id = trial_runner_id
        self._env = env
        assert self._env.parameters["trial_runner_id"] == self._trial_runner_id
        self._in_context = False  # True between __enter__ and __exit__.
        self._is_running = False  # True only while run_trial() is executing.
        self._event_loop_context = EventLoopContext()

    def __repr__(self) -> str:
        return (
            f"TrialRunner({self.trial_runner_id}, {repr(self.environment)}"
            f"""[trial_runner_id={self.environment.parameters.get("trial_runner_id")}])"""
        )

    def __str__(self) -> str:
        return f"TrialRunner({self.trial_runner_id}, {str(self.environment)})"

    @property
    def trial_runner_id(self) -> int:
        """Get the TrialRunner's id."""
        return self._trial_runner_id

    @property
    def environment(self) -> Environment:
        """Get the Environment."""
        return self._env

    def __enter__(self) -> "TrialRunner":
        """
        Enter the wrapped Environment's context and mark this runner as in-context.

        Raises
        ------
        AssertionError
            If the runner is already in its context.
        """
        assert not self._in_context
        _LOG.debug("TrialRunner START :: %s", self)
        # TODO: self._event_loop_context.enter()
        self._env.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exit the wrapped Environment's context.

        Exceptions raised inside the ``with`` block are never suppressed.

        Raises
        ------
        AssertionError
            If the runner is not in its context.
        """
        assert self._in_context
        _LOG.debug("TrialRunner END :: %s", self)
        try:
            self._env.__exit__(ex_type, ex_val, ex_tb)
        finally:
            # TODO: self._event_loop_context.exit()
            # Leave the context even if the Environment's cleanup raised, so the
            # runner does not keep claiming to be inside the `with` block.
            self._in_context = False
        return False  # Do not suppress exceptions

    @property
    def is_running(self) -> bool:
        """Get the running state of the current TrialRunner."""
        return self._is_running

    def run_trial(
        self,
        trial: Storage.Trial,
        global_config: dict[str, Any] | None = None,
    ) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Run a single trial on this TrialRunner's Environment and store the results in
        the backend Trial Storage.

        Parameters
        ----------
        trial : Storage.Trial
            A Storage class based Trial used to persist the experiment trial data.
        global_config : dict[str, Any] | None
            Global configuration parameters, passed to ``trial.config()`` to build
            the parameters for the Environment's setup.

        Returns
        -------
        (trial_status, timestamp, trial_results) : (Status, datetime, dict[str, TunableValue] | None)
            Final status of the trial, its timestamp, and its results.
            On setup failure the status is ``Status.FAILED`` and results are ``None``;
            otherwise all three come from ``Environment.run()``.

        Raises
        ------
        AssertionError
            If the runner is not in its context, is already running a trial,
            or the trial belongs to a different trial_runner_id.
        """
        assert self._in_context

        assert not self._is_running
        self._is_running = True
        try:
            assert trial.trial_runner_id == self.trial_runner_id, (
                f"TrialRunner {self} should not run trial {trial} "
                f"with different trial_runner_id {trial.trial_runner_id}."
            )

            if not self.environment.setup(trial.tunables, trial.config(global_config)):
                _LOG.warning("Setup failed: %s :: %s", self.environment, trial.tunables)
                # FIXME: Use the actual timestamp from the environment.
                (status, timestamp, results) = (Status.FAILED, datetime.now(UTC), None)
                _LOG.info("TrialRunner: Update trial results: %s :: %s", trial, status)
                trial.update(status, timestamp)
                return (status, timestamp, results)

            # TODO: start background status polling of the environments in the event loop.

            # Block and wait for the final result.
            (status, timestamp, results) = self.environment.run()
            _LOG.info("TrialRunner Results: %s :: %s\n%s", trial.tunables, status, results)

            # In async mode (TODO), poll the environment for status and telemetry
            # and update the storage with the intermediate results.
            (_status, _timestamp, telemetry) = self.environment.status()

            # Use the status and timestamp from `.run()` as it is the final status
            # of the experiment.
            # TODO: Use the `.status()` output in async mode.
            trial.update_telemetry(status, timestamp, telemetry)

            trial.update(status, timestamp, results)
            _LOG.info("TrialRunner: Update trial results: %s :: %s %s", trial, status, results)

            return (status, timestamp, results)
        finally:
            # Always clear the running flag so this runner can take the next trial.
            # This covers the setup-failure early return and any exception above.
            self._is_running = False

    def teardown(self) -> None:
        """
        Tear down the Environment.

        Call it after the completion of one (or more) `.run()` in the TrialRunner
        context.

        Raises
        ------
        AssertionError
            If the runner is not in its context.
        """
        assert self._in_context
        self._env.teardown()