Coverage for mlos_bench/mlos_bench/tests/storage/sql/fixtures.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Test fixtures for mlos_bench storage.""" 

6 

7from random import seed as rand_seed 

8from typing import Generator 

9 

10import pytest 

11 

12from mlos_bench.environments.mock_env import MockEnv 

13from mlos_bench.optimizers.mock_optimizer import MockOptimizer 

14from mlos_bench.schedulers.sync_scheduler import SyncScheduler 

15from mlos_bench.storage.base_experiment_data import ExperimentData 

16from mlos_bench.storage.sql.storage import SqlStorage 

17from mlos_bench.tests import SEED 

18from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20 

21# pylint: disable=redefined-outer-name 

22 

23 

@pytest.fixture
def storage() -> SqlStorage:
    """Provide a fresh SqlStorage backed by an in-memory SQLite3 database."""
    sqlite_config = {
        "drivername": "sqlite",
        "database": ":memory:",
        # To inspect the data after a run, point "database" at a file instead,
        # e.g. "mlos_bench.pytest.db".
    }
    return SqlStorage(service=None, config=sqlite_config)

35 

36 

@pytest.fixture
def exp_storage(
    storage: SqlStorage,
    tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment, None, None]:
    """
    Yield the "Test-001" Experiment stored in the in-memory SQLite3 storage.

    The experiment's context has already been entered when the test receives it.
    After the test finishes, the fixture checks that the context was exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-001",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment",
        tunables=tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # pylint: disable=protected-access
    assert not experiment._in_context

58 

59 

@pytest.fixture
def exp_no_tunables_storage(
    storage: SqlStorage,
) -> Generator[SqlStorage.Experiment, None, None]:
    """
    Yield the "Test-003" Experiment, which has no tunables at all.

    The experiment is stored in the in-memory SQLite3 storage and is built from an
    empty TunableGroups config. Its context has already been entered when the test
    receives it. After the test finishes, the fixture checks that the context was
    exited.
    """
    no_tunables: dict = {}
    experiment_ctx = storage.experiment(
        experiment_id="Test-003",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment - no tunables",
        tunables=TunableGroups(no_tunables),
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # pylint: disable=protected-access
    assert not experiment._in_context

81 

82 

@pytest.fixture
def mixed_numerics_exp_storage(
    storage: SqlStorage,
    mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment, None, None]:
    """
    Yield the "Test-002" Experiment, whose tunables mix numeric types.

    The experiment is stored in the in-memory SQLite3 storage. Its context has
    already been entered when the test receives it. After the test finishes, the
    fixture checks that the context was exited.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-002",
        trial_id=1,
        # NOTE(review): "dne" presumably means "does not exist"; this file is
        # never loaded here.
        root_env_config="dne.jsonc",
        description="pytest experiment",
        tunables=mixed_numerics_tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as experiment:
        yield experiment
    # pylint: disable=protected-access
    assert not experiment._in_context

105 

106 

def _dummy_run_exp(
    storage: SqlStorage,
    exp: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Simulate running ``exp`` and return the data it leaves behind in ``storage``.

    A MockEnv, a MockOptimizer and a SyncScheduler are built from the experiment's
    metadata. The scheduler is then started and torn down inside its context.

    Parameters
    ----------
    storage : SqlStorage
        The storage object handed to the scheduler.
    exp : SqlStorage.Experiment
        The experiment to "run". Only its metadata is read (id, trial id, tunables,
        optimization targets, root env config). This object itself is not updated;
        the scheduler works from that metadata instead.

    Returns
    -------
    ExperimentData
        ``storage.experiments[exp.experiment_id]`` after the simulated run.
    """
    # pylint: disable=too-many-locals

    # Reset the global RNG so every invocation produces the same trial data.
    rand_seed(SEED)

    env_config: dict = {
        "tunable_params": list(exp.tunables.get_covariant_group_names()),
        "mock_env_seed": SEED,
        "mock_env_range": [60, 120],
        "mock_env_metrics": ["score"],
    }
    env = MockEnv(name="Test Env", config=env_config, tunables=exp.tunables)

    opt_tunables = exp.tunables
    opt_config: dict = {
        "optimization_targets": exp.opt_targets,
        "seed": SEED,
        # "start_with_defaults" is intentionally left out so the default gets
        # exercised. It is expected to be True: the test logic relies on trial 1
        # (config 1) using the tunables' default values.
    }
    opt = MockOptimizer(tunables=opt_tunables, config=opt_config)

    # Every value here can be overridden from the global config.
    scheduler_config: dict = {
        "experiment_id": exp.experiment_id,
        "trial_id": exp.trial_id,
        "config_id": -1,
        "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
        # Run each of the CONFIG_COUNT configs CONFIG_TRIAL_REPEAT_COUNT times.
        "max_trials": CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT,
    }
    scheduler = SyncScheduler(
        config=scheduler_config,
        global_config={},
        environment=env,
        optimizer=opt,
        storage=storage,
        root_env_config=exp.root_env_config,
    )

    # "Run" the experiment so that trial data is written to storage.
    with scheduler:
        scheduler.start()
        scheduler.teardown()

    return storage.experiments[exp.experiment_id]

177 

178 

@pytest.fixture
def exp_data(
    storage: SqlStorage,
    exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide the ExperimentData produced by simulating a run of ``exp_storage``."""
    return _dummy_run_exp(storage=storage, exp=exp_storage)

186 

187 

@pytest.fixture
def exp_no_tunables_data(
    storage: SqlStorage,
    exp_no_tunables_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide the ExperimentData from a simulated run of the tunable-free experiment."""
    return _dummy_run_exp(storage=storage, exp=exp_no_tunables_storage)

195 

196 

@pytest.fixture
def mixed_numerics_exp_data(
    storage: SqlStorage,
    mixed_numerics_exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide the ExperimentData from a simulated run of the mixed-numerics experiment."""
    return _dummy_run_exp(storage=storage, exp=mixed_numerics_exp_storage)