Coverage for mlos_bench/mlos_bench/environments/mock_env.py: 94% (52 statements)
coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Scheduler-side environment to mock the benchmark results."""

import logging
import random
from datetime import datetime
from typing import Any

import numpy

from mlos_bench.environments.base_environment import Environment
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue

_LOG = logging.getLogger(__name__)


class MockEnv(Environment):
    """Scheduler-side environment to mock the benchmark results."""

    _NOISE_VAR = 0.2
    """Variance of the Gaussian noise added to the benchmark value."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
39 """
40 Create a new environment that produces mock benchmark data.
42 Parameters
43 ----------
44 name: str
45 Human-readable name of the environment.
46 config : dict
47 Free-format dictionary that contains the benchmark environment configuration.
48 global_config : dict
49 Free-format dictionary of global parameters (e.g., security credentials)
50 to be mixed in into the "const_args" section of the local config.
51 Optional arguments are `mock_env_seed`, `mock_env_range`, and `mock_env_metrics`.
52 Set `mock_env_seed` to -1 for deterministic behavior, 0 for default randomness.
53 tunables : TunableGroups
54 A collection of tunable parameters for *all* environments.
55 service: Service
56 An optional service object. Not used by this class.
57 """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )
        seed = int(self.config.get("mock_env_seed", -1))
        self._run_random = random.Random(seed or None) if seed >= 0 else None
        self._status_random = random.Random(seed or None) if seed >= 0 else None
        self._range: tuple[int, int] | None = self.config.get("mock_env_range")
        self._metrics: list[str] | None = self.config.get("mock_env_metrics", ["score"])
        self._is_ready = True
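
        # For example (illustrative values, not defaults), a config of
        #     {"mock_env_seed": 42, "mock_env_range": [60, 120], "mock_env_metrics": ["score", "latency"]}
        # produces reproducible pseudo-random scores scaled into [60, 120] and
        # reported under both the "score" and "latency" keys.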

    def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue]:
        # Simple convex function of all tunable parameters.
        score = numpy.mean(
            numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params])
        )

        # Add noise and shift the benchmark value from [0, 1] to a given range.
        noise = rand.gauss(0, self._NOISE_VAR) if rand else 0
        score = numpy.clip(score + noise, 0, 1)
        if self._range:
            score = self._range[0] + score * (self._range[1] - self._range[0])
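
        # Worked example (illustrative): two tunables normalized to 0.5 and 1.0 give a raw
        # score of mean([0.25, 1.0]) == 0.625; with no noise and mock_env_range == [60, 120],
        # the reported value is 60 + 0.625 * (120 - 60) == 97.5 for every metric.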
        return {metric: float(score) for metric in self._metrics or []}

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Produce mock benchmark data for one experiment.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not SUCCEEDED.
            The keys of the `output` dict are the names of the metrics
            specified in the config; by default it's just one metric
            named "score". All output metrics have the same value.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            return result
        metrics = self._produce_metrics(self._run_random)
        return (Status.SUCCEEDED, timestamp, metrics)

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Produce mock benchmark status telemetry for one experiment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is the UTC time stamp of the status; it's the current time by default.
            `telemetry` is a list (possibly empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, _) = result = super().status()
        if not status.is_ready():
            return result
        metrics = self._produce_metrics(self._status_random)
        return (
            # FIXME: this causes issues if we report RUNNING instead of READY
            Status.READY,
            timestamp,
            [(timestamp, metric, score) for (metric, score) in metrics.items()],
        )

    @staticmethod
    def _normalized(tunable: Tunable) -> float:
        """
        Get the NORMALIZED value of a tunable.

        That is, map the current value to the [0, 1] range.
        """
        val = None
        if tunable.is_categorical:
            val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1)
        elif tunable.is_numerical:
            val = (tunable.numerical_value - tunable.range[0]) / float(
                tunable.range[1] - tunable.range[0]
            )
        else:
            raise ValueError("Invalid parameter type: " + tunable.type)
        # Explicitly clip the value in case of numerical errors.
        ret: float = numpy.clip(val, 0, 1)
        return ret
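

# A minimal usage sketch (illustrative, not part of the module API): the tunables config
# layout, group/parameter names, and values below are assumptions for demonstration only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_tunables = TunableGroups(
        {
            "demo_group": {
                "cost": 1,
                "params": {
                    "x": {"type": "float", "range": [0, 1], "default": 0.5},
                },
            },
        }
    )
    demo_env = MockEnv(
        name="demo-mock-env",
        config={
            "tunable_params": ["demo_group"],
            "mock_env_seed": 42,
            "mock_env_range": [60, 120],
        },
        tunables=demo_tunables,
    )
    with demo_env as env_context:
        assert env_context.setup(demo_tunables)
        # Expect something like (Status.SUCCEEDED, <UTC timestamp>, {"score": ...}).
        print(env_context.run())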