Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_saas.py: 46%
82 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
A collection of Service functions for configuring SaaS instances on Azure.
7"""
8import logging
10from typing import Any, Callable, Dict, List, Optional, Tuple, Union
12import requests
14from mlos_bench.environments.status import Status
15from mlos_bench.services.base_service import Service
16from mlos_bench.services.types.authenticator_type import SupportsAuth
17from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig
18from mlos_bench.util import check_required_params, merge_parameters
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(__name__)
class AzureSaaSConfigService(Service, SupportsRemoteConfig):
    """
    Helper methods to configure Azure SaaS DB services (e.g., MySQL, PostgreSQL,
    and MariaDB, in either the "flexibleServers" or the "servers" flavor)
    through the Azure management REST API.
    """

    _REQUEST_TIMEOUT = 5  # seconds

    # REST API for Azure SaaS DB Services configuration as described in:
    # https://learn.microsoft.com/en-us/rest/api/mysql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations

    _URL_CONFIGURE = (
        "https://management.azure.com" +
        "/subscriptions/{subscription}" +
        "/resourceGroups/{resource_group}" +
        "/providers/{provider}" +
        "/{server_type}/{vm_name}" +
        "/{update}" +
        "?api-version={api_version}"
    )

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new instance of the Azure SaaS DB configuration service.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription", "resourceGroup", and
            "provider". The keys "supportsBatchUpdate", "isFlex", and
            "apiVersion" are optional for the known providers
            (Microsoft.DBforMySQL, Microsoft.DBforMariaDB,
            Microsoft.DBforPostgreSQL) and required for any other provider.
            "requestTimeout" (seconds) is optional.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config, global_config, parent,
            self.merge_methods(methods, [
                self.configure,
                self.is_config_pending
            ])
        )

        check_required_params(self.config, {
            "subscription",
            "resourceGroup",
            "provider",
        })

        # Provide sane defaults for known DB providers.
        provider = self.config.get("provider")
        if provider == "Microsoft.DBforMySQL":
            self._is_batch = self.config.get("supportsBatchUpdate", True)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-01-01")
        elif provider == "Microsoft.DBforMariaDB":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", False)
            api_version = self.config.get("apiVersion", "2018-06-01")
        elif provider == "Microsoft.DBforPostgreSQL":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-12-01")
        else:
            self._is_batch = self.config["supportsBatchUpdate"]
            is_flex = self.config["isFlex"]
            api_version = self.config["apiVersion"]

        # URL components shared by the "set" and "get" endpoints.
        # "{vm_name}" is substituted as a literal placeholder (str.format does not
        # re-expand substituted values), so it can be filled in per request.
        url_params = {
            "subscription": self.config["subscription"],
            "resource_group": self.config["resourceGroup"],
            "provider": self.config["provider"],
            "vm_name": "{vm_name}",
            "server_type": "flexibleServers" if is_flex else "servers",
            "api_version": api_version,
        }

        # Batch mode POSTs all parameters to a single "updateConfigurations"
        # endpoint; otherwise each parameter is PUT to its own URL, so
        # "{param_name}" is also left as a per-request placeholder.
        self._url_config_set = self._URL_CONFIGURE.format(
            update="updateConfigurations" if self._is_batch else "configurations/{param_name}",
            **url_params,
        )

        # Endpoint that lists all configurations of a server.
        self._url_config_get = self._URL_CONFIGURE.format(
            update="configurations",
            **url_params,
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))

    def configure(self, config: Dict[str, Any],
                  params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service.

        Uses the batch update API when the provider supports it, and falls back
        to one request per parameter otherwise.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {TIMED_OUT, SUCCEEDED, FAILED}
        """
        if self._is_batch:
            return self._config_batch(config, params)
        return self._config_many(config, params)

    def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Check if the configuration of an Azure DB service requires a reboot or restart.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. On success, the result contains a
            Boolean field "isConfigPendingReboot" that is True if any of the
            server's configuration entries reports "isConfigPendingRestart",
            i.e., rebooting is necessary for pending changes to take effect.
            On failure or timeout, the result is {}.
            Status is one of {TIMED_OUT, SUCCEEDED, FAILED}
        """
        config = merge_parameters(
            dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_get.format(vm_name=config["vmName"])
        _LOG.debug("Request: GET %s", url)
        # Listing configurations is a read-only operation: use GET (not PUT).
        response = requests.get(
            url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code != 200:
            return (Status.FAILED, {})
        # Currently, Azure Flex servers require a VM reboot.
        return (Status.SUCCEEDED, {"isConfigPendingReboot": any(
            self._parse_pending_restart(
                (val.get('properties') or {}).get('isConfigPendingRestart'))
            for val in response.json()['value']
        )})

    @staticmethod
    def _parse_pending_restart(value: Any) -> bool:
        """
        Interpret the "isConfigPendingRestart" property of one configuration entry.

        Accepts the "True"/"False" strings (compared case-insensitively) as well as
        a native boolean. A missing value (None) or any other value counts as False.
        NOTE(review): some providers/API versions (e.g., PostgreSQL flexible server)
        appear to report this property as a JSON boolean, and others may omit it
        entirely -- confirm against the REST API reference for each provider.
        """
        if isinstance(value, bool):
            return value
        if value is None:
            return False
        return str(value).strip().lower() == "true"

    def _get_headers(self) -> dict:
        """
        Get the headers for the REST API calls.

        The headers come from the parent service, which must implement SupportsAuth.
        """
        assert self._parent is not None and isinstance(self._parent, SupportsAuth), \
            "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    def _config_one(self, config: Dict[str, Any],
                    param_name: str, param_value: Any) -> Tuple[Status, dict]:
        """
        Update a single parameter of the Azure DB service.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        param_name : str
            Name of the parameter to update.
        param_value : Any
            Value of the parameter to update. It is sent as a string.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {TIMED_OUT, SUCCEEDED, FAILED}
        """
        config = merge_parameters(
            dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name)
        _LOG.debug("Request: PUT %s", url)
        response = requests.put(url, headers=self._get_headers(),
                                json={"properties": {"value": str(param_value)}},
                                timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code == 200:
            return (Status.SUCCEEDED, {})
        return (Status.FAILED, {})

    def _config_many(self, config: Dict[str, Any],
                     params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service one-by-one.
        (If batch API is not available for it).

        Stops at the first parameter whose update does not succeed.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {TIMED_OUT, SUCCEEDED, FAILED}. SUCCEEDED is
            also returned when `params` is empty.
        """
        for (param_name, param_value) in params.items():
            (status, result) = self._config_one(config, param_name, param_value)
            if not status.is_succeeded():
                return (status, result)
        return (Status.SUCCEEDED, {})

    def _config_batch(self, config: Dict[str, Any],
                      params: Dict[str, Any]) -> Tuple[Status, dict]:
        """
        Batch update the parameters of an Azure DB service.

        Sends all parameters in a single POST request. Values are sent as strings.

        Parameters
        ----------
        config : Dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : Dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {TIMED_OUT, SUCCEEDED, FAILED}
        """
        config = merge_parameters(
            dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"])
        json_req = {
            "value": [
                {"name": key, "properties": {"value": str(val)}}
                for (key, val) in params.items()
            ],
            # "resetAllToDefault": "True"
        }
        _LOG.debug("Request: POST %s", url)
        response = requests.post(url, headers=self._get_headers(),
                                 json=json_req, timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        if response.status_code == 504:
            return (Status.TIMED_OUT, {})
        if response.status_code == 200:
            return (Status.SUCCEEDED, {})
        return (Status.FAILED, {})