Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_saas.py: 46%
81 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A collection of Service functions for configuring SaaS instances on Azure."""
6import logging
7from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9import requests
11from mlos_bench.environments.status import Status
12from mlos_bench.services.base_service import Service
13from mlos_bench.services.types.authenticator_type import SupportsAuth
14from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig
15from mlos_bench.util import check_required_params, merge_parameters
17_LOG = logging.getLogger(__name__)
class AzureSaaSConfigService(Service, SupportsRemoteConfig):
    """
    Helper methods to configure Azure Flex services.

    The service registers two methods, `configure` and `is_config_pending`,
    and talks to the Azure management REST API using authorization headers
    obtained from the parent service.
    """

    _REQUEST_TIMEOUT = 5  # seconds

    # REST API for Azure SaaS DB Services configuration as described in:
    # https://learn.microsoft.com/en-us/rest/api/mysql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations

    _URL_CONFIGURE = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/{provider}"
        "/{server_type}/{vm_name}"
        "/{update}"
        "?api-version={api_version}"
    )

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of Azure services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription", "resourceGroup", and
            "provider". For providers other than the three known ones,
            "supportsBatchUpdate", "isFlex", and "apiVersion" are required too.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.

        Raises
        ------
        KeyError
            If the provider is not one of the known ones and any of
            "supportsBatchUpdate", "isFlex", or "apiVersion" is missing.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(methods, [self.configure, self.is_config_pending]),
        )

        check_required_params(
            self.config,
            {
                "subscription",
                "resourceGroup",
                "provider",
            },
        )

        # Provide sane defaults for known DB providers.
        provider = self.config.get("provider")
        if provider == "Microsoft.DBforMySQL":
            self._is_batch = self.config.get("supportsBatchUpdate", True)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-01-01")
        elif provider == "Microsoft.DBforMariaDB":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", False)
            api_version = self.config.get("apiVersion", "2018-06-01")
        elif provider == "Microsoft.DBforPostgreSQL":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-12-01")
        else:
            # Unknown provider: no defaults, every setting must be explicit.
            self._is_batch = self.config["supportsBatchUpdate"]
            is_flex = self.config["isFlex"]
            api_version = self.config["apiVersion"]

        # The "{vm_name}" (and, for non-batch updates, "{param_name}")
        # placeholders are deliberately left in the URL templates here;
        # they are filled in per request.
        self._url_config_set = self._URL_CONFIGURE.format(
            subscription=self.config["subscription"],
            resource_group=self.config["resourceGroup"],
            provider=self.config["provider"],
            vm_name="{vm_name}",
            server_type="flexibleServers" if is_flex else "servers",
            update="updateConfigurations" if self._is_batch else "configurations/{param_name}",
            api_version=api_version,
        )

        self._url_config_get = self._URL_CONFIGURE.format(
            subscription=self.config["subscription"],
            resource_group=self.config["resourceGroup"],
            provider=self.config["provider"],
            vm_name="{vm_name}",
            server_type="flexibleServers" if is_flex else "servers",
            update="configurations",
            api_version=api_version,
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))

    def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service.

        Dispatches to the batch update when the service is configured for it,
        and to one-by-one updates otherwise.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        if self._is_batch:
            return self._config_batch(config, params)
        return self._config_many(config, params)

    def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Check if the configuration of an Azure DB service requires a reboot or restart.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. On success, the result has a Boolean
            field "isConfigPendingReboot" that is True if any configuration
            entry in the response reports "isConfigPendingRestart" as "True".
            On timeout or failure the result is {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}

        Raises
        ------
        KeyError
            If an entry's "isConfigPendingRestart" value is neither "True"
            nor "False".
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_get.format(vm_name=config["vmName"])
        _LOG.debug("Request: GET %s", url)
        # Listing the configurations is a read-only operation, so it must be a GET
        # (matching the log message above), not a body-less PUT.
        response = requests.get(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code != 200:
            return (Status.FAILED, {})
        # Currently, Azure Flex servers require a VM reboot.
        return (
            Status.SUCCEEDED,
            {
                "isConfigPendingReboot": any(
                    {"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]]
                    for val in response.json()["value"]
                )
            },
        )

    def _get_headers(self) -> dict:
        """
        Get the headers for the REST API calls.

        The headers come from the parent service, which must implement
        SupportsAuth.
        """
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    def _config_one(
        self,
        config: Dict[str, Any],
        param_name: str,
        param_value: Any,
    ) -> Tuple[Status, dict]:
        """
        Update a single parameter of the Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        param_name : str
            Name of the parameter to update.
        param_value : Any
            Value of the parameter to update. It is sent as a string.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name)
        _LOG.debug("Request: PUT %s", url)
        response = requests.put(
            url,
            headers=self._get_headers(),
            json={"properties": {"value": str(param_value)}},
            timeout=self._request_timeout,
        )
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code == 200:
            return (Status.SUCCEEDED, {})
        return (Status.FAILED, {})

    def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service one-by-one. (If batch API is not
        available for it).

        Stops at the first parameter whose update does not succeed and returns
        that update's status; parameters updated before it stay updated.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        for param_name, param_value in params.items():
            (status, result) = self._config_one(config, param_name, param_value)
            if not status.is_succeeded():
                return (status, result)
        return (Status.SUCCEEDED, {})

    def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Batch update the parameters of an Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update. The values
            are sent as strings.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"])
        json_req = {
            "value": [
                {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items()
            ],
            # "resetAllToDefault": "True"
        }
        _LOG.debug("Request: POST %s", url)
        response = requests.post(
            url,
            headers=self._get_headers(),
            json=json_req,
            timeout=self._request_timeout,
        )
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code == 200:
            return (Status.SUCCEEDED, {})
        return (Status.FAILED, {})