Coverage for mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py: 92%
91 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Grid search optimizer for mlos_bench.
7"""
9import logging
11from typing import Dict, Iterable, Set, Optional, Sequence, Tuple, Union
13import numpy as np
14import ConfigSpace
15from ConfigSpace.util import generate_grid
17from mlos_bench.environments.status import Status
18from mlos_bench.tunables.tunable import TunableValue
19from mlos_bench.tunables.tunable_groups import TunableGroups
20from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
21from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values
22from mlos_bench.services.base_service import Service
24_LOG = logging.getLogger(__name__)
class GridSearchOptimizer(TrackBestOptimizer):
    """
    Grid search optimizer.

    Exhaustively suggests every point of the (quantized) tunables grid, in the
    column order given by ConfigSpace, tracking the best observation seen so far.
    """

    def __init__(self,
                 tunables: TunableGroups,
                 config: dict,
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        """
        Create a new grid search optimizer for the given set of tunables.

        Parameters
        ----------
        tunables : TunableGroups
            The tunable parameters to enumerate; all must have a finite grid
            (see `_sanity_check`).
        config : dict
            Optimizer configuration (passed through to the base class).
        global_config : Optional[dict]
            Global configuration overrides, if any.
        service : Optional[Service]
            Optional service facade for the optimizer to use.
        """
        super().__init__(tunables, config, global_config, service)

        # Best (config, score) seen so far, updated via register().
        self._best_config: Optional[TunableGroups] = None
        self._best_score: Optional[float] = None

        # Track the grid as a set of tuples of tunable values and reconstruct the
        # dicts as necessary.
        # Note: this is not the most efficient way to do this, but avoids
        # introducing a new data structure for hashable dicts.
        # See https://github.com/microsoft/MLOS/pull/690 for further discussion.

        self._sanity_check()
        # The ordered set of pending configs that have not yet been suggested.
        # (A dict with None values is used as an insertion-ordered set.)
        self._config_keys, self._pending_configs = self._get_grid()
        assert self._pending_configs
        # A set of suggested configs that have not yet been registered.
        self._suggested_configs: Set[Tuple[TunableValue, ...]] = set()
55 def _sanity_check(self) -> None:
56 size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables])
57 if size == np.inf:
58 raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}")
59 if size > 10000:
60 _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables)
61 if size > self._max_iter:
62 _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter)
64 def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]:
65 """
66 Gets a grid of configs to try.
68 Order is given by ConfigSpace, but preserved by dict ordering semantics.
69 """
70 # Since we are using ConfigSpace to generate the grid, but only tracking the
71 # values as (ordered) tuples, we also need to use its ordering on column
72 # names instead of the order given by TunableGroups.
73 configs = [
74 configspace_data_to_tunable_values(dict(config))
75 for config in
76 generate_grid(self.config_space, {
77 tunable.name: int(tunable.cardinality)
78 for (tunable, _group) in self._tunables
79 if tunable.quantization or tunable.type == "int"
80 })
81 ]
82 names = set(tuple(configs.keys()) for configs in configs)
83 assert len(names) == 1
84 return names.pop(), {tuple(configs.values()): None for configs in configs}
86 @property
87 def pending_configs(self) -> Iterable[Dict[str, TunableValue]]:
88 """
89 Gets the set of pending configs in this grid search optimizer.
91 Returns
92 -------
93 Iterable[Dict[str, TunableValue]]
94 """
95 # See NOTEs above.
96 return (dict(zip(self._config_keys, config)) for config in self._pending_configs.keys())
98 @property
99 def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]:
100 """
101 Gets the set of configs that have been suggested but not yet registered.
103 Returns
104 -------
105 Iterable[Dict[str, TunableValue]]
106 """
107 # See NOTEs above.
108 return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)
110 def bulk_register(self,
111 configs: Sequence[dict],
112 scores: Sequence[Optional[Dict[str, TunableValue]]],
113 status: Optional[Sequence[Status]] = None) -> bool:
114 if not super().bulk_register(configs, scores, status):
115 return False
116 if status is None:
117 status = [Status.SUCCEEDED] * len(configs)
118 for (params, score, trial_status) in zip(configs, scores, status):
119 tunables = self._tunables.copy().assign(params)
120 self.register(tunables, trial_status, score)
121 if _LOG.isEnabledFor(logging.DEBUG):
122 (best_score, _) = self.get_best_observation()
123 _LOG.debug("Update end: %s = %s", self.target, best_score)
124 return True
    def suggest(self) -> TunableGroups:
        """
        Generate the next grid search suggestion.

        Moves the chosen config from the pending set to the suggested set.
        On the first call (when configured to start with defaults) the default
        configuration is suggested; afterwards, grid points are suggested in
        the order produced by `_get_grid`.

        Returns
        -------
        TunableGroups
            The next configuration of tunable values to try.

        Raises
        ------
        ValueError
            If there are no more pending configs to suggest.
        """
        tunables = super().suggest()
        if self._start_with_defaults:
            _LOG.info("Use default values for the first trial")
            self._start_with_defaults = False
            tunables = tunables.restore_defaults()
            # Need to index based on ConfigSpace dict ordering.
            default_config = dict(self.config_space.get_default_configuration())
            assert tunables.get_param_values() == default_config
            # Move the default from the pending to the suggested set.
            default_config_values = tuple(default_config.values())
            del self._pending_configs[default_config_values]
            self._suggested_configs.add(default_config_values)
        else:
            # Select the first item from the pending configs.
            if not self._pending_configs and self._iter <= self._max_iter:
                # Budget remains but the grid is exhausted: start over.
                _LOG.info("No more pending configs to suggest. Restarting grid.")
                self._config_keys, self._pending_configs = self._get_grid()
            try:
                # Dict insertion order makes this the "first" pending config.
                next_config_values = next(iter(self._pending_configs.keys()))
            except StopIteration as exc:
                raise ValueError("No more pending configs to suggest.") from exc
            next_config = dict(zip(self._config_keys, next_config_values))
            tunables.assign(next_config)
            # Move it to the suggested set.
            self._suggested_configs.add(next_config_values)
            del self._pending_configs[next_config_values]
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables
159 def register(self, tunables: TunableGroups, status: Status,
160 score: Optional[Union[float, dict]] = None) -> Optional[float]:
161 registered_score = super().register(tunables, status, score)
162 try:
163 config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()))
164 self._suggested_configs.remove(tuple(config.values()))
165 except KeyError:
166 _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables)
167 return registered_score
169 def not_converged(self) -> bool:
170 if self._iter > self._max_iter:
171 if bool(self._pending_configs):
172 _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s",
173 len(self._pending_configs), list(self._pending_configs.keys()))
174 return False
175 return bool(self._pending_configs)