Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_saas.py: 46%

81 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""A collection of Service functions for configuring SaaS instances on Azure."""

6import logging 

7from typing import Any, Callable, Dict, List, Optional, Tuple, Union 

8 

9import requests 

10 

11from mlos_bench.environments.status import Status 

12from mlos_bench.services.base_service import Service 

13from mlos_bench.services.types.authenticator_type import SupportsAuth 

14from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig 

15from mlos_bench.util import check_required_params, merge_parameters 

16 

_LOG = logging.getLogger(__name__)


class AzureSaaSConfigService(Service, SupportsRemoteConfig):
    """
    Helper methods to configure Azure SaaS DB services (MySQL, MariaDB, PostgreSQL;
    flexible or single servers) via the Azure REST API.
    """

    _REQUEST_TIMEOUT = 5  # seconds

    # REST API for Azure SaaS DB Services configuration as described in:
    # https://learn.microsoft.com/en-us/rest/api/mysql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations

    _URL_CONFIGURE = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/{provider}"
        "/{server_type}/{vm_name}"
        "/{update}"
        "?api-version={api_version}"
    )

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of Azure services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription", "resourceGroup", and
            "provider". For providers other than MySQL, MariaDB, and PostgreSQL,
            "supportsBatchUpdate", "isFlex", and "apiVersion" are also required.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(methods, [self.configure, self.is_config_pending]),
        )

        check_required_params(
            self.config,
            {
                "subscription",
                "resourceGroup",
                "provider",
            },
        )

        # Provide sane defaults for known DB providers.
        provider = self.config.get("provider")
        if provider == "Microsoft.DBforMySQL":
            self._is_batch = self.config.get("supportsBatchUpdate", True)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-01-01")
        elif provider == "Microsoft.DBforMariaDB":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", False)
            api_version = self.config.get("apiVersion", "2018-06-01")
        elif provider == "Microsoft.DBforPostgreSQL":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-12-01")
        else:
            self._is_batch = self.config["supportsBatchUpdate"]
            is_flex = self.config["isFlex"]
            api_version = self.config["apiVersion"]

        # URL components shared by the "set" and "get" endpoints.
        # {vm_name} (and {param_name} for per-parameter updates) are left as
        # placeholders to be filled in on each request.
        url_params = {
            "subscription": self.config["subscription"],
            "resource_group": self.config["resourceGroup"],
            "provider": self.config["provider"],
            "vm_name": "{vm_name}",
            "server_type": "flexibleServers" if is_flex else "servers",
            "api_version": api_version,
        }

        # Endpoint to update the configuration: either the batch endpoint or
        # a per-parameter endpoint.
        self._url_config_set = self._URL_CONFIGURE.format(
            update="updateConfigurations" if self._is_batch else "configurations/{param_name}",
            **url_params,
        )

        # Endpoint to list the current configuration of the server.
        self._url_config_get = self._URL_CONFIGURE.format(update="configurations", **url_params)

        # These parameters can come from command line as strings, so conversion is needed.
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))

    def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        if self._is_batch:
            return self._config_batch(config, params)
        return self._config_many(config, params)

    def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Check if the configuration of an Azure DB service requires a reboot or restart.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. On success, the result contains a
            Boolean field "isConfigPendingReboot" that is True if any server
            parameter reports "isConfigPendingRestart"; in that case rebooting
            the VM is necessary. On failure, the result is {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url: Optional[str] = self._url_config_get.format(vm_name=config["vmName"])
        is_pending = False
        # The configuration list may be split into pages linked via "nextLink".
        # Follow them until exhausted or until a pending restart is found.
        while url and not is_pending:
            _LOG.debug("Request: GET %s", url)
            response = requests.get(url, headers=self._get_headers(), timeout=self._request_timeout)
            _LOG.debug("Response: %s :: %s", response, response.text)
            status = self._status_from_response(response)
            if not status.is_succeeded():
                return (status, {})
            try:
                body = response.json()
            except ValueError:
                _LOG.error("Cannot parse configuration list response: %s", url)
                return (Status.FAILED, {})
            if not isinstance(body, dict):
                _LOG.error("Unexpected configuration list response: %s :: %s", url, body)
                return (Status.FAILED, {})
            is_pending = any(self._is_pending_restart(item) for item in body.get("value") or [])
            url = body.get("nextLink")
        # Currently, Azure Flex servers require a VM reboot.
        return (Status.SUCCEEDED, {"isConfigPendingReboot": is_pending})

    @staticmethod
    def _is_pending_restart(item: Dict[str, Any]) -> bool:
        """
        Check whether a single configuration list item reports a pending restart.

        The "isConfigPendingRestart" property may be reported as the string
        "True"/"False" or as a JSON Boolean (depending on the service / API
        version); accept both. A missing property is treated as not pending.
        """
        value = (item.get("properties") or {}).get("isConfigPendingRestart", False)
        if isinstance(value, str):
            return value.strip().lower() == "true"
        return bool(value)

    @staticmethod
    def _status_from_response(response: requests.Response) -> Status:
        """Map the HTTP status code of an Azure REST API response to a Status."""
        if response.status_code == 504:
            return Status.TIMED_OUT
        if response.status_code == 200:
            return Status.SUCCEEDED
        return Status.FAILED

    def _get_headers(self) -> dict:
        """Get the headers for the REST API calls."""
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    def _config_one(
        self,
        config: Dict[str, Any],
        param_name: str,
        param_value: Any,
    ) -> Tuple[Status, dict]:
        """
        Update a single parameter of the Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        param_name : str
            Name of the parameter to update.
        param_value : Any
            Value of the parameter to update (sent as its string representation).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name)
        _LOG.debug("Request: PUT %s", url)
        response = requests.put(
            url,
            headers=self._get_headers(),
            json={"properties": {"value": str(param_value)}},
            timeout=self._request_timeout,
        )
        _LOG.debug("Response: %s :: %s", response, response.text)
        return (self._status_from_response(response), {})

    def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service one by one (used when the
        batch API is not available).

        Stops at the first parameter that fails to update.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        for param_name, param_value in params.items():
            (status, result) = self._config_one(config, param_name, param_value)
            if not status.is_succeeded():
                return (status, result)
        return (Status.SUCCEEDED, {})

    def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Batch update the parameters of an Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        if not params:
            # Nothing to update; behave like the one-by-one path and skip the request.
            return (Status.SUCCEEDED, {})
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"])
        json_req = {
            "value": [
                {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items()
            ],
            # "resetAllToDefault": "True"
        }
        _LOG.debug("Request: POST %s", url)
        response = requests.post(
            url,
            headers=self._get_headers(),
            json=json_req,
            timeout=self._request_timeout,
        )
        _LOG.debug("Response: %s :: %s", response, response.text)
        return (self._status_from_response(response), {})