Coverage for mlos_bench/mlos_bench/environments/mock_env.py: 95%

42 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Scheduler-side environment to mock the benchmark results. 

7""" 

8 

9import random 

10import logging 

11from datetime import datetime 

12from typing import Dict, Optional, Tuple 

13 

14import numpy 

15 

16from mlos_bench.services.base_service import Service 

17from mlos_bench.environments.status import Status 

18from mlos_bench.environments.base_environment import Environment 

19from mlos_bench.tunables import Tunable, TunableGroups, TunableValue 

20from mlos_bench.util import nullable 

21 

# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)

23 

24 

class MockEnv(Environment):
    """
    Scheduler-side environment to mock the benchmark results.
    """

    _NOISE_VAR = 0.2
    """
    Standard deviation (sigma) of the Gaussian noise added to the benchmark value.
    Despite the historical "_VAR" name, this is passed as `sigma` to `random.gauss`.
    """

    def __init__(self,
                 *,
                 name: str,
                 config: dict,
                 global_config: Optional[dict] = None,
                 tunables: Optional[TunableGroups] = None,
                 service: Optional[Service] = None):
        """
        Create a new environment that produces mock benchmark data.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment configuration.
            Optional arguments are:
            - `seed`: if given, seeds the RNG used to add Gaussian noise to the score;
              when omitted, no noise is added.
            - `range`: a `[min, max]` pair; the [0, 1] score is linearly rescaled to it.
            - `metrics`: list of metric names to report (default: `["score"]`).
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object. Not used by this class.
        """
        super().__init__(name=name, config=config, global_config=global_config,
                         tunables=tunables, service=service)
        seed = self.config.get("seed")
        # No seed -> no RNG -> run() adds no noise (see the truthiness check there).
        self._random = nullable(random.Random, seed)
        self._range = self.config.get("range")
        self._metrics = self.config.get("metrics", ["score"])
        self._is_ready = True

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Produce mock benchmark data for one experiment.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime, dict)
            3-tuple of (Status, timestamp, output) values.
            If the base environment reports a status that is not ready, its
            result is returned unchanged. Otherwise the status is SUCCEEDED and
            `output` is a dict whose keys are the metric names specified in the
            config (by default just one metric named "score"). All output
            metrics have the same plain-`float` value.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            return result

        # Simple convex function of all tunable parameters.
        # NOTE(review): with no tunables, numpy.mean([]) yields nan (plus a
        # RuntimeWarning), so the reported score would be nan.
        score = numpy.mean(numpy.square([
            self._normalized(tunable) for (tunable, _group) in self._tunable_params
        ]))

        # Add noise and shift the benchmark value from [0, 1] to a given range.
        noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0
        score = numpy.clip(score + noise, 0, 1)
        if self._range:
            score = self._range[0] + score * (self._range[1] - self._range[0])

        # Report a plain Python float rather than a numpy scalar so consumers
        # (storage, serialization, repr in logs) see a standard TunableValue.
        value = float(score)
        return (Status.SUCCEEDED, timestamp, {metric: value for metric in self._metrics})

    @staticmethod
    def _normalized(tunable: Tunable) -> float:
        """
        Get the NORMALIZED value of a tunable.
        That is, map current value to the [0, 1] range.

        Categorical tunables map their category index onto [0, 1]; numerical
        tunables map their value linearly from `tunable.range`. Degenerate
        tunables (a single category, or a range whose bounds coincide) map to
        0.0 instead of raising ZeroDivisionError.

        Raises
        ------
        ValueError
            If the tunable is neither categorical nor numerical.
        """
        if tunable.is_categorical:
            span = len(tunable.categories) - 1
            position = tunable.categories.index(tunable.category)
            val = position / float(span) if span > 0 else 0.0
        elif tunable.is_numerical:
            low = tunable.range[0]
            width = float(tunable.range[1] - low)
            val = (tunable.numerical_value - low) / width if width else 0.0
        else:
            raise ValueError(f"Invalid parameter type: {tunable.type}")
        # Explicitly clip the value in case of numerical errors, and return a
        # plain float to honor the annotated return type.
        return float(numpy.clip(val, 0, 1))