Coverage for mlos_bench/mlos_bench/optimizers/mock_optimizer.py: 97%
36 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Mock optimizer for mlos_bench.
8Mostly intended for testing and validation. This optimizer produces random suggestions.
9The range of the suggestions can be controlled by a config.
11See the test cases or example json configs for more details.
12"""
14import logging
15import random
16from collections.abc import Callable, Sequence
18from mlos_bench.environments.status import Status
19from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
20from mlos_bench.services.base_service import Service
21from mlos_bench.tunables.tunable import Tunable
22from mlos_bench.tunables.tunable_groups import TunableGroups
23from mlos_bench.tunables.tunable_types import TunableValue
# Module-level logger, named after this module so its output can be filtered per module.
_LOG: logging.Logger = logging.getLogger(__name__)
class MockOptimizer(TrackBestOptimizer):
    """
    Optimizer stand-in that proposes random tunable values.

    Intended for exercising the Environment API in tests.
    Every suggestion draws each tunable at random:
    - categorical tunables via ``choice`` over their categories;
    - float tunables via ``uniform`` over their range;
    - int tunables via ``randint`` over their int-cast range.

    The exception is the first suggestion when ``_start_with_defaults`` is set,
    which keeps the default values. Draws use a private PRNG seeded with
    ``self.seed``.
    """

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create a new mock optimizer.

        Parameters
        ----------
        tunables : TunableGroups
            Tunables to generate suggestions for (forwarded to the base class).
        config : dict
            Optimizer configuration (forwarded to the base class).
        global_config : dict | None
            Optional global configuration (forwarded to the base class).
        service : Service | None
            Optional parent service (forwarded to the base class).
        """
        super().__init__(tunables, config, global_config, service)
        # Per-instance generator; ``self.seed`` is presumably populated by the
        # base class from the config (TODO confirm) so runs can be reproduced.
        prng = random.Random(self.seed)

        def sample_categorical(param: Tunable) -> TunableValue:
            return prng.choice(param.categories)

        def sample_float(param: Tunable) -> TunableValue:
            return prng.uniform(*param.range)

        def sample_int(param: Tunable) -> TunableValue:
            # Cast both range bounds to int before handing them to randint.
            return prng.randint(*map(int, param.range))

        # Dispatch table: tunable type name -> sampler for tunables of that type.
        self._random: dict[str, Callable[[Tunable], TunableValue]] = {
            "categorical": sample_categorical,
            "float": sample_float,
            "int": sample_int,
        }

    def bulk_register(
        self,
        configs: Sequence[dict],
        scores: Sequence[dict[str, TunableValue] | None],
        status: Sequence[Status] | None = None,
    ) -> bool:
        """
        Feed a batch of previously observed trials to the optimizer.

        Parameters
        ----------
        configs : Sequence[dict]
            Tunable assignments, one per trial.
        scores : Sequence[dict[str, TunableValue] | None]
            Observed scores, positionally aligned with ``configs``.
        status : Sequence[Status] | None
            Per-trial status. When omitted, every trial counts as ``Status.SUCCEEDED``.

        Returns
        -------
        bool
            False, with nothing registered, if the base-class ``bulk_register``
            returns a falsy value. Otherwise True once every trial has been registered.
        """
        if not super().bulk_register(configs, scores, status):
            return False

        statuses = [Status.SUCCEEDED] * len(configs) if status is None else status
        for params, score, trial_status in zip(configs, scores, statuses):
            # Register each trial against its own copy of the tunables so the
            # optimizer's working set is left untouched.
            self.register(self._tunables.copy().assign(params), trial_status, score)

        # Only pay for the best-observation lookup when debug output is wanted.
        if _LOG.isEnabledFor(logging.DEBUG):
            best_score, _ = self.get_best_observation()
            _LOG.debug("Bulk register END: %s = %s", self, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """
        Produce the next suggestion.

        If ``_start_with_defaults`` is set, the tunables from the base class are
        returned unchanged (their defaults) and the flag is cleared. Otherwise
        every tunable is overwritten with a random value from the sampler that
        matches its type.
        """
        # Base-class suggest() supplies the tunables to fill in; it presumably
        # also advances ``self._iter`` (logged below) -- TODO confirm.
        proposal = super().suggest()
        if not self._start_with_defaults:
            for param, _group in proposal:
                sampler = self._random[param.type]
                param.value = sampler(param)
        else:
            _LOG.info("Use default tunable values")
            self._start_with_defaults = False
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, proposal)
        return proposal