Coverage for mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py: 92% (90 statements)
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Grid search Optimizer for mlos_bench.

Grid search is a simple optimizer that exhaustively searches the configuration space.

To do this, it generates a grid of configurations to try and then suggests them one by one.

Therefore, the number of configurations to try is the product of the
:py:attr:`~mlos_bench.tunables.tunable.Tunable.cardinality` of each of the
:py:mod:`~mlos_bench.tunables`
(i.e., continuous tunables that are not
:py:attr:`quantized <mlos_bench.tunables.tunable.Tunable.quantization_bins>`
are not supported, since their cardinality is infinite).
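
For example, the tunables declared below each have a cardinality of 3 (the
``float_param`` range is quantized into 3 bins), so the resulting grid contains
``3 * 3 * 3 == 27`` configurations.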

Examples
--------
Load tunables from a JSON string.
Note: normally these would be automatically loaded from the
:py:mod:`~mlos_bench.environments.base_environment.Environment`'s
``include_tunables`` config parameter.

>>> import json5 as json
>>> from mlos_bench.environments.status import Status
>>> from mlos_bench.services.config_persistence import ConfigPersistenceService
>>> service = ConfigPersistenceService()
>>> json_config = '''
... {
...   "group_1": {
...     "cost": 1,
...     "params": {
...       "colors": {
...         "type": "categorical",
...         "values": ["red", "blue", "green"],
...         "default": "green",
...       },
...       "int_param": {
...         "type": "int",
...         "range": [1, 3],
...         "default": 2,
...       },
...       "float_param": {
...         "type": "float",
...         "range": [0, 1],
...         "default": 0.5,
...         // Quantize the range into 3 bins
...         "quantization_bins": 3,
...       }
...     }
...   }
... }
... '''
>>> tunables = service.load_tunables(jsons=[json_config])
>>> # Check the defaults:
>>> tunables.get_param_values()
{'colors': 'green', 'int_param': 2, 'float_param': 0.5}

Now create a :py:class:`.GridSearchOptimizer` from a JSON config string.

>>> optimizer_json_config = '''
... {
...   "class": "mlos_bench.optimizers.grid_search_optimizer.GridSearchOptimizer",
...   "description": "GridSearchOptimizer",
...   "config": {
...     "max_suggestions": 100,
...     "optimization_targets": {"score": "max"},
...     "start_with_defaults": true
...   }
... }
... '''
>>> config = json.loads(optimizer_json_config)
>>> grid_search_optimizer = service.build_optimizer(
...     tunables=tunables,
...     service=service,
...     config=config,
... )
>>> # Should have 3 values for each of the 3 tunables
>>> len(list(grid_search_optimizer.pending_configs))
27
>>> next(grid_search_optimizer.pending_configs)
{'colors': 'red', 'float_param': 0, 'int_param': 1}

Here are some examples of suggesting and registering configurations.

>>> suggested_config_1 = grid_search_optimizer.suggest()
>>> # Default should be suggested first, per json config.
>>> suggested_config_1.get_param_values()
{'colors': 'green', 'int_param': 2, 'float_param': 0.5}
>>> # Get another suggestion.
>>> # Note that multiple suggestions can be pending prior to
>>> # registering their scores, supporting parallel trial execution.
>>> suggested_config_2 = grid_search_optimizer.suggest()
>>> suggested_config_2.get_param_values()
{'colors': 'red', 'int_param': 1, 'float_param': 0.0}
>>> # Register some scores.
>>> # Note: Maximization problems track negative scores to produce a minimization problem.
>>> grid_search_optimizer.register(suggested_config_1, Status.SUCCEEDED, {"score": 42})
{'score': -42.0}
>>> grid_search_optimizer.register(suggested_config_2, Status.SUCCEEDED, {"score": 7})
{'score': -7.0}
>>> (best_score, best_config) = grid_search_optimizer.get_best_observation()
>>> best_score
{'score': 42.0}
>>> assert best_config == suggested_config_1
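
The optimizer also reports whether any of the grid remains to be searched.
(The expected values below are an illustrative sketch based on the example
above: two of the 27 grid configurations have been suggested so far.)

>>> grid_search_optimizer.not_converged()
True
>>> len(list(grid_search_optimizer.pending_configs))
25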
107"""

import logging
from collections.abc import Iterable, Sequence

import ConfigSpace
import numpy as np
from ConfigSpace.util import generate_grid

from mlos_bench.environments.status import Status
from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue

_LOG = logging.getLogger(__name__)


class GridSearchOptimizer(TrackBestOptimizer):
    """
    Grid search optimizer.

    See :py:mod:`above <mlos_bench.optimizers.grid_search_optimizer>` for more details.
    """

    MAX_CONFIGS = 10000
    """Maximum number of configurations to enumerate."""

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        super().__init__(tunables, config, global_config, service)

        # Track the grid as a set of tuples of tunable values and reconstruct the
        # dicts as necessary.
        # Note: this is not the most efficient way to do this, but avoids
        # introducing a new data structure for hashable dicts.
        # See https://github.com/microsoft/MLOS/pull/690 for further discussion.
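        # For example (following the module docstring above), with
        # self._config_keys == ("colors", "float_param", "int_param"), one grid
        # entry is stored as the tuple ("red", 0.0, 1) rather than as the dict
        # {"colors": "red", "float_param": 0.0, "int_param": 1}.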

        self._sanity_check()
        # The ordered set of pending configs that have not yet been suggested.
        self._config_keys, self._pending_configs = self._get_grid()
        assert self._pending_configs
        # A set of suggested configs that have not yet been registered.
        self._suggested_configs: set[tuple[TunableValue, ...]] = set()

    def _sanity_check(self) -> None:
        size = np.prod([tunable.cardinality or np.inf for (tunable, _group) in self._tunables])
        if size == np.inf:
            raise ValueError(
                f"Unquantized tunables are not supported for grid search: {self._tunables}"
            )
        if size > self.MAX_CONFIGS:
            _LOG.warning(
                "Large number %d of config points requested for grid search: %s",
                size,
                self._tunables,
            )
        if size > self._max_suggestions:
            _LOG.warning(
                "Grid search size %d is greater than max iterations %d",
                size,
                self._max_suggestions,
            )

    def _get_grid(self) -> tuple[tuple[str, ...], dict[tuple[TunableValue, ...], None]]:
        """
        Gets a grid of configs to try.

        Order is given by ConfigSpace, but preserved by dict ordering semantics.
        """
        # Since we are using ConfigSpace to generate the grid, but only tracking the
        # values as (ordered) tuples, we also need to use its ordering on column
        # names instead of the order given by TunableGroups.
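        # Illustrative sketch: for the module docstring example, the num_steps
        # mapping built below would be roughly {"int_param": 3, "float_param": 3};
        # categorical tunables such as "colors" are enumerated by ConfigSpace
        # automatically and need no entry here.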
        configs = [
            configspace_data_to_tunable_values(dict(config))
            for config in generate_grid(
                self.config_space,
                {
                    tunable.name: tunable.cardinality or 0  # mypy wants an int
                    for (tunable, _group) in self._tunables
                    if tunable.is_numerical and tunable.cardinality
                },
            )
        ]
        names = {tuple(configs.keys()) for configs in configs}
        assert len(names) == 1
        return names.pop(), {tuple(configs.values()): None for configs in configs}

    @property
    def pending_configs(self) -> Iterable[dict[str, TunableValue]]:
        """
        Gets the set of pending configs in this grid search optimizer.

        Returns
        -------
        Iterable[dict[str, TunableValue]]
        """
        # See NOTEs above.
        return (dict(zip(self._config_keys, config)) for config in self._pending_configs.keys())

    @property
    def suggested_configs(self) -> Iterable[dict[str, TunableValue]]:
        """
        Gets the set of configs that have been suggested but not yet registered.

        Returns
        -------
        Iterable[dict[str, TunableValue]]
        """
        # See NOTEs above.
        return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)

    def bulk_register(
        self,
        configs: Sequence[dict],
        scores: Sequence[dict[str, TunableValue] | None],
        status: Sequence[Status] | None = None,
    ) -> bool:
        if not super().bulk_register(configs, scores, status):
            return False
        if status is None:
            status = [Status.SUCCEEDED] * len(configs)
        for params, score, trial_status in zip(configs, scores, status):
            tunables = self._tunables.copy().assign(params)
            self.register(tunables, trial_status, score)
        if _LOG.isEnabledFor(logging.DEBUG):
            (best_score, _) = self.get_best_observation()
            _LOG.debug("Update END: %s = %s", self, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """Generate the next grid search suggestion."""
        tunables = super().suggest()
        if self._start_with_defaults:
            _LOG.info("Use default values for the first trial")
            self._start_with_defaults = False
            tunables = tunables.restore_defaults()
            # Need to index based on ConfigSpace dict ordering.
            default_config = dict(self.config_space.get_default_configuration())
            assert tunables.get_param_values() == default_config
            # Move the default from the pending to the suggested set.
            default_config_values = tuple(default_config.values())
            del self._pending_configs[default_config_values]
            self._suggested_configs.add(default_config_values)
        else:
            # Select the first item from the pending configs.
            if not self._pending_configs and self._iter <= self._max_suggestions:
                _LOG.info("No more pending configs to suggest. Restarting grid.")
                self._config_keys, self._pending_configs = self._get_grid()
            try:
                next_config_values = next(iter(self._pending_configs.keys()))
            except StopIteration as exc:
                raise ValueError("No more pending configs to suggest.") from exc
            next_config = dict(zip(self._config_keys, next_config_values))
            tunables.assign(next_config)
            # Move it to the suggested set.
            self._suggested_configs.add(next_config_values)
            del self._pending_configs[next_config_values]
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables

    def register(
        self,
        tunables: TunableGroups,
        status: Status,
        score: dict[str, TunableValue] | None = None,
    ) -> dict[str, float] | None:
        registered_score = super().register(tunables, status, score)
        try:
            config = dict(
                ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())
            )
            self._suggested_configs.remove(tuple(config.values()))
        except KeyError:
            _LOG.warning(
                (
                    "Attempted to remove missing config "
                    "(previously registered?) from suggested set: %s"
                ),
                tunables,
            )
        return registered_score

    def not_converged(self) -> bool:
        if self._iter > self._max_suggestions:
            if bool(self._pending_configs):
                _LOG.warning(
                    "Exceeded max iterations, but still have %d pending configs: %s",
                    len(self._pending_configs),
                    list(self._pending_configs.keys()),
                )
            return False
        return bool(self._pending_configs)