Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_auth.py: 53%

49 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6A collection of Service functions for authenticating with Azure services. 

7""" 

8 

9import logging 

10from base64 import b64decode 

11from datetime import datetime 

12from typing import Any, Callable, Dict, List, Optional, Union 

13 

14from pytz import UTC 

15 

16import azure.identity as azure_id 

17from azure.keyvault.secrets import SecretClient 

18 

19from mlos_bench.services.base_service import Service 

20from mlos_bench.services.types.authenticator_type import SupportsAuth 

21from mlos_bench.util import check_required_params 

22 

23_LOG = logging.getLogger(__name__) 

24 

25 

class AzureAuthService(Service, SupportsAuth):
    """
    Helper methods to get access to Azure services.

    The service starts out with an Azure CLI credential. If the config
    contains "spClientId", the first token request switches to a service
    principal certificate credential whose certificate is fetched from
    Azure Key Vault (see `_init_sp`).
    """

    _REQ_INTERVAL = 300  # = 5 min

    def __init__(self,
                 config: Optional[Dict[str, Any]] = None,
                 global_config: Optional[Dict[str, Any]] = None,
                 parent: Optional[Service] = None,
                 methods: Union[Dict[str, Callable], List[Callable], None] = None):
        """
        Create a new instance of Azure authentication services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config, global_config, parent,
            self.merge_methods(methods, [
                self.get_access_token,
                self.get_auth_headers,
            ])
        )

        # This parameter can come from command line as strings, so conversion is needed.
        self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL))

        # Placeholder token plus an expiration of "now": the first call to
        # get_access_token() sees the token as due for renewal, provided the
        # request interval is positive.
        self._access_token = "RENEW *NOW*"
        self._token_expiration_ts = datetime.now(UTC)  # Typically, some future timestamp.

        # Login as ourselves
        self._cred: Union[azure_id.AzureCliCredential, azure_id.CertificateCredential]
        self._cred = azure_id.AzureCliCredential()

        # Verify info required for SP auth early
        if "spClientId" in self.config:
            check_required_params(
                self.config, {
                    "spClientId",
                    "keyVaultName",
                    "certName",
                    "tenant",
                }
            )

    def _init_sp(self) -> None:
        """
        Switch the credential to the configured service principal (SP).

        Fetches the SP certificate from Key Vault using the current credential
        and replaces `self._cred` with a certificate credential. Does nothing
        if the credential is already a certificate credential.

        Raises
        ------
        ValueError
            If the Key Vault secret holding the certificate has no value.
        """
        # Perform this initialization outside of __init__ so that environment loading tests
        # don't need to specifically mock keyvault interactions out

        # Already logged in as SP
        if isinstance(self._cred, azure_id.CertificateCredential):
            return

        sp_client_id = self.config["spClientId"]
        keyvault_name = self.config["keyVaultName"]
        cert_name = self.config["certName"]
        tenant_id = self.config["tenant"]

        # Get a client for fetching cert info
        keyvault_secrets_client = SecretClient(
            vault_url=f"https://{keyvault_name}.vault.azure.net",
            credential=self._cred,
        )

        # The certificate private key data is stored as hidden "Secret" (not Key strangely)
        # in PKCS12 format, but we need to decode it.
        secret = keyvault_secrets_client.get_secret(cert_name)
        # Raise explicitly instead of using `assert`: asserts are stripped under
        # `python -O`, which would turn a missing value into an opaque TypeError
        # inside b64decode().
        if secret.value is None:
            raise ValueError(
                f"Key Vault secret '{cert_name}' in vault '{keyvault_name}' has no value."
            )
        cert_bytes = b64decode(secret.value)

        # Reauthenticate as the service principal.
        self._cred = azure_id.CertificateCredential(
            tenant_id=tenant_id,
            client_id=sp_client_id,
            certificate_data=cert_bytes,
        )

    def get_access_token(self) -> str:
        """
        Get the access token for the Azure management API, renewing it if needed.

        A new token is requested from the current credential when the cached
        one expires in less than the configured request interval
        ("tokenRequestInterval", in seconds); otherwise the cached token is
        returned.

        Returns
        -------
        str
            The current access token.
        """
        # Ensure we are logged as the Service Principal, if provided
        if "spClientId" in self.config:
            self._init_sp()

        ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
        _LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
        if ts_diff < self._req_interval:
            _LOG.debug("Request new accessToken")
            res = self._cred.get_token("https://management.azure.com/.default")
            self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
            self._access_token = res.token
            _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
        return self._access_token

    def get_auth_headers(self) -> dict:
        """
        Get the authorization part of HTTP headers for REST API calls.

        Returns
        -------
        dict
            A single-entry dictionary with the "Authorization" bearer header
            built from the current access token.
        """
        return {"Authorization": "Bearer " + self.get_access_token()}