Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base class for the service mix-ins."""
7import json
8import logging
9from types import TracebackType
10from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type, Union
12from mlos_bench.config.schemas import ConfigSchema
13from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
14from mlos_bench.util import instantiate_from_config
# Module-level logger, named after this module (``__name__``) so that it
# picks up any configuration applied to its parent package loggers.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class Service:
    """An abstract base of all Environment Services and used to build up mix-ins."""

    @classmethod
    def new(
        cls,
        class_name: str,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
    ) -> "Service":
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional["Service"] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            It is validated against the service JSON schema and stored as
            `self.config` (an empty dict if None is given).
        global_config : dict
            Free-format dictionary of global parameters.
            In this base class it is only used for debug logging.
        parent : Service
            An optional parent service that can provide mixin functions.
            Its exported methods are registered on this instance, and it is
            used as the config loader if it implements `SupportsConfigLoading`.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
            They are registered after the parent's methods, so they take
            precedence over any parent method with the same name.
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        # Name -> callable table of all mix-in methods visible on this service.
        self._service_methods: Dict[str, Callable] = {}
        # Distinct Service instances that own (bound) methods in the table above.
        self._services: Set[Service] = set()
        # Results of `_enter_context()` for each of `_services` while in context.
        self._service_contexts: List[Service] = []
        self._in_context = False

        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Only assigned here when the parent can load configs; otherwise it is
        # left unset and `config_loader_service` raises AttributeError unless
        # something else (e.g., a subclass) assigns it.
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def merge_methods(
        ext_methods: Union[Dict[str, Callable], List[Callable], None],
        local_methods: Union[Dict[str, Callable], List[Callable]],
    ) -> Dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor just before
        invoking the constructor of the base class.

        Parameters
        ----------
        ext_methods : Union[Dict[str, Callable], List[Callable], None]
            Methods supplied by the caller. Lists are keyed by each
            function's ``__name__``. These win over `local_methods` on name clashes.
        local_methods : Union[Dict[str, Callable], List[Callable]]
            Methods provided by the derived class itself. Not mutated.

        Returns
        -------
        methods : Dict[str, Callable]
            A new dictionary of name -> callable pairs.
        """
        if isinstance(local_methods, dict):
            local_methods = local_methods.copy()
        else:
            local_methods = {svc.__name__: svc for svc in local_methods}

        if not ext_methods:
            return local_methods

        if not isinstance(ext_methods, dict):
            ext_methods = {svc.__name__: svc for svc in ext_methods}

        local_methods.update(ext_methods)
        return local_methods

    def __enter__(self) -> "Service":
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        Entering an already-entered Service is a no-op (re-entrant).
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            # A Service with no bound mix-in methods has no sub-contexts, so an
            # empty `_service_contexts` is legitimate in that case.
            assert self._service_contexts or not self._services
            assert all(svc._in_context for svc in self._services)
            return self
        self._service_contexts = [svc._enter_context() for svc in self._services]
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under this one,
        in reverse order of entry. Every sub-context is exited even if some of them
        raise; in that case the last exception caught is re-raised after this
        Service's state has been reset.
        Never suppresses the exception passed in (always returns False).
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw: Optional[Exception] = None
        for svc in reversed(self._service_contexts):
            try:
                svc._exit_context(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                ex_throw = ex
        # Reset the state *before* re-raising: otherwise a failing sub-context
        # would leave this Service marked as in-context with no sub-contexts,
        # and any later __enter__ would trip the re-entrancy assertions.
        self._service_contexts = []
        self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False

    def _enter_context(self) -> "Service":
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be used with
        mix-ins and overridden by subclasses.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """
        Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.

        Parameters
        ----------
        config : dict
            The service configuration to validate (wrapped under the "config" key
            together with this instance's fully-qualified class name).
        """
        if self.__class__ == Service:
            # Skip over the case where instantiate a bare base Service class in
            # order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        # NOTE(review): raises whatever ConfigSchema.validate raises on an
        # invalid config — exception type not visible from here.
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """Produce a human-readable string listing all public methods of the service."""
        return f"{self} ::\n" + "\n".join(
            f'    "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> SupportsConfigLoading:
        """
        Return a config loader service.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.

        Raises
        ------
        AttributeError
            If no config loader has been assigned (e.g., the parent did not
            implement `SupportsConfigLoading`).
        """
        return self._config_loader_service

    def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
        """
        Register new mix-in services.

        Parameters
        ----------
        services : dict or list
            A dictionary of string -> function pairs, or a list of functions
            (each keyed by its ``__name__``). Entries override previously
            registered methods with the same name.
        """
        # All service loading must happen prior to entering a context.
        # Check this up front so that a rejected registration does not leave
        # the method table partially updated.
        assert not self._in_context
        assert not self._service_contexts

        if not isinstance(services, dict):
            services = {svc.__name__: svc for svc in services}

        self._service_methods.update(services)
        # Expose every registered method as an instance attribute so it can be
        # called directly on the service object.
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # In order to get a list of all child contexts, we need to look at only
        # the bound methods that were not overridden by another mixin.
        # Then we inspect the internally bound __self__ variable to discover
        # which Service instance that method belongs to.
        # Some methods are actually stand-alone functions (or are bound to
        # non-Service objects), so we filter them out.
        # Unfortunately, by creating a set, we may destroy the ability to
        # preserve the context enter/exit order, but hopefully it doesn't
        # matter.
        self._services = {
            svc_method.__self__
            for svc_method in self._service_methods.values()
            if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
        }

    def export(self) -> Dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Note that the internal method table itself is returned (not a copy).

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs.
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods