Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%
105 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Base class for the service mix-ins.
7"""
9import json
10import logging
12from types import TracebackType
13from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
14from typing_extensions import Literal
16from mlos_bench.config.schemas import ConfigSchema
17from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
18from mlos_bench.util import instantiate_from_config
_LOG = logging.getLogger(__name__)


class Service:
    """
    An abstract base of all Environment Services and used to build up mix-ins.

    A Service keeps a registry of named callables (``_service_methods``) that are
    also attached to the instance as attributes, so a child Service can re-export
    the methods of its parent. The Service instances that own the registered bound
    methods are entered and exited together when this Service is used as a
    context manager.
    """

    @classmethod
    def new(cls,
            class_name: str,
            config: Optional[Dict[str, Any]] = None,
            global_config: Optional[Dict[str, Any]] = None,
            parent: Optional["Service"] = None) -> "Service":
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional["Service"] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            It is validated against the service config schema; a bare `Service`
            instance (used only to build up a mix-in) must have an empty config.
        global_config : dict
            Free-format dictionary of global parameters.
            In this base class it is only used for debug logging.
        parent : Service
            An optional parent service that can provide mixin functions.
            All of its exported methods are registered with this service.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service. They are registered after
            the parent's methods, so they win on name collisions.
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        self._service_methods: Dict[str, Callable] = {}
        self._services: Set[Service] = set()
        self._service_contexts: List[Service] = []
        self._in_context = False

        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Only declared here (no value is assigned): the attribute exists only if
        # the parent can load configs, or if it gets assigned elsewhere.
        # Otherwise `config_loader_service` raises AttributeError.
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None],
                      local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor
        just before invoking the constructor of the base class.

        Parameters
        ----------
        ext_methods : Union[Dict[str, Callable], List[Callable], None]
            Methods supplied by the caller. Lists are keyed by each function's
            ``__name__``. On name collisions these override `local_methods`.
        local_methods : Union[Dict[str, Callable], List[Callable]]
            Methods defined by the derived class. The argument is not mutated.

        Returns
        -------
        methods : Dict[str, Callable]
            A new dictionary of the merged name -> function pairs.
        """
        if isinstance(local_methods, dict):
            local_methods = local_methods.copy()
        else:
            local_methods = {svc.__name__: svc for svc in local_methods}

        if not ext_methods:
            return local_methods

        if not isinstance(ext_methods, dict):
            ext_methods = {svc.__name__: svc for svc in ext_methods}

        local_methods.update(ext_methods)
        return local_methods

    def __enter__(self) -> "Service":
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        Entering an already-entered Service is a no-op that returns `self`.
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            # A Service that owns no bound child Services legitimately has an
            # empty context list, so compare sizes rather than require non-empty.
            assert len(self._service_contexts) == len(self._services)
            assert all(svc._in_context for svc in self._services)
            return self
        self._service_contexts = [svc._enter_context() for svc in self._services]
        self._in_context = True
        return self

    def __exit__(self, ex_type: Optional[Type[BaseException]],
                 ex_val: Optional[BaseException],
                 ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under
        this one, in reverse order of entering. Exiting a Service that is not
        in a context is a no-op.

        Returns
        -------
        False
            Never suppresses an exception raised inside the `with` body.

        Raises
        ------
        Exception
            The last exception raised by a child `_exit_context()` call. All
            children are still exited, and this Service is marked as out of
            context before the exception propagates.
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw: Optional[Exception] = None
        try:
            for svc in reversed(self._service_contexts):
                try:
                    svc._exit_context(ex_type, ex_val, ex_tb)
                # pylint: disable=broad-exception-caught
                except Exception as ex:
                    _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                    ex_throw = ex
        finally:
            # Always reset the state, even when a child failed to exit.
            # Otherwise this Service stays marked as "entered" with an empty
            # context list, and a later __enter__ trips its own assertions.
            self._service_contexts = []
            self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False

    def _enter_context(self) -> "Service":
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be
        used with mix-ins and overridden by subclasses.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(self, ex_type: Optional[Type[BaseException]],
                      ex_val: Optional[BaseException],
                      ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be
        used with mix-ins and overridden by subclasses.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """
        Reconstructs a basic json config that this class might have been
        instantiated from in order to validate configs provided outside the
        file loading mechanism.

        The reconstructed config has the form
        ``{"class": "<module>.<ClassName>", "config": {...}}`` and is checked
        against the service config schema.
        """
        if self.__class__ == Service:
            # Skip over the case where instantiate a bare base Service class in order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """
        Produce a human-readable string listing all registered service methods.

        Each line shows the method name and the object it is bound to, or
        "stand-alone" for plain functions.
        """
        return f"{self} ::\n" + "\n".join(
            f'  "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> "SupportsConfigLoading":
        """
        Return a config loader service.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.

        Raises
        ------
        AttributeError
            If no config loader was set, i.e. the parent was not a
            `SupportsConfigLoading` and the attribute was not assigned otherwise.
        """
        return self._config_loader_service

    def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
        """
        Register new mix-in services.

        Parameters
        ----------
        services : dict or list
            A dictionary of string -> function pairs, or a list of functions
            (keyed by their ``__name__``). Later registrations override earlier
            ones with the same name.
        """
        if not isinstance(services, dict):
            services = {svc.__name__: svc for svc in services}

        self._service_methods.update(services)
        # Also expose the methods as instance attributes (e.g. `svc.some_method()`).
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # To get the list of all child contexts, we look only at the bound
        # methods that were not overridden by another mixin. We then inspect
        # each method's internally bound __self__ to discover which Service
        # instance that method belongs to.

        # All service loading must happen prior to entering a context.
        assert not self._in_context
        assert not self._service_contexts
        self._services = {
            # Enumerate the Services that are bound to this instance in the
            # order they were added.
            # Unfortunately, by creating a set, we may destroy the ability to
            # preserve the context enter/exit order, but hopefully it doesn't
            # matter.
            svc_method.__self__ for _, svc_method in self._service_methods.items()
            # Note: some methods are actually stand alone functions, so we need
            # to filter them out.
            if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service)
        }

    def export(self) -> Dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs. This is the service's
            internal registry (not a copy).
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods