Coverage for mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py: 92% (88 statements)
coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Grid search optimizer for mlos_bench."""

import logging
from typing import Dict, Iterable, Optional, Sequence, Set, Tuple

import ConfigSpace
import numpy as np
from ConfigSpace.util import generate_grid

from mlos_bench.environments.status import Status
from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups

_LOG = logging.getLogger(__name__)


class GridSearchOptimizer(TrackBestOptimizer):
    """Grid search optimizer."""

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        super().__init__(tunables, config, global_config, service)

        # Track the grid as a set of tuples of tunable values and reconstruct the
        # dicts as necessary.
        # Note: this is not the most efficient way to do this, but avoids
        # introducing a new data structure for hashable dicts.
        # See https://github.com/microsoft/MLOS/pull/690 for further discussion.
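        # For example (hypothetical tunable names, for illustration only): a config
        # such as {"vm_size": "Standard_B2s", "idle": "mwait"} is stored as the tuple
        # ("Standard_B2s", "mwait"), with the key order tracked once in self._config_keys.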

        self._sanity_check()
        # The ordered set of pending configs that have not yet been suggested.
        self._config_keys, self._pending_configs = self._get_grid()
        assert self._pending_configs
        # A set of suggested configs that have not yet been registered.
        self._suggested_configs: Set[Tuple[TunableValue, ...]] = set()

    def _sanity_check(self) -> None:
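        # The grid size is the product of the cardinalities of all tunables; an
        # unquantized continuous tunable has no finite cardinality and cannot be
        # enumerated, so it is rejected below.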
        size = np.prod([tunable.cardinality or np.inf for (tunable, _group) in self._tunables])
        if size == np.inf:
            raise ValueError(
                f"Unquantized tunables are not supported for grid search: {self._tunables}"
            )
        if size > 10000:
            _LOG.warning(
                "Large number %d of config points requested for grid search: %s",
                size,
                self._tunables,
            )
        if size > self._max_suggestions:
            _LOG.warning(
                "Grid search size %d is greater than max iterations %d",
                size,
                self._max_suggestions,
            )

    def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]:
        """
        Gets a grid of configs to try.

        Order is given by ConfigSpace, but preserved by dict ordering semantics.
        """
        # Since we are using ConfigSpace to generate the grid, but only tracking the
        # values as (ordered) tuples, we also need to use its ordering on column
        # names instead of the order given by TunableGroups.
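        # The second argument to generate_grid maps each numerical hyperparameter name
        # to the number of grid points to generate for it; categorical tunables are
        # expanded over all of their allowed values automatically.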
        configs = [
            configspace_data_to_tunable_values(dict(config))
            for config in generate_grid(
                self.config_space,
                {
                    tunable.name: tunable.cardinality or 0  # mypy wants an int
                    for (tunable, _group) in self._tunables
                    if tunable.is_numerical and tunable.cardinality
                },
            )
        ]
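        # Every config produced for the grid should share the same key ordering,
        # so capture that ordering once and reuse it for all of the value tuples.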
        names = set(tuple(config.keys()) for config in configs)
        assert len(names) == 1
        return names.pop(), {tuple(config.values()): None for config in configs}

    @property
    def pending_configs(self) -> Iterable[Dict[str, TunableValue]]:
        """
        Gets the set of pending configs in this grid search optimizer.

        Returns
        -------
        Iterable[Dict[str, TunableValue]]
        """
        # See NOTEs above.
        return (dict(zip(self._config_keys, config)) for config in self._pending_configs.keys())

    @property
    def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]:
        """
        Gets the set of configs that have been suggested but not yet registered.

        Returns
        -------
        Iterable[Dict[str, TunableValue]]
        """
        # See NOTEs above.
        return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)

    def bulk_register(
        self,
        configs: Sequence[dict],
        scores: Sequence[Optional[Dict[str, TunableValue]]],
        status: Optional[Sequence[Status]] = None,
    ) -> bool:
        if not super().bulk_register(configs, scores, status):
            return False
        if status is None:
            status = [Status.SUCCEEDED] * len(configs)
        for params, score, trial_status in zip(configs, scores, status):
            tunables = self._tunables.copy().assign(params)
            self.register(tunables, trial_status, score)
        if _LOG.isEnabledFor(logging.DEBUG):
            (best_score, _) = self.get_best_observation()
            _LOG.debug("Update END: %s = %s", self, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """Generate the next grid search suggestion."""
        tunables = super().suggest()
        if self._start_with_defaults:
            _LOG.info("Use default values for the first trial")
            self._start_with_defaults = False
            tunables = tunables.restore_defaults()
            # Need to index based on ConfigSpace dict ordering.
            default_config = dict(self.config_space.get_default_configuration())
            assert tunables.get_param_values() == default_config
            # Move the default from the pending to the suggested set.
            default_config_values = tuple(default_config.values())
            del self._pending_configs[default_config_values]
            self._suggested_configs.add(default_config_values)
        else:
            # Select the first item from the pending configs.
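            # If the whole grid has already been suggested but the iteration budget
            # is not yet exhausted, regenerate the grid and keep going.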
            if not self._pending_configs and self._iter <= self._max_suggestions:
                _LOG.info("No more pending configs to suggest. Restarting grid.")
                self._config_keys, self._pending_configs = self._get_grid()
            try:
                next_config_values = next(iter(self._pending_configs.keys()))
            except StopIteration as exc:
                raise ValueError("No more pending configs to suggest.") from exc
            next_config = dict(zip(self._config_keys, next_config_values))
            tunables.assign(next_config)
            # Move it to the suggested set.
            self._suggested_configs.add(next_config_values)
            del self._pending_configs[next_config_values]
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables

    def register(
        self,
        tunables: TunableGroups,
        status: Status,
        score: Optional[Dict[str, TunableValue]] = None,
    ) -> Optional[Dict[str, float]]:
        registered_score = super().register(tunables, status, score)
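        # Rebuild the config through ConfigSpace so its values come back in the same
        # canonical key order used for the tuples in the suggested set.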
        try:
            config = dict(
                ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())
            )
            self._suggested_configs.remove(tuple(config.values()))
        except KeyError:
            _LOG.warning(
                (
                    "Attempted to remove missing config "
                    "(previously registered?) from suggested set: %s"
                ),
                tunables,
            )
        return registered_score

    def not_converged(self) -> bool:
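        # Grid search is considered converged once the iteration budget is exhausted
        # or every config in the grid has been suggested.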
        if self._iter > self._max_suggestions:
            if bool(self._pending_configs):
                _LOG.warning(
                    "Exceeded max iterations, but still have %d pending configs: %s",
                    len(self._pending_configs),
                    list(self._pending_configs.keys()),
                )
            return False
        return bool(self._pending_configs)
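
# A minimal usage sketch (illustrative only, not part of this module), assuming
# `tunable_groups` and `optimizer_config` have already been loaded via the usual
# mlos_bench config machinery and that `run_trial` is a hypothetical helper that
# returns a (Status, score_dict) pair for a given TunableGroups instance:
#
#     optimizer = GridSearchOptimizer(tunable_groups, optimizer_config)
#     while optimizer.not_converged():
#         tunables = optimizer.suggest()
#         (status, score) = run_trial(tunables)
#         optimizer.register(tunables, status, score)
#     (best_score, best_config) = optimizer.get_best_observation()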