Coverage for mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py: 100%
121 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Unit tests for grid search mlos_bench optimizer.
7"""
9from typing import Dict, List
11import itertools
12import math
13import random
15import pytest
17from mlos_bench.environments.status import Status
18from mlos_bench.optimizers.grid_search_optimizer import GridSearchOptimizer
19from mlos_bench.tunables.tunable import TunableValue
20from mlos_bench.tunables.tunable_groups import TunableGroups
23# pylint: disable=redefined-outer-name
@pytest.fixture
def grid_search_tunables_config() -> dict:
    """
    Test fixture providing the tunables config for the grid search optimizer.

    One covering group ("grid") with one categorical, one int, and one
    quantized float tunable, so the full grid stays small but non-trivial.
    """
    params = {
        "cat": {"type": "categorical", "values": ["a", "b", "c"], "default": "a"},
        "int": {"type": "int", "range": [1, 3], "default": 2},
        "float": {"type": "float", "range": [0, 1], "default": 0.5, "quantization": 0.25},
    }
    return {"grid": {"cost": 1, "params": params}}
@pytest.fixture
def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]:
    """
    Test fixture that expands the tunable groups into the full cartesian grid
    of parameter dicts.
    Used to check that the grids are the same (ignoring order).
    """
    # Collect the names and the candidate value lists in a single pass,
    # skipping any tunable that does not enumerate a finite set of values.
    names: List[str] = []
    value_lists = []
    for tunable, _group in grid_search_tunables:
        if tunable.values is not None:
            names.append(tunable.name)
            value_lists.append(tunable.values)
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]
@pytest.fixture
def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups:
    """
    Test fixture constructing the TunableGroups from the config fixture.
    """
    tunable_groups = TunableGroups(grid_search_tunables_config)
    return tunable_groups
@pytest.fixture
def grid_search_opt(grid_search_tunables: TunableGroups,
                    grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer:
    """
    Test fixture for the grid search optimizer under test.
    """
    assert len(grid_search_tunables) == 3
    # Pick an iteration budget that is deliberately NOT a multiple of the grid
    # size so the convergence logic gets exercised partway through a repeat of
    # the grid.
    opt_config = {
        "max_suggestions": 2 * len(grid_search_tunables_grid) - 3,
        "optimization_targets": {"score": "max"},
    }
    return GridSearchOptimizer(tunables=grid_search_tunables, config=opt_config)
def test_grid_search_grid(grid_search_opt: GridSearchOptimizer,
                          grid_search_tunables: TunableGroups,
                          grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None:
    """
    Make sure that grid search optimizer initializes and works correctly.
    """
    # The expected grid size is the product of each tunable's cardinality.
    expected_grid_size = math.prod(tunable.cardinality for tunable, _group in grid_search_tunables)
    assert expected_grid_size > len(grid_search_tunables)
    assert len(grid_search_tunables_grid) == expected_grid_size

    # A known config that must appear in both the reference grid and the
    # optimizer's pending queue.
    sample_config: Dict[str, TunableValue] = {
        "cat": "a",
        "int": 2,
        "float": 0.75,
    }
    pending = list(grid_search_opt.pending_configs)
    assert sample_config in grid_search_tunables_grid
    assert sample_config in pending

    # Compare the full contents, order-insensitively: ConfigSpace and
    # TunableGroups don't guarantee a consistent parameter ordering, so we
    # check mutual containment of whole dicts rather than sequence equality.
    assert len(pending) == expected_grid_size
    assert all(config in grid_search_tunables_grid for config in pending)
    assert all(config in pending for config in grid_search_tunables_grid)
def test_grid_search(grid_search_opt: GridSearchOptimizer,
                     grid_search_tunables: TunableGroups,
                     grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None:
    """
    Make sure that grid search optimizer initializes and works correctly.

    Walks the full suggest/register cycle: defaults first, then the rest of
    the grid, then a refill of the grid until max_suggestions is exhausted.
    """
    score = 1.0
    status = Status.SUCCEEDED
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    default_config = grid_search_tunables.restore_defaults().get_param_values()

    # First suggestion should be the defaults.
    assert suggestion.get_param_values() == default_config
    # But that shouldn't be the first element in the grid search.
    assert suggestion_dict != next(iter(grid_search_tunables_grid))
    # The suggestion should no longer be in the pending_configs.
    assert suggestion_dict not in grid_search_opt.pending_configs
    # But it should be in the suggested_configs now (and the only one).
    assert list(grid_search_opt.suggested_configs) == [default_config]

    # Register a score for that suggestion.
    grid_search_opt.register(suggestion, status, score)
    # Now it shouldn't be in the suggested_configs.
    assert len(list(grid_search_opt.suggested_configs)) == 0

    # Mirror the consumed config in our local copy of the grid, then check
    # that the optimizer's pending set matches it exactly (both directions).
    grid_search_tunables_grid.remove(default_config)
    assert default_config not in grid_search_opt.pending_configs
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid)

    # The next suggestion should be a different element in the grid search.
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    assert suggestion_dict != default_config
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict in grid_search_opt.suggested_configs
    grid_search_opt.register(suggestion, status, score)
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict not in grid_search_opt.suggested_configs

    grid_search_tunables_grid.remove(suggestion.get_param_values())
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid)

    # We consider not_converged as either having reached "max_suggestions" or an empty grid?

    # Try to empty the rest of the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)

    # The grid search should be empty now.
    assert not list(grid_search_opt.pending_configs)
    assert not list(grid_search_opt.suggested_configs)
    assert not grid_search_opt.not_converged()

    # But if we still have iterations left, we should be able to suggest again by refilling the grid.
    # (max_suggestions was chosen to not be a multiple of the grid size.)
    assert grid_search_opt.current_iteration < grid_search_opt.max_iterations
    assert grid_search_opt.suggest()
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
    assert grid_search_opt.not_converged()

    # Try to finish the rest of our iterations by repeating the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)
    assert not grid_search_opt.not_converged()
    assert grid_search_opt.current_iteration >= grid_search_opt.max_iterations
    # The refilled grid is only partially consumed, so configs remain queued.
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
    """
    Make sure that grid search optimizer works correctly when suggest and register
    are called out of order.
    """
    score = 1.0
    status = Status.SUCCEEDED
    suggest_count = 10
    # Take several suggestions up front, then register them in shuffled order.
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    suggested_shuffled = suggested.copy()
    # Try to ensure the shuffled list is different.
    # (retry a few times in case shuffle happens to produce the same order)
    for _ in range(3):
        random.shuffle(suggested_shuffled)
        if suggested_shuffled != suggested:
            break
    assert suggested != suggested_shuffled

    # Registering in any order should move each config out of suggested_configs.
    for suggestion in suggested_shuffled:
        suggestion_dict = suggestion.get_param_values()
        assert suggestion_dict not in grid_search_opt.pending_configs
        assert suggestion_dict in grid_search_opt.suggested_configs
        grid_search_opt.register(suggestion, status, score)
        assert suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == score

    # test re-register with higher score
    # (direction-aware: pick a score strictly better than the current best)
    best_suggestion = suggested_shuffled[0]
    best_suggestion_dict = best_suggestion.get_param_values()
    assert best_suggestion_dict not in grid_search_opt.pending_configs
    assert best_suggestion_dict not in grid_search_opt.suggested_configs
    best_suggestion_score = score - 1 if grid_search_opt.direction == "min" else score + 1
    grid_search_opt.register(best_suggestion, status, best_suggestion_score)
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion

    # Check bulk register
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested)
    assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested)

    # Those new suggestions also shouldn't be in the set of previously suggested configs.
    assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested)

    grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested],
                                  [{"score": score}] * len(suggested),
                                  [status] * len(suggested))

    # Bulk registration should clear all of them from suggested_configs too.
    assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested)
    assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested)

    # The earlier (better) re-registered score should still win.
    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion