# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from olive.common.config_utils import ConfigBase
from olive.common.pydantic_v1 import Field, validator
from olive.common.utils import get_credentials
logger = logging.getLogger(__name__)
class AzureMLClientConfig(ConfigBase):
    """Configuration for AzureMLClient.

    This class is used to create an MLClient instance for AzureML operations. The target workspace is
    identified either by ``aml_config_path`` (a path to an AzureML config file) or by the
    ``subscription_id`` / ``resource_group`` / ``workspace_name`` fields.

    Some fields like `read_timeout`, `max_operation_retries`, `operation_retry_interval` are used to control the
    behavior of azureml operations like resource creation or download.
    """

    # The workspace-identifying fields default to None, so they are annotated Optional. create_client enforces
    # that they are all set when aml_config_path is not provided.
    subscription_id: Optional[str] = Field(
        None, description="Azure subscription id. Required if aml_config_path is not provided."
    )
    resource_group: Optional[str] = Field(
        None, description="Azure resource group. Required if aml_config_path is not provided."
    )
    workspace_name: Optional[str] = Field(
        None, description="Azure workspace name. Required if aml_config_path is not provided."
    )
    aml_config_path: Optional[str] = Field(
        None,
        description=(
            "Path to AzureML config file. If provided, subscription_id, resource_group and workspace_name are"
            " ignored."
        ),
    )
    # read timeout in seconds for HTTP requests, user can increase if they find the default value too small.
    # The default value from azureml sdk is 3000 which is too large and cause the evaluations and pass runs to
    # sometimes hang for a long time between retries of job stream and download steps.
    read_timeout: int = Field(60, description="Read timeout in seconds for HTTP requests.")
    max_operation_retries: int = Field(
        3, description="Max number of retries for AzureML operations like resource creation or download."
    )
    operation_retry_interval: int = Field(
        5,
        description=(
            "Initial interval in seconds between retries for AzureML operations like resource creation or download. The"
            " interval doubles after each retry."
        ),
    )
    # DefaultAzureCredential is used by default, so we need to provide a default auth config for it.
    # DefaultAzureCredential accepts arbitrary kwargs, which makes the config hard to validate here,
    # so we take a plain dict and let the user provide the correct config following the doc.
    default_auth_params: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Default auth config for AzureML client. Please refer to"
            " https://learn.microsoft.com/en-us/python/api/azure-identity/"
            "azure.identity.defaultazurecredential?view=azure-python#parameters"
            " for more details."
        ),
    )
    keyvault_name: Optional[str] = Field(
        None,
        description="Name of the keyvault to use. If provided, the keyvault will be used to retrieve secrets.",
    )

    @validator("aml_config_path", always=True)
    def validate_aml_config_path(cls, v, values):
        """Ensure that ``aml_config_path``, when given, points to an existing regular file.

        ``values`` is unused; it is part of the pydantic v1 validator signature.

        :raises ValueError: If the path does not exist or is not a file.
        """
        if v is None:
            return v
        config_path = Path(v)
        if not config_path.exists():
            raise ValueError(f"aml_config_path {v} does not exist")
        if not config_path.is_file():
            raise ValueError(f"aml_config_path {v} is not a file")
        return v

    def get_workspace_config(self) -> Dict[str, str]:
        """Get the workspace config as a dict.

        :return: The parsed JSON content of ``aml_config_path`` when it is set. Otherwise a dict with the
            ``subscription_id``, ``resource_group`` and ``workspace_name`` fields; their values may be None.
        """
        if self.aml_config_path:
            # Decode explicitly instead of relying on the platform default encoding.
            # "utf-8-sig" also accepts a leading byte-order mark, which json.load would otherwise reject.
            with open(self.aml_config_path, encoding="utf-8-sig") as f:
                return json.load(f)

        # No config file: build the config from the individual fields.
        return {
            "subscription_id": self.subscription_id,
            "resource_group": self.resource_group,
            "workspace_name": self.workspace_name,
        }

    def create_client(self):
        """Create an MLClient instance for the configured workspace.

        The client is built from ``aml_config_path`` via ``MLClient.from_config`` when it is set. Otherwise it is
        built from ``subscription_id``, ``resource_group`` and ``workspace_name``. ``read_timeout`` is forwarded
        in both cases.

        :return: An ``azure.ai.ml.MLClient``.
        :raises ValueError: If ``aml_config_path`` is not set and any of the three workspace fields is None.
        """
        # Imported lazily so that azure-ai-ml is only needed when a client is actually created.
        from azure.ai.ml import MLClient

        set_azure_logging_if_noset()

        if self.aml_config_path is not None:
            return MLClient.from_config(
                credential=get_credentials(self.default_auth_params),
                path=self.aml_config_path,
                read_timeout=self.read_timeout,
            )

        # Without a config file, every workspace-identifying field must be set.
        # The fields are checked in declaration order.
        for field_name in ("subscription_id", "resource_group", "workspace_name"):
            if getattr(self, field_name) is None:
                raise ValueError(f"{field_name} must be provided if aml_config_path is not provided")

        return MLClient(
            credential=get_credentials(self.default_auth_params),
            subscription_id=self.subscription_id,
            resource_group_name=self.resource_group,
            workspace_name=self.workspace_name,
            read_timeout=self.read_timeout,
        )

    def create_registry_client(self, registry_name: str):
        """Create an MLClient instance bound to the AzureML registry ``registry_name``.

        :param registry_name: Name of the AzureML registry to connect to.
        :return: An ``azure.ai.ml.MLClient`` scoped to the registry rather than to a workspace.
        """
        # Imported lazily so that azure-ai-ml is only needed when a client is actually created.
        from azure.ai.ml import MLClient

        set_azure_logging_if_noset()
        # NOTE(review): unlike create_client, read_timeout is not forwarded here, so registry operations use
        # the SDK default timeout. Confirm whether this is intended.
        return MLClient(credential=get_credentials(self.default_auth_params), registry_name=registry_name)
def set_azure_logging_if_noset():
    """Quiet the Azure SDK loggers unless their level was already chosen.

    For each of the ``azure.ai.ml`` and ``azure.identity`` loggers, raise the level to ``logging.ERROR``.
    A logger whose own level is already non-zero (i.e. explicitly set, e.g. by the user) is left untouched.
    """
    for sdk_logger_name in ("azure.ai.ml", "azure.identity"):
        sdk_logger = logging.getLogger(sdk_logger_name)
        # A level of 0 (NOTSET) means nobody configured this logger yet.
        if not sdk_logger.level:
            sdk_logger.setLevel(logging.ERROR)