# Source code for autogen_core.tool_agent._caller_loop

import asyncio
from typing import List

from .. import AgentId, AgentRuntime, BaseAgent, CancellationToken, FunctionCall
from ..models import (
    AssistantMessage,
    ChatCompletionClient,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
)
from ..tools import Tool, ToolSchema
from ._tool_agent import ToolException


async def tool_agent_caller_loop(
    caller: BaseAgent | AgentRuntime,
    tool_agent_id: AgentId,
    model_client: ChatCompletionClient,
    input_messages: List[LLMMessage],
    tool_schema: List[ToolSchema] | List[Tool],
    cancellation_token: CancellationToken | None = None,
    caller_source: str = "assistant",
) -> List[LLMMessage]:
    """Start a caller loop for a tool agent.

    This function sends messages to the tool agent and the model client in an
    alternating fashion until the model client stops generating tool calls.

    Args:
        caller (BaseAgent | AgentRuntime): The agent or runtime whose
            ``send_message`` is used to deliver each function call to the tool agent.
        tool_agent_id (AgentId): The Agent ID of the tool agent.
        model_client (ChatCompletionClient): The model client to use for the model API.
        input_messages (List[LLMMessage]): The list of input messages. It is not
            mutated; generated messages are appended to a separate list.
        tool_schema (List[Tool] | List[ToolSchema]): The list of tools that the model can use.
        cancellation_token (CancellationToken | None): Token forwarded to every
            model request and every message sent to the tool agent. Defaults to None.
        caller_source (str): Value used as ``source`` on each generated
            AssistantMessage. Defaults to "assistant".

    Returns:
        List[LLMMessage]: The list of output messages created in the caller loop.
        Assistant responses alternate with FunctionExecutionResultMessages, and
        the list always ends with the model's final AssistantMessage.

    Raises:
        BaseException: Any exception raised while delivering a function call
            that is not a ToolException is re-raised unchanged. A ToolException
            is instead turned into a FunctionExecutionResult whose content is
            ``"Error: <exception text>"`` so the model can see the failure.
    """
    generated_messages: List[LLMMessage] = []

    # Get a response from the model.
    response = await model_client.create(input_messages, tools=tool_schema, cancellation_token=cancellation_token)
    # Add the response to the generated messages.
    generated_messages.append(AssistantMessage(content=response.content, source=caller_source))

    # Keep iterating while the model responds with a non-empty batch of tool calls.
    # The length check matters: ``all(...)`` is vacuously True for an empty list,
    # which would otherwise send an empty FunctionExecutionResultMessage back to
    # the model and could loop indefinitely.
    while (
        isinstance(response.content, list)
        and len(response.content) > 0
        and all(isinstance(item, FunctionCall) for item in response.content)
    ):
        # Execute all function calls concurrently by sending them to the tool agent.
        # return_exceptions=True keeps one failing call from cancelling the others;
        # gather preserves input order, so results line up with response.content.
        results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
            *[
                caller.send_message(
                    message=call,
                    recipient=tool_agent_id,
                    cancellation_token=cancellation_token,
                )
                for call in response.content
            ],
            return_exceptions=True,
        )

        # Combine the results into a single response and handle exceptions.
        function_results: List[FunctionExecutionResult] = []
        for result in results:
            if isinstance(result, FunctionExecutionResult):
                function_results.append(result)
            elif isinstance(result, ToolException):
                # Expected tool failure: report it to the model as an error result.
                function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
            elif isinstance(result, BaseException):
                raise result  # Unexpected exception.
        generated_messages.append(FunctionExecutionResultMessage(content=function_results))

        # Query the model again with the full conversation so far.
        response = await model_client.create(
            input_messages + generated_messages, tools=tool_schema, cancellation_token=cancellation_token
        )
        generated_messages.append(AssistantMessage(content=response.content, source=caller_source))

    # Return the generated messages.
    return generated_messages