Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%

179 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for certain Azure Services classes that do deployments.""" 

6 

7import abc 

8import json 

9import logging 

10import time 

11from collections.abc import Callable 

12from typing import Any 

13 

14import requests 

15from requests.adapters import HTTPAdapter, Retry 

16 

17from mlos_bench.dict_templater import DictTemplater 

18from mlos_bench.environments.status import Status 

19from mlos_bench.services.base_service import Service 

20from mlos_bench.services.types.authenticator_type import SupportsAuth 

21from mlos_bench.util import check_required_params, merge_parameters 

22 

23_LOG = logging.getLogger(__name__) 

24 

25 

class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
    """
    Helper methods to manage and deploy Azure resources via REST APIs.

    Subclasses must implement `_set_default_params()`. An authorization service
    (`SupportsAuth`) is expected as the parent; see `_get_headers()`.
    """

    _POLL_INTERVAL = 4  # seconds
    _POLL_TIMEOUT = 300  # seconds
    _REQUEST_TIMEOUT = 5  # seconds
    _REQUEST_TOTAL_RETRIES = 10  # Total number retries for each request
    # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
    _REQUEST_RETRY_BACKOFF_FACTOR = 0.3

    # Azure Resources Deployment REST API as described in
    # https://docs.microsoft.com/en-us/rest/api/resources/deployments

    _URL_DEPLOY = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Resources"
        "/deployments/{deployment_name}"
        "?api-version=2022-05-01"
    )

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of an Azure Services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription" and "resourceGroup".
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        super().__init__(config, global_config, parent, methods)

        check_required_params(
            self.config,
            [
                "subscription",
                "resourceGroup",
            ],
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
        self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
        self._total_retries = int(
            self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
        )
        self._backoff_factor = float(
            self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
        )

        self._deploy_template: dict[str, Any] = {}
        self._deploy_params: dict[str, Any] = {}
        if self.config.get("deploymentTemplatePath") is not None:
            # TODO: Provide external schema validation?
            template = self.config_loader_service.load_config(
                self.config["deploymentTemplatePath"],
                schema_type=None,
            )
            assert template is not None and isinstance(template, dict)
            self._deploy_template = template

            # Allow for recursive variable expansion as we do with global params and const_args.
            # A template may rely entirely on its own defaults, so a missing (or null)
            # "deploymentTemplateParameters" entry is treated as "no parameters"
            # rather than failing with a bare KeyError.
            deploy_params = DictTemplater(
                self.config.get("deploymentTemplateParameters") or {}
            ).expand_vars(extra_source_dict=global_config)
            self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
        else:
            _LOG.info(
                "No deploymentTemplatePath provided. Deployment services will be unavailable.",
            )

    @property
    def deploy_params(self) -> dict:
        """Get the deployment parameters."""
        return self._deploy_params

    @abc.abstractmethod
    def _set_default_params(self, params: dict) -> dict:
        """
        Optionally set some default parameters for the request.

        Parameters
        ----------
        params : dict
            The parameters.

        Returns
        -------
        dict
            The updated parameters.
        """
        raise NotImplementedError("Should be overridden by subclass.")

    def _get_session(self, params: dict) -> requests.Session:
        """
        Get a session object that includes automatic retries and headers for REST API
        calls.

        Parameters
        ----------
        params : dict
            Per-call parameters. "requestTotalRetries" and "requestBackoffFactor",
            if present, override the values configured for the service.

        Returns
        -------
        session : requests.Session
            A new session; the caller is responsible for closing it.
        """
        # Like the service-level settings in __init__, per-call overrides may arrive
        # as strings (e.g., from the command line), so convert them explicitly.
        total_retries = int(params.get("requestTotalRetries", self._total_retries))
        backoff_factor = float(params.get("requestBackoffFactor", self._backoff_factor))
        session = requests.Session()
        session.mount(
            "https://",
            HTTPAdapter(
                max_retries=Retry(
                    total=total_retries,
                    backoff_factor=backoff_factor,
                    status_forcelist=[503],
                )
            ),
        )
        session.headers.update(self._get_headers())
        return session

    def _get_headers(self) -> dict:
        """Get the headers for the REST API calls from the parent auth service."""
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _format_response_body(response: requests.Response) -> str:
        """
        Render a response body for debug logging.

        Returns pretty-printed JSON when the body parses as JSON, the raw text when
        it does not, and an empty string for an empty body, so that enabling debug
        logging can never make a request fail on a non-JSON payload.
        """
        if not response.content:
            return ""
        try:
            return json.dumps(response.json(), indent=2)
        except ValueError:
            # JSON decoding errors raised by `requests` derive from ValueError.
            return response.text

    @staticmethod
    def _extract_arm_parameters(json_data: dict) -> dict:
        """
        Extract parameters from the ARM Template REST response JSON.

        Parameters
        ----------
        json_data : dict
            Parsed response; parameters are read from `properties.parameters`.

        Returns
        -------
        parameters : dict
            Flat dictionary of parameters and their values.
            Parameters without a (non-None) "value" are omitted.
        """
        arm_params = json_data.get("properties", {}).get("parameters", {})
        return {
            key: val["value"] for (key, val) in arm_params.items() if val.get("value") is not None
        }

    def _azure_rest_api_post_helper(self, params: dict, url: str) -> tuple[Status, dict]:
        """
        General pattern for performing an action on an Azure resource via its REST API.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        url: str
            REST API url for the target to perform on the Azure VM.
            Should be a url that we intend to POST to.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            Result will have a value for 'asyncResultsUrl' if status is PENDING,
            and 'pollInterval' if suggested by the API.
        """
        _LOG.debug("Request: POST %s", url)

        response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s", response)

        # Logical flow for async operations based on:
        # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
        if response.status_code == 200:
            return (Status.SUCCEEDED, params.copy())
        elif response.status_code == 202:
            result = params.copy()
            # Prefer the Azure-AsyncOperation header over Location when both are present.
            if "Azure-AsyncOperation" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
            elif "Location" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Location")
            if "Retry-After" in response.headers:
                result["pollInterval"] = float(response.headers["Retry-After"])

            return (Status.PENDING, result)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            return (Status.FAILED, {})

    def _check_operation_status(self, params: dict) -> tuple[Status, dict]:
        """
        Checks the status of a pending operation on an Azure resource.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            (RUNNING is also returned when the status request itself times out).
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            return Status.RUNNING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking operation status", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # A new session is created for every poll; release its connection pool.
            # The response body is already fully read at this point.
            session.close()

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response, self._format_response_body(response))

        if response.status_code == 200:
            output = response.json()
            status = output.get("status")
            if status == "InProgress":
                return Status.RUNNING, {}
            elif status == "Succeeded":
                return Status.SUCCEEDED, output

        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def _wait_deployment(self, params: dict, *, is_setup: bool) -> tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or
        FAILED. Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for resource being deployed; otherwise, wait for
            successful deprovisioning. Only affects the log message here.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        params = self._set_default_params(params)
        _LOG.info(
            "Wait for %s to %s",
            params.get("deploymentName"),
            "provision" if is_setup else "deprovision",
        )
        return self._wait_while(self._check_deployment, Status.PENDING, params)

    def _wait_while(
        self,
        func: Callable[[dict], tuple[Status, dict]],
        loop_status: Status,
        params: dict,
    ) -> tuple[Status, dict]:
        """
        Invoke `func` periodically while the status is equal to `loop_status`. Return
        TIMED_OUT when timing out.

        Parameters
        ----------
        func : a function
            A function that takes `params` and returns a pair of (Status, {})
        loop_status: Status
            Steady state status - keep polling `func` while it returns `loop_status`.
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Requires deploymentName. An optional "pollInterval" (seconds)
            overrides the poll interval configured for the service.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )

        # The interval may come from the command line as a string, so convert it.
        poll_period = float(params.get("pollInterval", self._poll_interval))

        _LOG.debug(
            "Wait for %s status %s :: poll %.2f timeout %d s",
            config["deploymentName"],
            loop_status,
            poll_period,
            self._poll_timeout,
        )

        ts_timeout = time.time() + self._poll_timeout
        poll_delay = poll_period
        while time.time() < ts_timeout:
            # Wait for the suggested time first then check status
            if poll_delay > 0:
                _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
                time.sleep(poll_delay)

            # Time only the status check itself: counting the sleep above as well
            # would drive the next delay negative and trigger an immediate re-poll
            # on every other iteration.
            ts_call = time.time()
            (status, output) = func(params)
            if status != loop_status:
                return status, output

            # Sleep for whatever is left of the poll period after the check.
            poll_delay = poll_period - (time.time() - ts_call)

        _LOG.warning("Request timed out: %s", params)
        return (Status.TIMED_OUT, {})

    def _check_deployment(self, params: dict) -> tuple[Status, dict]:
        # pylint: disable=too-many-return-statements
        """
        Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Merged with the service config; "subscription", "resourceGroup",
            and "deploymentName" must be available after the merge.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, PENDING, FAILED}.
            PENDING is returned while the deployment is in progress, not found
            (HTTP 404), or when the status request itself timed out.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
            ],
        )

        _LOG.info("Check deployment: %s", config["deploymentName"])

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            # `_wait_deployment()` keeps polling only while the status is PENDING,
            # so report a transient timeout as PENDING to retry on the next poll
            # instead of aborting the wait with an undocumented RUNNING status.
            return Status.PENDING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking deployment", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # A new session is created for every poll; release its connection pool.
            # The response body is already fully read at this point.
            session.close()

        _LOG.debug("Response: %s", response)

        if response.status_code == 200:
            output = response.json()
            state = output.get("properties", {}).get("provisioningState", "")

            if state == "Succeeded":
                return (Status.SUCCEEDED, {})
            elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
                return (Status.PENDING, {})
            else:
                _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
                return (Status.FAILED, {})
        elif response.status_code == 404:
            return (Status.PENDING, {})

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _provision_resource(self, params: dict) -> tuple[Status, dict]:
        """
        Attempts to (re)deploy a resource.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Tunables are variable parameters that, together with the
            Environment configuration, are sufficient to provision the resource.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is the input `params` plus the
            parameters extracted from the response JSON, or {} if the status is FAILED.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            NOTE(review): on HTTP 200 the merged service config is returned instead
            of the deployment parameters - confirm this asymmetry is intended.

        Raises
        ------
        ValueError
            If no deployment template was configured for this service.
        """
        if not self._deploy_template:
            raise ValueError(f"Missing deployment template: {self}")
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )
        _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

        # Per-call params take precedence over the configured deployment parameters.
        params = merge_parameters(dest=self._deploy_params.copy(), source=params)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Deploy: %s merged params ::\n%s",
                config["deploymentName"],
                json.dumps(params, indent=2),
            )

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        # Only pass along the parameters that the template actually declares.
        template_params = self._deploy_template.get("parameters", {})
        json_req = {
            "properties": {
                "mode": "Incremental",
                "template": self._deploy_template,
                "parameters": {
                    key: {"value": val}
                    for (key, val) in params.items()
                    if key in template_params
                },
            }
        }

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        response = requests.put(
            url,
            json=json_req,
            headers=self._get_headers(),
            timeout=self._request_timeout,
        )

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response, self._format_response_body(response))
        else:
            _LOG.info("Response: %s", response)

        if response.status_code == 200:
            return (Status.PENDING, config)
        elif response.status_code == 201:
            output = self._extract_arm_parameters(response.json())
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
            params.update(output)
            params.setdefault("asyncResultsUrl", url)
            params.setdefault("deploymentName", config["deploymentName"])
            return (Status.PENDING, params)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            return (Status.FAILED, {})