Coverage for mlos_bench/mlos_bench/optimizers/mock_optimizer.py: 97%

36 statements  

coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Mock optimizer for mlos_bench.

Mostly intended for testing and validation. This optimizer produces random suggestions.
The range of the suggestions can be controlled by a config.

See the test cases or example json configs for more details.
"""

import logging
import random
from collections.abc import Callable, Sequence

from mlos_bench.environments.status import Status
from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue

_LOG = logging.getLogger(__name__)


class MockOptimizer(TrackBestOptimizer):
    """Mock optimizer to test the Environment API."""

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        super().__init__(tunables, config, global_config, service)
        rnd = random.Random(self.seed)
        self._random: dict[str, Callable[[Tunable], TunableValue]] = {
            "categorical": lambda tunable: rnd.choice(tunable.categories),
            "float": lambda tunable: rnd.uniform(*tunable.range),
            "int": lambda tunable: rnd.randint(*(int(x) for x in tunable.range)),
        }

    def bulk_register(
        self,
        configs: Sequence[dict],
        scores: Sequence[dict[str, TunableValue] | None],
        status: Sequence[Status] | None = None,
    ) -> bool:
        if not super().bulk_register(configs, scores, status):
            return False
        if status is None:
            status = [Status.SUCCEEDED] * len(configs)
        for params, score, trial_status in zip(configs, scores, status):
            tunables = self._tunables.copy().assign(params)
            self.register(tunables, trial_status, score)
        if _LOG.isEnabledFor(logging.DEBUG):
            (best_score, _) = self.get_best_observation()
            _LOG.debug("Bulk register END: %s = %s", self, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """Generate the next (random) suggestion."""
        tunables = super().suggest()
        if self._start_with_defaults:
            _LOG.info("Use default tunable values")
            self._start_with_defaults = False
        else:
            for tunable, _group in tunables:
                tunable.value = self._random[tunable.type](tunable)
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables
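
The core mechanism above is the dispatch table built in MockOptimizer.__init__: each tunable type name maps to a sampler closed over a single seeded random.Random instance, so repeated runs with the same seed produce the same sequence of suggestions. The standalone sketch below reproduces that pattern outside of mlos_bench for illustration only; FakeTunable is a hypothetical stand-in for the library's Tunable class (it is not part of mlos_bench) and exposes just the attributes the samplers read.

# Standalone sketch of the type-to-sampler dispatch used by MockOptimizer.
# FakeTunable is a hypothetical stand-in for mlos_bench's Tunable.
import random
from dataclasses import dataclass, field


@dataclass
class FakeTunable:
    """Minimal stand-in exposing ``type``, ``range``, and ``categories``."""

    type: str
    range: tuple[float, float] | None = None
    categories: list[str] = field(default_factory=list)


rnd = random.Random(42)  # fixed seed, mirroring random.Random(self.seed)
samplers = {
    "categorical": lambda t: rnd.choice(t.categories),
    "float": lambda t: rnd.uniform(*t.range),
    "int": lambda t: rnd.randint(*(int(x) for x in t.range)),
}

# Each call draws from the shared seeded generator, so the whole
# sequence of suggestions is reproducible for a given seed.
print(samplers["int"](FakeTunable(type="int", range=(0, 100))))
print(samplers["float"](FakeTunable(type="float", range=(0.0, 1.0))))
print(samplers["categorical"](FakeTunable(type="categorical", categories=["a", "b", "c"])))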