Coverage for mlos_bench/mlos_bench/services/base_service.py: 94%

105 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base class for the service mix-ins. 

7""" 

8 

9import json 

10import logging 

11 

12from types import TracebackType 

13from typing import Any, Callable, Dict, List, Optional, Set, Type, Union 

14from typing_extensions import Literal 

15 

16from mlos_bench.config.schemas import ConfigSchema 

17from mlos_bench.services.types.config_loader_type import SupportsConfigLoading 

18from mlos_bench.util import instantiate_from_config 

19 

# Module-level logger named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)

21 

22 

class Service:
    """
    An abstract base of all Environment Services and used to build up mix-ins.

    Each Service keeps a table of named "service methods" (callables), built
    from the methods exported by an optional parent Service plus any methods
    passed to the constructor. Every entry of that table is also exposed as an
    instance attribute. The Service instances that own *bound* methods in the
    table are entered/exited together with this Service's own context.
    """

    @classmethod
    def new(cls,
            class_name: str,
            config: Optional[Dict[str, Any]] = None,
            global_config: Optional[Dict[str, Any]] = None,
            parent: Optional["Service"] = None) -> "Service":
        """
        Factory method for a new service with a given config.

        Parameters
        ----------
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.services.remote.azure.AzureVMService".
            Must be derived from the `Service` class.
        config : dict
            Free-format dictionary that contains the service configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            A parent service that can provide mixin functions.

        Returns
        -------
        svc : Service
            An instance of the `Service` class initialized with `config`.
        """
        assert issubclass(cls, Service)
        return instantiate_from_config(cls, class_name, config, global_config, parent)

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional["Service"] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new service with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the service configuration.
            It is validated by `_validate_json_config()`; a bare `Service`
            instance must have an empty config.
        global_config : dict
            Free-format dictionary of global parameters.
            In this base constructor it is only used for debug logging.
        parent : Service
            An optional parent service that can provide mixin functions.
            All of its exported methods are registered on this service.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service. They are registered
            after the parent's, so they win on name clashes.
        """
        self.config = config or {}
        self._validate_json_config(self.config)
        self._parent = parent
        self._service_methods: Dict[str, Callable] = {}
        self._services: Set[Service] = set()
        self._service_contexts: List[Service] = []
        self._in_context = False

        if parent:
            self.register(parent.export())
        if methods:
            self.register(methods)

        # Only assigned when the parent is a config loader; otherwise the
        # `config_loader_service` property raises AttributeError.
        self._config_loader_service: SupportsConfigLoading
        if parent and isinstance(parent, SupportsConfigLoading):
            self._config_loader_service = parent

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
            _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
            _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)

    @staticmethod
    def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None],
                      local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]:
        """
        Merge methods from the external caller with the local ones.

        This function is usually called by the derived class constructor
        just before invoking the constructor of the base class.

        Lists of callables are keyed by each callable's ``__name__``. The
        result is a new dict (the inputs are not modified), and external
        methods override local ones with the same name.
        """
        if isinstance(local_methods, dict):
            local_methods = local_methods.copy()
        else:
            local_methods = {svc.__name__: svc for svc in local_methods}

        if not ext_methods:
            return local_methods

        if not isinstance(ext_methods, dict):
            ext_methods = {svc.__name__: svc for svc in ext_methods}

        local_methods.update(ext_methods)
        return local_methods

    def __enter__(self) -> "Service":
        """
        Enter the Service mix-in context.

        Calls the _enter_context() method of all the Services registered under this one.
        Entering an already-entered Service is a no-op (re-entrant context).
        If entering one of the registered Services fails, the ones already
        entered are exited again (in reverse order) before the error
        propagates, so no sub-service is left stuck "in context".
        """
        if self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            # A Service may have no bound sub-services at all (e.g., only
            # stand-alone functions), so the list of contexts may be empty.
            assert len(self._service_contexts) == len(self._services)
            assert all(svc._in_context for svc in self._services)
            return self
        entered: List[Service] = []
        try:
            for svc in self._services:
                entered.append(svc._enter_context())
        except BaseException as ex:
            # `with` does not call __exit__ when __enter__ raises, so undo the
            # partial enter here; otherwise a later enter trips the
            # sub-services' "not in context" assertions.
            for ctx in reversed(entered):
                try:
                    ctx._exit_context(type(ex), ex, ex.__traceback__)
                # pylint: disable=broad-exception-caught
                except Exception as exit_ex:
                    _LOG.error("Exception while exiting Service context '%s': %s", ctx, exit_ex)
            raise
        self._service_contexts = entered
        self._in_context = True
        return self

    def __exit__(self, ex_type: Optional[Type[BaseException]],
                 ex_val: Optional[BaseException],
                 ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Exit the Service mix-in context.

        Calls the _exit_context() method of all the Services registered under this one,
        in reverse order of entering. Every sub-service is given a chance to
        exit even if an earlier one raises; the last such exception is
        re-raised after this Service's context state has been reset.
        """
        if not self._in_context:
            # Multiple environments can share the same Service, so we need to
            # add a check and make this a re-entrant Service context.
            assert not self._service_contexts
            assert all(not svc._in_context for svc in self._services)
            return False
        ex_throw: Optional[Exception] = None
        for svc in reversed(self._service_contexts):
            try:
                svc._exit_context(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
                ex_throw = ex
        # Reset the state *before* re-raising so this Service is not left
        # marked "in context" with no recorded sub-contexts.
        self._service_contexts = []
        self._in_context = False
        if ex_throw is not None:
            raise ex_throw
        return False

    def _enter_context(self) -> "Service":
        """
        Enters the context for this particular Service instance.

        Called by the base __enter__ method of the Service class so it can be
        used with mix-ins and overridden by subclasses.
        """
        assert not self._in_context
        self._in_context = True
        return self

    def _exit_context(self, ex_type: Optional[Type[BaseException]],
                      ex_val: Optional[BaseException],
                      ex_tb: Optional[TracebackType]) -> Literal[False]:
        """
        Exits the context for this particular Service instance.

        Called by the base __exit__ method of the Service class so it can be
        used with mix-ins and overridden by subclasses.
        """
        # pylint: disable=unused-argument
        assert self._in_context
        self._in_context = False
        return False

    def _validate_json_config(self, config: dict) -> None:
        """
        Reconstructs a basic json config that this class might have been
        instantiated from in order to validate configs provided outside the
        file loading mechanism.
        """
        if self.__class__ is Service:
            # Skip over the case where we instantiate a bare base Service class
            # in order to build up a mix-in.
            assert config == {}
            return
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.SERVICE.validate(json_config)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}@{hex(id(self))}"

    def pprint(self) -> str:
        """
        Produce a human-readable string listing all public methods of the service.

        Each line shows the method name and the object it is bound to, or
        "stand-alone" for plain functions.
        """
        return f"{self} ::\n" + "\n".join(
            f'  "{key}": {getattr(val, "__self__", "stand-alone")}'
            for (key, val) in self._service_methods.items()
        )

    @property
    def config_loader_service(self) -> SupportsConfigLoading:
        """
        Return a config loader service.

        Returns
        -------
        config_loader_service : SupportsConfigLoading
            A config loader service.

        Raises
        ------
        AttributeError
            If no config loader was set, e.g. because the parent passed to the
            constructor was missing or did not support config loading.
        """
        return self._config_loader_service

    def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None:
        """
        Register new mix-in services.

        Parameters
        ----------
        services : dict or list
            A dictionary of string -> function pairs, or a list of functions
            (keyed by each function's ``__name__``). Entries override
            previously registered methods with the same name.
        """
        # All service loading must happen prior to entering a context.
        # Check this before mutating anything so a rejected call leaves the
        # Service unchanged.
        assert not self._in_context
        assert not self._service_contexts

        if not isinstance(services, dict):
            services = {svc.__name__: svc for svc in services}

        self._service_methods.update(services)
        # Expose the methods as instance attributes, e.g. `svc.method_name(...)`.
        self.__dict__.update(self._service_methods)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Added methods to: %s", self.pprint())

        # In order to get a list of all child contexts, we need to look at only
        # the bound methods that were not overridden by another mixin.
        # Then we inspect the internally bound __self__ variable to discover
        # which Service instance that method belongs to.
        self._services = {
            # Unfortunately, by creating a set, we may destroy the ability to
            # preserve the context enter/exit order, but hopefully it doesn't
            # matter.
            svc_method.__self__
            for svc_method in self._service_methods.values()
            # Note: some methods are actually stand-alone functions (or are
            # bound to non-Service objects), so we need to filter them out.
            if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service)
        }

    def export(self) -> Dict[str, Callable]:
        """
        Return a dictionary of functions available in this service.

        Returns
        -------
        services : dict
            A dictionary of string -> function pairs.
            This is the service's internal table itself, not a copy.
        """
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Export methods from: %s", self.pprint())

        return self._service_methods