Coverage for mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py: 92%

91 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Grid search optimizer for mlos_bench. 

7""" 

8 

9import logging 

10 

11from typing import Dict, Iterable, Set, Optional, Sequence, Tuple, Union 

12 

13import numpy as np 

14import ConfigSpace 

15from ConfigSpace.util import generate_grid 

16 

17from mlos_bench.environments.status import Status 

18from mlos_bench.tunables.tunable import TunableValue 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer 

21from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values 

22from mlos_bench.services.base_service import Service 

23 

# Module-level logger, named after this module per project convention.
_LOG = logging.getLogger(__name__)

25 

26 

class GridSearchOptimizer(TrackBestOptimizer):
    """
    Grid search optimizer.

    Exhaustively suggests each point of the (quantized) tunables cross-product
    grid exactly once, optionally starting with the default configuration, and
    tracks the best observation seen so far (via ``TrackBestOptimizer``).
    """

    def __init__(self,
                 tunables: TunableGroups,
                 config: dict,
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        super().__init__(tunables, config, global_config, service)

        self._best_config: Optional[TunableGroups] = None
        self._best_score: Optional[float] = None

        # Track the grid as a set of tuples of tunable values and reconstruct the
        # dicts as necessary.
        # Note: this is not the most efficient way to do this, but avoids
        # introducing a new data structure for hashable dicts.
        # See https://github.com/microsoft/MLOS/pull/690 for further discussion.

        self._sanity_check()
        # The ordered set of pending configs that have not yet been suggested.
        self._config_keys, self._pending_configs = self._get_grid()
        assert self._pending_configs
        # A set of suggested configs that have not yet been registered.
        self._suggested_configs: Set[Tuple[TunableValue, ...]] = set()

    def _sanity_check(self) -> None:
        """
        Check that the grid is finite and warn when it looks unreasonably large.

        Raises
        ------
        ValueError
            If any tunable has infinite cardinality (i.e., is unquantized),
            since such a tunable cannot be enumerated by grid search.
        """
        size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables])
        if size == np.inf:
            raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}")
        if size > 10000:
            _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables)
        if size > self._max_iter:
            _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter)

    def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]:
        """
        Gets a grid of configs to try.

        Order is given by ConfigSpace, but preserved by dict ordering semantics.

        Returns
        -------
        Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]
            The ordered tuple of config (column) names, and a dict used as an
            ordered set whose keys are the tuples of corresponding values.
        """
        # Since we are using ConfigSpace to generate the grid, but only tracking the
        # values as (ordered) tuples, we also need to use its ordering on column
        # names instead of the order given by TunableGroups.
        # NOTE: the loop variable was renamed from `configs` to `config` here and
        # below to stop it shadowing the list it iterates over.
        configs = [
            configspace_data_to_tunable_values(dict(config))
            for config in
            generate_grid(self.config_space, {
                tunable.name: int(tunable.cardinality)
                for (tunable, _group) in self._tunables
                if tunable.quantization or tunable.type == "int"
            })
        ]
        # Every config produced for the grid must share the same column ordering.
        names = {tuple(config.keys()) for config in configs}
        assert len(names) == 1
        return names.pop(), {tuple(config.values()): None for config in configs}

    @property
    def pending_configs(self) -> Iterable[Dict[str, TunableValue]]:
        """
        Gets the set of pending configs in this grid search optimizer.

        Returns
        -------
        Iterable[Dict[str, TunableValue]]
        """
        # See NOTEs above: reconstruct dicts from the tracked value tuples.
        return (dict(zip(self._config_keys, config)) for config in self._pending_configs.keys())

    @property
    def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]:
        """
        Gets the set of configs that have been suggested but not yet registered.

        Returns
        -------
        Iterable[Dict[str, TunableValue]]
        """
        # See NOTEs above: reconstruct dicts from the tracked value tuples.
        return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)

    def bulk_register(self,
                      configs: Sequence[dict],
                      scores: Sequence[Optional[Dict[str, TunableValue]]],
                      status: Optional[Sequence[Status]] = None) -> bool:
        """
        Pre-load the optimizer with the results of previous experiments.

        Parameters
        ----------
        configs : Sequence[dict]
            Tunable values of the configs previously tried.
        scores : Sequence[Optional[Dict[str, TunableValue]]]
            Benchmark results for the corresponding configs.
        status : Optional[Sequence[Status]]
            Trial statuses; defaults to SUCCEEDED for all entries when omitted.

        Returns
        -------
        bool
            True if the update was successful (i.e., accepted by the base class).
        """
        if not super().bulk_register(configs, scores, status):
            return False
        if status is None:
            status = [Status.SUCCEEDED] * len(configs)
        for (params, score, trial_status) in zip(configs, scores, status):
            tunables = self._tunables.copy().assign(params)
            self.register(tunables, trial_status, score)
        if _LOG.isEnabledFor(logging.DEBUG):
            (best_score, _) = self.get_best_observation()
            _LOG.debug("Update end: %s = %s", self.target, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """
        Generate the next grid search suggestion.

        Moves the chosen config from the pending set to the suggested set.
        If configured, the first suggestion is the default configuration.
        """
        tunables = super().suggest()
        if self._start_with_defaults:
            _LOG.info("Use default values for the first trial")
            self._start_with_defaults = False
            tunables = tunables.restore_defaults()
            # Need to index based on ConfigSpace dict ordering.
            default_config = dict(self.config_space.get_default_configuration())
            assert tunables.get_param_values() == default_config
            # Move the default from the pending to the suggested set.
            default_config_values = tuple(default_config.values())
            del self._pending_configs[default_config_values]
            self._suggested_configs.add(default_config_values)
        else:
            # Select the first item from the pending configs.
            if not self._pending_configs and self._iter <= self._max_iter:
                _LOG.info("No more pending configs to suggest. Restarting grid.")
                self._config_keys, self._pending_configs = self._get_grid()
            try:
                next_config_values = next(iter(self._pending_configs.keys()))
            except StopIteration as exc:
                raise ValueError("No more pending configs to suggest.") from exc
            next_config = dict(zip(self._config_keys, next_config_values))
            tunables.assign(next_config)
            # Move it to the suggested set.
            self._suggested_configs.add(next_config_values)
            del self._pending_configs[next_config_values]
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables

    def register(self, tunables: TunableGroups, status: Status,
                 score: Optional[Union[float, dict]] = None) -> Optional[float]:
        """
        Register the observation for the given config and remove it from the
        suggested set (warning, rather than failing, if it was not there).
        """
        registered_score = super().register(tunables, status, score)
        try:
            # Re-canonicalize the values via ConfigSpace to match the grid's ordering.
            config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()))
            self._suggested_configs.remove(tuple(config.values()))
        except KeyError:
            _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables)
        return registered_score

    def not_converged(self) -> bool:
        """
        Return True while there are still pending configs and the iteration
        budget has not been exhausted.
        """
        if self._iter > self._max_iter:
            if bool(self._pending_configs):
                _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s",
                             len(self._pending_configs), list(self._pending_configs.keys()))
            return False
        return bool(self._pending_configs)