Coverage for mlos_bench/mlos_bench/tests/storage/sql/fixtures.py: 100%

109 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Test fixtures for mlos_bench storage.""" 

6 

7import json 

8import os 

9import tempfile 

10from collections.abc import Generator 

11from contextlib import contextmanager 

12from importlib.resources import files 

13from random import seed as rand_seed 

14 

15import pytest 

16from fasteners import InterProcessLock 

17from pytest_docker.plugin import Services as DockerServices 

18from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture 

19 

20from mlos_bench.optimizers.mock_optimizer import MockOptimizer 

21from mlos_bench.schedulers.sync_scheduler import SyncScheduler 

22from mlos_bench.schedulers.trial_runner import TrialRunner 

23from mlos_bench.services.config_persistence import ConfigPersistenceService 

24from mlos_bench.storage.base_experiment_data import ExperimentData 

25from mlos_bench.storage.sql.storage import SqlStorage 

26from mlos_bench.storage.storage_factory import from_config 

27from mlos_bench.tests import DOCKER, SEED, wait_docker_service_healthy 

28from mlos_bench.tests.storage import ( 

29 CONFIG_TRIAL_REPEAT_COUNT, 

30 MAX_TRIALS, 

31 TRIAL_RUNNER_COUNT, 

32) 

33from mlos_bench.tests.storage.sql import ( 

34 MYSQL_TEST_SERVER_NAME, 

35 PGSQL_TEST_SERVER_NAME, 

36 SqlTestServerInfo, 

37) 

38from mlos_bench.tunables.tunable_groups import TunableGroups 

39from mlos_bench.util import path_join 

40 

41# pylint: disable=redefined-outer-name 

42 

# Try to test multiple DBMS engines.
# The Docker-backed engines are only listed when DOCKER is truthy.
DOCKER_DBMS_FIXTURES = (
    [
        lazy_fixture("mysql_storage"),
        lazy_fixture("postgres_storage"),
    ]
    if DOCKER
    else []
)

# SQLite is always included; the Docker-backed engines (if any) follow it.
PERSISTENT_SQL_STORAGE_FIXTURES = [
    lazy_fixture("sqlite_storage"),
    *DOCKER_DBMS_FIXTURES,
]

54 

55 

@pytest.fixture(scope="session")
def mysql_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Session-scoped fixture providing connection info for the MySQL test server.

    Builds a SqlTestServerInfo for the MySQL docker service and calls
    wait_docker_service_healthy for it before handing the info to the tests.

    Returns
    -------
    SqlTestServerInfo
        The connection info of the MySQL test server.
    """
    server_info = SqlTestServerInfo(
        hostname=docker_hostname,
        service_name=MYSQL_TEST_SERVER_NAME,
        compose_project_name=docker_compose_project_name,
    )
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)
    return server_info

75 

76 

@pytest.fixture(scope="session")
def postgres_storage_info(
    docker_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> SqlTestServerInfo:
    """
    Session-scoped fixture providing connection info for the Postgres test server.

    Builds a SqlTestServerInfo for the Postgres docker service and calls
    wait_docker_service_healthy for it before handing the info to the tests.

    Returns
    -------
    SqlTestServerInfo
        The connection info of the Postgres test server.
    """
    server_info = SqlTestServerInfo(
        hostname=docker_hostname,
        service_name=PGSQL_TEST_SERVER_NAME,
        compose_project_name=docker_compose_project_name,
    )
    project_name = server_info.compose_project_name
    service_name = server_info.service_name
    wait_docker_service_healthy(locked_docker_services, project_name, service_name)
    return server_info

95 

96 

@contextmanager
def _create_storage_from_test_server_info(
    config_file: str,
    test_server_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Context manager that builds a SqlStorage for the given test server.

    Parameters
    ----------
    config_file : str
        Path to the storage config file to load.
    test_server_info : SqlTestServerInfo
        Connection details of the test server to point the storage at.
    shared_temp_dir : str
        Directory in which the inter-process lock file is placed.
    short_testrun_uid : str
        Identifier of the test run; part of the lock file name.

    Notes
    -----
    The schema is reset when the context exits (even on error) so that each
    function-scoped fixture built on top of this gets a fresh storage instance.
    A file lock named after the service and the test run uid is held for the
    whole lifetime of the context so only one test uses the storage at a time.

    Yields
    ------
    SqlStorage
    """
    lock_file = path_join(
        shared_temp_dir,
        f"{test_server_info.service_name}-{short_testrun_uid}.lock",
    )
    with InterProcessLock(lock_file):
        # Connection settings are injected as a JSON-encoded global config.
        connection_overrides = {
            "host": test_server_info.hostname,
            "port": test_server_info.get_port() or 0,
            "database": test_server_info.database,
            "username": test_server_info.username,
            "password": test_server_info.password,
            "lazy_schema_create": True,
        }
        sql_storage = from_config(
            config_file,
            global_configs=[json.dumps(connection_overrides)],
        )
        assert isinstance(sql_storage, SqlStorage)
        try:
            yield sql_storage
        finally:
            # Cleanup the storage on return, while the lock is still held.
            sql_storage._reset_schema(force=True)  # pylint: disable=protected-access

139 

140 

@pytest.fixture(scope="function")
def mysql_storage(
    mysql_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Function-scoped fixture yielding a SqlStorage backed by the MySQL test server.

    The storage is configured from the packaged ``storage/mysql.jsonc`` config.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    config_root = str(files("mlos_bench.config"))
    mysql_config_file = path_join(config_root, "storage", "mysql.jsonc")
    with _create_storage_from_test_server_info(
        mysql_config_file,
        mysql_storage_info,
        shared_temp_dir,
        short_testrun_uid,
    ) as sql_storage:
        yield sql_storage

161 

162 

@pytest.fixture(scope="function")
def postgres_storage(
    postgres_storage_info: SqlTestServerInfo,
    shared_temp_dir: str,
    short_testrun_uid: str,
) -> Generator[SqlStorage]:
    """
    Function-scoped fixture yielding a SqlStorage backed by the Postgres test server.

    The storage is configured from the packaged ``storage/postgresql.jsonc`` config.

    See Also
    --------
    _create_storage_from_test_server_info
    """
    config_root = str(files("mlos_bench.config"))
    pgsql_config_file = path_join(config_root, "storage", "postgresql.jsonc")
    with _create_storage_from_test_server_info(
        pgsql_config_file,
        postgres_storage_info,
        shared_temp_dir,
        short_testrun_uid,
    ) as sql_storage:
        yield sql_storage

183 

184 

@pytest.fixture
def sqlite_storage() -> Generator[SqlStorage]:
    """
    Fixture for file based SQLite storage in a temporary directory.

    The database file lives inside a temporary directory that is removed when
    the fixture is torn down. ``storage.dispose()`` is always called before
    that removal, including when the schema update fails during setup.

    Yields
    ------
    SqlStorage
        A storage instance whose schema has been updated.

    Notes
    -----
    Can't be used in parallel tests on Windows.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "mlos_bench.sqlite")
        config_str = json.dumps(
            {
                "class": "mlos_bench.storage.sql.storage.SqlStorage",
                "config": {
                    "drivername": "sqlite",
                    "database": db_path,
                    "lazy_schema_create": False,
                },
            }
        )

        storage = from_config(config_str)
        assert isinstance(storage, SqlStorage)
        try:
            storage.update_schema()
            yield storage
        finally:
            # pytest does not run post-yield teardown when setup raises, so
            # dispose here to make sure it happens before the temporary
            # directory (and the database file in it) is deleted.
            storage.dispose()

216 

217 

@pytest.fixture
def storage() -> SqlStorage:
    """
    Test fixture for in-memory SQLite3 storage.

    Returns
    -------
    SqlStorage
        A storage instance using the ``sqlite`` driver with a ``:memory:`` database.
    """
    # NOTE: for debugging, a file name such as "mlos_bench.pytest.db" can be
    # used as the "database" value instead of ":memory:".
    in_memory_config = {
        "drivername": "sqlite",
        "database": ":memory:",
    }
    return SqlStorage(service=None, config=in_memory_config)

229 

230 

@pytest.fixture
def exp_storage(
    storage: SqlStorage,
    tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for Experiment using in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the test
    receives the object; after the test, the fixture checks that the context
    has been exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-001",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment",
        tunables=tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # The context manager must have marked the experiment as exited.
    # pylint: disable=protected-access
    assert not experiment._in_context

252 

253 

@pytest.fixture
def exp_no_tunables_storage(
    storage: SqlStorage,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for an Experiment without tunables using in-memory SQLite3 storage.

    Note: the experiment context has already been entered when the test
    receives the object; after the test, the fixture checks that the context
    has been exited.
    """
    no_tunables_config: dict = {}
    experiment_ctx = storage.experiment(
        experiment_id="Test-003",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment - no tunables",
        tunables=TunableGroups(no_tunables_config),
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # The context manager must have marked the experiment as exited.
    # pylint: disable=protected-access
    assert not experiment._in_context

275 

276 

@pytest.fixture
def mixed_numerics_exp_storage(
    storage: SqlStorage,
    mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for an Experiment with mixed numerics tunables using in-memory
    SQLite3 storage.

    Note: the experiment context has already been entered when the test
    receives the object; after the test, the fixture checks that the context
    has been exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-002",
        trial_id=1,
        root_env_config="dne.jsonc",
        description="pytest experiment",
        tunables=mixed_numerics_tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # The context manager must have marked the experiment as exited.
    # pylint: disable=protected-access
    assert not experiment._in_context

299 

300 

def _dummy_run_exp(
    storage: SqlStorage,
    exp: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Produce experiment data by simulating a run of the given experiment.

    A MockEnv-based set of trial runners, a MockOptimizer and a SyncScheduler
    are wired together from the experiment's metadata and run against the
    given storage.

    Parameters
    ----------
    storage : SqlStorage
        The storage object to use.
    exp : SqlStorage.Experiment
        The experiment to "run".
        Note: this particular object won't be updated, but a new one will be created
        from its metadata.

    Returns
    -------
    ExperimentData
        The data generated by the simulated run.
    """
    # pylint: disable=too-many-locals

    # Seed the global RNG first so the simulated run is repeatable.
    rand_seed(SEED)

    global_config: dict = {}
    config_loader = ConfigPersistenceService()

    # JSON list items: each covariant group name wrapped in double quotes.
    group_names = exp.tunables.get_covariant_group_names()
    tunable_params = ",".join(f'"{name}"' for name in group_names)
    mock_env_json = f"""
    {{
        "class": "mlos_bench.environments.mock_env.MockEnv",
        "name": "Test Env",
        "config": {{
            "tunable_params": [{tunable_params}],
            "mock_env_seed": {SEED},
            "mock_env_range": [60, 120],
            "mock_env_metrics": ["score"]
        }}
    }}
    """
    trial_runners: list[TrialRunner] = TrialRunner.create_from_json(
        config_loader=config_loader,
        global_config=global_config,
        tunable_groups=exp.tunables,
        env_json=mock_env_json,
        svcs_json=None,
        num_trial_runners=TRIAL_RUNNER_COUNT,
    )

    optimizer_config = {
        "optimization_targets": exp.opt_targets,
        "seed": SEED,
        # This should be the default, so we leave it omitted for now to test the default.
        # But the test logic relies on this (e.g., trial 1 is config 1 is the
        # default values for the tunable params)
        # "start_with_defaults": True,
        "max_suggestions": MAX_TRIALS,
    }
    optimizer = MockOptimizer(
        tunables=exp.tunables,
        config=optimizer_config,
        global_config=global_config,
    )

    # All config values can be overridden from global config
    scheduler_config = {
        "experiment_id": exp.experiment_id,
        "trial_id": exp.trial_id,
        "config_id": -1,
        "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
        "max_trials": MAX_TRIALS,
    }
    scheduler = SyncScheduler(
        config=scheduler_config,
        global_config=global_config,
        trial_runners=trial_runners,
        optimizer=optimizer,
        storage=storage,
        root_env_config=exp.root_env_config,
    )

    # Add some trial data to that experiment by "running" it.
    with scheduler:
        scheduler.start()
        scheduler.teardown()

    # Look the experiment up again in storage to return the recorded data.
    experiment_id = exp.experiment_id
    return storage.experiments[experiment_id]

387 

388 

@pytest.fixture
def exp_data(
    storage: SqlStorage,
    exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Test fixture for ExperimentData.

    Returns
    -------
    ExperimentData
        The data produced by a simulated run of the ``exp_storage`` experiment.
    """
    simulated_run_data = _dummy_run_exp(storage, exp_storage)
    return simulated_run_data

396 

397 

@pytest.fixture
def exp_no_tunables_data(
    storage: SqlStorage,
    exp_no_tunables_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Test fixture for ExperimentData with no tunable configs.

    Returns
    -------
    ExperimentData
        The data produced by a simulated run of the
        ``exp_no_tunables_storage`` experiment.
    """
    simulated_run_data = _dummy_run_exp(storage, exp_no_tunables_storage)
    return simulated_run_data

405 

406 

@pytest.fixture
def mixed_numerics_exp_data(
    storage: SqlStorage,
    mixed_numerics_exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Test fixture for ExperimentData with mixed numerical tunable types.

    Returns
    -------
    ExperimentData
        The data produced by a simulated run of the
        ``mixed_numerics_exp_storage`` experiment.
    """
    simulated_run_data = _dummy_run_exp(storage, mixed_numerics_exp_storage)
    return simulated_run_data