Coverage for mlos_bench/mlos_bench/optimizers/track_best_optimizer.py: 97%
38 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base optimizer for mlos_bench that keeps track of the best score and configuration."""
7import logging
8from abc import ABCMeta
9from typing import Dict, Optional, Tuple, Union
11from mlos_bench.environments.status import Status
12from mlos_bench.optimizers.base_optimizer import Optimizer
13from mlos_bench.services.base_service import Service
14from mlos_bench.tunables.tunable import TunableValue
15from mlos_bench.tunables.tunable_groups import TunableGroups
17_LOG = logging.getLogger(__name__)
class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
    """
    Optimizer base class that remembers the best score and configuration seen so far.

    Each successfully registered trial is compared with the stored best one; when it
    wins, its score and a copy of its tunables replace the stored pair.
    """

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """Pass all arguments to the parent constructor and start with no best trial."""
        super().__init__(tunables, config, global_config, service)
        # Nothing is recorded until `register()` accepts a trial.
        self._best_config: Optional[TunableGroups] = None
        self._best_score: Optional[Dict[str, float]] = None

    def register(
        self,
        tunables: TunableGroups,
        status: Status,
        score: Optional[Dict[str, TunableValue]] = None,
    ) -> Optional[Dict[str, float]]:
        """
        Register a trial with the parent class and update the best trial if it improved.

        Parameters
        ----------
        tunables : TunableGroups
            Configuration that was evaluated; a copy is stored if it becomes the best.
        status : Status
            Outcome of the trial; only succeeded trials are considered for the best.
        score : Optional[Dict[str, TunableValue]]
            Raw scores of the trial, forwarded unchanged to the parent `register()`.

        Returns
        -------
        Optional[Dict[str, float]]
            Whatever the parent `register()` returned for this trial.
        """
        new_score = super().register(tunables, status, score)
        if not status.is_succeeded():
            return new_score
        if self._is_better(new_score):
            self._best_score = new_score
            # Store a copy so the best configuration does not alias the caller's object.
            self._best_config = tunables.copy()
        return new_score

    def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool:
        """
        Tell whether `registered_score` beats the stored best score.

        Targets are visited in the iteration order of the stored best score and
        compared lexicographically: the first target whose values differ decides,
        and the smaller value wins. A score that ties on every target is not better.
        With no best score stored yet, any score is accepted.
        """
        incumbent = self._best_score
        if incumbent is None:
            return True
        assert registered_score is not None
        for target, best_value in incumbent.items():
            candidate = registered_score[target]
            if candidate < best_value:
                return True
            if candidate > best_value:
                return False
        # Equal on every target: keep the incumbent.
        return False

    def get_best_observation(
        self,
    ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]:
        """
        Return the best (score, configuration) pair recorded so far.

        Returns
        -------
        Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]
            `(None, None)` when no best score has been stored. Otherwise the stored
            best score passed through `self._get_scores()` with `Status.SUCCEEDED`,
            together with the stored best configuration.
        """
        stored_score = self._best_score
        if stored_score is None:
            return (None, None)
        best = self._get_scores(Status.SUCCEEDED, stored_score)
        assert best is not None
        assert self._best_config is not None
        return (best, self._best_config)