Coverage for mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py: 100%

121 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for grid search mlos_bench optimizer. 

7""" 

8 

9from typing import Dict, List 

10 

11import itertools 

12import math 

13import random 

14 

15import pytest 

16 

17from mlos_bench.environments.status import Status 

18from mlos_bench.optimizers.grid_search_optimizer import GridSearchOptimizer 

19from mlos_bench.tunables.tunable import TunableValue 

20from mlos_bench.tunables.tunable_groups import TunableGroups 

21 

22 

23# pylint: disable=redefined-outer-name 

24 

@pytest.fixture
def grid_search_tunables_config() -> dict:
    """
    Test fixture providing the tunables config used by the grid search optimizer tests.
    """
    # One small finite parameter of each type so the full grid stays tiny.
    cat_param = {
        "type": "categorical",
        "values": ["a", "b", "c"],
        "default": "a",
    }
    int_param = {
        "type": "int",
        "range": [1, 3],
        "default": 2,
    }
    # Quantization makes the float range finite so it can be enumerated.
    float_param = {
        "type": "float",
        "range": [0, 1],
        "default": 0.5,
        "quantization": 0.25,
    }
    return {
        "grid": {
            "cost": 1,
            "params": {
                "cat": cat_param,
                "int": int_param,
                "float": float_param,
            },
        },
    }

53 

54 

@pytest.fixture
def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]:
    """
    Test fixture that builds the full cross-product grid from the tunable groups.
    Used to check that the grids are the same (ignoring order).
    """
    # Only tunables with an enumerable set of values participate in the grid.
    finite_tunables = [tunable for tunable, _group in grid_search_tunables if tunable.values is not None]
    param_names = tuple(tunable.name for tunable in finite_tunables)
    value_combos = itertools.product(*(tunable.values for tunable in finite_tunables))
    return [dict(zip(param_names, combo)) for combo in value_combos]

64 

65 

@pytest.fixture
def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups:
    """
    Test fixture that instantiates the tunable groups for the grid search optimizer.
    """
    tunable_groups = TunableGroups(grid_search_tunables_config)
    return tunable_groups

72 

73 

@pytest.fixture
def grid_search_opt(grid_search_tunables: TunableGroups,
                    grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer:
    """
    Test fixture for the grid search optimizer under test.
    """
    assert len(grid_search_tunables) == 3
    # Pick an iteration budget that is deliberately *not* a multiple of the grid
    # size so the convergence logic gets exercised mid-grid on the second pass.
    max_iterations = len(grid_search_tunables_grid) * 2 - 3
    opt_config = {
        "max_suggestions": max_iterations,
        "optimization_targets": {"score": "max"},
    }
    return GridSearchOptimizer(tunables=grid_search_tunables, config=opt_config)

88 

89 

def test_grid_search_grid(grid_search_opt: GridSearchOptimizer,
                          grid_search_tunables: TunableGroups,
                          grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None:
    """
    Make sure that grid search optimizer initializes and works correctly.
    """
    # The grid size should equal the product of all tunable cardinalities.
    cardinalities = (tunable.cardinality for tunable, _group in grid_search_tunables)
    expected_grid_size = math.prod(cardinalities)
    assert expected_grid_size > len(grid_search_tunables)
    assert len(grid_search_tunables_grid) == expected_grid_size
    # Spot-check that a specific known config appears in both grids.
    sample_config: Dict[str, TunableValue] = {
        "cat": "a",
        "int": 2,
        "float": 0.75,
    }
    pending_configs = list(grid_search_opt.pending_configs)
    assert sample_config in grid_search_tunables_grid
    assert sample_config in pending_configs
    # Compare the full contents as unordered collections of dicts.
    # Note: ConfigSpace param name vs TunableGroup parameter name order is not
    # consistent, so we need to full dict comparison.
    assert len(pending_configs) == expected_grid_size
    assert all(config in grid_search_tunables_grid for config in pending_configs)
    assert all(config in pending_configs for config in grid_search_tunables_grid)
    # Order is less relevant to us, so we'll just check that the sets are the same.
    # assert grid_search_opt.pending_configs == grid_search_tunables_grid

117 

118 

def test_grid_search(grid_search_opt: GridSearchOptimizer,
                     grid_search_tunables: TunableGroups,
                     grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None:
    """
    Make sure that grid search optimizer initializes and works correctly.

    Walks the optimizer through the full suggest/register lifecycle: defaults
    first, then the rest of the grid, then a refill of the grid until the
    "max_suggestions" iteration budget is exhausted.
    """
    score = 1.0
    status = Status.SUCCEEDED
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    default_config = grid_search_tunables.restore_defaults().get_param_values()

    # First suggestion should be the defaults.
    assert suggestion.get_param_values() == default_config
    # But that shouldn't be the first element in the grid search.
    assert suggestion_dict != next(iter(grid_search_tunables_grid))
    # The suggestion should no longer be in the pending_configs.
    assert suggestion_dict not in grid_search_opt.pending_configs
    # But it should be in the suggested_configs now (and the only one).
    assert list(grid_search_opt.suggested_configs) == [default_config]

    # Register a score for that suggestion.
    grid_search_opt.register(suggestion, status, score)
    # Now it shouldn't be in the suggested_configs.
    assert len(list(grid_search_opt.suggested_configs)) == 0

    # Drop the consumed config from our local copy of the grid and check that
    # the optimizer's remaining pending configs match it exactly (ignoring order).
    grid_search_tunables_grid.remove(default_config)
    assert default_config not in grid_search_opt.pending_configs
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid)

    # The next suggestion should be a different element in the grid search.
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    assert suggestion_dict != default_config
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict in grid_search_opt.suggested_configs
    grid_search_opt.register(suggestion, status, score)
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict not in grid_search_opt.suggested_configs

    # Again, keep the local grid copy in sync and re-check the remainder.
    grid_search_tunables_grid.remove(suggestion.get_param_values())
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid)

    # We consider not_converged as either having reached "max_suggestions" or an empty grid?

    # Try to empty the rest of the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)

    # The grid search should be empty now.
    assert not list(grid_search_opt.pending_configs)
    assert not list(grid_search_opt.suggested_configs)
    assert not grid_search_opt.not_converged()

    # But if we still have iterations left, we should be able to suggest again by refilling the grid.
    # (The fixture's max_suggestions budget is larger than one pass over the grid.)
    assert grid_search_opt.current_iteration < grid_search_opt.max_iterations
    assert grid_search_opt.suggest()
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
    assert grid_search_opt.not_converged()

    # Try to finish the rest of our iterations by repeating the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)
    assert not grid_search_opt.not_converged()
    assert grid_search_opt.current_iteration >= grid_search_opt.max_iterations
    # The refilled grid is only partially consumed when the budget runs out,
    # since max_suggestions is not a multiple of the grid size.
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)

191 

192 

def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
    """
    Make sure that grid search optimizer works correctly when suggest and register
    are called out of order.

    Also covers re-registering a config with a better score and bulk_register.
    """
    score = 1.0
    status = Status.SUCCEEDED
    suggest_count = 10
    # Collect a batch of suggestions up front, then register them in a
    # different (shuffled) order to simulate asynchronous trial completion.
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    suggested_shuffled = suggested.copy()
    # Try to ensure the shuffled list is different.
    for _ in range(3):
        random.shuffle(suggested_shuffled)
        if suggested_shuffled != suggested:
            break
    assert suggested != suggested_shuffled

    for suggestion in suggested_shuffled:
        suggestion_dict = suggestion.get_param_values()
        # Each suggestion moves from pending to suggested when handed out ...
        assert suggestion_dict not in grid_search_opt.pending_configs
        assert suggestion_dict in grid_search_opt.suggested_configs
        grid_search_opt.register(suggestion, status, score)
        # ... and leaves suggested once its result is registered.
        assert suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == score

    # test re-register with higher score
    best_suggestion = suggested_shuffled[0]
    best_suggestion_dict = best_suggestion.get_param_values()
    assert best_suggestion_dict not in grid_search_opt.pending_configs
    assert best_suggestion_dict not in grid_search_opt.suggested_configs
    # Pick a score strictly better than the others w.r.t. the optimization direction.
    best_suggestion_score = score - 1 if grid_search_opt.direction == "min" else score + 1
    grid_search_opt.register(best_suggestion, status, best_suggestion_score)
    # Re-registering an already-registered config must not re-add it to suggested.
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion

    # Check bulk register
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested)
    assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested)

    # Those new suggestions also shouldn't be in the set of previously suggested configs.
    assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested)

    grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested],
                                  [{"score": score}] * len(suggested),
                                  [status] * len(suggested))

    # Bulk registration should clear all of those configs from both queues.
    assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested)
    assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested)

    # The best observation should be unaffected since the bulk scores are not better.
    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion