Coverage for mlos_bench/mlos_bench/environments/mock_env.py: 94%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Scheduler-side environment to mock the benchmark results."""
7import logging
8import random
9from datetime import datetime
10from typing import Any, Dict, List, Optional, Tuple
12import numpy
14from mlos_bench.environments.base_environment import Environment
15from mlos_bench.environments.status import Status
16from mlos_bench.services.base_service import Service
17from mlos_bench.tunables.tunable import Tunable, TunableValue
18from mlos_bench.tunables.tunable_groups import TunableGroups
20_LOG = logging.getLogger(__name__)
class MockEnv(Environment):
    """Scheduler-side environment to mock the benchmark results."""

    _NOISE_VAR = 0.2
    """
    Spread of the Gaussian noise added to the benchmark value.

    NOTE: despite the name, this value is passed to `random.Random.gauss` as `sigma`,
    i.e., it acts as the standard deviation of the noise, not its variance.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment that produces mock benchmark data.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment configuration.
            Optional arguments read from `self.config` (as set up by the base class)
            are `mock_env_seed`, `mock_env_range`, and `mock_env_metrics`.
            A negative `mock_env_seed` (the default is -1) disables the noise and makes
            the output deterministic; 0 seeds the generators from the system
            (default randomness); a positive value gives reproducible noise.
            `mock_env_range`, if set, is a `(low, high)` pair to rescale the score to.
            `mock_env_metrics` is a list of metric names; `["score"]` by default.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object. Passed to the base class;
            not used by this class otherwise.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )
        seed = int(self.config.get("mock_env_seed", -1))
        # `seed or None` maps 0 to None, so that Random() seeds itself from the system.
        # Separate generators keep the `run()` and `status()` noise streams independent.
        self._run_random = random.Random(seed or None) if seed >= 0 else None
        self._status_random = random.Random(seed or None) if seed >= 0 else None
        self._range = self.config.get("mock_env_range")
        self._metrics = self.config.get("mock_env_metrics", ["score"])
        self._is_ready = True

    def _produce_metrics(self, rand: Optional[random.Random]) -> Dict[str, TunableValue]:
        """
        Compute the mock score and report it under every configured metric name.

        Parameters
        ----------
        rand : Optional[random.Random]
            Random number generator to draw the noise from, or None for no noise.

        Returns
        -------
        metrics : Dict[str, TunableValue]
            A dict that maps each name in `mock_env_metrics` to the same score value.
        """
        # Simple convex function of all tunable parameters.
        normalized = [self._normalized(tunable) for (tunable, _group) in self._tunable_params]
        # The mean of an empty list is NaN (plus a RuntimeWarning); use 0 if no tunables.
        score = float(numpy.mean(numpy.square(normalized))) if normalized else 0.0

        # Add noise and shift the benchmark value from [0, 1] to a given range.
        noise = rand.gauss(0, self._NOISE_VAR) if rand else 0
        score = float(numpy.clip(score + noise, 0, 1))
        if self._range:
            score = self._range[0] + score * (self._range[1] - self._range[0])

        return {metric: score for metric in self._metrics}

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Produce mock benchmark data for one experiment.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            The keys of the `output` dict are the names of the metrics
            specified in the config; by default it's just one metric
            named "score". All output metrics have the same value.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            # The base class reports that we cannot run; pass its result through as is.
            return result
        metrics = self._produce_metrics(self._run_random)
        return (Status.SUCCEEDED, timestamp, metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Produce mock benchmark status telemetry for one experiment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, _) = result = super().status()
        if not status.is_ready():
            # The base class reports that we are not ready; pass its result through as is.
            return result
        metrics = self._produce_metrics(self._status_random)
        return (
            # FIXME: this causes issues if we report RUNNING instead of READY
            Status.READY,
            timestamp,
            [(timestamp, metric, score) for (metric, score) in metrics.items()],
        )

    @staticmethod
    def _normalized(tunable: Tunable) -> float:
        """
        Get the NORMALIZED value of a tunable.

        That is, map current value to the [0, 1] range.

        Parameters
        ----------
        tunable : Tunable
            The tunable parameter to normalize.

        Returns
        -------
        value : float
            Position of the current value within the tunable's categories or range.
            It is 0.0 if the tunable has a single category or a zero-width range.

        Raises
        ------
        ValueError
            If the tunable is neither categorical nor numerical.
        """
        if tunable.is_categorical:
            span = len(tunable.categories) - 1
            offset = tunable.categories.index(tunable.category)
        elif tunable.is_numerical:
            span = tunable.range[1] - tunable.range[0]
            offset = tunable.numerical_value - tunable.range[0]
        else:
            raise ValueError(f"Invalid parameter type: {tunable.type}")
        if span == 0:
            # Only one possible value: avoid division by zero.
            return 0.0
        # Explicitly clip the value in case of numerical errors.
        return float(numpy.clip(offset / float(span), 0, 1))