Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 89%
101 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Saving and restoring the benchmark data using SQLAlchemy.
7"""
9import logging
10import hashlib
11from datetime import datetime
12from typing import Optional, Tuple, List, Dict, Iterator, Any
14from pytz import UTC
16from sqlalchemy import Engine, Connection, CursorResult, Table, column, func, select
18from mlos_bench.environments.status import Status
19from mlos_bench.tunables.tunable_groups import TunableGroups
20from mlos_bench.storage.base_storage import Storage
21from mlos_bench.storage.sql.schema import DbSchema
22from mlos_bench.storage.sql.trial import Trial
23from mlos_bench.util import nullable, utcify_timestamp
# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)
class Experiment(Storage.Experiment):
    """
    Logic for retrieving and storing the results of a single experiment.

    Experiment, trial, config, and telemetry records live in the SQL database
    reached through the SQLAlchemy `engine`, using the tables in `schema`.
    This object additionally tracks the next trial ID to hand out.
    """

    def __init__(self, *,
                 engine: Engine,
                 schema: DbSchema,
                 tunables: TunableGroups,
                 experiment_id: str,
                 trial_id: int,
                 root_env_config: str,
                 description: str,
                 opt_target: str,
                 opt_direction: Optional[str]):
        """
        Create a SQL-backed experiment handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections and transactions.
        schema : DbSchema
            Table definitions of the storage database.
        tunables : TunableGroups
            Tunable parameters of the experiment.
        experiment_id : str
            Identifier of the experiment (primary key of the experiment table).
        trial_id : int
            Initial trial ID. `_setup()` replaces it with (last stored trial ID + 1)
            when resuming an experiment that already has trials.
        root_env_config : str
            Root environment config (presumably a config file path - confirm).
        description : str
            Human-readable description of the experiment.
        opt_target : str
            Name of the optimization target metric.
        opt_direction : Optional[str]
            Optimization direction (presumably "min"/"max" - confirm against Storage).
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_target=opt_target,
            opt_direction=opt_direction)
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Register a new experiment in the database, or resume an existing one.

        For a new experiment, insert its row into the experiment table and its
        optimization target into the objectives table. For an existing one,
        continue numbering trials after the last stored trial ID, and log a
        warning if the stored git commit differs from the current one.
        Everything runs in a single transaction.
        """
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # The outer join keeps the experiment row even when it has no
            # trials yet; MAX(trial_id) is NULL (None) in that case.
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select().with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                ).join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True
                ).where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                ).group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(self._schema.experiment.insert().values(
                    exp_id=self._experiment_id,
                    description=self._description,
                    git_repo=self._git_repo,
                    git_commit=self._git_commit,
                    root_env_config=self._root_env_config,
                ))
                # TODO: Expand for multiple objectives.
                conn.execute(self._schema.objectives.insert().values(
                    exp_id=self._experiment_id,
                    optimization_target=self._opt_target,
                    optimization_direction=self._opt_direction,
                ))
            else:
                if exp_info.trial_id is not None:
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info("Continue experiment: %s last trial: %s resume from: %d",
                          self._experiment_id, exp_info.trial_id, self._trial_id)
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning("Experiment %s git expected: %s %s",
                                 self, exp_info.git_repo, exp_info.git_commit)

    def merge(self, experiment_ids: List[str]) -> None:
        """
        Merge the data of other experiments into this one.

        Not implemented yet.

        Raises
        ------
        NotImplementedError
            Always.
        """
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO")

    def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
        """
        Load the tunable values stored for `config_id`.

        Returns
        -------
        Dict[str, Any]
            Mapping of `param_id` -> `param_value` from the config_param table.
        """
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
        """
        Load the telemetry of one trial of this experiment.

        Returns
        -------
        List[Tuple[datetime, str, Any]]
            (timestamp, metric_id, metric_value) tuples ordered by timestamp and
            then metric_id. Stored timestamps are interpreted as UTC.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select().where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id
                ).order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                    for row in cur_telemetry.fetchall()]

    def load(self, last_trial_id: int = -1,
             ) -> Tuple[List[int], List[Dict[str, Any]], List[Optional[Dict[str, Any]]], List[Status]]:
        """
        Load the finished trials of this experiment.

        Only trials with status SUCCEEDED, FAILED or TIMED_OUT and with
        trial_id > `last_trial_id` are returned, ordered by ascending trial_id.

        Returns
        -------
        (trial_ids, configs, scores, status)
            Four parallel lists. `configs` holds each trial's tunable values.
            `scores` holds the trial's results (metric_id -> metric_value) for
            trials whose status `is_succeeded()`, and None for the others.
        """
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select().with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                ).where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']),
                ).order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: List[int] = []
            configs: List[Dict[str, Any]] = []
            scores: List[Optional[Dict[str, Any]]] = []
            status: List[Status] = []

            for trial in cur_trials.fetchall():
                stat = Status[trial.status]
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(self._get_key_val(
                    conn, self._schema.config_param, "param", config_id=trial.config_id))
                if stat.is_succeeded():
                    scores.append(self._get_key_val(
                        conn, self._schema.trial_result, "metric",
                        exp_id=self._experiment_id, trial_id=trial.trial_id))
                else:
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database
        (e.g., tunable configurations, trial parameters, and trial results).

        Selects the `{field}_id` and `{field}_value` columns of `table` for
        the rows matching all `column == value` equality filters in `kwargs`,
        and returns them as an `{id: value}` dict.
        """
        cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            ).select_from(table).where(
                *[column(key) == val for (key, val) in kwargs.items()]
            )
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts.
        return dict(row._tuple() for row in cur_result.fetchall())  # pylint: disable=protected-access

    @staticmethod
    def _save_params(conn: Connection, table: Table,
                     params: Dict[str, Any], **kwargs: Any) -> None:
        """
        Insert one `(param_id, param_value)` row per entry of `params` into
        `table`, adding the extra column values from `kwargs` to every row.

        Values are stored as strings (None stays None). Does nothing when
        `params` is empty.
        """
        if not params:
            return
        conn.execute(table.insert(), [
            {
                **kwargs,
                "param_id": key,
                "param_value": nullable(str, val)
            }
            for (key, val) in params.items()
        ])

    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
        """
        Yield the trials of this experiment that are due to run.

        A trial qualifies when it has no end time, its start time is unset or
        not later than `timestamp`, and its status is PENDING. If `running` is
        True, READY and RUNNING trials also qualify.

        All matching rows and their parameters are read up front, and the
        connection is closed before the first trial is yielded. This way no
        database connection stays checked out while the caller works on a
        (possibly long-running) trial, or if the caller stops iterating early.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            pending_status = ['PENDING', 'READY', 'RUNNING']
        else:
            pending_status = ['PENDING']
        # (trial row, tunable values, framework config) for every pending trial.
        pending: List[Tuple[Any, Dict[str, Any], Dict[str, Any]]] = []
        with self._engine.connect() as conn:
            cur_trials = conn.execute(self._schema.trial.select().where(
                self._schema.trial.c.exp_id == self._experiment_id,
                (self._schema.trial.c.ts_start.is_(None) |
                 (self._schema.trial.c.ts_start <= timestamp)),
                self._schema.trial.c.ts_end.is_(None),
                self._schema.trial.c.status.in_(pending_status),
            ))
            for trial in cur_trials.fetchall():
                tunables = self._get_key_val(
                    conn, self._schema.config_param, "param",
                    config_id=trial.config_id)
                config = self._get_key_val(
                    conn, self._schema.trial_param, "param",
                    exp_id=self._experiment_id, trial_id=trial.trial_id)
                pending.append((trial, tunables, config))
        # The connection is closed here; fetched rows are fully buffered.
        for (trial, tunables, config) in pending:
            yield Trial(
                engine=self._engine,
                schema=self._schema,
                # Reset .is_updated flag after the assignment:
                tunables=self._tunables.copy().assign(tunables).reset(),
                experiment_id=self._experiment_id,
                trial_id=trial.trial_id,
                config_id=trial.config_id,
                opt_target=self._opt_target,
                opt_direction=self._opt_direction,
                config=config,
            )

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables. If the config does not exist,
        create a new record for it.

        Configs are identified by the SHA-256 hex digest of `str(tunables)`.
        A new config also gets its tunable name/value pairs stored in the
        config_param table.
        """
        config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest()
        cur_config = conn.execute(self._schema.config.select().where(
            self._schema.config.c.config_hash == config_hash
        )).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        config_id: int = conn.execute(self._schema.config.insert().values(
            config_hash=config_hash)).inserted_primary_key[0]
        self._save_params(
            conn, self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id)
        return config_id

    def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None,
                  config: Optional[Dict[str, Any]] = None) -> Storage.Trial:
        """
        Create and store a new PENDING trial with the next trial ID.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable values of the trial. Their config record is reused if one
            already exists, and created otherwise.
        ts_start : Optional[datetime]
            Earliest start time of the trial. Defaults to now (UTC).
        config : Optional[Dict[str, Any]]
            Framework (not target environment) parameters to store with the trial.

        Returns
        -------
        Storage.Trial
            Handle to the newly stored trial.
        """
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local")
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        # `begin()` commits on normal exit and rolls back (then re-raises) on error.
        with self._engine.begin() as conn:
            config_id = self._get_config_id(conn, tunables)
            conn.execute(self._schema.trial.insert().values(
                exp_id=self._experiment_id,
                trial_id=self._trial_id,
                config_id=config_id,
                ts_start=ts_start,
                status='PENDING',
            ))

            # Note: config here is the framework config, not the target
            # environment config (i.e., tunables).
            if config is not None:
                self._save_params(
                    conn, self._schema.trial_param, config,
                    exp_id=self._experiment_id, trial_id=self._trial_id)

        # Only reached after a successful commit, so the trial ID counter never
        # advances past a trial that was not actually stored.
        trial = Trial(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            config_id=config_id,
            opt_target=self._opt_target,
            opt_direction=self._opt_direction,
            config=config,
        )
        self._trial_id += 1
        return trial