Coverage for mlos_bench/mlos_bench/optimizers/track_best_optimizer.py: 96%
25 statements
coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base optimizer for mlos_bench that tracks the best score and configuration.
"""
import logging
from abc import ABCMeta
from typing import Optional, Tuple, Union

from mlos_bench.environments.status import Status
from mlos_bench.tunables.tunable_groups import TunableGroups

from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.services.base_service import Service

_LOG = logging.getLogger(__name__)
class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
    """
    Base Optimizer class that keeps track of the best score and configuration.
    """

    def __init__(self,
                 tunables: TunableGroups,
                 config: dict,
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        super().__init__(tunables, config, global_config, service)
        self._best_config: Optional[TunableGroups] = None
        self._best_score: Optional[float] = None
    def register(self, tunables: TunableGroups, status: Status,
                 score: Optional[Union[float, dict]] = None) -> Optional[float]:
        # The base class returns the score normalized for minimization
        # (multiplied by _opt_sign), so a strictly lower value is better.
        registered_score = super().register(tunables, status, score)
        # Update the best observation on successful trials only.
        if status.is_succeeded() and (
            self._best_score is None or (registered_score is not None and registered_score < self._best_score)
        ):
            self._best_score = registered_score
            self._best_config = tunables.copy()
        return registered_score
    def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
        if self._best_score is None:
            return (None, None)
        assert self._best_config is not None
        # Undo the sign normalization so the score is reported in the
        # user-facing optimization direction.
        return (self._best_score * self._opt_sign, self._best_config)
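
For illustration, below is a minimal, self-contained sketch of the same track-best pattern that runs without mlos_bench installed. TinyTracker and all of its names are hypothetical and not part of the mlos_bench API; the sketch only mirrors the sign-normalization and best-score bookkeeping of the class above, under that assumption.

# Self-contained sketch of the track-best pattern above; none of these names
# exist in mlos_bench -- this only mirrors the logic for illustration.
from typing import Optional, Tuple


class TinyTracker:
    """Keeps the best (lowest) normalized score and its configuration."""

    def __init__(self, maximize: bool = False) -> None:
        # Analogue of _opt_sign: scores are negated when maximizing, so a
        # single "lower is better" comparison handles both directions.
        self._opt_sign = -1 if maximize else 1
        self._best_score: Optional[float] = None
        self._best_config: Optional[dict] = None

    def register(self, config: dict, succeeded: bool,
                 score: Optional[float]) -> Optional[float]:
        normalized = None if score is None else score * self._opt_sign
        if succeeded and normalized is not None and (
            self._best_score is None or normalized < self._best_score
        ):
            self._best_score = normalized
            self._best_config = dict(config)  # analogous to tunables.copy()
        return normalized

    def get_best_observation(self) -> Tuple[Optional[float], Optional[dict]]:
        if self._best_score is None:
            return (None, None)
        # Undo the sign normalization before reporting, as in the class above.
        return (self._best_score * self._opt_sign, self._best_config)


if __name__ == "__main__":
    tracker = TinyTracker(maximize=True)
    tracker.register({"vm_size": "small"}, succeeded=True, score=10.0)
    tracker.register({"vm_size": "large"}, succeeded=True, score=42.0)
    tracker.register({"vm_size": "huge"}, succeeded=False, score=99.0)  # ignored
    print(tracker.get_best_observation())  # (42.0, {'vm_size': 'large'})

The design point worth noting is the sign trick: by negating scores when maximizing, register() needs only one strictly-less-than comparison for both optimization directions, and get_best_observation() undoes the negation on the way out.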