Source code for autogen_ext.models.azure._azure_ai_client
import asyncio
import logging
import re
from asyncio import Task
from inspect import getfullargspec
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast

from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, Image
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelFamily,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
    validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import (
    AssistantMessage as AzureAssistantMessage,
)
from azure.ai.inference.models import (
    ChatCompletions,
    ChatCompletionsToolCall,
    ChatCompletionsToolDefinition,
    CompletionsFinishReason,
    ContentItem,
    FunctionDefinition,
    ImageContentItem,
    ImageDetailLevel,
    ImageUrl,
    StreamingChatChoiceUpdate,
    StreamingChatCompletionsUpdate,
    TextContentItem,
)
from azure.ai.inference.models import (
    FunctionCall as AzureFunctionCall,
)
from azure.ai.inference.models import (
    SystemMessage as AzureSystemMessage,
)
from azure.ai.inference.models import (
    ToolMessage as AzureToolMessage,
)
from azure.ai.inference.models import (
    UserMessage as AzureUserMessage,
)
from pydantic import BaseModel
from typing_extensions import AsyncGenerator, Union, Unpack

from autogen_ext.models.azure.config import (
    GITHUB_MODELS_ENDPOINT,
    AzureAIChatCompletionClientConfig,
)

from .._utils.parse_r1_content import parse_r1_content

create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs)
AzureMessage = Union[AzureSystemMessage, AzureUserMessage, AzureAssistantMessage, AzureToolMessage]

logger = logging.getLogger(EVENT_LOGGER_NAME)


def _is_github_model(endpoint: str) -> bool:
    return endpoint == GITHUB_MODELS_ENDPOINT


def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]:
    result: List[ChatCompletionsToolDefinition] = []
    for tool in tools:
        if isinstance(tool, Tool):
            tool_schema = tool.schema.copy()
        else:
            assert isinstance(tool, dict)
            tool_schema = tool.copy()
        if "parameters" in tool_schema:
            for value in tool_schema["parameters"]["properties"].values():
                if "title" in value.keys():
                    del value["title"]
        function_def: Dict[str, Any] = dict(name=tool_schema["name"])
        if "description" in tool_schema:
            function_def["description"] = tool_schema["description"]
        if "parameters" in tool_schema:
            function_def["parameters"] = tool_schema["parameters"]
        result.append(
            ChatCompletionsToolDefinition(
                function=FunctionDefinition(**function_def),
            ),
        )
    return result


def _func_call_to_azure(message: FunctionCall) -> ChatCompletionsToolCall:
    return ChatCompletionsToolCall(
        id=message.id,
        function=AzureFunctionCall(arguments=message.arguments, name=message.name),
    )


def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage:
    return AzureSystemMessage(content=message.content)


def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
    assert_valid_name(message.source)
    if isinstance(message.content, str):
        return AzureUserMessage(content=message.content)
    else:
        parts: List[ContentItem] = []
        for part in message.content:
            if isinstance(part, str):
                parts.append(TextContentItem(text=part))
            elif isinstance(part, Image):
                # TODO: support url based images
                # TODO: support specifying details
                parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO)))
            else:
                raise ValueError(f"Unknown content type: {message.content}")
        return AzureUserMessage(content=parts)


def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage:
    assert_valid_name(message.source)
    if isinstance(message.content, list):
        return AzureAssistantMessage(
            tool_calls=[_func_call_to_azure(x) for x in message.content],
        )
    else:
        return AzureAssistantMessage(content=message.content)


def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]:
    return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content]


def to_azure_message(message: LLMMessage) -> Sequence[AzureMessage]:
    if isinstance(message, SystemMessage):
        return [_system_message_to_azure(message)]
    elif isinstance(message, UserMessage):
        return [_user_message_to_azure(message)]
    elif isinstance(message, AssistantMessage):
        return [_assistant_message_to_azure(message)]
    else:
        return _tool_message_to_azure(message)


def normalize_name(name: str) -> str:
    """
    LLMs sometimes produce function names that ignore their own format requirements;
    this function replaces invalid characters with "_".

    Prefer assert_valid_name for validating user configuration or input.
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]


def assert_valid_name(name: str) -> str:
    """
    Ensure that configured names are valid; raises ValueError if not.

    For munging LLM responses use normalize_name to ensure LLM-specified names don't break the API.
    """
    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be 64 characters or fewer.")
    return name
class AzureAIChatCompletionClient(ChatCompletionClient):
    """
    Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
    See `here <https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions>`_ for more info.

    Args:
        endpoint (str): The endpoint to use. **Required.**
        credential (Union[AzureKeyCredential, AsyncTokenCredential]): The credentials to use. **Required.**
        model_info (ModelInfo): The model family and capabilities of the model. **Required.**
        model (str): The name of the model. **Required if the model is hosted on GitHub Models.**
        frequency_penalty: (optional, float)
        presence_penalty: (optional, float)
        temperature: (optional, float)
        top_p: (optional, float)
        max_tokens: (optional, int)
        response_format: (optional, Literal["text", "json_object"])
        stop: (optional, List[str])
        tools: (optional, List[ChatCompletionsToolDefinition])
        tool_choice: (optional, Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice])
        seed: (optional, int)
        model_extras: (optional, Dict[str, Any])

    To use this client, you must install the `azure` extra:

    .. code-block:: bash

        pip install "autogen-ext[azure]"

    The following code snippet shows how to use the client with GitHub Models:

    .. code-block:: python

        import asyncio
        import os

        from azure.core.credentials import AzureKeyCredential
        from autogen_ext.models.azure import AzureAIChatCompletionClient
        from autogen_core.models import UserMessage


        async def main():
            client = AzureAIChatCompletionClient(
                model="Phi-4",
                endpoint="https://models.inference.ai.azure.com",
                # To authenticate with the model you will need to generate a personal access token (PAT) in your GitHub settings.
                # Create your PAT token by following instructions here: https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
                credential=AzureKeyCredential(os.environ["GITHUB_TOKEN"]),
                model_info={
                    "json_output": False,
                    "function_calling": False,
                    "vision": False,
                    "family": "unknown",
                    "structured_output": False,
                },
            )

            result = await client.create([UserMessage(content="What is the capital of France?", source="user")])
            print(result)

            # Close the client.
            await client.close()


        if __name__ == "__main__":
            asyncio.run(main())

    To use streaming, you can use the `create_stream` method:

    .. code-block:: python

        import asyncio
        import os

        from autogen_core.models import UserMessage
        from autogen_ext.models.azure import AzureAIChatCompletionClient
        from azure.core.credentials import AzureKeyCredential


        async def main():
            client = AzureAIChatCompletionClient(
                model="Phi-4",
                endpoint="https://models.inference.ai.azure.com",
                # To authenticate with the model you will need to generate a personal access token (PAT) in your GitHub settings.
                # Create your PAT token by following instructions here: https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens
                credential=AzureKeyCredential(os.environ["GITHUB_TOKEN"]),
                model_info={
                    "json_output": False,
                    "function_calling": False,
                    "vision": False,
                    "family": "unknown",
                    "structured_output": False,
                },
            )

            # Create a stream.
            stream = client.create_stream([UserMessage(content="Write a poem about the ocean", source="user")])
            async for chunk in stream:
                print(chunk, end="", flush=True)
            print()

            # Close the client.
            await client.close()


        if __name__ == "__main__":
            asyncio.run(main())
    """

    def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
        config = self._validate_config(kwargs)  # type: ignore
        self._model_info = config["model_info"]  # type: ignore
        self._client = self._create_client(config)
        self._create_args = self._prepare_create_args(config)

        self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

    @staticmethod
    def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfig:
        if "endpoint" not in config:
            raise ValueError("endpoint is required for AzureAIChatCompletionClient")
        if "credential" not in config:
            raise ValueError("credential is required for AzureAIChatCompletionClient")
        if "model_info" not in config:
            raise ValueError("model_info is required for AzureAIChatCompletionClient")
        validate_model_info(config["model_info"])
        if _is_github_model(config["endpoint"]) and "model" not in config:
            raise ValueError("model is required when using a GitHub model with AzureAIChatCompletionClient")
        return cast(AzureAIChatCompletionClientConfig, config)

    @staticmethod
    def _create_client(config: AzureAIChatCompletionClientConfig) -> ChatCompletionsClient:
        return ChatCompletionsClient(**config)

    @staticmethod
    def _prepare_create_args(config: Mapping[str, Any]) -> Dict[str, Any]:
        create_args = {k: v for k, v in config.items() if k in create_kwargs}
        return create_args
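    # Example (editor's note, hypothetical values): given a config of
    # {"endpoint": ..., "credential": ..., "model_info": ..., "temperature": 0.2},
    # _prepare_create_args keeps only the keys accepted by
    # ChatCompletionsClient.complete (here, {"temperature": 0.2}); endpoint and
    # credential are consumed by the client constructor instead.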
    def _validate_model_info(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool | type[BaseModel]],
        create_args: Dict[str, Any],
    ) -> None:
        if self.model_info["vision"] is False:
            for message in messages:
                if isinstance(message, UserMessage):
                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
                        raise ValueError("Model does not support vision and image was provided")

        if json_output is not None:
            if self.model_info["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output")
            if isinstance(json_output, type):
                # TODO: we should support this in the future.
                raise ValueError("Structured output is not currently supported for AzureAIChatCompletionClient")
            if json_output is True and "response_format" not in create_args:
                create_args["response_format"] = "json_object"

        if self.model_info["function_calling"] is False and len(tools) > 0:
            raise ValueError("Model does not support function calling")
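    # Editor's sketch of the checks above, with hypothetical inputs:
    #   - model_info["vision"] False + a UserMessage containing an Image -> ValueError
    #   - json_output=True with model_info["json_output"] False -> ValueError
    #   - json_output=SomePydanticModel (a type) -> ValueError (structured output unsupported)
    #   - json_output=True otherwise -> create_args["response_format"] = "json_object"
    #   - non-empty tools with model_info["function_calling"] False -> ValueError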
    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        extra_create_args_keys = set(extra_create_args.keys())
        if not create_kwargs.issuperset(extra_create_args_keys):
            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")

        # Copy the create args and overwrite anything in extra_create_args
        create_args = self._create_args.copy()
        create_args.update(extra_create_args)

        self._validate_model_info(messages, tools, json_output, create_args)

        azure_messages_nested = [to_azure_message(msg) for msg in messages]
        azure_messages = [item for sublist in azure_messages_nested for item in sublist]

        task: Task[ChatCompletions]

        if len(tools) > 0:
            converted_tools = convert_tools(tools)
            task = asyncio.create_task(  # type: ignore
                self._client.complete(messages=azure_messages, tools=converted_tools, **create_args)  # type: ignore
            )
        else:
            task = asyncio.create_task(  # type: ignore
                self._client.complete(  # type: ignore
                    messages=azure_messages,
                    **create_args,
                )
            )

        if cancellation_token is not None:
            cancellation_token.link_future(task)

        result: ChatCompletions = await task

        usage = RequestUsage(
            prompt_tokens=result.usage.prompt_tokens if result.usage else 0,
            completion_tokens=result.usage.completion_tokens if result.usage else 0,
        )

        logger.info(
            LLMCallEvent(
                messages=[m.as_dict() for m in azure_messages],
                response=result.as_dict(),
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
            )
        )

        choice = result.choices[0]
        if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
            assert choice.message.tool_calls is not None

            content: Union[str, List[FunctionCall]] = [
                FunctionCall(
                    id=x.id,
                    arguments=x.function.arguments,
                    name=normalize_name(x.function.name),
                )
                for x in choice.message.tool_calls
            ]
            finish_reason = "function_calls"
        else:
            if isinstance(choice.finish_reason, CompletionsFinishReason):
                finish_reason = choice.finish_reason.value
            else:
                finish_reason = choice.finish_reason  # type: ignore
            content = choice.message.content or ""

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        response = CreateResult(
            finish_reason=finish_reason,  # type: ignore
            content=content,
            usage=usage,
            cached=False,
            thought=thought,
        )

        self.add_usage(usage)

        return response
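    # Editor's usage sketch (hypothetical tool, not part of the original module):
    # a ToolSchema dict can be passed straight to create(); with a
    # function-calling-capable model the result content is a List[FunctionCall].
    #
    #   get_weather: ToolSchema = {
    #       "name": "get_weather",
    #       "description": "Look up the weather for a city.",
    #       "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    #   }
    #   result = await client.create(
    #       [UserMessage(content="Weather in Paris?", source="user")],
    #       tools=[get_weather],
    #   )
    #   # result.finish_reason == "function_calls"; result.content is a List[FunctionCall]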
    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        extra_create_args_keys = set(extra_create_args.keys())
        if not create_kwargs.issuperset(extra_create_args_keys):
            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")

        create_args: Dict[str, Any] = self._create_args.copy()
        create_args.update(extra_create_args)

        self._validate_model_info(messages, tools, json_output, create_args)

        azure_messages_nested = [to_azure_message(msg) for msg in messages]
        azure_messages = [item for sublist in azure_messages_nested for item in sublist]

        if len(tools) > 0:
            converted_tools = convert_tools(tools)
            task = asyncio.create_task(
                self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args)
            )
        else:
            task = asyncio.create_task(
                self._client.complete(messages=azure_messages, stream=True, **create_args)
            )

        if cancellation_token is not None:
            cancellation_token.link_future(task)

        finish_reason: Optional[FinishReasons] = None
        content_deltas: List[str] = []
        full_tool_calls: Dict[str, FunctionCall] = {}
        prompt_tokens = 0
        completion_tokens = 0
        chunk: Optional[StreamingChatCompletionsUpdate] = None
        choice: Optional[StreamingChatChoiceUpdate] = None
        first_chunk = True

        async for chunk in await task:  # type: ignore
            if first_chunk:
                first_chunk = False
                # Emit the start event.
                logger.info(
                    LLMStreamStartEvent(
                        messages=[m.as_dict() for m in azure_messages],
                    )
                )
            assert isinstance(chunk, StreamingChatCompletionsUpdate)
            choice = chunk.choices[0] if len(chunk.choices) > 0 else None
            if choice and choice.finish_reason is not None:
                if isinstance(choice.finish_reason, CompletionsFinishReason):
                    finish_reason = cast(FinishReasons, choice.finish_reason.value)
                else:
                    if choice.finish_reason in ["stop", "length", "function_calls", "content_filter", "unknown"]:
                        finish_reason = choice.finish_reason  # type: ignore
                    else:
                        raise ValueError(f"Unexpected finish reason: {choice.finish_reason}")

            # We first try to load the content
            if choice and choice.delta.content is not None:
                content_deltas.append(choice.delta.content)
                yield choice.delta.content

            # Otherwise, we try to load the tool calls
            if choice and choice.delta.tool_calls is not None:
                for tool_call_chunk in choice.delta.tool_calls:
                    if "index" in tool_call_chunk:
                        idx = tool_call_chunk["index"]
                    else:
                        idx = tool_call_chunk.id
                    if idx not in full_tool_calls:
                        full_tool_calls[idx] = FunctionCall(id="", arguments="", name="")

                    full_tool_calls[idx].id += tool_call_chunk.id
                    full_tool_calls[idx].name += tool_call_chunk.function.name
                    full_tool_calls[idx].arguments += tool_call_chunk.function.arguments

        if chunk and chunk.usage:
            prompt_tokens = chunk.usage.prompt_tokens

        if finish_reason is None:
            raise ValueError("No stop reason found")

        if choice and choice.finish_reason is CompletionsFinishReason.TOOL_CALLS:
            finish_reason = "function_calls"

        content: Union[str, List[FunctionCall]]

        # Any accumulated text takes precedence over tool calls. (The original
        # checked `> 1`, which misclassified single-chunk text responses as
        # tool calls; `> 0` is the intended condition.)
        if len(content_deltas) > 0:
            content = "".join(content_deltas)
            if chunk and chunk.usage:
                completion_tokens = chunk.usage.completion_tokens
            else:
                completion_tokens = 0
        else:
            content = list(full_tool_calls.values())

        usage = RequestUsage(
            completion_tokens=completion_tokens,
            prompt_tokens=prompt_tokens,
        )

        if isinstance(content, str) and self._model_info["family"] == ModelFamily.R1:
            thought, content = parse_r1_content(content)
        else:
            thought = None

        result = CreateResult(
            finish_reason=finish_reason,
            content=content,
            usage=usage,
            cached=False,
            thought=thought,
        )

        # Log the end of the stream.
        logger.info(
            LLMStreamEndEvent(
                response=result.model_dump(),
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
            )
        )

        self.add_usage(usage)

        yield result
    @property
    def model_info(self) -> ModelInfo:
        return self._model_info

    @property
    def capabilities(self) -> ModelInfo:
        return self.model_info

    def __del__(self) -> None:
        # TODO: This is a hack to close the open client
        if hasattr(self, "_client"):
            try:
                asyncio.get_running_loop().create_task(self._client.close())
            except RuntimeError:
                asyncio.run(self._client.close())
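# Editor's note: the __del__ hook above is a best-effort fallback (as its TODO
# says). In application code, prefer closing the client explicitly, e.g.
# `await client.close()` as in the docstring examples, rather than relying on
# garbage collection, since the running event loop may be gone when __del__ fires.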