Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%

179 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base class for certain Azure Services classes that do deployments. 

7""" 

8 

9import abc 

10import json 

11import time 

12import logging 

13 

14from typing import Any, Callable, Dict, List, Optional, Tuple, Union 

15 

16import requests 

17from requests.adapters import HTTPAdapter, Retry 

18 

19from mlos_bench.dict_templater import DictTemplater 

20from mlos_bench.environments.status import Status 

21from mlos_bench.services.base_service import Service 

22from mlos_bench.services.types.authenticator_type import SupportsAuth 

23from mlos_bench.util import check_required_params, merge_parameters 

24 

25_LOG = logging.getLogger(__name__) 

26 

27 

class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
    """
    Helper methods to manage and deploy Azure resources via REST APIs.

    Subclasses must implement `_set_default_params()`; every public-ish helper
    here calls it before reading values from `params`.
    """

    _POLL_INTERVAL = 4     # seconds
    _POLL_TIMEOUT = 300    # seconds
    _REQUEST_TIMEOUT = 5   # seconds
    _REQUEST_TOTAL_RETRIES = 10     # Total number retries for each request
    # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
    _REQUEST_RETRY_BACKOFF_FACTOR = 0.3

    # Azure Resources Deployment REST API as described in
    # https://docs.microsoft.com/en-us/rest/api/resources/deployments

    _URL_DEPLOY = (
        "https://management.azure.com" +
        "/subscriptions/{subscription}" +
        "/resourceGroups/{resource_group}" +
        "/providers/Microsoft.Resources" +
        "/deployments/{deployment_name}" +
        "?api-version=2022-05-01"
    )

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new instance of an Azure Services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription" and "resourceGroup".
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(config, global_config, parent, methods)

        check_required_params(self.config, [
            "subscription",
            "resourceGroup",
        ])

        # These parameters can come from command line as strings, so conversion is needed.
        self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
        self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
        self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES))
        self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR))

        # Both stay empty when no template path is configured; `_provision_resource()`
        # checks `_deploy_template` and refuses to run in that case.
        self._deploy_template = {}
        self._deploy_params = {}
        if self.config.get("deploymentTemplatePath") is not None:
            # TODO: Provide external schema validation?
            template = self.config_loader_service.load_config(
                self.config['deploymentTemplatePath'], schema_type=None)
            assert template is not None and isinstance(template, dict)
            self._deploy_template = template

            # Allow for recursive variable expansion as we do with global params and const_args.
            deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config)
            self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
        else:
            _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.")

    @property
    def deploy_params(self) -> dict:
        """
        Get the deployment parameters.
        """
        return self._deploy_params

    @abc.abstractmethod
    def _set_default_params(self, params: dict) -> dict:
        """
        Optionally set some default parameters for the request.

        Parameters
        ----------
        params : dict
            The parameters.

        Returns
        -------
        dict
            The updated parameters.
        """
        raise NotImplementedError("Should be overridden by subclass.")

    def _get_session(self, params: dict) -> requests.Session:
        """
        Get a session object that includes automatic retries and headers for REST API calls.

        Parameters
        ----------
        params : dict
            May override the service-level retry settings via the
            "requestTotalRetries" and "requestBackoffFactor" keys.

        Returns
        -------
        session : requests.Session
            A new session; the caller is responsible for closing it.
        """
        # Per-request overrides can arrive as strings (same as the config values
        # converted in `__init__`), so normalize them before handing to `Retry`.
        total_retries = int(params.get("requestTotalRetries", self._total_retries))
        backoff_factor = float(params.get("requestBackoffFactor", self._backoff_factor))
        session = requests.Session()
        session.mount(
            "https://",
            HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)))
        session.headers.update(self._get_headers())
        return session

    def _get_headers(self) -> dict:
        """
        Get the headers for the REST API calls.

        The headers are obtained from the parent service, which must support
        the `SupportsAuth` protocol.
        """
        assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
            "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _extract_arm_parameters(json_data: dict) -> dict:
        """
        Extract parameters from the ARM Template REST response JSON.

        Parameters
        ----------
        json_data : dict
            Response JSON; parameters are read from `properties.parameters`.

        Returns
        -------
        parameters : dict
            Flat dictionary of parameters and their values.
            Parameters without a "value" (or with a null one) are skipped.
        """
        return {
            key: val.get("value")
            for (key, val) in json_data.get("properties", {}).get("parameters", {}).items()
            if val.get("value") is not None
        }

    def _azure_rest_api_post_helper(self, params: dict, url: str) -> Tuple[Status, dict]:
        """
        General pattern for performing an action on an Azure resource via its REST API.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        url: str
            REST API url for the target to perform on the Azure VM.
            Should be a url that we intend to POST to.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            Result will have a value for 'asyncResultsUrl' if status is PENDING,
            and 'pollInterval' if suggested by the API.
        """
        _LOG.debug("Request: POST %s", url)

        response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s", response)

        # Logical flow for async operations based on:
        # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
        if response.status_code == 200:
            return (Status.SUCCEEDED, params.copy())
        elif response.status_code == 202:
            result = params.copy()
            # Prefer the Azure-AsyncOperation header; fall back to Location.
            if "Azure-AsyncOperation" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
            elif "Location" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Location")
            if "Retry-After" in response.headers:
                result["pollInterval"] = float(response.headers["Retry-After"])

            return (Status.PENDING, result)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})

    def _check_operation_status(self, params: dict) -> Tuple[Status, dict]:
        """
        Checks the status of a pending operation on an Azure resource.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            return Status.RUNNING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking operation status", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # A new session is created for every poll; release its connection pool.
            # The (non-streaming) response body is already read at this point.
            session.close()

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response,
                       json.dumps(response.json(), indent=2)
                       if response.content else "")

        if response.status_code == 200:
            output = response.json()
            status = output.get("status")
            if status == "InProgress":
                return Status.RUNNING, {}
            elif status == "Succeeded":
                return Status.SUCCEEDED, output

        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or FAILED.
        Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for resource being deployed; otherwise, wait for successful deprovisioning.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        params = self._set_default_params(params)
        _LOG.info("Wait for %s to %s", params.get("deploymentName"),
                  "provision" if is_setup else "deprovision")
        return self._wait_while(self._check_deployment, Status.PENDING, params)

    def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]],
                    loop_status: Status, params: dict) -> Tuple[Status, dict]:
        """
        Invoke `func` periodically while the status is equal to `loop_status`.
        Return TIMED_OUT when timing out.

        The first call to `func` happens after one full poll interval; subsequent
        calls are scheduled so that they start roughly one poll interval apart
        (the time spent inside `func` is deducted from the next sleep).

        Parameters
        ----------
        func : a function
            A function that takes `params` and returns a pair of (Status, {})
        loop_status: Status
            Steady state status - keep polling `func` while it returns `loop_status`.
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Requires deploymentName.
            An optional "pollInterval" overrides the service-level poll interval.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(), source=params, required_keys=["deploymentName"])

        # The interval can arrive as a string (e.g., from the command line);
        # it is compared and subtracted below, so it must be numeric.
        poll_period = float(params.get("pollInterval", self._poll_interval))

        _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s",
                   config["deploymentName"], loop_status, poll_period, self._poll_timeout)

        ts_timeout = time.time() + self._poll_timeout
        poll_delay = poll_period
        while True:
            # Wait for the suggested time first then check status
            ts_start = time.time()
            if ts_start >= ts_timeout:
                break

            slept = 0.0
            if poll_delay > 0:
                _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
                time.sleep(poll_delay)
                slept = poll_delay

            (status, output) = func(params)
            if status != loop_status:
                return status, output

            ts_end = time.time()
            # Deduct only the time spent polling, not the time we slept on purpose.
            # Otherwise the delay goes non-positive after every sleep and `func`
            # gets invoked twice back-to-back in each period.
            poll_delay = poll_period - (ts_end - ts_start - slept)

        _LOG.warning("Request timed out: %s", params)
        return (Status.TIMED_OUT, {})

    def _check_deployment(self, params: dict) -> Tuple[Status, dict]:   # pylint: disable=too-many-return-statements
        """
        Check if Azure deployment exists.
        Return SUCCEEDED if true, PENDING otherwise.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Merged with the service config, it must provide "subscription",
            "resourceGroup", and "deploymentName".

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, PENDING, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
            ]
        )

        _LOG.info("Check deployment: %s", config["deploymentName"])

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            # `_wait_deployment()` polls this method while the status is PENDING.
            # Reporting PENDING keeps the wait loop going (bounded by the poll
            # timeout) instead of aborting it on a single slow request.
            return Status.PENDING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking deployment", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # A new session is created for every poll; release its connection pool.
            # The (non-streaming) response body is already read at this point.
            session.close()

        _LOG.debug("Response: %s", response)

        if response.status_code == 200:
            output = response.json()
            state = output.get("properties", {}).get("provisioningState", "")

            if state == "Succeeded":
                return (Status.SUCCEEDED, {})
            elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
                return (Status.PENDING, {})
            else:
                _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
                return (Status.FAILED, {})
        elif response.status_code == 404:
            # The deployment is not visible (yet); keep waiting.
            return (Status.PENDING, {})

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _provision_resource(self, params: dict) -> Tuple[Status, dict]:
        """
        Attempts to (re)deploy a resource.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Tunables are variable parameters that, together with the
            Environment configuration, are sufficient to provision the resource.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is the input `params` plus the
            parameters extracted from the response JSON, or {} if the status is FAILED.
            Status is one of {PENDING, SUCCEEDED, FAILED}

        Raises
        ------
        ValueError
            If no deployment template was loaded for this service.
        """
        if not self._deploy_template:
            raise ValueError(f"Missing deployment template: {self}")
        params = self._set_default_params(params)
        config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"])
        _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

        params = merge_parameters(dest=self._deploy_params.copy(), source=params)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Deploy: %s merged params ::\n%s",
                       config["deploymentName"], json.dumps(params, indent=2))

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        json_req = {
            "properties": {
                "mode": "Incremental",
                "template": self._deploy_template,
                # Only pass along the parameters that the template declares.
                "parameters": {
                    key: {"value": val} for (key, val) in params.items()
                    if key in self._deploy_template.get("parameters", {})
                }
            }
        }

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        response = requests.put(url, json=json_req,
                                headers=self._get_headers(), timeout=self._request_timeout)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response,
                       json.dumps(response.json(), indent=2)
                       if response.content else "")
        else:
            _LOG.info("Response: %s", response)

        if response.status_code == 200:
            return (Status.PENDING, config)
        elif response.status_code == 201:
            output = self._extract_arm_parameters(response.json())
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
            params.update(output)
            params.setdefault("asyncResultsUrl", url)
            params.setdefault("deploymentName", config["deploymentName"])
            return (Status.PENDING, params)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})