Coverage for mlos_bench/mlos_bench/optimizers/track_best_optimizer.py: 96%

25 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base optimizer for mlos_bench that keeps track of the best score and configuration.

7""" 

8 

9import logging 

10from abc import ABCMeta 

11from typing import Optional, Tuple, Union 

12 

13from mlos_bench.environments.status import Status 

14from mlos_bench.tunables.tunable_groups import TunableGroups 

15 

16from mlos_bench.optimizers.base_optimizer import Optimizer 

17from mlos_bench.services.base_service import Service 

18 

19_LOG = logging.getLogger(__name__) 

20 

21 

class TrackBestOptimizer(Optimizer, metaclass=ABCMeta):
    """
    Base Optimizer class that keeps track of the best score and configuration.

    ``register`` records a copy of the tunables of every successful trial
    whose registered score is lower than the best seen so far;
    ``get_best_observation`` reports that pair back.
    """

    def __init__(self,
                 tunables: TunableGroups,
                 config: dict,
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        super().__init__(tunables, config, global_config, service)
        # Invariant: both stay None until the first successful trial with a
        # non-None registered score; from then on they are updated together.
        self._best_config: Optional[TunableGroups] = None
        self._best_score: Optional[float] = None

    def register(self, tunables: TunableGroups, status: Status,
                 score: Optional[Union[float, dict]] = None) -> Optional[float]:
        """
        Register a trial result and remember it if it is the best so far.

        Parameters
        ----------
        tunables : TunableGroups
            The configuration that was evaluated.
        status : Status
            Final status of the trial; only succeeded trials can become best.
        score : Optional[Union[float, dict]]
            Raw score, passed through unchanged to the base class ``register``.

        Returns
        -------
        registered_score : Optional[float]
            The value returned by the base class ``register``, unchanged.
        """
        registered_score = super().register(tunables, status, score)
        # A lower registered score is treated as better. NOTE(review): presumably
        # the base class has already normalized the score via ``_opt_sign`` so
        # that lower is always better -- confirm against Optimizer.register.
        # Requiring a non-None score keeps ``_best_config`` from being set
        # without a matching ``_best_score``.
        if (status.is_succeeded()
                and registered_score is not None
                and (self._best_score is None or registered_score < self._best_score)):
            self._best_score = registered_score
            # Copy so later in-place changes to ``tunables`` don't alter the record.
            self._best_config = tunables.copy()
        return registered_score

    def get_best_observation(self) -> Union[Tuple[float, TunableGroups], Tuple[None, None]]:
        """
        Get the best score and configuration observed so far.

        Returns
        -------
        (best_score, best_config) : Union[Tuple[float, TunableGroups], Tuple[None, None]]
            The stored best score multiplied by ``self._opt_sign`` together
            with the tunables copy it was obtained with, or ``(None, None)``
            if no successful trial with a score has been registered yet.
        """
        if self._best_score is None:
            return (None, None)
        # Guaranteed by ``register``: the two fields are always set together.
        assert self._best_config is not None
        # NOTE(review): multiplying by ``_opt_sign`` presumably undoes the sign
        # normalization applied at registration time -- confirm in the base class.
        return (self._best_score * self._opt_sign, self._best_config)