Source code for autogen_ext.tools.azure._ai_search

from __future__ import annotations

import asyncio
import logging
import time
from abc import ABC, abstractmethod
from contextvars import ContextVar
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    Union,
)

from autogen_core import CancellationToken, Component
from autogen_core.tools import BaseTool, ToolSchema
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.search.documents.aio import SearchClient
from pydantic import BaseModel, Field

from ._config import (
    DEFAULT_API_VERSION,
    AzureAISearchConfig,
)

SearchDocument = Dict[str, Any]
MetadataDict = Dict[str, Any]
ContentDict = Dict[str, Any]

if TYPE_CHECKING:
    from azure.search.documents.aio import AsyncSearchItemPaged

    SearchResultsIterable = AsyncSearchItemPaged[SearchDocument]
else:
    SearchResultsIterable = Any

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from azure.search.documents.models import (
        VectorizableTextQuery,
        VectorizedQuery,
        VectorQuery,
    )

try:
    from azure.search.documents.models import VectorizableTextQuery, VectorizedQuery, VectorQuery

    has_azure_search = True
except ImportError:
    has_azure_search = False
    logger.error(
        "The 'azure-search-documents' package is required for this tool but was not found. "
        "Please install it with: uv add install azure-search-documents"
    )


if TYPE_CHECKING:
    from typing import Protocol

    class SearchClientProtocol(Protocol):
        async def search(self, **kwargs: Any) -> SearchResultsIterable: ...
        async def close(self) -> None: ...
else:
    SearchClientProtocol = Any

__all__ = [
    "AzureAISearchTool",
    "BaseAzureAISearchTool",
    "SearchQuery",
    "SearchResults",
    "SearchResult",
    "VectorizableTextQuery",
    "VectorizedQuery",
    "VectorQuery",
]


class SearchQuery(BaseModel):
    """Search query parameters.

    This simplified interface only requires a search query string. All other parameters
    (top, filters, vector fields, etc.) are specified during tool creation rather than
    at query time, making it easier for language models to generate structured output.

    Args:
        query (str): The search query text.
    """

    query: str = Field(description="Search query text")


class SearchResult(BaseModel):
    """Search result.

    Args:
        score (float): The search score.
        content (ContentDict): The document content.
        metadata (MetadataDict): Additional metadata about the document.
    """

    score: float = Field(description="The search score")
    content: ContentDict = Field(description="The document content")
    metadata: MetadataDict = Field(description="Additional metadata about the document")


class SearchResults(BaseModel):
    """Container for search results.

    Args:
        results (List[SearchResult]): List of search results.
    """

    results: List[SearchResult] = Field(description="List of search results")
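
# Illustrative sketch (not part of the module): how these pydantic models fit together.
# A language model only has to produce {"query": "..."}; the tool returns a
# SearchResults container. The field values below are made up.
#
#   query = SearchQuery(query="machine learning basics")
#   result = SearchResult(score=0.95, content={"title": "Intro to ML"}, metadata={"@search.score": 0.95})
#   results = SearchResults(results=[result])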


class EmbeddingProvider(Protocol):
    """Protocol defining the interface for embedding generation."""

    async def _get_embedding(self, query: str) -> List[float]:
        """Generate embedding vector for the query text."""
        ...


class EmbeddingProviderMixin:
    """Mixin class providing embedding generation functionality."""

    search_config: AzureAISearchConfig

    async def _get_embedding(self, query: str) -> List[float]:
        """Generate embedding vector for the query text."""
        if not hasattr(self, "search_config"):
            raise ValueError("Host class must have a search_config attribute")

        search_config = self.search_config
        embedding_provider = getattr(search_config, "embedding_provider", None)
        embedding_model = getattr(search_config, "embedding_model", None)

        if not embedding_provider or not embedding_model:
            raise ValueError(
                "Client-side embedding is not configured. `embedding_provider` and `embedding_model` must be set."
            ) from None

        if embedding_provider.lower() == "azure_openai":
            try:
                from azure.identity import DefaultAzureCredential
                from openai import AsyncAzureOpenAI
            except ImportError:
                raise ImportError(
                    "Azure OpenAI SDK is required for client-side embedding generation. "
                    "Please install it with: uv add openai azure-identity"
                ) from None

            api_key = getattr(search_config, "openai_api_key", None)
            api_version = getattr(search_config, "openai_api_version", "2023-11-01")
            endpoint = getattr(search_config, "openai_endpoint", None)

            if not endpoint:
                raise ValueError(
                    "Azure OpenAI endpoint (`openai_endpoint`) must be provided for client-side Azure OpenAI embeddings."
                ) from None

            if api_key:
                azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint)
            else:

                def get_token() -> str:
                    credential = DefaultAzureCredential()
                    token = credential.get_token("https://cognitiveservices.azure.com/.default")
                    if not token or not token.token:
                        raise ValueError("Failed to acquire token using DefaultAzureCredential for Azure OpenAI.")
                    return token.token

                azure_client = AsyncAzureOpenAI(
                    azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint
                )

            try:
                response = await azure_client.embeddings.create(model=embedding_model, input=query)
                return response.data[0].embedding
            except Exception as e:
                raise ValueError(f"Failed to generate embeddings with Azure OpenAI: {str(e)}") from e

        elif embedding_provider.lower() == "openai":
            try:
                from openai import AsyncOpenAI
            except ImportError:
                raise ImportError(
                    "OpenAI SDK is required for client-side embedding generation. "
                    "Please install it with: uv add openai"
                ) from None

            api_key = getattr(search_config, "openai_api_key", None)
            openai_client = AsyncOpenAI(api_key=api_key)

            try:
                response = await openai_client.embeddings.create(model=embedding_model, input=query)
                return response.data[0].embedding
            except Exception as e:
                raise ValueError(f"Failed to generate embeddings with OpenAI: {str(e)}") from e
        else:
            raise ValueError(
                f"Unsupported client-side embedding provider: {embedding_provider}. "
                "Currently supported providers are 'azure_openai' and 'openai'."
            )


class BaseAzureAISearchTool(
    BaseTool[SearchQuery, SearchResults], Component[AzureAISearchConfig], EmbeddingProvider, ABC
):
    """Abstract base class for Azure AI Search tools.

    This class defines the common interface and functionality for all Azure AI Search tools.
    It handles configuration management, client initialization, and the abstract methods
    that subclasses must implement.

    Attributes:
        search_config: Configuration parameters for the search service.

    Note:
        This is an abstract base class and should not be instantiated directly.
        Use concrete implementations or the factory methods in AzureAISearchTool.
    """

    component_config_schema = AzureAISearchConfig
    component_provider_override = "autogen_ext.tools.azure.BaseAzureAISearchTool"

    def __init__(
        self,
        name: str,
        endpoint: str,
        index_name: str,
        credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]],
        description: Optional[str] = None,
        api_version: str = DEFAULT_API_VERSION,
        query_type: Literal["simple", "full", "semantic", "vector"] = "simple",
        search_fields: Optional[List[str]] = None,
        select_fields: Optional[List[str]] = None,
        vector_fields: Optional[List[str]] = None,
        top: Optional[int] = None,
        filter: Optional[str] = None,
        semantic_config_name: Optional[str] = None,
        enable_caching: bool = False,
        cache_ttl_seconds: int = 300,
        embedding_provider: Optional[str] = None,
        embedding_model: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        openai_api_version: Optional[str] = None,
        openai_endpoint: Optional[str] = None,
    ):
        """Initialize the Azure AI Search tool.

        Args:
            name (str): The name of this tool instance.
            endpoint (str): The full URL of your Azure AI Search service.
            index_name (str): Name of the search index to query.
            credential (Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]]): Azure credential for authentication.
            description (Optional[str]): Optional description explaining the tool's purpose.
            api_version (str): Azure AI Search API version to use.
            query_type (Literal["simple", "full", "semantic", "vector"]): Type of search to perform.
            search_fields (Optional[List[str]]): Fields to search within documents.
            select_fields (Optional[List[str]]): Fields to return in search results.
            vector_fields (Optional[List[str]]): Fields to use for vector search.
            top (Optional[int]): Maximum number of results to return.
            filter (Optional[str]): OData filter expression to refine search results.
            semantic_config_name (Optional[str]): Semantic configuration name for enhanced results.
            enable_caching (bool): Whether to cache search results.
            cache_ttl_seconds (int): How long to cache results, in seconds.
            embedding_provider (Optional[str]): Name of the embedding provider for client-side embeddings.
            embedding_model (Optional[str]): Model name for client-side embeddings.
            openai_api_key (Optional[str]): API key for OpenAI/Azure OpenAI embeddings.
            openai_api_version (Optional[str]): API version for Azure OpenAI embeddings.
            openai_endpoint (Optional[str]): Endpoint URL for Azure OpenAI embeddings.
        """
        if not has_azure_search:
            raise ImportError(
                "Azure Search SDK is required but not installed. "
                "Please install it with: pip install azure-search-documents>=11.4.0"
            )

        if description is None:
            description = (
                f"Search for information in the {index_name} index using Azure AI Search. "
                f"Supports full-text search with optional filters and semantic capabilities."
            )

        super().__init__(
            args_type=SearchQuery,
            return_type=SearchResults,
            name=name,
            description=description,
        )

        processed_credential = self._process_credential(credential)

        self.search_config: AzureAISearchConfig = AzureAISearchConfig(
            name=name,
            description=description,
            endpoint=endpoint,
            index_name=index_name,
            credential=processed_credential,
            api_version=api_version,
            query_type=query_type,
            search_fields=search_fields,
            select_fields=select_fields,
            vector_fields=vector_fields,
            top=top,
            filter=filter,
            semantic_config_name=semantic_config_name,
            enable_caching=enable_caching,
            cache_ttl_seconds=cache_ttl_seconds,
            embedding_provider=embedding_provider,
            embedding_model=embedding_model,
            openai_api_key=openai_api_key,
            openai_api_version=openai_api_version,
            openai_endpoint=openai_endpoint,
        )

        self._endpoint = endpoint
        self._index_name = index_name
        self._credential = processed_credential
        self._api_version = api_version
        self._client: Optional[SearchClient] = None
        self._cache: Dict[str, Dict[str, Any]] = {}

        if self.search_config.api_version == "2023-11-01" and self.search_config.vector_fields:
            warning_message = (
                f"When explicitly setting api_version='{self.search_config.api_version}' for vector search: "
                "If client-side embedding is NOT configured (e.g., `embedding_model` is not set), "
                "this tool defaults to service-side vectorization (VectorizableTextQuery), "
                "which may fail or have limitations with this API version. "
                "If client-side embedding IS configured, the tool will use VectorizedQuery, "
                "which is generally compatible. "
                "For robust vector search, consider omitting api_version (recommended, to use the SDK default) "
                "or use a newer API version."
            )
            logger.warning(warning_message)

    async def close(self) -> None:
        """Explicitly close the Azure SearchClient if needed (for cleanup)."""
        if self._client is not None:
            try:
                await self._client.close()
            except Exception:
                pass
            finally:
                self._client = None

    def _process_credential(
        self, credential: Union[AzureKeyCredential, AsyncTokenCredential, Dict[str, str]]
    ) -> Union[AzureKeyCredential, AsyncTokenCredential]:
        """Process credential to ensure it's the correct type for the async SearchClient.

        Converts dictionary credentials with an 'api_key' to AzureKeyCredential objects.

        Args:
            credential: The credential in either object or dictionary form.

        Returns:
            A properly formatted credential object.

        Raises:
            ValueError: If the credential dictionary doesn't contain an 'api_key'.
            TypeError: If the credential is not of a supported type.
        """
        if isinstance(credential, dict):
            if "api_key" in credential:
                return AzureKeyCredential(credential["api_key"])
            raise ValueError("If credential is a dict, it must contain an 'api_key' key")
        if isinstance(credential, (AzureKeyCredential, AsyncTokenCredential)):
            return credential
        raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict")

    async def _get_client(self) -> SearchClient:
        """Get the search client for the configured index.

        Returns:
            SearchClient: Initialized search client.

        Raises:
            ValueError: If the index doesn't exist or authentication fails.
        """
        if self._client is not None:
            return self._client

        try:
            self._client = SearchClient(
                endpoint=self.search_config.endpoint,
                index_name=self.search_config.index_name,
                credential=self.search_config.credential,
                api_version=self.search_config.api_version,
            )
            return self._client
        except ResourceNotFoundError as e:
            raise ValueError(f"Index '{self.search_config.index_name}' not found in Azure AI Search service.") from e
        except HttpResponseError as e:
            if e.status_code == 401:
                raise ValueError("Authentication failed. Please check your credentials.") from e
            elif e.status_code == 403:
                raise ValueError("Permission denied to access this index.") from e
            else:
                raise ValueError(f"Error connecting to Azure AI Search: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Unexpected error initializing search client: {str(e)}") from e
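
    # Illustrative sketch (not part of the module): the three credential forms accepted
    # by _process_credential. The key values are placeholders.
    #
    #   AzureKeyCredential("<api-key>")                       # used as-is
    #   {"api_key": "<api-key>"}                              # converted to AzureKeyCredential
    #   azure.identity.aio.DefaultAzureCredential()           # any AsyncTokenCredential is used as-is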

    async def run(
        self, args: Union[str, Dict[str, Any], SearchQuery], cancellation_token: Optional[CancellationToken] = None
    ) -> SearchResults:
        """Execute a search against the Azure AI Search index.

        Args:
            args: Search query text or SearchQuery object.
            cancellation_token: Optional token to cancel the operation.

        Returns:
            SearchResults: Container with search results and metadata.

        Raises:
            ValueError: If the search query is empty or invalid.
            ValueError: If there is an authentication error or other search issue.
            asyncio.CancelledError: If the operation is cancelled.
        """
        if isinstance(args, str):
            if not args.strip():
                raise ValueError("Search query cannot be empty")
            search_query = SearchQuery(query=args)
        elif isinstance(args, dict) and "query" in args:
            search_query = SearchQuery(query=args["query"])
        elif isinstance(args, SearchQuery):
            search_query = args
        else:
            raise ValueError("Invalid search query format. Expected string, dict with 'query', or SearchQuery")

        if cancellation_token is not None and cancellation_token.is_cancelled():
            raise asyncio.CancelledError("Operation cancelled")

        cache_key = ""
        if self.search_config.enable_caching:
            cache_key_parts = [
                search_query.query,
                str(self.search_config.top),
                self.search_config.query_type,
                ",".join(sorted(self.search_config.search_fields or [])),
                ",".join(sorted(self.search_config.select_fields or [])),
                ",".join(sorted(self.search_config.vector_fields or [])),
                str(self.search_config.filter or ""),
                str(self.search_config.semantic_config_name or ""),
            ]
            cache_key = ":".join(filter(None, cache_key_parts))
            if cache_key in self._cache:
                cache_entry = self._cache[cache_key]
                cache_age = time.time() - cache_entry["timestamp"]
                if cache_age < self.search_config.cache_ttl_seconds:
                    logger.debug(f"Using cached results for query: {search_query.query}")
                    return SearchResults(
                        results=[
                            SearchResult(score=r.score, content=r.content, metadata=r.metadata)
                            for r in cache_entry["results"]
                        ]
                    )

        try:
            search_kwargs: Dict[str, Any] = {}

            if self.search_config.query_type != "vector":
                search_kwargs["search_text"] = search_query.query
                search_kwargs["query_type"] = self.search_config.query_type
                if self.search_config.search_fields:
                    search_kwargs["search_fields"] = self.search_config.search_fields  # type: ignore[assignment]
                if self.search_config.query_type == "semantic" and self.search_config.semantic_config_name:
                    search_kwargs["semantic_configuration_name"] = self.search_config.semantic_config_name

            if self.search_config.select_fields:
                search_kwargs["select"] = self.search_config.select_fields  # type: ignore[assignment]
            if self.search_config.filter:
                search_kwargs["filter"] = str(self.search_config.filter)
            if self.search_config.top is not None:
                search_kwargs["top"] = self.search_config.top  # type: ignore[assignment]

            if self.search_config.vector_fields and len(self.search_config.vector_fields) > 0:
                if not search_query.query:
                    raise ValueError("Query text cannot be empty for vector search operations")

                use_client_side_embeddings = bool(
                    self.search_config.embedding_model and self.search_config.embedding_provider
                )
                vector_queries: List[Union[VectorizedQuery, VectorizableTextQuery]] = []

                if use_client_side_embeddings:
                    from azure.search.documents.models import VectorizedQuery

                    embedding_vector: List[float] = await self._get_embedding(search_query.query)
                    for field_spec in self.search_config.vector_fields:
                        fields = field_spec if isinstance(field_spec, str) else ",".join(field_spec)
                        vector_queries.append(
                            VectorizedQuery(
                                vector=embedding_vector,
                                k_nearest_neighbors=self.search_config.top or 5,
                                fields=fields,
                                kind="vector",
                            )
                        )
                else:
                    from azure.search.documents.models import VectorizableTextQuery

                    for field in self.search_config.vector_fields:
                        fields = field if isinstance(field, str) else ",".join(field)
                        vector_queries.append(
                            VectorizableTextQuery(  # type: ignore
                                text=search_query.query,
                                k_nearest_neighbors=self.search_config.top or 5,
                                fields=fields,
                                kind="vectorizable",
                            )
                        )

                search_kwargs["vector_queries"] = vector_queries  # type: ignore[assignment]

            if cancellation_token is not None:
                dummy_task = asyncio.create_task(asyncio.sleep(60))
                cancellation_token.link_future(dummy_task)

                def is_cancelled() -> bool:
                    return cancellation_token.is_cancelled()
            else:

                def is_cancelled() -> bool:
                    return False

            client = await self._get_client()
            search_results: SearchResultsIterable = await client.search(**search_kwargs)  # type: ignore[arg-type]

            results: List[SearchResult] = []
            async for doc in search_results:
                if is_cancelled():
                    raise asyncio.CancelledError("Operation was cancelled")

                try:
                    metadata: Dict[str, Any] = {}
                    content: Dict[str, Any] = {}
                    for key, value in doc.items():
                        if isinstance(key, str) and key.startswith(("@", "_")):
                            metadata[key] = value
                        else:
                            content[str(key)] = value

                    score = float(metadata.get("@search.score", 0.0))
                    results.append(SearchResult(score=score, content=content, metadata=metadata))
                except Exception as e:
                    logger.warning(f"Error processing search document: {e}")
                    continue

            if self.search_config.enable_caching:
                self._cache[cache_key] = {"results": results, "timestamp": time.time()}

            return SearchResults(results=results)

        except asyncio.CancelledError:
            raise
        except Exception as e:
            error_msg = str(e)
            if isinstance(e, HttpResponseError):
                if hasattr(e, "message") and e.message:
                    error_msg = e.message

            if "not found" in error_msg.lower():
                raise ValueError(f"Index '{self.search_config.index_name}' not found.") from e
            elif "unauthorized" in error_msg.lower() or "401" in error_msg:
                raise ValueError(f"Authentication failed: {error_msg}") from e
            else:
                raise ValueError(f"Error from Azure AI Search: {error_msg}") from e
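
    # Illustrative sketch (not part of the module): the three argument forms accepted
    # by run(). The tool instance and query text are placeholders.
    #
    #   results = await tool.run("quarterly revenue report")
    #   results = await tool.run({"query": "quarterly revenue report"})
    #   results = await tool.run(SearchQuery(query="quarterly revenue report"))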

    def _to_config(self) -> AzureAISearchConfig:
        """Convert the current instance to a configuration object."""
        return self.search_config

    @property
    def schema(self) -> ToolSchema:
        """Return the schema for the tool."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string", "description": "Search query text"}},
                "required": ["query"],
                "additionalProperties": False,
            },
            "strict": True,
        }

    def return_value_as_string(self, value: SearchResults) -> str:
        """Convert the search results to a string representation."""
        if not value.results:
            return "No results found."

        result_strings: List[str] = []
        for i, result in enumerate(value.results, 1):
            content_items = [f"{k}: {str(v) if v is not None else 'None'}" for k, v in result.content.items()]
            content_str = ", ".join(content_items)
            result_strings.append(f"Result {i} (Score: {result.score:.2f}): {content_str}")

        return "\n".join(result_strings)

    @classmethod
    def _validate_config(
        cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"]
    ) -> None:
        """Validate configuration for specific search types."""
        credential = config_dict.get("credential")
        if isinstance(credential, str):
            raise TypeError("Credential must be AzureKeyCredential, AsyncTokenCredential, or a valid dict")
        if isinstance(credential, dict) and "api_key" not in credential:
            raise ValueError("If credential is a dict, it must contain an 'api_key' key")

        try:
            _ = AzureAISearchConfig(**config_dict)
        except Exception as e:
            raise ValueError(f"Invalid configuration: {str(e)}") from e

        if search_type == "vector":
            vector_fields = config_dict.get("vector_fields")
            if not vector_fields or len(vector_fields) == 0:
                raise ValueError("vector_fields must contain at least one field name for vector search")
        elif search_type == "hybrid":
            vector_fields = config_dict.get("vector_fields")
            search_fields = config_dict.get("search_fields")
            if not vector_fields or len(vector_fields) == 0:
                raise ValueError("vector_fields must contain at least one field name for hybrid search")
            if not search_fields or len(search_fields) == 0:
                raise ValueError("search_fields must contain at least one field name for hybrid search")

    @classmethod
    @abstractmethod
    def _from_config(cls, config: AzureAISearchConfig) -> "BaseAzureAISearchTool":
        """Create a tool instance from a configuration object.

        This is an abstract method that must be implemented by subclasses.
        """
        if cls is BaseAzureAISearchTool:
            raise NotImplementedError(
                "BaseAzureAISearchTool is an abstract base class and cannot be instantiated directly. "
                "Use a concrete implementation like AzureAISearchTool."
            )
        raise NotImplementedError("Subclasses must implement _from_config")

    @abstractmethod
    async def _get_embedding(self, query: str) -> List[float]:
        """Generate embedding vector for the query text."""
        raise NotImplementedError("Subclasses must implement _get_embedding")
_allow_private_constructor = ContextVar("_allow_private_constructor", default=False)


class AzureAISearchTool(EmbeddingProviderMixin, BaseAzureAISearchTool):
    """Azure AI Search tool for querying Azure search indexes.

    This tool provides a simplified interface for querying Azure AI Search indexes using
    various search methods. It's recommended to use the factory methods to create
    instances tailored for specific search types:

    1. **Full-Text Search**: For traditional keyword-based searches, Lucene queries, or
       semantically re-ranked results.

       - Use `AzureAISearchTool.create_full_text_search()`
       - Supports `query_type`: "simple" (keyword), "full" (Lucene), "semantic".

    2. **Vector Search**: For pure similarity searches based on vector embeddings.

       - Use `AzureAISearchTool.create_vector_search()`

    3. **Hybrid Search**: For combining vector search with full-text or semantic search
       to get the benefits of both.

       - Use `AzureAISearchTool.create_hybrid_search()`
       - The text component can be "simple", "full", or "semantic" via the `query_type` parameter.

    Each factory method configures the tool with appropriate defaults and validations
    for the chosen search strategy.

    .. warning::

        If you set `query_type="semantic"`, you must also provide a valid
        `semantic_config_name`. This configuration must be set up in your
        Azure AI Search index beforehand.
    """

    component_provider_override = "autogen_ext.tools.azure.AzureAISearchTool"

    @classmethod
    def _from_config(cls, config: AzureAISearchConfig) -> "AzureAISearchTool":
        """Create a tool instance from a configuration object.

        Args:
            config: The configuration object with tool settings.

        Returns:
            AzureAISearchTool: An initialized tool instance.
        """
        token = _allow_private_constructor.set(True)
        try:
            instance = cls(
                name=config.name,
                description=config.description or "",
                endpoint=config.endpoint,
                index_name=config.index_name,
                credential=config.credential,
                api_version=config.api_version,
                query_type=config.query_type,
                search_fields=config.search_fields,
                select_fields=config.select_fields,
                vector_fields=config.vector_fields,
                top=config.top,
                filter=config.filter,
                semantic_config_name=config.semantic_config_name,
                enable_caching=config.enable_caching,
                cache_ttl_seconds=config.cache_ttl_seconds,
                embedding_provider=config.embedding_provider,
                embedding_model=config.embedding_model,
                openai_api_key=config.openai_api_key,
                openai_api_version=config.openai_api_version,
                openai_endpoint=config.openai_endpoint,
            )
            return instance
        finally:
            _allow_private_constructor.reset(token)

    @classmethod
    def _create_from_params(
        cls, config_dict: Dict[str, Any], search_type: Literal["full_text", "vector", "hybrid"]
    ) -> "AzureAISearchTool":
        """Private helper to create an instance from parameters after validation.

        Args:
            config_dict: Dictionary with configuration parameters.
            search_type: Type of search for validation.

        Returns:
            Configured AzureAISearchTool instance.
        """
        cls._validate_config(config_dict, search_type)
        token = _allow_private_constructor.set(True)
        try:
            return cls(**config_dict)
        finally:
            _allow_private_constructor.reset(token)
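
    # Illustrative sketch (not part of the module): creating a tool via the factory
    # methods named in the class docstring. The exact factory signatures are assumed to
    # mirror the base class constructor; endpoint, index name, and key are placeholders.
    #
    #   tool = AzureAISearchTool.create_hybrid_search(
    #       name="document-search",
    #       endpoint="https://<service>.search.windows.net",
    #       index_name="<index-name>",
    #       credential={"api_key": "<api-key>"},
    #       vector_fields=["content_vector"],
    #       search_fields=["title", "content"],
    #       top=5,
    #   )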