Coverage for mlos_bench/mlos_bench/schedulers/trial_runner.py: 94%
85 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Simple class to run an individual Trial on a given Environment."""
7import logging
8from datetime import datetime
9from types import TracebackType
10from typing import Any, Literal
12from pytz import UTC
14from mlos_bench.environments.base_environment import Environment
15from mlos_bench.environments.status import Status
16from mlos_bench.event_loop_context import EventLoopContext
17from mlos_bench.services.base_service import Service
18from mlos_bench.services.config_persistence import ConfigPersistenceService
19from mlos_bench.services.local.local_exec import LocalExecService
20from mlos_bench.services.types import SupportsConfigLoading
21from mlos_bench.storage.base_storage import Storage
22from mlos_bench.tunables.tunable_groups import TunableGroups
23from mlos_bench.tunables.tunable_types import TunableValue
# Module-level logger, named after this module so its output can be filtered
# or configured through the standard logging hierarchy.
_LOG: logging.Logger = logging.getLogger(__name__)
class TrialRunner:
    """
    Simple class to help run an individual Trial on an environment.

    TrialRunner manages the lifecycle of a single trial: setup, run, and teardown
    of its Environment, and persisting the outcome to Trial Storage.

    An EventLoopContext is created for each runner so that async status polling
    can eventually run in background threads. That integration is not wired up
    yet: the context is never entered (see the TODOs in ``__enter__``,
    ``__exit__`` and ``run_trial``), so trials currently run synchronously.

    Multiple TrialRunners can be used in a multi-processing pool to run multiple trials
    in parallel, for instance.
    """

    @classmethod
    def create_from_json(
        cls,
        *,
        config_loader: Service,
        env_json: str,
        svcs_json: str | list[str] | None = None,
        num_trial_runners: int = 1,
        tunable_groups: TunableGroups | None = None,
        global_config: dict[str, Any] | None = None,
    ) -> list["TrialRunner"]:
        # pylint: disable=too-many-arguments
        """
        Create a list of TrialRunner instances, and their associated Environments and
        Services, from JSON configurations.

        Since each TrialRunner instance is independent, they can be run in parallel,
        and hence must each get their own copy of the Environment and Services to
        operate on.

        The global_config is shared across all TrialRunners, but each copy gets its
        own unique trial_runner_id.

        Parameters
        ----------
        config_loader : Service
            A service instance capable of loading configuration (i.e., SupportsConfigLoading).
        env_json : str
            JSON file or string representing the environment configuration.
        svcs_json : str | list[str] | None
            JSON file(s) or string(s) representing the Services configuration.
        num_trial_runners : int
            Number of TrialRunner instances to create. Default is 1.
            Values below 1 produce an empty list.
        tunable_groups : TunableGroups | None
            TunableGroups instance to use as the parent Tunables for the
            environment. Each TrialRunner receives its own ``.copy()`` of it.
            Default is None (an empty TunableGroups).
        global_config : dict[str, Any] | None
            Global configuration parameters. Default is None.
            The caller's dict is never mutated; each runner gets a shallow copy.

        Returns
        -------
        list[TrialRunner]
            A list of TrialRunner instances created from the provided configuration,
            with trial_runner_ids 1..num_trial_runners.
        """
        assert isinstance(config_loader, SupportsConfigLoading)
        svcs_json = svcs_json or []
        tunable_groups = tunable_groups or TunableGroups()
        global_config = global_config or {}
        trial_runners: list[TrialRunner] = []
        for trial_runner_id in range(1, num_trial_runners + 1):  # use 1-based indexing
            # Make a fresh Environment and Services copy for each TrialRunner.
            # Give each global_config copy its own unique trial_runner_id.
            # This is important in case multiple TrialRunners are running in parallel.
            global_config_copy = global_config.copy()
            global_config_copy["trial_runner_id"] = trial_runner_id
            # Each Environment's parent service chain starts with a
            # ConfigPersistenceService wrapped by a LocalExecService,
            # before any user-specified Services are layered on top.
            parent_service: Service = ConfigPersistenceService(
                config={"config_path": config_loader.get_config_paths()},
                global_config=global_config_copy,
            )
            parent_service = LocalExecService(parent=parent_service)
            parent_service = config_loader.load_services(
                svcs_json,
                global_config_copy,
                parent_service,
            )
            env = config_loader.load_environment(
                env_json,
                tunable_groups.copy(),
                global_config_copy,
                service=parent_service,
            )
            trial_runners.append(TrialRunner(trial_runner_id, env))
        return trial_runners

    def __init__(self, trial_runner_id: int, env: Environment) -> None:
        """
        Create a TrialRunner bound to a single Environment.

        Parameters
        ----------
        trial_runner_id : int
            Unique id of this TrialRunner; must match the Environment's
            ``trial_runner_id`` parameter.
        env : Environment
            The Environment this runner will set up, run, and tear down.
        """
        self._trial_runner_id = trial_runner_id
        self._env = env
        assert self._env.parameters["trial_runner_id"] == self._trial_runner_id
        # True while inside a `with TrialRunner(...)` block.
        self._in_context = False
        # True while a `run_trial()` call is in progress.
        self._is_running = False
        # Reserved for async status polling; not entered yet (see TODOs).
        self._event_loop_context = EventLoopContext()

    def __repr__(self) -> str:
        return (
            f"TrialRunner({self.trial_runner_id}, {repr(self.environment)}"
            f"""[trial_runner_id={self.environment.parameters.get("trial_runner_id")}])"""
        )

    def __str__(self) -> str:
        return f"TrialRunner({self.trial_runner_id}, {str(self.environment)})"

    @property
    def trial_runner_id(self) -> int:
        """Get the TrialRunner's id."""
        return self._trial_runner_id

    @property
    def environment(self) -> Environment:
        """Get the Environment."""
        return self._env

    def __enter__(self) -> "TrialRunner":
        """Enter the Environment's context; nested entry is not allowed."""
        assert not self._in_context
        _LOG.debug("TrialRunner START :: %s", self)
        # TODO: self._event_loop_context.enter()
        self._env.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """Exit the Environment's context; exceptions are never suppressed."""
        assert self._in_context
        _LOG.debug("TrialRunner END :: %s", self)
        try:
            self._env.__exit__(ex_type, ex_val, ex_tb)
        finally:
            # TODO: self._event_loop_context.exit()
            # Leave the context even if the Environment's cleanup raised, so the
            # runner is not stranded in a half-exited state.
            self._in_context = False
        return False  # Do not suppress exceptions

    @property
    def is_running(self) -> bool:
        """Get the running state of the current TrialRunner."""
        return self._is_running

    def run_trial(
        self,
        trial: Storage.Trial,
        global_config: dict[str, Any] | None = None,
    ) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Run a single trial on this TrialRunner's Environment and store the results in
        the backend Trial Storage.

        Must be called inside the TrialRunner's context (``with`` block), and not
        concurrently with another ``run_trial`` call on the same runner.

        Parameters
        ----------
        trial : Storage.Trial
            A Storage class based Trial used to persist the experiment trial data.
            Its trial_runner_id must match this TrialRunner's id.
        global_config : dict[str, Any] | None
            Global configuration parameters, passed to ``trial.config()`` when
            setting up the Environment.

        Returns
        -------
        (trial_status, timestamp, results) : (Status, datetime, dict[str, TunableValue] | None)
            Final status of the trial, the timestamp of that status, and the
            trial results (None when setup fails).
        """
        assert self._in_context

        assert not self._is_running
        self._is_running = True
        try:
            assert trial.trial_runner_id == self.trial_runner_id, (
                f"TrialRunner {self} should not run trial {trial} "
                f"with different trial_runner_id {trial.trial_runner_id}."
            )

            if not self.environment.setup(trial.tunables, trial.config(global_config)):
                _LOG.warning("Setup failed: %s :: %s", self.environment, trial.tunables)
                # FIXME: Use the actual timestamp from the environment.
                (status, timestamp, results) = (Status.FAILED, datetime.now(UTC), None)
                _LOG.info("TrialRunner: Update trial results: %s :: %s", trial, status)
                trial.update(status, timestamp)
                return (status, timestamp, results)

            # TODO: start background status polling of the environments in the event loop.

            # Block and wait for the final result.
            (status, timestamp, results) = self.environment.run()
            _LOG.info("TrialRunner Results: %s :: %s\n%s", trial.tunables, status, results)

            # In async mode (TODO), poll the environment for status and telemetry
            # and update the storage with the intermediate results.
            (_status, _timestamp, telemetry) = self.environment.status()

            # Use the status and timestamp from `.run()` as it is the final status of the experiment.
            # TODO: Use the `.status()` output in async mode.
            trial.update_telemetry(status, timestamp, telemetry)

            trial.update(status, timestamp, results)
            _LOG.info("TrialRunner: Update trial results: %s :: %s %s", trial, status, results)

            return (status, timestamp, results)
        finally:
            # Always clear the running flag. This covers the setup-failure early
            # return and any exception raised above; otherwise the next
            # run_trial() on this runner would fail the `not _is_running` check.
            self._is_running = False

    def teardown(self) -> None:
        """
        Tear down the Environment.

        Call it after the completion of one (or more) `.run()` in the TrialRunner
        context.
        """
        assert self._in_context
        self._env.teardown()