Coverage for mlos_bench/mlos_bench/optimizers/mock_optimizer.py: 97%
36 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Mock optimizer for mlos_bench.
7"""
9import random
10import logging
12from typing import Callable, Dict, Optional, Sequence
14from mlos_bench.environments.status import Status
15from mlos_bench.tunables.tunable import Tunable, TunableValue
16from mlos_bench.tunables.tunable_groups import TunableGroups
18from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
19from mlos_bench.services.base_service import Service
21_LOG = logging.getLogger(__name__)
class MockOptimizer(TrackBestOptimizer):
    """
    Mock optimizer to test the Environment API.

    Except for an optional first "defaults" suggestion, every suggested
    tunable value is a random draw chosen by the tunable's type
    ("categorical", "float" or "int") from a private ``random.Random``
    generator seeded with ``self.seed``.
    """

    def __init__(self,
                 tunables: TunableGroups,
                 config: dict,
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        super().__init__(tunables, config, global_config, service)
        # Use a dedicated generator instance rather than the global `random`
        # state, so the draws are reproducible whenever `self.seed` is fixed.
        rng = random.Random(self.seed)
        # One sampler per tunable type:
        # - categorical values are picked uniformly from `categories`;
        # - floats come from `uniform` over `range`;
        # - ints come from `randint`, which is inclusive at both ends of `range`.
        samplers: Dict[str, Callable[[Tunable], TunableValue]] = {
            "categorical": lambda param: rng.choice(param.categories),
            "float": lambda param: rng.uniform(*param.range),
            "int": lambda param: rng.randint(*param.range),
        }
        self._random = samplers

    def bulk_register(self,
                      configs: Sequence[dict],
                      scores: Sequence[Optional[Dict[str, TunableValue]]],
                      status: Optional[Sequence[Status]] = None) -> bool:
        """
        Register a batch of past observations, one observation at a time.

        Parameters
        ----------
        configs : Sequence[dict]
            Tunable assignments, one per observation.
        scores : Sequence[Optional[Dict[str, TunableValue]]]
            Scores for each observation. Entries are forwarded to
            ``register`` unchanged, ``None`` included.
        status : Optional[Sequence[Status]]
            Status of each observation. When omitted, every observation is
            registered as ``Status.SUCCEEDED``.

        Returns
        -------
        bool
            False if the base-class ``bulk_register`` returned False, in which
            case nothing is registered here. True otherwise.
        """
        accepted = super().bulk_register(configs, scores, status)
        if not accepted:
            return False
        statuses = [Status.SUCCEEDED] * len(configs) if status is None else status
        # NOTE(review): zip() stops at the shortest sequence; presumably the
        # base class has already checked that the lengths agree -- confirm.
        for params, score, trial_status in zip(configs, scores, statuses):
            # Assign onto a copy of the tunables, not onto self._tunables itself.
            self.register(self._tunables.copy().assign(params), trial_status, score)
        if _LOG.isEnabledFor(logging.DEBUG):
            # Only query the best observation when debug output will be emitted.
            best_score, _best_config = self.get_best_observation()
            _LOG.debug("Bulk register end: %s = %s", self.target, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """
        Produce the next (random) configuration to try.

        If ``_start_with_defaults`` is set, the values returned by the
        base-class ``suggest()`` are kept as-is (logged as the defaults) and
        the flag is cleared. Otherwise every tunable's value is overwritten
        with a random draw from the sampler registered for its type.

        Returns
        -------
        TunableGroups
            The suggested tunable values.
        """
        tunables = super().suggest()
        if not self._start_with_defaults:
            for tunable, _group in tunables:
                sample = self._random[tunable.type]
                tunable.value = sample(tunable)
        else:
            _LOG.info("Use default values for the first trial")
            self._start_with_defaults = False
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables