Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_auth.py: 53%
49 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-05 00:36 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
"""
A collection of Service functions for authenticating with Azure
(access tokens and HTTP authorization headers for Azure REST API calls).
"""
9import logging
10from base64 import b64decode
11from datetime import datetime
12from typing import Any, Callable, Dict, List, Optional, Union
14from pytz import UTC
16import azure.identity as azure_id
17from azure.keyvault.secrets import SecretClient
19from mlos_bench.services.base_service import Service
20from mlos_bench.services.types.authenticator_type import SupportsAuth
21from mlos_bench.util import check_required_params
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class AzureAuthService(Service, SupportsAuth):
    """
    Helper methods to get access to Azure services.

    Hands out bearer tokens for the Azure management API
    (scope ``https://management.azure.com/.default``). By default it
    authenticates via the Azure CLI login. If ``spClientId`` is present in the
    config, it instead authenticates as a service principal, using a
    certificate fetched from Azure Key Vault.
    """

    _REQ_INTERVAL = 300  # = 5 min

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new instance of Azure authentication services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Keys read here:
            ``tokenRequestInterval`` (seconds; default ``_REQ_INTERVAL``);
            and, for service principal auth, ``spClientId``, ``keyVaultName``,
            ``certName`` and ``tenant``. The last four are all required
            if ``spClientId`` is present.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config, global_config, parent,
            self.merge_methods(methods, [
                self.get_access_token,
                self.get_auth_headers,
            ])
        )

        # This parameter can come from command line as strings, so conversion is needed.
        self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL))

        # Placeholder token. The expiration is "now", so the first
        # get_access_token() call always requests a real token.
        self._access_token = "RENEW *NOW*"
        self._token_expiration_ts = datetime.now(UTC)  # Typically, some future timestamp.

        # Login as ourselves
        self._cred: Union[azure_id.AzureCliCredential, azure_id.CertificateCredential]
        self._cred = azure_id.AzureCliCredential()

        # Verify info required for SP auth early
        if "spClientId" in self.config:
            check_required_params(
                self.config, {
                    "spClientId",
                    "keyVaultName",
                    "certName",
                    "tenant",
                }
            )

    def _init_sp(self) -> None:
        """
        Switch the credential to the configured service principal.

        The service principal's certificate is read from Azure Key Vault using
        the current (Azure CLI) credential. It is then used to build a
        ``CertificateCredential``. This is a no-op if that switch has already
        happened.

        Raises
        ------
        ValueError
            If the Key Vault secret holding the certificate has no value.
        """
        # Perform this initialization outside of __init__ so that environment loading tests
        # don't need to specifically mock keyvault interactions out

        # Already logged in as SP
        if isinstance(self._cred, azure_id.CertificateCredential):
            return

        sp_client_id = self.config["spClientId"]
        keyvault_name = self.config["keyVaultName"]
        cert_name = self.config["certName"]
        tenant_id = self.config["tenant"]

        # Get a client for fetching cert info. It is used as a context manager
        # so its HTTP transport is closed once the secret has been retrieved.
        with SecretClient(
            vault_url=f"https://{keyvault_name}.vault.azure.net",
            credential=self._cred,
        ) as keyvault_secrets_client:
            # The certificate private key data is stored as hidden "Secret" (not Key strangely)
            # in PKCS12 format, but we need to decode it.
            secret = keyvault_secrets_client.get_secret(cert_name)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if secret.value is None:
            raise ValueError(
                f"Key Vault secret '{cert_name}' in vault '{keyvault_name}' has no value."
            )
        cert_bytes = b64decode(secret.value)

        # Reauthenticate as the service principal.
        self._cred = azure_id.CertificateCredential(
            tenant_id=tenant_id,
            client_id=sp_client_id,
            certificate_data=cert_bytes,
        )
        # Any cached token was issued to the previous identity: force a renewal.
        self._access_token = "RENEW *NOW*"
        self._token_expiration_ts = datetime.now(UTC)

    def get_access_token(self) -> str:
        """
        Get an access token for the Azure management API.

        A new token is requested from the current credential when the cached
        one expires in less than ``tokenRequestInterval`` seconds. The
        credential is the Azure CLI login, or the service principal
        certificate when ``spClientId`` is configured. Otherwise the cached
        token is returned.

        Returns
        -------
        access_token : str
            The bearer token string.
        """
        # Ensure we are logged as the Service Principal, if provided
        if "spClientId" in self.config:
            self._init_sp()

        ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
        _LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
        if ts_diff < self._req_interval:
            _LOG.debug("Request new accessToken")
            res = self._cred.get_token("https://management.azure.com/.default")
            # `expires_on` is a POSIX timestamp; keep it timezone-aware (UTC).
            self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
            self._access_token = res.token
            _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
        return self._access_token

    def get_auth_headers(self) -> Dict[str, str]:
        """
        Get the authorization part of HTTP headers for REST API calls.

        Returns
        -------
        headers : Dict[str, str]
            ``{"Authorization": "Bearer <token>"}`` with a token from
            ``get_access_token()``.
        """
        return {"Authorization": "Bearer " + self.get_access_token()}