Custom Agents

You may have agents with behaviors that do not fall into a preset. In such cases, you can build custom agents.

All agents in AgentChat inherit from the BaseChatAgent class and implement the following abstract methods and attributes:

  • on_messages(): The abstract method that defines the behavior of the agent in response to messages. This method is called when the agent is asked to provide a response in run(). It returns a Response object.

  • on_reset(): The abstract method that resets the agent to its initial state. This method is called when the agent is asked to reset itself.

  • produced_message_types: The list of possible ChatMessage message types the agent can produce in its response.

Optionally, you can implement the on_messages_stream() method to stream messages as they are generated by the agent. If this method is not implemented, the agent uses the default implementation of on_messages_stream(), which calls the on_messages() method and yields all messages in the response.
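
The fallback described above can be sketched in plain Python. This is a minimal illustration, not the library's source: Message, Response, and SketchAgent are hypothetical stand-ins for the AgentChat types so that the snippet runs without the package installed, and the cancellation token argument is omitted for brevity.

```python
import asyncio
from dataclasses import dataclass, field
from typing import AsyncGenerator, List, Sequence, Union


# Hypothetical stand-ins for the AgentChat message and response types.
@dataclass
class Message:
    content: str
    source: str


@dataclass
class Response:
    chat_message: Message
    inner_messages: List[Message] = field(default_factory=list)


class SketchAgent:
    """Hypothetical agent whose real logic lives only in on_messages()."""

    name = "sketch"

    async def on_messages(self, messages: Sequence[Message]) -> Response:
        steps = [Message("thinking...", self.name)]
        return Response(Message("Done!", self.name), inner_messages=steps)

    async def on_messages_stream(
        self, messages: Sequence[Message]
    ) -> AsyncGenerator[Union[Message, Response], None]:
        # Fallback behavior: build the full response first, then replay
        # its inner messages followed by the response itself.
        response = await self.on_messages(messages)
        for inner in response.inner_messages:
            yield inner
        yield response


async def collect() -> List[str]:
    agent = SketchAgent()
    seen: List[str] = []
    async for item in agent.on_messages_stream([]):
        if isinstance(item, Response):
            seen.append(item.chat_message.content)
        else:
            seen.append(item.content)
    return seen


streamed = asyncio.run(collect())
print(streamed)  # ['thinking...', 'Done!']
```

With a fallback of this shape nothing reaches the caller until on_messages() has finished, which is why an agent that wants incremental output overrides on_messages_stream() itself.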

CountDownAgent

In this example, we create a simple agent that counts down from a given number to zero, and produces a stream of messages with the current count.

from typing import AsyncGenerator, List, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentEvent, ChatMessage, TextMessage
from autogen_core import CancellationToken


class CountDownAgent(BaseChatAgent):
    def __init__(self, name: str, count: int = 3):
        super().__init__(name, "A simple agent that counts down.")
        self._count = count

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Calls the on_messages_stream.
        response: Response | None = None
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
                response = message
        assert response is not None
        return response

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
        inner_messages: List[AgentEvent | ChatMessage] = []
        for i in range(self._count, 0, -1):
            msg = TextMessage(content=f"{i}...", source=self.name)
            inner_messages.append(msg)
            yield msg
        # The response is returned at the end of the stream.
        # It contains the final message and all the inner messages.
        yield Response(chat_message=TextMessage(content="Done!", source=self.name), inner_messages=inner_messages)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass


async def run_countdown_agent() -> None:
    # Create a countdown agent.
    countdown_agent = CountDownAgent("countdown")

    # Run the agent with a given task and stream the response.
    async for message in countdown_agent.on_messages_stream([], CancellationToken()):
        if isinstance(message, Response):
            print(message.chat_message.content)
        else:
            print(message.content)


# Use asyncio.run(run_countdown_agent()) when running in a script.
await run_countdown_agent()
3...
2...
1...
Done!

ArithmeticAgent

In this example, we create an agent class that can perform simple arithmetic operations on a given integer. Then, we will use different instances of this agent class in a SelectorGroupChat to transform a given integer into another integer by applying a sequence of arithmetic operations.

The ArithmeticAgent class takes an operator_func, a function that applies an arithmetic operation to an integer and returns the resulting integer. In its on_messages method, the agent applies operator_func to the integer in the most recent message of its history and returns a response with the result.

from typing import Callable, List, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.messages import ChatMessage, TextMessage
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient


class ArithmeticAgent(BaseChatAgent):
    def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:
        super().__init__(name, description=description)
        self._operator_func = operator_func
        self._message_history: List[ChatMessage] = []

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        # Update the message history.
        # NOTE: messages may be an empty list, which means the agent was called previously
        # and is being called again without any new messages.
        self._message_history.extend(messages)
        # Parse the number in the last message.
        assert isinstance(self._message_history[-1], TextMessage)
        number = int(self._message_history[-1].content)
        # Apply the operator function to the number.
        result = self._operator_func(number)
        # Create a new message with the result.
        response_message = TextMessage(content=str(result), source=self.name)
        # Update the message history.
        self._message_history.append(response_message)
        # Return the response.
        return Response(chat_message=response_message)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        # Return to the initial state by discarding the stored history.
        self._message_history.clear()

Note

The on_messages method may be called with an empty list of messages. This means the agent was called previously and is now being called again without any new messages from the caller. It is therefore important to keep a history of the messages the agent has received and to use that history to generate the response.
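
To make the empty-list case concrete, here is a minimal sketch that does not depend on the library. Message and HistoryAgent are hypothetical stand-ins that copy only the history handling of ArithmeticAgent; the cancellation token and the Response wrapper are left out.

```python
import asyncio
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class Message:
    """Hypothetical stand-in for TextMessage."""

    content: str
    source: str


class HistoryAgent:
    """Hypothetical stand-in that mirrors ArithmeticAgent's history handling."""

    def __init__(self, name: str, operator_func: Callable[[int], int]) -> None:
        self.name = name
        self._operator_func = operator_func
        self._message_history: List[Message] = []

    async def on_messages(self, messages: Sequence[Message]) -> Message:
        # The new messages may be empty; the history still holds the last number.
        self._message_history.extend(messages)
        number = int(self._message_history[-1].content)
        reply = Message(str(self._operator_func(number)), self.name)
        self._message_history.append(reply)
        return reply


async def demo() -> List[str]:
    agent = HistoryAgent("add_agent", lambda x: x + 1)
    first = await agent.on_messages([Message("10", "user")])
    # Called again with no new messages: the agent reads its own last reply.
    second = await agent.on_messages([])
    return [first.content, second.content]


results = asyncio.run(demo())
print(results)  # ['11', '12']
```

Without the stored history, the second call would have no number to parse.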

Now we can define 5 instances of ArithmeticAgent:

  • one that adds 1 to the input integer,

  • one that subtracts 1 from the input integer,

  • one that multiplies the input integer by 2,

  • one that divides the input integer by 2 and rounds down to the nearest integer, and

  • one that returns the input integer unchanged.

We then create a SelectorGroupChat with these agents, and set the appropriate selector settings:

  • allow the same agent to be selected consecutively to allow for repeated operations, and

  • customize the selector prompt to tailor the model’s response to the specific task.

async def run_number_agents() -> None:
    # Create agents for number operations.
    add_agent = ArithmeticAgent("add_agent", "Adds 1 to the number.", lambda x: x + 1)
    multiply_agent = ArithmeticAgent("multiply_agent", "Multiplies the number by 2.", lambda x: x * 2)
    subtract_agent = ArithmeticAgent("subtract_agent", "Subtracts 1 from the number.", lambda x: x - 1)
    divide_agent = ArithmeticAgent("divide_agent", "Divides the number by 2 and rounds down.", lambda x: x // 2)
    identity_agent = ArithmeticAgent("identity_agent", "Returns the number as is.", lambda x: x)

    # The termination condition is to stop after 10 messages.
    termination_condition = MaxMessageTermination(10)

    # Create a selector group chat.
    selector_group_chat = SelectorGroupChat(
        [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        termination_condition=termination_condition,
        allow_repeated_speaker=True,  # Allow the same agent to speak multiple times, necessary for this task.
        selector_prompt=(
            "Available roles:\n{roles}\nTheir job descriptions:\n{participants}\n"
            "Current conversation history:\n{history}\n"
            "Please select the most appropriate role for the next message, and only return the role name."
        ),
    )

    # Run the selector group chat with a given task and stream the response.
    task: List[ChatMessage] = [
        TextMessage(content="Apply the operations to turn the given number into 25.", source="user"),
        TextMessage(content="10", source="user"),
    ]
    stream = selector_group_chat.run_stream(task=task)
    await Console(stream)


# Use asyncio.run(run_number_agents()) when running in a script.
await run_number_agents()
---------- user ----------
Apply the operations to turn the given number into 25.
---------- user ----------
10
---------- multiply_agent ----------
20
---------- add_agent ----------
21
---------- multiply_agent ----------
42
---------- divide_agent ----------
21
---------- add_agent ----------
22
---------- add_agent ----------
23
---------- add_agent ----------
24
---------- add_agent ----------
25
---------- Summary ----------
Number of messages: 10
Finish reason: Maximum number of messages 10 reached, current message count: 10
Total prompt tokens: 0
Total completion tokens: 0
Duration: 2.40 seconds

From the output, we can see that the agents have successfully transformed the input integer from 10 to 25 by choosing appropriate agents that apply the arithmetic operations in sequence.
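
The arithmetic in the transcript can be checked without a model. The sketch below replays the sequence of speakers from the output using the same five operations; the operations table and the replay helper are illustrative names introduced here, not part of AgentChat.

```python
from typing import Callable, Dict, List

# Same operations as the five ArithmeticAgent instances above.
operations: Dict[str, Callable[[int], int]] = {
    "add_agent": lambda x: x + 1,
    "subtract_agent": lambda x: x - 1,
    "multiply_agent": lambda x: x * 2,
    "divide_agent": lambda x: x // 2,
    "identity_agent": lambda x: x,
}


def replay(start: int, speakers: List[str]) -> List[int]:
    """Apply each speaker's operation in turn and record every intermediate value."""
    values: List[int] = []
    current = start
    for speaker in speakers:
        current = operations[speaker](current)
        values.append(current)
    return values


# Speaker order taken from the transcript above.
transcript_order = [
    "multiply_agent",
    "add_agent",
    "multiply_agent",
    "divide_agent",
    "add_agent",
    "add_agent",
    "add_agent",
    "add_agent",
]
trace = replay(10, transcript_order)
print(trace)  # [20, 21, 42, 21, 22, 23, 24, 25]

# A shorter route to 25 also exists with the same operations.
shorter = replay(10, ["add_agent", "add_agent", "multiply_agent", "add_agent"])
print(shorter)  # [11, 12, 24, 25]
```

The two user messages plus the eight agent replies account for the 10 messages at which MaxMessageTermination stops the run. The selector model picks a valid route, but as the shorter sequence shows, not necessarily the shortest one.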