Coverage for mlos_bench/mlos_bench/tests/storage/sql/fixtures.py: 100%
116 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-30 00:51 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Test fixtures for mlos_bench storage."""
7import json
8import os
9import tempfile
10from collections.abc import Generator
11from contextlib import contextmanager
12from importlib.resources import files
13from random import seed as rand_seed
15import pytest
16from fasteners import InterProcessLock
17from pytest_docker.plugin import Services as DockerServices
18from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture
20from mlos_bench.optimizers.mock_optimizer import MockOptimizer
21from mlos_bench.schedulers.sync_scheduler import SyncScheduler
22from mlos_bench.schedulers.trial_runner import TrialRunner
23from mlos_bench.services.config_persistence import ConfigPersistenceService
24from mlos_bench.storage.base_experiment_data import ExperimentData
25from mlos_bench.storage.sql.storage import SqlStorage
26from mlos_bench.storage.storage_factory import from_config
27from mlos_bench.tests import DOCKER, SEED, wait_docker_service_healthy
28from mlos_bench.tests.storage import (
29 CONFIG_TRIAL_REPEAT_COUNT,
30 MAX_TRIALS,
31 TRIAL_RUNNER_COUNT,
32)
33from mlos_bench.tests.storage.sql import (
34 MYSQL_TEST_SERVER_NAME,
35 PGSQL_TEST_SERVER_NAME,
36 SqlTestServerInfo,
37)
38from mlos_bench.tunables.tunable_groups import TunableGroups
39from mlos_bench.util import path_join
41# pylint: disable=redefined-outer-name
# Try to test multiple DBMS engines.
# The Docker-backed engines are only listed when DOCKER is truthy.
DOCKER_DBMS_FIXTURES = (
    [
        lazy_fixture("mysql_storage"),
        lazy_fixture("postgres_storage"),
    ]
    if DOCKER
    else []
)

# SQLite is always listed first, followed by whichever Docker-backed engines
# were selected above (none when DOCKER is falsy).
PERSISTENT_SQL_STORAGE_FIXTURES = [lazy_fixture("sqlite_storage"), *DOCKER_DBMS_FIXTURES]
@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig: pytest.Config) -> list[str]:
    """
    Locate the docker-compose file(s) used by these tests.

    Parameters
    ----------
    pytestconfig : pytest.Config
        Accepted but not used by this fixture.

    Returns
    -------
    list[str]
        A single-element list with the path to ``docker-compose.yml`` located
        in the same directory as this module.
    """
    _ = pytestconfig  # unused
    this_dir = os.path.dirname(__file__)
    compose_path = os.path.join(this_dir, "docker-compose.yml")
    return [compose_path]
@pytest.fixture(scope="session")
def docker_compose_project_name(short_testrun_uid: str) -> str:
    """
    Build the docker-compose project name for this test module.

    Parameters
    ----------
    short_testrun_uid : str
        Identifier that is embedded in the project name.

    Returns
    -------
    str
        ``mlos_bench-test-<short_testrun_uid>-<module name>``, where the dots
        in the module name are replaced with dashes.
    """
    module_slug = __name__.replace(".", "-")
    return f"mlos_bench-test-{short_testrun_uid}-{module_slug}"
@pytest.fixture(scope="session")
def mysql_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Session-scoped connection info for the MySQL test server.

    Builds the server info and passes it through ``wait_docker_service_healthy``
    before returning it.

    Returns
    -------
    SqlTestServerInfo
        Info describing the ``MYSQL_TEST_SERVER_NAME`` service.
    """
    server_info = SqlTestServerInfo(
        hostname=docker_hostname,
        service_name=MYSQL_TEST_SERVER_NAME,
        compose_project_name=docker_compose_project_name,
    )
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)

    return server_info
@pytest.fixture(scope="session")
def postgres_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Session-scoped connection info for the Postgres test server.

    Builds the server info and passes it through ``wait_docker_service_healthy``
    before returning it.

    Returns
    -------
    SqlTestServerInfo
        Info describing the ``PGSQL_TEST_SERVER_NAME`` service.
    """
    server_info = SqlTestServerInfo(
        hostname=docker_hostname,
        service_name=PGSQL_TEST_SERVER_NAME,
        compose_project_name=docker_compose_project_name,
    )
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)
    return server_info
@contextmanager
def _create_storage_from_test_server_info(
    config_file: str,
    test_server_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Context manager that builds a SqlStorage for the given test server.

    Parameters
    ----------
    config_file : str
        Storage config passed to ``from_config``.
    test_server_info : SqlTestServerInfo
        Connection details (host, port, database, credentials) of the server.
    shared_temp_dir : str
        Directory in which the inter-process lock file is placed.
    short_testrun_uid : str
        Identifier that is made part of the lock file name.

    Notes
    -----
    The storage is created and yielded while holding a file lock named after the
    service and the test run id, so only one holder of that lock can use the
    storage at a time.
    When the ``with`` body finishes (normally or with an error) the schema is
    reset so that the next user starts from a fresh storage instance.

    Yields
    ------
    SqlStorage
    """
    lock_file = path_join(
        shared_temp_dir,
        f"{test_server_info.service_name}-{short_testrun_uid}.lock",
    )
    with InterProcessLock(lock_file):
        # Connection settings that are supplied to the config as global overrides.
        connection_overrides = {
            "host": test_server_info.hostname,
            "port": test_server_info.get_port() or 0,
            "database": test_server_info.database,
            "username": test_server_info.username,
            "password": test_server_info.password,
            "lazy_schema_create": True,
        }
        overrides_json = json.dumps(connection_overrides)
        storage = from_config(config_file, global_configs=[overrides_json])
        assert isinstance(storage, SqlStorage)
        try:
            yield storage
        finally:
            # Always wipe the schema on the way out, even if the caller failed.
            storage._reset_schema(force=True)  # pylint: disable=protected-access
@pytest.fixture(scope="function")
def mysql_storage(
    mysql_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Function-scoped SqlStorage backed by the MySQL test server.

    The storage is built from the packaged ``storage/mysql.jsonc`` config.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    config_path = path_join(str(files("mlos_bench.config")), "storage", "mysql.jsonc")
    with _create_storage_from_test_server_info(
        config_path,
        mysql_storage_info,
        shared_temp_dir,
        short_testrun_uid,
    ) as sql_storage:
        yield sql_storage
@pytest.fixture(scope="function")
def postgres_storage(
    postgres_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Function-scoped SqlStorage backed by the Postgres test server.

    The storage is built from the packaged ``storage/postgresql.jsonc`` config.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    config_path = path_join(str(files("mlos_bench.config")), "storage", "postgresql.jsonc")
    with _create_storage_from_test_server_info(
        config_path,
        postgres_storage_info,
        shared_temp_dir,
        short_testrun_uid,
    ) as sql_storage:
        yield sql_storage
@pytest.fixture
def sqlite_storage() -> Generator[SqlStorage]:
    """
    Fixture for file based SQLite storage in a temporary directory.

    The database file ``mlos_bench.sqlite`` is created inside a fresh temporary
    directory and the schema is set up eagerly via ``update_schema()`` before
    the storage is handed to the test.

    Yields
    ------
    SqlStorage

    Notes
    -----
    Can't be used in parallel tests on Windows.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "mlos_bench.sqlite")
        config_str = json.dumps(
            {
                "class": "mlos_bench.storage.sql.storage.SqlStorage",
                "config": {
                    "drivername": "sqlite",
                    "database": db_path,
                    "lazy_schema_create": False,
                },
            }
        )

        storage = from_config(config_str)
        assert isinstance(storage, SqlStorage)
        try:
            storage.update_schema()
            yield storage
        finally:
            # Dispose unconditionally: pytest skips the code after ``yield`` when
            # setup raises, and the storage should be released before the
            # temporary directory holding its database file is removed.
            storage.dispose()
@pytest.fixture
def storage() -> SqlStorage:
    """
    Test fixture for in-memory SQLite3 storage.

    Returns
    -------
    SqlStorage
        Storage created without a parent service, using the ``sqlite`` driver
        and the ``:memory:`` database.
    """
    sqlite_config: dict = {
        "drivername": "sqlite",
        # Alternative on-disk target: "database": "mlos_bench.pytest.db",
        "database": ":memory:",
    }
    return SqlStorage(service=None, config=sqlite_config)
@pytest.fixture
def exp_storage(
    storage: SqlStorage,
    tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for Experiment ``Test-001`` using in-memory SQLite3 storage.

    Note: The experiment context has already been entered when the test receives it.
    After the test, the fixture checks that the context has been exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-001",
        trial_id=1,
        description="pytest experiment",
        root_env_config="environment.jsonc",
        tunables=tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # Leaving the ``with`` block must have closed the experiment context.
    assert not experiment._in_context  # pylint: disable=protected-access
@pytest.fixture
def exp_no_tunables_storage(
    storage: SqlStorage,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for Experiment ``Test-003`` (built from an empty tunables config)
    using in-memory SQLite3 storage.

    Note: The experiment context has already been entered when the test receives it.
    After the test, the fixture checks that the context has been exited.
    """
    empty_config: dict = {}
    experiment_ctx = storage.experiment(
        experiment_id="Test-003",
        trial_id=1,
        description="pytest experiment - no tunables",
        root_env_config="environment.jsonc",
        tunables=TunableGroups(empty_config),
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # Leaving the ``with`` block must have closed the experiment context.
    assert not experiment._in_context  # pylint: disable=protected-access
@pytest.fixture
def mixed_numerics_exp_storage(
    storage: SqlStorage,
    mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for Experiment ``Test-002`` with mixed numerics tunables using
    in-memory SQLite3 storage.

    Note: The experiment context has already been entered when the test receives it.
    After the test, the fixture checks that the context has been exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-002",
        trial_id=1,
        description="pytest experiment",
        root_env_config="dne.jsonc",
        tunables=mixed_numerics_tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # Leaving the ``with`` block must have closed the experiment context.
    assert not experiment._in_context  # pylint: disable=protected-access
def _dummy_run_exp(
    storage: SqlStorage,
    exp: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Populate the storage with data from a simulated run of the given experiment.

    A mock environment, a mock optimizer and a synchronous scheduler are wired
    together using the metadata of ``exp`` and run to completion.

    Parameters
    ----------
    storage : SqlStorage
        The storage object to use.
    exp : SqlStorage.Experiment
        The experiment to "run".
        Note: this particular object won't be updated, but a new one will be created
        from its metadata.

    Returns
    -------
    ExperimentData
        The data generated by the simulated run, looked up in ``storage`` by the
        experiment id.
    """
    # pylint: disable=too-many-locals

    # Seed the global RNG before anything else is constructed.
    rand_seed(SEED)

    # The same (initially empty) dict object is handed to every component below.
    global_config: dict = {}
    config_loader = ConfigPersistenceService()

    # Render the covariant group names as a comma separated list of JSON strings.
    quoted_group_names = (f'"{group}"' for group in exp.tunables.get_covariant_group_names())
    tunable_params = ",".join(quoted_group_names)
    mock_env_json = f"""
    {{
        "class": "mlos_bench.environments.mock_env.MockEnv",
        "name": "Test Env",
        "config": {{
            "tunable_params": [{tunable_params}],
            "mock_env_seed": {SEED},
            "mock_env_range": [60, 120],
            "mock_env_metrics": ["score"]
        }}
    }}
    """
    trial_runners: list[TrialRunner] = TrialRunner.create_from_json(
        config_loader=config_loader,
        global_config=global_config,
        tunable_groups=exp.tunables,
        env_json=mock_env_json,
        svcs_json=None,
        num_trial_runners=TRIAL_RUNNER_COUNT,
    )

    optimizer = MockOptimizer(
        tunables=exp.tunables,
        config={
            "optimization_targets": exp.opt_targets,
            "seed": SEED,
            # "start_with_defaults": True is deliberately left out so that the
            # default is what gets tested. The test logic still relies on it
            # (e.g., trial 1 is config 1 is the default values for the tunable
            # params).
            "max_suggestions": MAX_TRIALS,
        },
        global_config=global_config,
    )

    # All config values can be overridden from global config
    scheduler_config: dict = {
        "experiment_id": exp.experiment_id,
        "trial_id": exp.trial_id,
        "config_id": -1,
        "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
        "max_trials": MAX_TRIALS,
    }
    scheduler = SyncScheduler(
        config=scheduler_config,
        global_config=global_config,
        trial_runners=trial_runners,
        optimizer=optimizer,
        storage=storage,
        root_env_config=exp.root_env_config,
    )

    # "Run" the experiment so that some trial data gets recorded for it.
    with scheduler:
        scheduler.start()
        scheduler.teardown()

    return storage.experiments[exp.experiment_id]
@pytest.fixture
def exp_data(
    storage: SqlStorage,
    exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Test fixture for ExperimentData produced by a simulated run of ``exp_storage``."""
    return _dummy_run_exp(storage=storage, exp=exp_storage)
@pytest.fixture
def exp_no_tunables_data(
    storage: SqlStorage,
    exp_no_tunables_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Test fixture for ExperimentData from a simulated run with no tunable configs."""
    return _dummy_run_exp(storage=storage, exp=exp_no_tunables_storage)
@pytest.fixture
def mixed_numerics_exp_data(
    storage: SqlStorage,
    mixed_numerics_exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Test fixture for ExperimentData from a simulated run with mixed numerical tunable types."""
    return _dummy_run_exp(storage=storage, exp=mixed_numerics_exp_storage)