Coverage for mlos_bench/mlos_bench/optimizers/mock_optimizer.py: 97%

36 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Mock optimizer for mlos_bench.
"""

8 

9import random 

10import logging 

11 

12from typing import Callable, Dict, Optional, Sequence 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.tunables.tunable import Tunable, TunableValue 

16from mlos_bench.tunables.tunable_groups import TunableGroups 

17 

18from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer 

19from mlos_bench.services.base_service import Service 

20 

_LOG = logging.getLogger(__name__)


class MockOptimizer(TrackBestOptimizer):
    """
    Mock optimizer to test the Environment API.

    It does not learn from observations: every suggestion (other than the
    pass-through one issued while ``_start_with_defaults`` is set) is a
    uniformly random sample for each tunable.
    """

    def __init__(self,
                 tunables: TunableGroups,
                 config: dict,
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        """
        Create a new mock optimizer.

        Parameters
        ----------
        tunables : TunableGroups
            The tunables to optimize; forwarded to the base class.
        config : dict
            Optimizer configuration; forwarded to the base class.
        global_config : Optional[dict]
            Global configuration; forwarded to the base class.
        service : Optional[Service]
            Optional service object; forwarded to the base class.
        """
        super().__init__(tunables, config, global_config, service)
        # Per-instance generator seeded with ``self.seed`` (set up by the base class).
        rng = random.Random(self.seed)
        # Dispatch table: tunable type name -> function drawing one value for it.
        # All three samplers are uniform; ``randint`` includes both range ends.
        self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
            "categorical": lambda tun: rng.choice(tun.categories),
            "float": lambda tun: rng.uniform(*tun.range),
            "int": lambda tun: rng.randint(*tun.range),
        }

    def bulk_register(self,
                      configs: Sequence[dict],
                      scores: Sequence[Optional[Dict[str, TunableValue]]],
                      status: Optional[Sequence[Status]] = None) -> bool:
        """
        Replay a batch of previously observed trials into the optimizer.

        Parameters
        ----------
        configs : Sequence[dict]
            Parameter assignments of the past trials.
        scores : Sequence[Optional[Dict[str, TunableValue]]]
            Scores of the past trials, aligned with ``configs``.
        status : Optional[Sequence[Status]]
            Status of each past trial. When omitted, every trial is treated
            as ``Status.SUCCEEDED``.

        Returns
        -------
        bool
            False when the base class ``bulk_register`` returns False
            (nothing is registered in that case), True otherwise.
        """
        if not super().bulk_register(configs, scores, status):
            return False
        trial_statuses = (
            status if status is not None
            else [Status.SUCCEEDED] * len(configs)
        )
        for config, score, trial_status in zip(configs, scores, trial_statuses):
            # Apply each past config to a fresh copy so the optimizer's own
            # tunables are left untouched.
            self.register(self._tunables.copy().assign(config), trial_status, score)
        # Only compute the best observation when it will actually be logged.
        if _LOG.isEnabledFor(logging.DEBUG):
            best_score, _best_config = self.get_best_observation()
            _LOG.debug("Bulk register end: %s = %s", self.target, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """
        Generate the next (random) suggestion.

        While ``_start_with_defaults`` is set, the values returned by the base
        class ``suggest()`` are passed through unchanged and the flag is cleared.
        Otherwise every tunable is overwritten with a random sample for its type.

        Returns
        -------
        TunableGroups
            The tunables object returned by the base class, possibly with
            randomized values.
        """
        suggestion = super().suggest()
        if not self._start_with_defaults:
            for tunable, _group in suggestion:
                sampler = self._random[tunable.type]
                tunable.value = sampler(tunable)
        else:
            _LOG.info("Use default values for the first trial")
            self._start_with_defaults = False
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, suggestion)
        return suggestion