Coverage for mlos_bench/mlos_bench/tests/storage/sql/fixtures.py: 100%
109 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-14 00:55 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Test fixtures for mlos_bench storage."""
7import json
8import os
9import tempfile
10from collections.abc import Generator
11from contextlib import contextmanager
12from importlib.resources import files
13from random import seed as rand_seed
15import pytest
16from fasteners import InterProcessLock
17from pytest_docker.plugin import Services as DockerServices
18from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture
20from mlos_bench.optimizers.mock_optimizer import MockOptimizer
21from mlos_bench.schedulers.sync_scheduler import SyncScheduler
22from mlos_bench.schedulers.trial_runner import TrialRunner
23from mlos_bench.services.config_persistence import ConfigPersistenceService
24from mlos_bench.storage.base_experiment_data import ExperimentData
25from mlos_bench.storage.sql.storage import SqlStorage
26from mlos_bench.storage.storage_factory import from_config
27from mlos_bench.tests import DOCKER, SEED, wait_docker_service_healthy
28from mlos_bench.tests.storage import (
29 CONFIG_TRIAL_REPEAT_COUNT,
30 MAX_TRIALS,
31 TRIAL_RUNNER_COUNT,
32)
33from mlos_bench.tests.storage.sql import (
34 MYSQL_TEST_SERVER_NAME,
35 PGSQL_TEST_SERVER_NAME,
36 SqlTestServerInfo,
37)
38from mlos_bench.tunables.tunable_groups import TunableGroups
39from mlos_bench.util import path_join
41# pylint: disable=redefined-outer-name
# Docker-backed DBMS engines are only included when DOCKER is truthy.
DOCKER_DBMS_FIXTURES = (
    [
        lazy_fixture("mysql_storage"),
        lazy_fixture("postgres_storage"),
    ]
    if DOCKER
    else []
)

# The SQLite fixture is always listed; the Docker-backed ones are appended
# (the list above is empty when DOCKER is falsy).
PERSISTENT_SQL_STORAGE_FIXTURES = [
    lazy_fixture("sqlite_storage"),
    *DOCKER_DBMS_FIXTURES,
]
@pytest.fixture(scope="session")
def mysql_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Session-scoped fixture providing connection info for the MySQL test server.

    Calls ``wait_docker_service_healthy`` for the MySQL compose service before
    handing the info back to the caller.
    """
    server_info = SqlTestServerInfo(
        service_name=MYSQL_TEST_SERVER_NAME,
        hostname=docker_hostname,
        compose_project_name=docker_compose_project_name,
    )
    # Read the identifiers back from the info object so the wait targets
    # exactly what the returned object describes.
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)
    return server_info
@pytest.fixture(scope="session")
def postgres_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Session-scoped fixture providing connection info for the Postgres test server.

    Calls ``wait_docker_service_healthy`` for the Postgres compose service before
    handing the info back to the caller.
    """
    server_info = SqlTestServerInfo(
        service_name=PGSQL_TEST_SERVER_NAME,
        hostname=docker_hostname,
        compose_project_name=docker_compose_project_name,
    )
    # Read the identifiers back from the info object so the wait targets
    # exactly what the returned object describes.
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)
    return server_info
@contextmanager
def _create_storage_from_test_server_info(
    config_file: str,
    test_server_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Context manager that yields a SqlStorage built for the given test server.

    Notes
    -----
    An inter-process file lock, named after the service and the test run id and
    placed in ``shared_temp_dir``, is held for the whole lifetime of the context
    so only one holder of that lock uses the storage at a time.
    On exit (including when the body raised) the schema is force-reset so the
    next user starts from a fresh storage instance.

    Yields
    ------
    SqlStorage
    """
    lock_path = path_join(
        shared_temp_dir,
        f"{test_server_info.service_name}-{short_testrun_uid}.lock",
    )
    with InterProcessLock(lock_path):
        # Connection settings are only read once the lock is held.
        connection_overrides = json.dumps(
            {
                "host": test_server_info.hostname,
                "port": test_server_info.get_port() or 0,
                "database": test_server_info.database,
                "username": test_server_info.username,
                "password": test_server_info.password,
                "lazy_schema_create": True,
            }
        )
        storage = from_config(config_file, global_configs=[connection_overrides])
        assert isinstance(storage, SqlStorage)
        try:
            yield storage
        finally:
            # Always leave a clean schema behind for the next user.
            storage._reset_schema(force=True)  # pylint: disable=protected-access
@pytest.fixture(scope="function")
def mysql_storage(
    mysql_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Function-scoped fixture yielding a SqlStorage backed by the MySQL test server.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    # Storage config shipped inside the mlos_bench.config package.
    mysql_config_file = path_join(
        str(files("mlos_bench.config")),
        "storage",
        "mysql.jsonc",
    )
    storage_ctx = _create_storage_from_test_server_info(
        mysql_config_file,
        mysql_storage_info,
        shared_temp_dir,
        short_testrun_uid,
    )
    with storage_ctx as storage:
        yield storage
@pytest.fixture(scope="function")
def postgres_storage(
    postgres_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Function-scoped fixture yielding a SqlStorage backed by the Postgres test server.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    # Storage config shipped inside the mlos_bench.config package.
    postgres_config_file = path_join(
        str(files("mlos_bench.config")),
        "storage",
        "postgresql.jsonc",
    )
    storage_ctx = _create_storage_from_test_server_info(
        postgres_config_file,
        postgres_storage_info,
        shared_temp_dir,
        short_testrun_uid,
    )
    with storage_ctx as storage:
        yield storage
@pytest.fixture
def sqlite_storage() -> Generator[SqlStorage]:
    """
    Fixture for file based SQLite storage in a temporary directory.

    Yields
    ------
    Generator[SqlStorage]

    Notes
    -----
    Can't be used in parallel tests on Windows.

    The storage is disposed in a ``finally`` block so that it is released even
    when ``update_schema()`` fails before the fixture yields (pytest does not
    run post-``yield`` teardown code in that case), and always before the
    temporary directory holding the database file is removed.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "mlos_bench.sqlite")
        config_str = json.dumps(
            {
                "class": "mlos_bench.storage.sql.storage.SqlStorage",
                "config": {
                    "drivername": "sqlite",
                    "database": db_path,
                    "lazy_schema_create": False,
                },
            }
        )

        storage = from_config(config_str)
        assert isinstance(storage, SqlStorage)
        try:
            storage.update_schema()
            yield storage
        finally:
            # Dispose before TemporaryDirectory cleanup runs.
            # NOTE(review): presumably this releases the open SQLite file
            # handle, which would otherwise block directory removal on Windows.
            storage.dispose()
@pytest.fixture
def storage() -> SqlStorage:
    """Provide a SqlStorage instance backed by an in-memory SQLite3 database."""
    # For debugging, a file name such as "mlos_bench.pytest.db" can be used as
    # the "database" value instead of ":memory:".
    sqlite_config = {
        "drivername": "sqlite",
        "database": ":memory:",
    }
    return SqlStorage(service=None, config=sqlite_config)
@pytest.fixture
def exp_storage(
    storage: SqlStorage,
    tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield an Experiment ("Test-001") created on the in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the test
    receives it; after the test, the fixture checks that it has been exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-001",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment",
        tunables=tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as exp:
        yield exp
    # Leaving the ``with`` block must have closed the experiment context.
    # pylint: disable=protected-access
    assert not exp._in_context
@pytest.fixture
def exp_no_tunables_storage(
    storage: SqlStorage,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield an Experiment ("Test-003") with an empty set of tunables, created on
    the in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the test
    receives it; after the test, the fixture checks that it has been exited.
    """
    no_tunables_config: dict = {}
    experiment_ctx = storage.experiment(
        experiment_id="Test-003",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment - no tunables",
        tunables=TunableGroups(no_tunables_config),
        opt_targets={"score": "min"},
    )
    with experiment_ctx as exp:
        yield exp
    # Leaving the ``with`` block must have closed the experiment context.
    # pylint: disable=protected-access
    assert not exp._in_context
@pytest.fixture
def mixed_numerics_exp_storage(
    storage: SqlStorage,
    mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield an Experiment ("Test-002") using the mixed numerics tunables, created
    on the in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the test
    receives it; after the test, the fixture checks that it has been exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-002",
        trial_id=1,
        root_env_config="dne.jsonc",
        description="pytest experiment",
        tunables=mixed_numerics_tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as exp:
        yield exp
    # Leaving the ``with`` block must have closed the experiment context.
    # pylint: disable=protected-access
    assert not exp._in_context
def _dummy_run_exp(
    storage: SqlStorage,
    exp: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Populate the storage with trial data by simulating a run of an experiment.

    A mock environment, a MockOptimizer and a SyncScheduler are wired together
    using the experiment's metadata and the scheduler is run to completion.

    Parameters
    ----------
    storage : SqlStorage
        The storage object the scheduler writes to and the result is read from.
    exp : SqlStorage.Experiment
        The experiment to "run".
        Note: this particular object won't be updated, but a new one will be created
        from its metadata.

    Returns
    -------
    ExperimentData
        The data for the experiment id, as looked up in ``storage.experiments``
        after the simulated run.
    """
    # pylint: disable=too-many-locals

    # Seed the global RNG before anything else is constructed.
    rand_seed(SEED)

    global_config: dict = {}
    config_loader = ConfigPersistenceService()

    # JSON list body naming every covariant group of the experiment's tunables.
    quoted_groups = ",".join(
        f'"{group_name}"' for group_name in exp.tunables.get_covariant_group_names()
    )
    mock_env_json = f"""
    {{
        "class": "mlos_bench.environments.mock_env.MockEnv",
        "name": "Test Env",
        "config": {{
            "tunable_params": [{quoted_groups}],
            "mock_env_seed": {SEED},
            "mock_env_range": [60, 120],
            "mock_env_metrics": ["score"]
        }}
    }}
    """
    trial_runners: list[TrialRunner] = TrialRunner.create_from_json(
        config_loader=config_loader,
        global_config=global_config,
        tunable_groups=exp.tunables,
        env_json=mock_env_json,
        svcs_json=None,
        num_trial_runners=TRIAL_RUNNER_COUNT,
    )

    optimizer_config = {
        "optimization_targets": exp.opt_targets,
        "seed": SEED,
        # This should be the default, so we leave it omitted for now to test the default.
        # But the test logic relies on this (e.g., trial 1 is config 1 is the
        # default values for the tunable params)
        # "start_with_defaults": True,
        "max_suggestions": MAX_TRIALS,
    }
    opt = MockOptimizer(
        tunables=exp.tunables,
        config=optimizer_config,
        global_config=global_config,
    )

    # All config values can be overridden from global config
    scheduler_config = {
        "experiment_id": exp.experiment_id,
        "trial_id": exp.trial_id,
        "config_id": -1,
        "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
        "max_trials": MAX_TRIALS,
    }
    scheduler = SyncScheduler(
        config=scheduler_config,
        global_config=global_config,
        trial_runners=trial_runners,
        optimizer=opt,
        storage=storage,
        root_env_config=exp.root_env_config,
    )

    # "Run" the experiment so that trial data gets recorded in the storage.
    with scheduler:
        scheduler.start()
        scheduler.teardown()

    experiment_id = exp.experiment_id
    return storage.experiments[experiment_id]
@pytest.fixture
def exp_data(
    storage: SqlStorage,
    exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide ExperimentData produced by a simulated run of ``exp_storage``."""
    return _dummy_run_exp(storage=storage, exp=exp_storage)
@pytest.fixture
def exp_no_tunables_data(
    storage: SqlStorage,
    exp_no_tunables_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide ExperimentData from a simulated run of the experiment with no tunables."""
    return _dummy_run_exp(storage=storage, exp=exp_no_tunables_storage)
@pytest.fixture
def mixed_numerics_exp_data(
    storage: SqlStorage,
    mixed_numerics_exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide ExperimentData from a simulated run of the mixed numerics experiment."""
    return _dummy_run_exp(storage=storage, exp=mixed_numerics_exp_storage)