Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%
179 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Base class for certain Azure Services classes that do deployments.
7"""
9import abc
10import json
11import time
12import logging
14from typing import Any, Callable, Dict, List, Optional, Tuple, Union
16import requests
17from requests.adapters import HTTPAdapter, Retry
19from mlos_bench.dict_templater import DictTemplater
20from mlos_bench.environments.status import Status
21from mlos_bench.services.base_service import Service
22from mlos_bench.services.types.authenticator_type import SupportsAuth
23from mlos_bench.util import check_required_params, merge_parameters
25_LOG = logging.getLogger(__name__)
class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
    """
    Helper methods to manage and deploy Azure resources via REST APIs.

    Subclasses must implement `_set_default_params()`.
    """

    _POLL_INTERVAL = 4     # seconds
    _POLL_TIMEOUT = 300    # seconds
    _REQUEST_TIMEOUT = 5   # seconds
    _REQUEST_TOTAL_RETRIES = 10    # Total number retries for each request
    _REQUEST_RETRY_BACKOFF_FACTOR = 0.3    # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))

    # Azure Resources Deployment REST API as described in
    # https://docs.microsoft.com/en-us/rest/api/resources/deployments

    _URL_DEPLOY = (
        "https://management.azure.com" +
        "/subscriptions/{subscription}" +
        "/resourceGroups/{resource_group}" +
        "/providers/Microsoft.Resources" +
        "/deployments/{deployment_name}" +
        "?api-version=2022-05-01"
    )

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new instance of an Azure Services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription" and "resourceGroup".
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(config, global_config, parent, methods)

        check_required_params(self.config, [
            "subscription",
            "resourceGroup",
        ])

        # These parameters can come from command line as strings, so conversion is needed.
        self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
        self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
        self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES))
        self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR))

        # Both stay empty when no deployment template is configured;
        # `_provision_resource()` refuses to run in that case.
        self._deploy_template: Dict[str, Any] = {}
        self._deploy_params: Dict[str, Any] = {}
        if self.config.get("deploymentTemplatePath") is not None:
            # TODO: Provide external schema validation?
            template = self.config_loader_service.load_config(
                self.config['deploymentTemplatePath'], schema_type=None)
            assert template is not None and isinstance(template, dict)
            self._deploy_template = template

            # Allow for recursive variable expansion as we do with global params and const_args.
            deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config)
            self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
        else:
            _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.")

    @property
    def deploy_params(self) -> dict:
        """
        Get the deployment parameters.
        """
        return self._deploy_params

    @abc.abstractmethod
    def _set_default_params(self, params: dict) -> dict:
        """
        Optionally set some default parameters for the request.

        Parameters
        ----------
        params : dict
            The parameters.

        Returns
        -------
        dict
            The updated parameters.
        """
        raise NotImplementedError("Should be overridden by subclass.")

    def _get_session(self, params: dict) -> requests.Session:
        """
        Get a session object that includes automatic retries and headers for REST API calls.

        Parameters
        ----------
        params : dict
            May contain "requestTotalRetries" and "requestBackoffFactor" to
            override the values configured on the service.

        Returns
        -------
        session : requests.Session
            A new session. The caller is responsible for closing it.
        """
        # Per-call overrides may arrive as strings (e.g., from the command line),
        # just like the config values converted in `__init__`.
        total_retries = int(params.get("requestTotalRetries", self._total_retries))
        backoff_factor = float(params.get("requestBackoffFactor", self._backoff_factor))
        session = requests.Session()
        session.mount(
            "https://",
            HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)))
        session.headers.update(self._get_headers())
        return session

    def _get_headers(self) -> dict:
        """
        Get the headers for the REST API calls.

        The headers are obtained from the parent service, which must
        implement the `SupportsAuth` protocol.
        """
        assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
            "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _extract_arm_parameters(json_data: dict) -> dict:
        """
        Extract parameters from the ARM Template REST response JSON.

        Parameters whose "value" is missing or None are omitted.

        Returns
        -------
        parameters : dict
            Flat dictionary of parameters and their values.
        """
        return {
            key: val.get("value")
            for (key, val) in json_data.get("properties", {}).get("parameters", {}).items()
            if val.get("value") is not None
        }

    def _azure_rest_api_post_helper(self, params: dict, url: str) -> Tuple[Status, dict]:
        """
        General pattern for performing an action on an Azure resource via its REST API.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        url: str
            REST API url for the target to perform on the Azure VM.
            Should be a url that we intend to POST to.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            Result will have a value for 'asyncResultsUrl' if status is PENDING,
            and 'pollInterval' if suggested by the API.
        """
        _LOG.debug("Request: POST %s", url)

        response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s", response)

        # Logical flow for async operations based on:
        # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
        if response.status_code == 200:
            return (Status.SUCCEEDED, params.copy())
        elif response.status_code == 202:
            result = params.copy()
            if "Azure-AsyncOperation" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
            elif "Location" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Location")
            retry_after = response.headers.get("Retry-After")
            if retry_after is not None:
                # Retry-After may be an HTTP-date rather than a number of seconds;
                # in that case fall back to the default poll interval
                # instead of failing the whole request.
                try:
                    result["pollInterval"] = float(retry_after)
                except ValueError:
                    _LOG.debug("Ignoring non-numeric Retry-After header: %s", retry_after)

            return (Status.PENDING, result)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            return (Status.FAILED, {})

    def _check_operation_status(self, params: dict) -> Tuple[Status, dict]:
        """
        Checks the status of a pending operation on an Azure resource.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            # A read timeout is treated as "still running" so the caller keeps polling.
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            return Status.RUNNING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking operation status", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # The session is created per call; release its connections right away.
            session.close()

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response,
                       json.dumps(response.json(), indent=2)
                       if response.content else "")

        if response.status_code == 200:
            output = response.json()
            status = output.get("status")
            if status == "InProgress":
                return Status.RUNNING, {}
            elif status == "Succeeded":
                return Status.SUCCEEDED, output

        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or FAILED.
        Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for resource being deployed; otherwise, wait for successful deprovisioning.
            Only affects the log message.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        params = self._set_default_params(params)
        _LOG.info("Wait for %s to %s", params.get("deploymentName"),
                  "provision" if is_setup else "deprovision")
        return self._wait_while(self._check_deployment, Status.PENDING, params)

    def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]],
                    loop_status: Status, params: dict) -> Tuple[Status, dict]:
        """
        Invoke `func` periodically while the status is equal to `loop_status`.
        Return TIMED_OUT when timing out.

        Parameters
        ----------
        func : a function
            A function that takes `params` and returns a pair of (Status, {})
        loop_status: Status
            Steady state status - keep polling `func` while it returns `loop_status`.
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Requires deploymentName.
            May contain "pollInterval" to override the configured poll interval.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(), source=params, required_keys=["deploymentName"])

        # The interval may come from the command line as a string, so convert it.
        poll_period = float(params.get("pollInterval", self._poll_interval))

        _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s",
                   config["deploymentName"], loop_status, poll_period, self._poll_timeout)

        ts_timeout = time.time() + self._poll_timeout
        poll_delay = poll_period
        while time.time() < ts_timeout:
            # Wait for the suggested time first then check status
            if poll_delay > 0:
                _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
                time.sleep(poll_delay)

            # Time only the status check itself. Including the sleep in the
            # measurement would make the next delay non-positive and cause
            # two back-to-back polls every other iteration.
            ts_call = time.time()
            (status, output) = func(params)
            if status != loop_status:
                return status, output

            # Keep the polling period roughly constant by subtracting
            # the time spent in the check from the next sleep.
            poll_delay = poll_period - (time.time() - ts_call)

        _LOG.warning("Request timed out: %s", params)
        return (Status.TIMED_OUT, {})

    def _check_deployment(self, params: dict) -> Tuple[Status, dict]:   # pylint: disable=too-many-return-statements
        """
        Check if Azure deployment exists.
        Return SUCCEEDED if true, PENDING otherwise.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Together with the service config it must provide
            "subscription", "resourceGroup", and "deploymentName".

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, PENDING, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
            ]
        )

        _LOG.info("Check deployment: %s", config["deploymentName"])

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            # Report PENDING (the status `_wait_deployment()` polls on) so that
            # a transient timeout does not end the wait prematurely.
            return Status.PENDING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking deployment", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # The session is created per call; release its connections right away.
            session.close()

        _LOG.debug("Response: %s", response)

        if response.status_code == 200:
            output = response.json()
            state = output.get("properties", {}).get("provisioningState", "")

            if state == "Succeeded":
                return (Status.SUCCEEDED, {})
            elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
                return (Status.PENDING, {})
            else:
                _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
                return (Status.FAILED, {})
        elif response.status_code == 404:
            # The deployment is not visible (yet): keep waiting.
            return (Status.PENDING, {})

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _provision_resource(self, params: dict) -> Tuple[Status, dict]:
        """
        Attempts to (re)deploy a resource.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Tunables are variable parameters that, together with the
            Environment configuration, are sufficient to provision the resource.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is the input `params` plus the
            parameters extracted from the response JSON, or {} if the status is FAILED.
            Status is one of {PENDING, SUCCEEDED, FAILED}

        Raises
        ------
        ValueError
            If no deployment template was configured for this service.
        """
        if not self._deploy_template:
            raise ValueError(f"Missing deployment template: {self}")
        params = self._set_default_params(params)
        config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"])
        _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

        params = merge_parameters(dest=self._deploy_params.copy(), source=params)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Deploy: %s merged params ::\n%s",
                       config["deploymentName"], json.dumps(params, indent=2))

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        # Only pass along the parameters that the template actually declares.
        json_req = {
            "properties": {
                "mode": "Incremental",
                "template": self._deploy_template,
                "parameters": {
                    key: {"value": val} for (key, val) in params.items()
                    if key in self._deploy_template.get("parameters", {})
                }
            }
        }

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        response = requests.put(url, json=json_req,
                                headers=self._get_headers(), timeout=self._request_timeout)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response,
                       json.dumps(response.json(), indent=2)
                       if response.content else "")
        else:
            _LOG.info("Response: %s", response)

        if response.status_code == 200:
            return (Status.PENDING, config)
        elif response.status_code == 201:
            output = self._extract_arm_parameters(response.json())
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
            params.update(output)
            params.setdefault("asyncResultsUrl", url)
            params.setdefault("deploymentName", config["deploymentName"])
            return (Status.PENDING, params)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            return (Status.FAILED, {})