Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%
178 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base class for certain Azure Services classes that do deployments."""
7import abc
8import json
9import logging
10import time
11from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13import requests
14from requests.adapters import HTTPAdapter, Retry
16from mlos_bench.dict_templater import DictTemplater
17from mlos_bench.environments.status import Status
18from mlos_bench.services.base_service import Service
19from mlos_bench.services.types.authenticator_type import SupportsAuth
20from mlos_bench.util import check_required_params, merge_parameters
# Module-level logger, named after this module so log records can be filtered per module.
_LOG = logging.getLogger(name=__name__)
class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
    """Helper methods to manage and deploy Azure resources via REST APIs."""

    _POLL_INTERVAL = 4  # seconds
    _POLL_TIMEOUT = 300  # seconds
    _REQUEST_TIMEOUT = 5  # seconds
    _REQUEST_TOTAL_RETRIES = 10  # Total number retries for each request
    # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
    _REQUEST_RETRY_BACKOFF_FACTOR = 0.3

    # Azure Resources Deployment REST API as described in
    # https://docs.microsoft.com/en-us/rest/api/resources/deployments

    _URL_DEPLOY = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Resources"
        "/deployments/{deployment_name}"
        "?api-version=2022-05-01"
    )

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of an Azure Services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription" and "resourceGroup".
            If "deploymentTemplatePath" is set, "deploymentTemplateParameters"
            must be set as well.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(config, global_config, parent, methods)

        check_required_params(
            self.config,
            [
                "subscription",
                "resourceGroup",
            ],
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
        self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
        self._total_retries = int(
            self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
        )
        self._backoff_factor = float(
            self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
        )

        # Both stay empty when no template is configured; _provision_resource()
        # refuses to run in that case.
        self._deploy_template: dict = {}
        self._deploy_params: dict = {}
        if self.config.get("deploymentTemplatePath") is not None:
            # TODO: Provide external schema validation?
            template = self.config_loader_service.load_config(
                self.config["deploymentTemplatePath"],
                schema_type=None,
            )
            assert template is not None and isinstance(template, dict)
            self._deploy_template = template

            # Allow for recursive variable expansion as we do with global params and const_args.
            deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars(
                extra_source_dict=global_config
            )
            self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
        else:
            _LOG.info(
                "No deploymentTemplatePath provided. Deployment services will be unavailable.",
            )

    @property
    def deploy_params(self) -> dict:
        """Get the deployment parameters."""
        return self._deploy_params

    @abc.abstractmethod
    def _set_default_params(self, params: dict) -> dict:
        """
        Optionally set some default parameters for the request.

        Parameters
        ----------
        params : dict
            The parameters.

        Returns
        -------
        dict
            The updated parameters.
        """
        raise NotImplementedError("Should be overridden by subclass.")

    def _get_session(self, params: dict) -> requests.Session:
        """
        Get a session object that includes automatic retries and headers for REST API
        calls.

        The caller owns the returned session and should close it (e.g., by using it
        as a context manager) to release its pooled connections.

        Parameters
        ----------
        params : dict
            May override the retry policy via "requestTotalRetries" and
            "requestBackoffFactor"; otherwise the service-level values are used.
        """
        # Per-request overrides may arrive as strings (e.g., from the command line),
        # so coerce them the same way __init__ coerces the service-level settings.
        total_retries = int(params.get("requestTotalRetries", self._total_retries))
        backoff_factor = float(params.get("requestBackoffFactor", self._backoff_factor))
        # Fetch headers first so a missing auth service doesn't leave an unclosed session.
        headers = self._get_headers()
        session = requests.Session()
        session.mount(
            "https://",
            HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
        )
        session.headers.update(headers)
        return session

    def _get_headers(self) -> dict:
        """Get the headers for the REST API calls from the parent auth service."""
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _format_response_body(response: requests.Response) -> str:
        """
        Render a response body for debug logging.

        Pretty-prints JSON bodies and falls back to the raw text for non-JSON
        bodies, so that enabling DEBUG logging can never raise.
        """
        if not response.content:
            return ""
        try:
            return json.dumps(response.json(), indent=2)
        except ValueError:  # JSON decode errors subclass ValueError.
            return response.text

    @staticmethod
    def _extract_arm_parameters(json_data: dict) -> dict:
        """
        Extract parameters from the ARM Template REST response JSON.

        Returns
        -------
        parameters : dict
            Flat dictionary of parameters and their values.
            Parameters whose "value" is missing or None are omitted.
        """
        # `or {}` also covers keys that are present but explicitly null.
        arm_params = (json_data.get("properties") or {}).get("parameters") or {}
        return {
            key: val.get("value")
            for (key, val) in arm_params.items()
            if val.get("value") is not None
        }

    def _azure_rest_api_post_helper(self, params: dict, url: str) -> Tuple[Status, dict]:
        """
        General pattern for performing an action on an Azure resource via its REST API.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        url: str
            REST API url for the target to perform on the Azure VM.
            Should be a url that we intend to POST to.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            Result will have a value for 'asyncResultsUrl' if status is PENDING,
            and 'pollInterval' if suggested by the API.
        """
        _LOG.debug("Request: POST %s", url)

        try:
            response = requests.post(
                url,
                headers=self._get_headers(),
                timeout=self._request_timeout,
            )
        except requests.exceptions.RequestException as ex:
            # Be consistent with the status-check helpers: report, don't raise.
            _LOG.exception("Error in POST request: %s", url, exc_info=ex)
            return (Status.FAILED, {})
        _LOG.debug("Response: %s", response)

        # Logical flow for async operations based on:
        # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
        if response.status_code == 200:
            return (Status.SUCCEEDED, params.copy())
        elif response.status_code == 202:
            result = params.copy()
            if "Azure-AsyncOperation" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
            elif "Location" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Location")
            retry_after = response.headers.get("Retry-After")
            if retry_after is not None:
                # Retry-After may legally be an HTTP-date rather than seconds;
                # fall back to the default poll interval in that case.
                try:
                    result["pollInterval"] = float(retry_after)
                except ValueError:
                    _LOG.warning("Ignoring non-numeric Retry-After header: %s", retry_after)

            return (Status.PENDING, result)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})

    def _check_operation_status(self, params: dict) -> Tuple[Status, dict]:
        # pylint: disable=too-many-return-statements
        """
        Checks the status of a pending operation on an Azure resource.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        # Close the session (and its connection pool) once the single GET is done.
        with self._get_session(params) as session:
            try:
                response = session.get(url, timeout=self._request_timeout)
            except requests.exceptions.ReadTimeout:
                _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
                return Status.RUNNING, {}
            except requests.exceptions.RequestException as ex:
                _LOG.exception("Error in request checking operation status", exc_info=ex)
                return (Status.FAILED, {})

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Response: %s\n%s",
                response,
                self._format_response_body(response),
            )

        if response.status_code == 200:
            output = response.json()
            status = output.get("status")
            if status == "InProgress":
                return Status.RUNNING, {}
            elif status == "Succeeded":
                return Status.SUCCEEDED, output
        elif response.status_code == 202:
            # A "Location" polling URL (see _azure_rest_api_post_helper) answers
            # 202 Accepted while the operation is still in progress.
            return Status.RUNNING, {}

        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or
        FAILED. Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for resource being deployed; otherwise, wait for
            successful deprovisioning. Only affects the log message here.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        params = self._set_default_params(params)
        _LOG.info(
            "Wait for %s to %s",
            params.get("deploymentName"),
            "provision" if is_setup else "deprovision",
        )
        return self._wait_while(self._check_deployment, Status.PENDING, params)

    def _wait_while(
        self,
        func: Callable[[dict], Tuple[Status, dict]],
        loop_status: Status,
        params: dict,
    ) -> Tuple[Status, dict]:
        """
        Invoke `func` periodically while the status is equal to `loop_status`. Return
        TIMED_OUT when timing out.

        Parameters
        ----------
        func : a function
            A function that takes `params` and returns a pair of (Status, {})
        loop_status: Status
            Steady state status - keep polling `func` while it returns `loop_status`.
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Requires deploymentName. An optional "pollInterval" overrides the
            service-level poll interval.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )

        # May come from a Retry-After header or the command line; normalize to float.
        poll_period = float(params.get("pollInterval", self._poll_interval))

        _LOG.debug(
            "Wait for %s status %s :: poll %.2f timeout %d s",
            config["deploymentName"],
            loop_status,
            poll_period,
            self._poll_timeout,
        )

        ts_timeout = time.time() + self._poll_timeout
        poll_delay = poll_period
        while True:
            # Wait for the suggested time first then check status
            ts_start = time.time()
            if ts_start >= ts_timeout:
                break

            if poll_delay > 0:
                _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
                time.sleep(poll_delay)

            (status, output) = func(params)
            if status != loop_status:
                return status, output

            # Subtract the time spent in this iteration so polls stay ~poll_period apart.
            ts_end = time.time()
            poll_delay = poll_period - ts_end + ts_start

        _LOG.warning("Request timed out: %s", params)
        return (Status.TIMED_OUT, {})

    def _check_deployment(self, params: dict) -> Tuple[Status, dict]:
        # pylint: disable=too-many-return-statements
        """
        Check the provisioning state of an Azure deployment.

        Return SUCCEEDED once the deployment has provisioned successfully, PENDING
        while it does not exist yet, is still in progress, or the status request
        timed out, and FAILED otherwise.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Together with the service config it must provide "subscription",
            "resourceGroup", and "deploymentName". Retry overrides are also
            read from it.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, PENDING, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
            ],
        )

        _LOG.info("Check deployment: %s", config["deploymentName"])

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        # Close the session (and its connection pool) once the single GET is done.
        with self._get_session(params) as session:
            try:
                response = session.get(url, timeout=self._request_timeout)
            except requests.exceptions.ReadTimeout:
                _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
                # Report the polling status so _wait_deployment() keeps waiting
                # instead of aborting on a single transient timeout.
                return Status.PENDING, {}
            except requests.exceptions.RequestException as ex:
                _LOG.exception("Error in request checking deployment", exc_info=ex)
                return (Status.FAILED, {})

        _LOG.debug("Response: %s", response)

        if response.status_code == 200:
            output = response.json()
            state = output.get("properties", {}).get("provisioningState", "")

            if state == "Succeeded":
                return (Status.SUCCEEDED, {})
            elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
                return (Status.PENDING, {})
            else:
                _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
                return (Status.FAILED, {})
        elif response.status_code == 404:
            # The deployment has not been created (yet).
            return (Status.PENDING, {})

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _provision_resource(self, params: dict) -> Tuple[Status, dict]:
        """
        Attempts to (re)deploy a resource.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Tunables are variable parameters that, together with the
            Environment configuration, are sufficient to provision the resource.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. On HTTP 201 the result is the merged
            deployment parameters plus the parameters extracted from the response
            JSON (with "asyncResultsUrl" and "deploymentName" set); on HTTP 200 it
            is the service config merged with `params`; it is {} if the status is
            FAILED.
            Status is one of {PENDING, FAILED}

        Raises
        ------
        ValueError
            If no deployment template was configured.
        """
        if not self._deploy_template:
            raise ValueError(f"Missing deployment template: {self}")
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )
        _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

        params = merge_parameters(dest=self._deploy_params.copy(), source=params)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Deploy: %s merged params ::\n%s",
                config["deploymentName"],
                json.dumps(params, indent=2),
            )

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        # Only pass the parameters that the ARM template actually declares.
        json_req = {
            "properties": {
                "mode": "Incremental",
                "template": self._deploy_template,
                "parameters": {
                    key: {"value": val}
                    for (key, val) in params.items()
                    if key in self._deploy_template.get("parameters", {})
                },
            }
        }

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        try:
            response = requests.put(
                url,
                json=json_req,
                headers=self._get_headers(),
                timeout=self._request_timeout,
            )
        except requests.exceptions.RequestException as ex:
            # Be consistent with the status-check helpers: report, don't raise.
            _LOG.exception("Error in deployment request: %s", url, exc_info=ex)
            return (Status.FAILED, {})

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Response: %s\n%s",
                response,
                self._format_response_body(response),
            )
        else:
            _LOG.info("Response: %s", response)

        if response.status_code == 200:
            return (Status.PENDING, config)
        elif response.status_code == 201:
            output = self._extract_arm_parameters(response.json())
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
            params.update(output)
            params.setdefault("asyncResultsUrl", url)
            params.setdefault("deploymentName", config["deploymentName"])
            return (Status.PENDING, params)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})