Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%
179 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-01 00:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base class for certain Azure Services classes that do deployments."""
7import abc
8import json
9import logging
10import time
11from collections.abc import Callable
12from typing import Any
14import requests
15from requests.adapters import HTTPAdapter, Retry
17from mlos_bench.dict_templater import DictTemplater
18from mlos_bench.environments.status import Status
19from mlos_bench.services.base_service import Service
20from mlos_bench.services.types.authenticator_type import SupportsAuth
21from mlos_bench.util import check_required_params, merge_parameters
23_LOG = logging.getLogger(__name__)
class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
    """
    Helper methods to manage and deploy Azure resources via REST APIs.

    Subclasses must implement `_set_default_params()`. The service needs a parent
    that implements `SupportsAuth` in order to obtain the request headers.
    """

    _POLL_INTERVAL = 4  # seconds
    _POLL_TIMEOUT = 300  # seconds
    _REQUEST_TIMEOUT = 5  # seconds
    _REQUEST_TOTAL_RETRIES = 10  # Total number retries for each request
    # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
    _REQUEST_RETRY_BACKOFF_FACTOR = 0.3

    # Azure Resources Deployment REST API as described in
    # https://docs.microsoft.com/en-us/rest/api/resources/deployments

    _URL_DEPLOY = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Resources"
        "/deployments/{deployment_name}"
        "?api-version=2022-05-01"
    )

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of an Azure Services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription" and "resourceGroup".
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        super().__init__(config, global_config, parent, methods)

        check_required_params(
            self.config,
            [
                "subscription",
                "resourceGroup",
            ],
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
        self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
        self._total_retries = int(
            self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
        )
        self._backoff_factor = float(
            self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
        )

        # Both stay empty when no template is configured; `_provision_resource()`
        # refuses to run in that case.
        self._deploy_template: dict = {}
        self._deploy_params: dict = {}
        if self.config.get("deploymentTemplatePath") is not None:
            # TODO: Provide external schema validation?
            template = self.config_loader_service.load_config(
                self.config["deploymentTemplatePath"],
                schema_type=None,
            )
            assert template is not None and isinstance(template, dict)
            self._deploy_template = template

            # Allow for recursive variable expansion as we do with global params and const_args.
            deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars(
                extra_source_dict=global_config
            )
            self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
        else:
            _LOG.info(
                "No deploymentTemplatePath provided. Deployment services will be unavailable.",
            )

    @property
    def deploy_params(self) -> dict:
        """Get the deployment parameters (the internal dict, not a copy)."""
        return self._deploy_params

    @abc.abstractmethod
    def _set_default_params(self, params: dict) -> dict:
        """
        Optionally set some default parameters for the request.

        Parameters
        ----------
        params : dict
            The parameters.

        Returns
        -------
        dict
            The updated parameters.
        """
        raise NotImplementedError("Should be overridden by subclass.")

    def _get_session(self, params: dict) -> requests.Session:
        """
        Get a session object that includes automatic retries and headers for REST API
        calls.

        Parameters
        ----------
        params : dict
            Flat dictionary of parameters. "requestTotalRetries" and
            "requestBackoffFactor" override the values configured on the service.

        Returns
        -------
        session : requests.Session
            A new session; the caller is responsible for closing it.
        """
        total_retries = params.get("requestTotalRetries", self._total_retries)
        backoff_factor = params.get("requestBackoffFactor", self._backoff_factor)
        # Per-request overrides can arrive as strings (e.g., from the command line),
        # just like the config values converted in __init__; Retry needs real numbers.
        if isinstance(total_retries, str):
            total_retries = int(total_retries)
        if isinstance(backoff_factor, str):
            backoff_factor = float(backoff_factor)

        session = requests.Session()
        session.mount(
            "https://",
            HTTPAdapter(
                max_retries=Retry(
                    total=total_retries, backoff_factor=backoff_factor, status_forcelist=[503]
                )
            ),
        )
        session.headers.update(self._get_headers())
        return session

    def _get_headers(self) -> dict:
        """Get the headers for the REST API calls from the parent auth service."""
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _extract_arm_parameters(json_data: dict) -> dict:
        """
        Extract parameters from the ARM Template REST response JSON.

        Parameters
        ----------
        json_data : dict
            Parsed response body; parameters are read from
            `json_data["properties"]["parameters"]` when present.

        Returns
        -------
        parameters : dict
            Flat dictionary of parameters and their values.
            Parameters whose "value" is missing or None are skipped.
        """
        return {
            key: val.get("value")
            for (key, val) in json_data.get("properties", {}).get("parameters", {}).items()
            if val.get("value") is not None
        }

    @staticmethod
    def _format_response_body(response: requests.Response) -> str:
        """
        Render a response body for debug logging without ever raising on bad JSON.

        Returns
        -------
        body : str
            Pretty-printed JSON, the raw text if the body is not valid JSON,
            or an empty string if there is no body.
        """
        if not response.content:
            return ""
        try:
            return json.dumps(response.json(), indent=2)
        except ValueError:
            # Not JSON (JSON decoding errors are ValueErrors): log the body verbatim
            # so that enabling DEBUG logging cannot change the outcome of a request.
            return response.text

    def _azure_rest_api_post_helper(self, params: dict, url: str) -> tuple[Status, dict]:
        """
        General pattern for performing an action on an Azure resource via its REST API.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        url: str
            REST API url of the action to perform on the Azure resource.
            Should be a url that we intend to POST to.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            Result will have a value for 'asyncResultsUrl' if status is PENDING,
            and 'pollInterval' if suggested by the API.
        """
        _LOG.debug("Request: POST %s", url)

        response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s", response)

        # Logical flow for async operations based on:
        # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
        if response.status_code == 200:
            return (Status.SUCCEEDED, params.copy())
        elif response.status_code == 202:
            result = params.copy()
            if "Azure-AsyncOperation" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
            elif "Location" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Location")
            if "Retry-After" in response.headers:
                retry_after = response.headers["Retry-After"]
                try:
                    result["pollInterval"] = float(retry_after)
                except ValueError:
                    # Retry-After may also be an HTTP-date; keep the default poll
                    # interval instead of failing an otherwise accepted request.
                    _LOG.warning("Ignoring non-numeric Retry-After header: %s", retry_after)

            return (Status.PENDING, result)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})

    def _check_operation_status(self, params: dict) -> tuple[Status, dict]:
        """
        Checks the status of a pending operation on an Azure resource.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            return Status.RUNNING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking operation status", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # This method is called once per poll; release the pooled connections
            # instead of leaking one session per call.
            session.close()

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Response: %s\n%s",
                response,
                self._format_response_body(response),
            )

        if response.status_code == 200:
            output = response.json()
            status = output.get("status")
            if status == "InProgress":
                return Status.RUNNING, {}
            elif status == "Succeeded":
                return Status.SUCCEEDED, output

        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def _wait_deployment(self, params: dict, *, is_setup: bool) -> tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or
        FAILED. Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for resource being deployed; otherwise, wait for
            successful deprovisioning.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        params = self._set_default_params(params)
        _LOG.info(
            "Wait for %s to %s",
            params.get("deploymentName"),
            "provision" if is_setup else "deprovision",
        )
        return self._wait_while(self._check_deployment, Status.PENDING, params)

    def _wait_while(
        self,
        func: Callable[[dict], tuple[Status, dict]],
        loop_status: Status,
        params: dict,
    ) -> tuple[Status, dict]:
        """
        Invoke `func` periodically while the status is equal to `loop_status`. Return
        TIMED_OUT when timing out.

        Parameters
        ----------
        func : a function
            A function that takes `params` and returns a pair of (Status, {})
        loop_status: Status
            Steady state status - keep polling `func` while it returns `loop_status`.
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Requires deploymentName. An optional "pollInterval" (seconds)
            overrides the poll interval configured on the service.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )

        # The override may be a string (e.g., from the command line), so convert it
        # the same way __init__ converts the configured value.
        poll_period = float(params.get("pollInterval", self._poll_interval))

        _LOG.debug(
            "Wait for %s status %s :: poll %.2f timeout %d s",
            config["deploymentName"],
            loop_status,
            poll_period,
            self._poll_timeout,
        )

        # Use the monotonic clock: wall-clock adjustments must not stretch or
        # shrink the timeout.
        ts_timeout = time.monotonic() + self._poll_timeout
        poll_delay = poll_period
        while time.monotonic() < ts_timeout:
            # Wait for the suggested time first then check status
            if poll_delay > 0:
                _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
                time.sleep(poll_delay)

            ts_poll = time.monotonic()
            (status, output) = func(params)
            if status != loop_status:
                return status, output

            # Only the time spent in `func` counts against the next delay.
            # Counting the sleep as well would make every other delay non-positive
            # and produce back-to-back polls.
            poll_delay = poll_period - (time.monotonic() - ts_poll)

        _LOG.warning("Request timed out: %s", params)
        return (Status.TIMED_OUT, {})

    def _check_deployment(self, params: dict) -> tuple[Status, dict]:
        # pylint: disable=too-many-return-statements
        """
        Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Merged over the service config; the result must provide
            "subscription", "resourceGroup", and "deploymentName".

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, PENDING, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
            ],
        )

        _LOG.info("Check deployment: %s", config["deploymentName"])

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            # `_wait_deployment()` polls this method while it returns PENDING.
            # Any other status would end the wait on a single transient timeout.
            return Status.PENDING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking deployment", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # This method is called once per poll; release the pooled connections
            # instead of leaking one session per call.
            session.close()

        _LOG.debug("Response: %s", response)

        if response.status_code == 200:
            output = response.json()
            state = output.get("properties", {}).get("provisioningState", "")

            if state == "Succeeded":
                return (Status.SUCCEEDED, {})
            elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
                return (Status.PENDING, {})
            else:
                _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
                return (Status.FAILED, {})
        elif response.status_code == 404:
            # The deployment is not visible (yet): keep waiting.
            return (Status.PENDING, {})

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _provision_resource(self, params: dict) -> tuple[Status, dict]:
        """
        Attempts to (re)deploy a resource.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Tunables are variable parameters that, together with the
            Environment configuration, are sufficient to provision the resource.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is the input `params` plus the
            parameters extracted from the response JSON, or {} if the status is FAILED.
            Status is one of {PENDING, SUCCEEDED, FAILED}

        Raises
        ------
        ValueError
            If no deployment template was loaded for this service.
        """
        if not self._deploy_template:
            raise ValueError(f"Missing deployment template: {self}")
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )
        _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

        params = merge_parameters(dest=self._deploy_params.copy(), source=params)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Deploy: %s merged params ::\n%s",
                config["deploymentName"],
                json.dumps(params, indent=2),
            )

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        # Only parameters declared by the template are sent with the request.
        template_params = self._deploy_template.get("parameters", {})
        json_req = {
            "properties": {
                "mode": "Incremental",
                "template": self._deploy_template,
                "parameters": {
                    key: {"value": val}
                    for (key, val) in params.items()
                    if key in template_params
                },
            }
        }

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        response = requests.put(
            url,
            json=json_req,
            headers=self._get_headers(),
            timeout=self._request_timeout,
        )

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Response: %s\n%s",
                response,
                self._format_response_body(response),
            )
        else:
            _LOG.info("Response: %s", response)

        if response.status_code == 200:
            return (Status.PENDING, config)
        elif response.status_code == 201:
            output = self._extract_arm_parameters(response.json())
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
            params.update(output)
            params.setdefault("asyncResultsUrl", url)
            params.setdefault("deploymentName", config["deploymentName"])
            return (Status.PENDING, params)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})