# Source code for autogen_core.tool_agent._caller_loop
import asyncio
from typing import List
from .. import AgentId, AgentRuntime, BaseAgent, CancellationToken, FunctionCall
from ..models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
)
from ..tools import Tool, ToolSchema
from ._tool_agent import ToolException
# [docs]
async def tool_agent_caller_loop(
    caller: BaseAgent | AgentRuntime,
    tool_agent_id: AgentId,
    model_client: ChatCompletionClient,
    input_messages: List[LLMMessage],
    tool_schema: List[ToolSchema] | List[Tool],
    cancellation_token: CancellationToken | None = None,
    caller_source: str = "assistant",
) -> List[LLMMessage]:
    """Start a caller loop for a tool agent. This function sends messages to the tool agent
    and the model client in an alternating fashion until the model client stops generating tool calls.

    Args:
        caller (BaseAgent | AgentRuntime): The agent or runtime whose ``send_message`` is used
            to deliver each function call to the tool agent.
        tool_agent_id (AgentId): The Agent ID of the tool agent.
        model_client (ChatCompletionClient): The model client to use for the model API.
        input_messages (List[LLMMessage]): The list of input messages. It is not modified.
        tool_schema (List[Tool | ToolSchema]): The list of tools that the model can use.
        cancellation_token (CancellationToken | None): Forwarded to every model call and to
            every message sent to the tool agent. Defaults to None.
        caller_source (str): The ``source`` set on every generated :class:`AssistantMessage`.
            Defaults to ``"assistant"``.

    Returns:
        List[LLMMessage]: The list of output messages created in the caller loop. It starts
        with the model's first :class:`AssistantMessage`. Each round of tool calls then appends
        one :class:`FunctionExecutionResultMessage` followed by the model's next
        :class:`AssistantMessage`.

    Raises:
        BaseException: Any exception raised for a tool call that is not a
            :class:`ToolException` is re-raised unchanged.
        TypeError: If the tool agent replies with something other than a
            :class:`FunctionExecutionResult`.
    """
    generated_messages: List[LLMMessage] = []
    # Get a response from the model.
    response = await model_client.create(input_messages, tools=tool_schema, cancellation_token=cancellation_token)
    # Add the response to the generated messages.
    generated_messages.append(AssistantMessage(content=response.content, source=caller_source))
    # Keep iterating while the model responds with a NON-EMPTY list of function calls.
    # An empty list is treated as "no tool calls". Otherwise `all([])` is True and the loop
    # would send an empty result message and could re-query the model indefinitely.
    while (
        isinstance(response.content, list)
        and len(response.content) > 0
        and all(isinstance(item, FunctionCall) for item in response.content)
    ):
        # Execute the calls concurrently by sending each one to the tool agent.
        # return_exceptions=True makes gather return per-call failures as values,
        # so tool errors can be reported back to the model instead of aborting the loop.
        results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
            *[
                caller.send_message(
                    message=call,
                    recipient=tool_agent_id,
                    cancellation_token=cancellation_token,
                )
                for call in response.content
            ],
            return_exceptions=True,
        )
        generated_messages.append(FunctionExecutionResultMessage(content=_to_function_results(results)))
        # Query the model again with the full conversation so far.
        response = await model_client.create(
            input_messages + generated_messages, tools=tool_schema, cancellation_token=cancellation_token
        )
        generated_messages.append(AssistantMessage(content=response.content, source=caller_source))
    # Return the generated messages.
    return generated_messages


def _to_function_results(
    results: List[FunctionExecutionResult | BaseException],
) -> List[FunctionExecutionResult]:
    """Convert the gathered tool-agent replies into function execution results.

    A :class:`ToolException` is converted into an error result for its ``call_id`` so the
    model can see the failure. Any other exception is re-raised. A reply that is neither a
    result nor an exception raises :class:`TypeError` rather than being silently dropped,
    because dropping it would leave one of the tool calls without a matching result.
    """
    function_results: List[FunctionExecutionResult] = []
    for result in results:
        if isinstance(result, FunctionExecutionResult):
            function_results.append(result)
        elif isinstance(result, ToolException):
            # NOTE(review): the text here comes from str(result); confirm that ToolException
            # (defined in ._tool_agent) renders its error details via __str__.
            function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
        elif isinstance(result, BaseException):
            raise result  # Unexpected exception.
        else:
            raise TypeError(
                f"Unexpected response from tool agent: expected FunctionExecutionResult, got {type(result).__name__}"
            )
    return function_results