Coverage for mlos_bench/mlos_bench/optimizers/mock_optimizer.py: 97%

35 statements  

coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Mock optimizer for mlos_bench."""

import logging
import random
from typing import Callable, Dict, Optional, Sequence

from mlos_bench.environments.status import Status
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import Tunable, TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups

_LOG = logging.getLogger(__name__)


class MockOptimizer(TrackBestOptimizer):
    """Mock optimizer to test the Environment API."""

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        super().__init__(tunables, config, global_config, service)
        rnd = random.Random(self.seed)
        self._random: Dict[str, Callable[[Tunable], TunableValue]] = {
            "categorical": lambda tunable: rnd.choice(tunable.categories),
            "float": lambda tunable: rnd.uniform(*tunable.range),
            "int": lambda tunable: rnd.randint(*tunable.range),
        }

    def bulk_register(
        self,
        configs: Sequence[dict],
        scores: Sequence[Optional[Dict[str, TunableValue]]],
        status: Optional[Sequence[Status]] = None,
    ) -> bool:
        if not super().bulk_register(configs, scores, status):
            return False
        if status is None:
            status = [Status.SUCCEEDED] * len(configs)
        for params, score, trial_status in zip(configs, scores, status):
            tunables = self._tunables.copy().assign(params)
            self.register(tunables, trial_status, score)
        if _LOG.isEnabledFor(logging.DEBUG):
            (best_score, _) = self.get_best_observation()
            _LOG.debug("Bulk register END: %s = %s", self, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """Generate the next (random) suggestion."""
        tunables = super().suggest()
        if self._start_with_defaults:
            _LOG.info("Use default tunable values")
            self._start_with_defaults = False
        else:
            for tunable, _group in tunables:
                tunable.value = self._random[tunable.type](tunable)
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables
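
A minimal driver sketch of the suggest/register loop this class implements. Only suggest(), register(), Status, and get_best_observation() appear in the listing above; the TunableGroups definition layout, the optimizer config keys ("seed", "start_with_defaults"), the tunable names ("buffer_kb", "policy"), and the "score" metric are assumptions about the surrounding mlos_bench config schema, used here purely for illustration.

# Hypothetical usage sketch -- not part of mock_optimizer.py itself.
from mlos_bench.environments.status import Status
from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.tunables.tunable_groups import TunableGroups

# Assumed covariant-group layout: one group with an int and a categorical tunable.
tunables = TunableGroups({
    "demo_group": {
        "cost": 1,
        "params": {
            "buffer_kb": {"type": "int", "range": [16, 1024], "default": 64},
            "policy": {"type": "categorical", "values": ["lru", "fifo"], "default": "lru"},
        },
    },
})

# Assumed optimizer config keys; "seed" makes the random suggestions repeatable.
opt = MockOptimizer(tunables, config={"seed": 42, "start_with_defaults": True})

for _ in range(5):
    suggestion = opt.suggest()              # defaults first, then per-type random values
    score = float(suggestion["buffer_kb"])  # stand-in for a real benchmark measurement
    opt.register(suggestion, Status.SUCCEEDED, {"score": score})

(best_score, best_config) = opt.get_best_observation()
print(best_score, best_config)

Because the samplers in self._random are keyed by tunable.type and seeded from self.seed, the same seed reproduces the same sequence of suggestions, which is what makes this optimizer useful for testing the Environment API without a real optimization backend.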