Coverage for mlos_bench/mlos_bench/environments/mock_env.py: 94%

52 statements  

coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Scheduler-side environment to mock the benchmark results."""

import logging
import random
from datetime import datetime
from typing import Any

import numpy

from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue

_LOG = logging.getLogger(__name__)


class MockEnv(Environment):
    """Scheduler-side environment to mock the benchmark results."""

    _NOISE_VAR = 0.2
28 """Variance of the Gaussian noise added to the benchmark value.""" 

29 

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
        """
        Create a new environment that produces mock benchmark data.

        Parameters
        ----------
        name : str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment configuration.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
            Optional arguments are `mock_env_seed`, `mock_env_range`, and `mock_env_metrics`.
            Set `mock_env_seed` to -1 for deterministic behavior, 0 for default randomness.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service : Service
            An optional service object. Not used by this class.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )
        # Mock RNGs: a negative `mock_env_seed` disables them (no added noise),
        # 0 seeds them from system entropy, and a positive value makes the
        # generated noise reproducible.
        seed = int(self.config.get("mock_env_seed", -1))
        self._run_random = random.Random(seed or None) if seed >= 0 else None
        self._status_random = random.Random(seed or None) if seed >= 0 else None
        self._range: tuple[int, int] | None = self.config.get("mock_env_range")
        self._metrics: list[str] | None = self.config.get("mock_env_metrics", ["score"])
        self._is_ready = True

    def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue]:
        # Simple convex function of all tunable parameters.
        score = numpy.mean(
            numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params])
        )

        # Add noise and shift the benchmark value from [0, 1] to a given range.
        noise = rand.gauss(0, self._NOISE_VAR) if rand else 0
        score = numpy.clip(score + noise, 0, 1)
        if self._range:
            score = self._range[0] + score * (self._range[1] - self._range[0])

        return {metric: float(score) for metric in self._metrics or []}
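        # Worked example (no noise): two tunables normalized to 0.6 and 0.8 give
        # score = mean([0.36, 0.64]) = 0.5; with mock_env_range = [60, 120] the
        # value reported for every metric is 60 + 0.5 * (120 - 60) = 90.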

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Produce mock benchmark data for one experiment.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            The keys of the `output` dict are the names of the metrics
            specified in the config; by default it's just one metric
            named "score". All output metrics have the same value.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            return result
        metrics = self._produce_metrics(self._run_random)
        return (Status.SUCCEEDED, timestamp, metrics)

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Produce mock benchmark status telemetry for one experiment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is the UTC timestamp of the status; it is the current time by default.
            `telemetry` is a list (possibly empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, _) = result = super().status()
        if not status.is_ready():
            return result
        metrics = self._produce_metrics(self._status_random)
        return (
            # FIXME: this causes issues if we report RUNNING instead of READY
            Status.READY,
            timestamp,
            [(timestamp, metric, score) for (metric, score) in metrics.items()],
        )

    @staticmethod
    def _normalized(tunable: Tunable) -> float:
        """
        Get the NORMALIZED value of a tunable.

        That is, map the current value to the [0, 1] range.
        """

        val = None
        if tunable.is_categorical:
            val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1)
        elif tunable.is_numerical:
            val = (tunable.numerical_value - tunable.range[0]) / float(
                tunable.range[1] - tunable.range[0]
            )
        else:
            raise ValueError("Invalid parameter type: " + tunable.type)
        # Explicitly clip the value in case of numerical errors.
        ret: float = numpy.clip(val, 0, 1)
        return ret
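
For reference, a minimal usage sketch of the class above. It follows the constructor, run(), and status() signatures shown in this file; the dict layout passed to TunableGroups, the "tunable_params" config key, and the context-manager + setup() calling convention are assumptions drawn from typical mlos_bench usage rather than from this listing.

# Hypothetical, self-contained example (not part of mock_env.py).
from mlos_bench.environments.mock_env import MockEnv
from mlos_bench.tunables.tunable_groups import TunableGroups

# Assumed TunableGroups dict format: {group: {"cost": ..., "params": {...}}}.
tunables = TunableGroups(
    {
        "mock_group": {  # hypothetical group name, for illustration only
            "cost": 1,
            "params": {
                "x": {"type": "int", "default": 50, "range": [0, 100]},
                "flag": {
                    "type": "categorical",
                    "default": "on",
                    "values": ["off", "on"],
                },
            },
        },
    }
)

env = MockEnv(
    name="mock_env_example",
    config={
        "tunable_params": ["mock_group"],  # tunable group(s) this environment uses
        "mock_env_seed": 42,  # > 0: reproducible noise; -1: no noise; 0: system randomness
        "mock_env_range": [60, 120],  # rescale the [0, 1] score into this range
        "mock_env_metrics": ["score"],  # names of the (identical) output metrics
    },
    tunables=tunables,
)

# Environments are used as context managers; setup() applies the tunable values.
with env as env_context:
    assert env_context.setup(tunables)
    (status, timestamp, output) = env_context.run()
    print(status, output)  # e.g. Status.SUCCEEDED {'score': ...}
    (status, timestamp, telemetry) = env_context.status()
    print(status, telemetry)  # e.g. Status.READY [(timestamp, 'score', ...)]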