Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_saas.py: 46%

82 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A collection of Service functions for configuring SaaS instances on Azure. 

7""" 

8import logging 

9 

10from typing import Any, Callable, Dict, List, Optional, Tuple, Union 

11 

12import requests 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.services.base_service import Service 

16from mlos_bench.services.types.authenticator_type import SupportsAuth 

17from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig 

18from mlos_bench.util import check_required_params, merge_parameters 

19 

20_LOG = logging.getLogger(__name__) 

21 

22 

class AzureSaaSConfigService(Service, SupportsRemoteConfig):
    """
    Helper methods to configure Azure SaaS DB services (MySQL, MariaDB, PostgreSQL),
    on either flexible or single servers, through the Azure Management REST API.
    """

    _REQUEST_TIMEOUT = 5  # seconds

    # REST API for Azure SaaS DB Services configuration as described in:
    # https://learn.microsoft.com/en-us/rest/api/mysql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations

    _URL_CONFIGURE = (
        "https://management.azure.com" +
        "/subscriptions/{subscription}" +
        "/resourceGroups/{resource_group}" +
        "/providers/{provider}" +
        "/{server_type}/{vm_name}" +
        "/{update}" +
        "?api-version={api_version}"
    )

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new instance of Azure services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription", "resourceGroup" and
            "provider". For providers other than MySQL, MariaDB and PostgreSQL,
            "supportsBatchUpdate", "isFlex" and "apiVersion" are also required.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config, global_config, parent,
            self.merge_methods(methods, [
                self.configure,
                self.is_config_pending
            ])
        )

        check_required_params(self.config, {
            "subscription",
            "resourceGroup",
            "provider",
        })

        # Provide sane defaults for known DB providers.
        provider = self.config.get("provider")
        if provider == "Microsoft.DBforMySQL":
            self._is_batch = self.config.get("supportsBatchUpdate", True)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-01-01")
        elif provider == "Microsoft.DBforMariaDB":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", False)
            api_version = self.config.get("apiVersion", "2018-06-01")
        elif provider == "Microsoft.DBforPostgreSQL":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-12-01")
        else:
            self._is_batch = self.config["supportsBatchUpdate"]
            is_flex = self.config["isFlex"]
            api_version = self.config["apiVersion"]

        # Fields shared by the "set" and "get" URL templates. "{vm_name}" is kept
        # as a literal placeholder so it can be filled in per request.
        url_fields = {
            "subscription": self.config["subscription"],
            "resource_group": self.config["resourceGroup"],
            "provider": self.config["provider"],
            "vm_name": "{vm_name}",
            "server_type": "flexibleServers" if is_flex else "servers",
            "api_version": api_version,
        }

        # Batch updates POST to a single endpoint. Otherwise, each parameter is PUT
        # to its own URL, so "{param_name}" is left as a second placeholder.
        self._url_config_set = self._URL_CONFIGURE.format(
            update="updateConfigurations" if self._is_batch else "configurations/{param_name}",
            **url_fields,
        )
        self._url_config_get = self._URL_CONFIGURE.format(
            update="configurations",
            **url_fields,
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))

    def configure(self, config: Dict[str, Any],
                  params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        if self._is_batch:
            return self._config_batch(config, params)
        return self._config_many(config, params)

    def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Check if the configuration of an Azure DB service requires a reboot or restart.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. On success, the Boolean field
            "isConfigPendingReboot" is True if any server configuration entry
            reports "isConfigPendingRestart" as true, meaning that a reboot
            is necessary. On failure, the result is {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        url = self._url_config_get.format(vm_name=self._get_vm_name(config))
        _LOG.debug("Request: GET %s", url)
        # Listing the server configurations is a read-only GET, not a PUT.
        response = requests.get(
            url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code != 200:
            return (Status.FAILED, {})
        # Currently, Azure Flex servers require a VM reboot.
        return (Status.SUCCEEDED, {"isConfigPendingReboot": any(
            self._is_pending_restart(
                (val.get("properties") or {}).get("isConfigPendingRestart"))
            for val in response.json().get("value", [])
        )})

    @staticmethod
    def _is_pending_restart(flag: Any) -> bool:
        """
        Interpret the "isConfigPendingRestart" property of one configuration entry.

        Per the REST API references listed at the top of this class, the field may
        be a "True"/"False" string (MySQL flexible server) or a JSON boolean
        (PostgreSQL flexible server). It may also be absent, e.g. for MariaDB,
        which is treated as False. String values are compared case-insensitively.
        """
        if flag is None:
            return False
        if isinstance(flag, bool):
            return flag
        return str(flag).strip().lower() == "true"

    def _get_vm_name(self, config: Dict[str, Any]) -> str:
        """
        Merge the per-call `config` on top of the service config and return its "vmName".

        "vmName" is passed to `merge_parameters` as a required key.
        """
        merged = merge_parameters(
            dest=self.config.copy(), source=config, required_keys=["vmName"])
        return merged["vmName"]

    def _get_headers(self) -> dict:
        """
        Get the headers for the REST API calls.
        """
        assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
            "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    def _config_one(self, config: Dict[str, Any],
                    param_name: str, param_value: Any) -> Tuple[Status, dict]:
        """
        Update a single parameter of the Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        param_name : str
            Name of the parameter to update.
        param_value : Any
            Value of the parameter to update (sent as its `str()` form).

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        url = self._url_config_set.format(
            vm_name=self._get_vm_name(config), param_name=param_name)
        _LOG.debug("Request: PUT %s", url)
        response = requests.put(url, headers=self._get_headers(),
                                json={"properties": {"value": str(param_value)}},
                                timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code == 200:
            return (Status.SUCCEEDED, {})
        # NOTE(review): any other code (e.g. 202 Accepted for a long-running
        # update) is reported as FAILED -- confirm this is intended.
        return (Status.FAILED, {})

    def _config_many(self, config: Dict[str, Any],
                     params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service one-by-one.
        (If batch API is not available for it).

        Stops at the first parameter whose update does not succeed and returns
        that status; parameters after it are not sent.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        for (param_name, param_value) in params.items():
            (status, result) = self._config_one(config, param_name, param_value)
            if not status.is_succeeded():
                return (status, result)
        return (Status.SUCCEEDED, {})

    def _config_batch(self, config: Dict[str, Any],
                      params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Batch update the parameters of an Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        url = self._url_config_set.format(vm_name=self._get_vm_name(config))
        # NOTE: the batch request body also has a "resetAllToDefault" field,
        # which is deliberately not set here.
        json_req = {
            "value": [
                {"name": key, "properties": {"value": str(val)}}
                for (key, val) in params.items()
            ],
        }
        _LOG.debug("Request: POST %s", url)
        response = requests.post(url, headers=self._get_headers(),
                                 json=json_req, timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code == 200:
            return (Status.SUCCEEDED, {})
        # NOTE(review): any other code (e.g. 202 Accepted for a long-running
        # update) is reported as FAILED -- confirm this is intended.
        return (Status.FAILED, {})