Coverage for mlos_bench/mlos_bench/tests/storage/sql/fixtures.py: 100%

116 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-30 00:51 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Test fixtures for mlos_bench storage.""" 

6 

7import json 

8import os 

9import tempfile 

10from collections.abc import Generator 

11from contextlib import contextmanager 

12from importlib.resources import files 

13from random import seed as rand_seed 

14 

15import pytest 

16from fasteners import InterProcessLock 

17from pytest_docker.plugin import Services as DockerServices 

18from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture 

19 

20from mlos_bench.optimizers.mock_optimizer import MockOptimizer 

21from mlos_bench.schedulers.sync_scheduler import SyncScheduler 

22from mlos_bench.schedulers.trial_runner import TrialRunner 

23from mlos_bench.services.config_persistence import ConfigPersistenceService 

24from mlos_bench.storage.base_experiment_data import ExperimentData 

25from mlos_bench.storage.sql.storage import SqlStorage 

26from mlos_bench.storage.storage_factory import from_config 

27from mlos_bench.tests import DOCKER, SEED, wait_docker_service_healthy 

28from mlos_bench.tests.storage import ( 

29 CONFIG_TRIAL_REPEAT_COUNT, 

30 MAX_TRIALS, 

31 TRIAL_RUNNER_COUNT, 

32) 

33from mlos_bench.tests.storage.sql import ( 

34 MYSQL_TEST_SERVER_NAME, 

35 PGSQL_TEST_SERVER_NAME, 

36 SqlTestServerInfo, 

37) 

38from mlos_bench.tunables.tunable_groups import TunableGroups 

39from mlos_bench.util import path_join 

40 

41# pylint: disable=redefined-outer-name 

42 

# Try to test multiple DBMS engines.
# The Docker backed storage fixtures are only listed when the DOCKER flag is
# truthy; otherwise the list stays empty.
DOCKER_DBMS_FIXTURES = (
    [
        lazy_fixture("mysql_storage"),
        lazy_fixture("postgres_storage"),
    ]
    if DOCKER
    else []
)

# SQLite is always included; the Docker backed engines (if any) are appended.
PERSISTENT_SQL_STORAGE_FIXTURES = [
    lazy_fixture("sqlite_storage"),
    *DOCKER_DBMS_FIXTURES,
]

54 

55 

@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig: pytest.Config) -> list[str]:
    """
    Provide the docker-compose file(s) used by the SQL storage tests.

    Parameters
    ----------
    pytestconfig : pytest.Config
        Accepted as part of the fixture signature; its value is not used.

    Returns
    -------
    list[str]
        Single-element list holding the path of the ``docker-compose.yml``
        file located in the same directory as this module.
    """
    del pytestconfig  # unused
    module_dir = os.path.dirname(__file__)
    compose_path = os.path.join(module_dir, "docker-compose.yml")
    return [compose_path]

74 

75 

@pytest.fixture(scope="session")
def docker_compose_project_name(short_testrun_uid: str) -> str:
    """
    Build the docker-compose project name for this test module.

    Parameters
    ----------
    short_testrun_uid : str
        Identifier embedded in the project name.

    Returns
    -------
    str
        ``mlos_bench-test-<short_testrun_uid>-<module>``, where ``<module>`` is
        this module's dotted name with the dots replaced by dashes.
    """
    module_slug = __name__.replace(".", "-")
    return f"mlos_bench-test-{short_testrun_uid}-{module_slug}"

87 

88 

@pytest.fixture(scope="session")
def mysql_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Fixture providing the connection info of the MySQL test server.

    Waits (via ``wait_docker_service_healthy``) on the MySQL service of the
    compose project before handing the info to the tests.

    Returns
    -------
    SqlTestServerInfo
        Info describing the MySQL test service.
    """
    server_info = SqlTestServerInfo(
        compose_project_name=docker_compose_project_name,
        service_name=MYSQL_TEST_SERVER_NAME,
        hostname=docker_hostname,
    )
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)
    return server_info

108 

109 

@pytest.fixture(scope="session")
def postgres_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Fixture providing the connection info of the Postgres test server.

    Waits (via ``wait_docker_service_healthy``) on the Postgres service of the
    compose project before handing the info to the tests.

    Returns
    -------
    SqlTestServerInfo
        Info describing the Postgres test service.
    """
    server_info = SqlTestServerInfo(
        compose_project_name=docker_compose_project_name,
        service_name=PGSQL_TEST_SERVER_NAME,
        hostname=docker_hostname,
    )
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)
    return server_info

128 

129 

@contextmanager
def _create_storage_from_test_server_info(
    config_file: str,
    test_server_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Context manager that builds a SqlStorage for the given test server.

    Parameters
    ----------
    config_file : str
        Storage config file passed to ``from_config``.
    test_server_info : SqlTestServerInfo
        Supplies the host, port, database, and credentials for the connection.
    shared_temp_dir : str
        Directory in which the lock file is placed.
    short_testrun_uid : str
        Identifier that is made part of the lock file name.

    Notes
    -----
    The schema is reset when the context exits, so each test gets a fresh
    storage instance.
    A file lock named after the service and test run is held for the whole
    lifetime of the context so that only one test can access the storage at a
    time.

    Yields
    ------
    SqlStorage
    """
    lock_path = path_join(
        shared_temp_dir,
        f"{test_server_info.service_name}-{short_testrun_uid}.lock",
    )
    with InterProcessLock(lock_path):
        # Connection settings are handed to the config loader as a JSON encoded
        # global config.
        connection_json = json.dumps(
            {
                "host": test_server_info.hostname,
                "port": test_server_info.get_port() or 0,
                "database": test_server_info.database,
                "username": test_server_info.username,
                "password": test_server_info.password,
                "lazy_schema_create": True,
            }
        )
        sql_storage = from_config(config_file, global_configs=[connection_json])
        assert isinstance(sql_storage, SqlStorage)
        try:
            yield sql_storage
        finally:
            # Leave a clean schema behind for the next user of this server.
            sql_storage._reset_schema(force=True)  # pylint: disable=protected-access

172 

173 

@pytest.fixture(scope="function")
def mysql_storage(
    mysql_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Fixture yielding a SqlStorage engine backed by the MySQL test server.

    Loads ``storage/mysql.jsonc`` from the ``mlos_bench.config`` package
    resources.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    config_root = str(files("mlos_bench.config"))
    mysql_config_file = path_join(config_root, "storage", "mysql.jsonc")
    storage_context = _create_storage_from_test_server_info(
        config_file=mysql_config_file,
        test_server_info=mysql_storage_info,
        shared_temp_dir=shared_temp_dir,
        short_testrun_uid=short_testrun_uid,
    )
    with storage_context as sql_storage:
        yield sql_storage

194 

195 

@pytest.fixture(scope="function")
def postgres_storage(
    postgres_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Fixture yielding a SqlStorage engine backed by the Postgres test server.

    Loads ``storage/postgresql.jsonc`` from the ``mlos_bench.config`` package
    resources.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    config_root = str(files("mlos_bench.config"))
    postgres_config_file = path_join(config_root, "storage", "postgresql.jsonc")
    storage_context = _create_storage_from_test_server_info(
        config_file=postgres_config_file,
        test_server_info=postgres_storage_info,
        shared_temp_dir=shared_temp_dir,
        short_testrun_uid=short_testrun_uid,
    )
    with storage_context as sql_storage:
        yield sql_storage

216 

217 

@pytest.fixture
def sqlite_storage() -> Generator[SqlStorage]:
    """
    Fixture for file based SQLite storage in a temporary directory.

    Yields
    ------
    SqlStorage
        Storage backed by an ``mlos_bench.sqlite`` file inside a fresh
        temporary directory, with ``update_schema()`` already applied.

    Notes
    -----
    Can't be used in parallel tests on Windows.

    The storage is disposed in a ``finally`` clause, i.e., before the temporary
    directory is removed, even if the schema update fails or the generator is
    closed early.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "mlos_bench.sqlite")
        config_str = json.dumps(
            {
                "class": "mlos_bench.storage.sql.storage.SqlStorage",
                "config": {
                    "drivername": "sqlite",
                    "database": db_path,
                    "lazy_schema_create": False,
                },
            }
        )

        storage = from_config(config_str)
        assert isinstance(storage, SqlStorage)
        try:
            storage.update_schema()
            yield storage
        finally:
            # Always release the storage while the directory still exists.
            # NOTE(review): presumably this closes the open DB file handles,
            # which would otherwise block directory removal on Windows.
            storage.dispose()

249 

250 

@pytest.fixture
def storage() -> SqlStorage:
    """
    Test fixture for in-memory SQLite3 storage.

    Returns
    -------
    SqlStorage
        Storage configured with the ``sqlite`` driver and the ``:memory:``
        database, without a parent service.
    """
    sqlite_config = {
        "drivername": "sqlite",
        "database": ":memory:",
        # "database": "mlos_bench.pytest.db",
    }
    return SqlStorage(service=None, config=sqlite_config)

262 

263 

@pytest.fixture
def exp_storage(
    storage: SqlStorage,
    tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for the "Test-001" Experiment on the in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the fixture
    value is handed to the test, and it is checked to have been exited again
    on teardown.
    """
    experiment = storage.experiment(
        experiment_id="Test-001",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment",
        tunables=tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment as exp:
        yield exp
    # Leaving the ``with`` block must have exited the experiment context.
    # pylint: disable=protected-access
    assert not exp._in_context

285 

286 

@pytest.fixture
def exp_no_tunables_storage(
    storage: SqlStorage,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for the "Test-003" Experiment (built from an empty tunables
    config) on the in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the fixture
    value is handed to the test, and it is checked to have been exited again
    on teardown.
    """
    no_tunables = TunableGroups({})
    experiment = storage.experiment(
        experiment_id="Test-003",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment - no tunables",
        tunables=no_tunables,
        opt_targets={"score": "min"},
    )
    with experiment as exp:
        yield exp
    # Leaving the ``with`` block must have exited the experiment context.
    # pylint: disable=protected-access
    assert not exp._in_context

308 

309 

@pytest.fixture
def mixed_numerics_exp_storage(
    storage: SqlStorage,
    mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for the "Test-002" Experiment with mixed numerics tunables on
    the in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the fixture
    value is handed to the test, and it is checked to have been exited again
    on teardown.
    """
    experiment = storage.experiment(
        experiment_id="Test-002",
        trial_id=1,
        root_env_config="dne.jsonc",
        description="pytest experiment",
        tunables=mixed_numerics_tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment as exp:
        yield exp
    # Leaving the ``with`` block must have exited the experiment context.
    # pylint: disable=protected-access
    assert not exp._in_context

332 

333 

def _dummy_run_exp(
    storage: SqlStorage,
    exp: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Populate an experiment with trial data by doing a simulated run of it.

    A mock environment, a mock optimizer, and a synchronous scheduler are
    assembled from the experiment's metadata and run against the storage.

    Parameters
    ----------
    storage : SqlStorage
        The storage object to use.
    exp : SqlStorage.Experiment
        The experiment to "run".
        Note: this particular object won't be updated, but a new one will be
        created from its metadata.

    Returns
    -------
    ExperimentData
        The data stored under the experiment's id after the simulated run.
    """
    # pylint: disable=too-many-locals

    # Seed the global random generator with the fixed test seed.
    rand_seed(SEED)

    global_config: dict = {}
    config_loader = ConfigPersistenceService()
    quoted_group_names = [f'"{name}"' for name in exp.tunables.get_covariant_group_names()]
    tunable_params = ",".join(quoted_group_names)
    mock_env_json = f"""
    {{
        "class": "mlos_bench.environments.mock_env.MockEnv",
        "name": "Test Env",
        "config": {{
            "tunable_params": [{tunable_params}],
            "mock_env_seed": {SEED},
            "mock_env_range": [60, 120],
            "mock_env_metrics": ["score"]
        }}
    }}
    """
    trial_runners: list[TrialRunner] = TrialRunner.create_from_json(
        config_loader=config_loader,
        global_config=global_config,
        tunable_groups=exp.tunables,
        env_json=mock_env_json,
        svcs_json=None,
        num_trial_runners=TRIAL_RUNNER_COUNT,
    )

    optimizer_config = {
        "optimization_targets": exp.opt_targets,
        "seed": SEED,
        # This should be the default, so we leave it omitted for now to test the default.
        # But the test logic relies on this (e.g., trial 1 is config 1 is the
        # default values for the tunable params)
        # "start_with_defaults": True,
        "max_suggestions": MAX_TRIALS,
    }
    optimizer = MockOptimizer(
        tunables=exp.tunables,
        config=optimizer_config,
        global_config=global_config,
    )

    # All config values can be overridden from global config
    scheduler_config = {
        "experiment_id": exp.experiment_id,
        "trial_id": exp.trial_id,
        "config_id": -1,
        "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
        "max_trials": MAX_TRIALS,
    }
    scheduler = SyncScheduler(
        config=scheduler_config,
        global_config=global_config,
        trial_runners=trial_runners,
        optimizer=optimizer,
        storage=storage,
        root_env_config=exp.root_env_config,
    )

    # Add some trial data to that experiment by "running" it.
    with scheduler:
        scheduler.start()
        scheduler.teardown()

    experiment_id = exp.experiment_id
    return storage.experiments[experiment_id]

420 

421 

@pytest.fixture
def exp_data(
    storage: SqlStorage,
    exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Test fixture for ExperimentData produced by a simulated run of ``exp_storage``."""
    simulated_run_data = _dummy_run_exp(storage=storage, exp=exp_storage)
    return simulated_run_data

429 

430 

@pytest.fixture
def exp_no_tunables_data(
    storage: SqlStorage,
    exp_no_tunables_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Test fixture for ExperimentData with no tunable configs, produced by a
    simulated run of ``exp_no_tunables_storage``.
    """
    simulated_run_data = _dummy_run_exp(storage=storage, exp=exp_no_tunables_storage)
    return simulated_run_data

438 

439 

@pytest.fixture
def mixed_numerics_exp_data(
    storage: SqlStorage,
    mixed_numerics_exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Test fixture for ExperimentData with mixed numerical tunable types, produced
    by a simulated run of ``mixed_numerics_exp_storage``.
    """
    simulated_run_data = _dummy_run_exp(storage=storage, exp=mixed_numerics_exp_storage)
    return simulated_run_data