Coverage for mlos_bench/mlos_bench/environments/mock_env.py: 94%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Scheduler-side environment to mock the benchmark results.""" 

6 

7import logging 

8import random 

9from datetime import datetime 

10from typing import Any, Dict, List, Optional, Tuple 

11 

12import numpy 

13 

14from mlos_bench.environments.base_environment import Environment 

15from mlos_bench.environments.status import Status 

16from mlos_bench.services.base_service import Service 

17from mlos_bench.tunables.tunable import Tunable, TunableValue 

18from mlos_bench.tunables.tunable_groups import TunableGroups 

19 

20_LOG = logging.getLogger(__name__)  # Module-level logger named after this module.

21 

22 

class MockEnv(Environment):
    """Scheduler-side environment to mock the benchmark results."""

    _NOISE_VAR = 0.2
    """
    Spread of the Gaussian noise added to the benchmark value.

    NOTE: despite the name, this value is passed as `sigma` (the standard
    deviation, not the variance) to `random.Random.gauss`.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment that produces mock benchmark data.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment configuration.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
            Optional arguments are `mock_env_seed`, `mock_env_range`, and `mock_env_metrics`.
            Set `mock_env_seed` to -1 for deterministic behavior, 0 for default randomness.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object. Not used by this class.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )
        seed = int(self.config.get("mock_env_seed", -1))
        # seed < 0  : no generator at all, hence no noise (deterministic output).
        # seed == 0 : `seed or None` evaluates to None, so the generator is
        #             seeded by `random.Random` itself (non-reproducible).
        # seed > 0  : reproducible pseudo-random noise.
        # Both generators start from the same seed, as in the original code.
        self._run_random = random.Random(seed or None) if seed >= 0 else None
        self._status_random = random.Random(seed or None) if seed >= 0 else None
        self._range = self.config.get("mock_env_range")
        self._metrics = self.config.get("mock_env_metrics", ["score"])
        self._is_ready = True

    def _produce_metrics(self, rand: Optional[random.Random]) -> Dict[str, TunableValue]:
        """
        Compute the mock score and report it under every configured metric name.

        Parameters
        ----------
        rand : Optional[random.Random]
            Generator used to draw the Gaussian noise; no noise is added if None.

        Returns
        -------
        metrics : Dict[str, TunableValue]
            A dict that maps each name in `mock_env_metrics` to the same score.
        """
        # Simple convex function of all tunable parameters.
        normalized = [self._normalized(tunable) for (tunable, _group) in self._tunable_params]
        # `numpy.mean` of an empty sequence is NaN (and emits a RuntimeWarning),
        # which would survive the clipping below; use 0 when there are no tunables.
        score = numpy.mean(numpy.square(normalized)) if normalized else 0.0

        # Add noise and shift the benchmark value from [0, 1] to a given range.
        noise = rand.gauss(0, self._NOISE_VAR) if rand else 0
        score = numpy.clip(score + noise, 0, 1)
        if self._range:
            score = self._range[0] + score * (self._range[1] - self._range[0])

        return {metric: score for metric in self._metrics}

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Produce mock benchmark data for one experiment.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values.
            If the base class reports a status that is not ready, its result
            is returned unchanged. Otherwise the status is SUCCEEDED and
            `output` is a dict with the results.
            The keys of the `output` dict are the names of the metrics
            specified in the config; by default it's just one metric
            named "score". All output metrics have the same value.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            return result
        metrics = self._produce_metrics(self._run_random)
        return (Status.SUCCEEDED, timestamp, metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Produce mock benchmark status telemetry for one experiment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is the time stamp returned by the base class `status()`.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
            If the base class reports a status that is not ready, its result
            is returned unchanged.
        """
        (status, timestamp, _) = result = super().status()
        if not status.is_ready():
            return result
        metrics = self._produce_metrics(self._status_random)
        return (
            # FIXME: this causes issues if we report RUNNING instead of READY
            Status.READY,
            timestamp,
            [(timestamp, metric, score) for (metric, score) in metrics.items()],
        )

    @staticmethod
    def _normalized(tunable: Tunable) -> float:
        """
        Get the NORMALIZED value of a tunable.

        That is, map current value to the [0, 1] range. A tunable whose domain
        has zero width (a single category, or a numerical range with equal
        bounds) is mapped to 0 instead of raising ZeroDivisionError.

        Raises
        ------
        ValueError
            If the tunable is neither categorical nor numerical.
        """
        if tunable.is_categorical:
            span = len(tunable.categories) - 1
            offset = tunable.categories.index(tunable.category)
        elif tunable.is_numerical:
            span = tunable.range[1] - tunable.range[0]
            offset = tunable.numerical_value - tunable.range[0]
        else:
            raise ValueError("Invalid parameter type: " + tunable.type)
        # Guard against a zero-width domain before dividing.
        val = offset / float(span) if span else 0.0
        # Explicitly clip the value in case of numerical errors.
        ret: float = numpy.clip(val, 0, 1)
        return ret

144 return ret