Source code for autogen_agentchat.teams._group_chat._base_group_chat

import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Callable, List

from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import (
    AgentId,
    AgentInstantiationContext,
    AgentRuntime,
    AgentType,
    CancellationToken,
    MessageContext,
    TopicId,
)
from autogen_core.components import ClosureAgent, TypeSubscription

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentMessage, MultiModalMessage, TextMessage
from ._base_group_chat_manager import BaseGroupChatManager
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination

event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class BaseGroupChat(Team, ABC):
    """The base class for group chat teams.

    To implement a group chat team, first create a subclass of :class:`BaseGroupChatManager` and then
    create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
    """

    def __init__(
        self,
        participants: List[ChatAgent],
        group_chat_manager_class: type[BaseGroupChatManager],
        termination_condition: TerminationCondition | None = None,
    ):
        if len(participants) == 0:
            raise ValueError("At least one participant is required.")
        if len(participants) != len(set(participant.name for participant in participants)):
            raise ValueError("The participant names must be unique.")
        self._participants = participants
        self._base_group_chat_manager_class = group_chat_manager_class
        self._termination_condition = termination_condition

        # Constants for the group chat.
        self._team_id = str(uuid.uuid4())
        self._group_topic_type = "group_topic"
        self._output_topic_type = "output_topic"
        self._group_chat_manager_topic_type = "group_chat_manager"
        self._participant_topic_types: List[str] = [participant.name for participant in participants]
        self._participant_descriptions: List[str] = [participant.description for participant in participants]
        self._collector_agent_type = "collect_output_messages"

        # Constants for the closure agent to collect the output messages.
        self._stop_reason: str | None = None
        self._output_message_queue: asyncio.Queue[AgentMessage | None] = asyncio.Queue()

        # Create a runtime for the team.
        # TODO: The runtime should be created by a managed context.
        self._runtime = SingleThreadedAgentRuntime()

        # Flag to track if the group chat has been initialized.
        self._initialized = False

        # Flag to track if the group chat is running.
        self._is_running = False

    @abstractmethod
    def _create_group_chat_manager_factory(
        self,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_descriptions: List[str],
        termination_condition: TerminationCondition | None,
    ) -> Callable[[], BaseGroupChatManager]: ...

    def _create_participant_factory(
        self,
        parent_topic_type: str,
        output_topic_type: str,
        agent: ChatAgent,
    ) -> Callable[[], ChatAgentContainer]:
        def _factory() -> ChatAgentContainer:
            id = AgentInstantiationContext.current_agent_id()
            assert id == AgentId(type=agent.name, key=self._team_id)
            container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
            assert container.id == id
            return container

        return _factory

    async def _init(self, runtime: AgentRuntime) -> None:
        # Constants for the group chat manager.
        group_chat_manager_agent_type = AgentType(self._group_chat_manager_topic_type)

        # Register participants.
        for participant, participant_topic_type in zip(
            self._participants, self._participant_topic_types, strict=False
        ):
            # Use the participant topic type as the agent type.
            agent_type = participant_topic_type
            # Register the participant factory.
            await ChatAgentContainer.register(
                runtime,
                type=agent_type,
                factory=self._create_participant_factory(
                    self._group_topic_type, self._output_topic_type, participant
                ),
            )
            # Add subscriptions for the participant.
            await runtime.add_subscription(
                TypeSubscription(topic_type=participant_topic_type, agent_type=agent_type)
            )
            await runtime.add_subscription(
                TypeSubscription(topic_type=self._group_topic_type, agent_type=agent_type)
            )

        # Register the group chat manager.
        await self._base_group_chat_manager_class.register(
            runtime,
            type=group_chat_manager_agent_type.type,
            factory=self._create_group_chat_manager_factory(
                group_topic_type=self._group_topic_type,
                output_topic_type=self._output_topic_type,
                participant_topic_types=self._participant_topic_types,
                participant_descriptions=self._participant_descriptions,
                termination_condition=self._termination_condition,
            ),
        )
        # Add subscriptions for the group chat manager.
        await runtime.add_subscription(
            TypeSubscription(
                topic_type=self._group_chat_manager_topic_type, agent_type=group_chat_manager_agent_type.type
            )
        )
        await runtime.add_subscription(
            TypeSubscription(topic_type=self._group_topic_type, agent_type=group_chat_manager_agent_type.type)
        )

        async def collect_output_messages(
            _runtime: AgentRuntime,
            id: AgentId,
            message: GroupChatStart | GroupChatMessage | GroupChatTermination,
            ctx: MessageContext,
        ) -> None:
            event_logger.info(message.message)
            if isinstance(message, GroupChatTermination):
                self._stop_reason = message.message.content
                return
            await self._output_message_queue.put(message.message)

        await ClosureAgent.register(
            runtime,
            type=self._collector_agent_type,
            closure=collect_output_messages,
            subscriptions=lambda: [
                TypeSubscription(topic_type=self._output_topic_type, agent_type=self._collector_agent_type),
            ],
        )

        self._initialized = True
    async def run(
        self,
        *,
        task: str | TextMessage | MultiModalMessage | None = None,
        cancellation_token: CancellationToken | None = None,
    ) -> TaskResult:
        """Run the team and return the result. The base implementation uses
        :meth:`run_stream` to run the team and then returns the final result.
        Once the team is stopped, the termination condition is reset.

        Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

        .. code-block:: python

            import asyncio
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.task import MaxMessageTermination
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_ext.models import OpenAIChatCompletionClient


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                agent1 = AssistantAgent("Assistant1", model_client=model_client)
                agent2 = AssistantAgent("Assistant2", model_client=model_client)

                termination = MaxMessageTermination(3)
                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

                result = await team.run(task="Count from 1 to 10, respond one at a time.")
                print(result)

                # Run the team again without a task to continue the previous task.
                result = await team.run()
                print(result)


            asyncio.run(main())
        """
        result: TaskResult | None = None
        async for message in self.run_stream(
            task=task,
            cancellation_token=cancellation_token,
        ):
            if isinstance(message, TaskResult):
                result = message
        if result is not None:
            return result
        raise AssertionError("The stream should have returned the final result.")
    async def run_stream(
        self,
        *,
        task: str | TextMessage | MultiModalMessage | None = None,
        cancellation_token: CancellationToken | None = None,
    ) -> AsyncGenerator[AgentMessage | TaskResult, None]:
        """Run the team and produce a stream of messages, with the final result
        of the type :class:`TaskResult` as the last item in the stream.
        Once the team is stopped, the termination condition is reset.

        Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

        .. code-block:: python

            import asyncio
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.task import MaxMessageTermination
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_ext.models import OpenAIChatCompletionClient


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                agent1 = AssistantAgent("Assistant1", model_client=model_client)
                agent2 = AssistantAgent("Assistant2", model_client=model_client)

                termination = MaxMessageTermination(3)
                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

                stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
                async for message in stream:
                    print(message)

                # Run the team again without a task to continue the previous task.
                stream = team.run_stream()
                async for message in stream:
                    print(message)


            asyncio.run(main())
        """
        if self._is_running:
            raise ValueError("The team is already running, it cannot run again until it is stopped.")
        self._is_running = True
        # Start the runtime.
        # TODO: The runtime should be started by a managed context.
        self._runtime.start()

        if not self._initialized:
            await self._init(self._runtime)

        # Run the team by publishing the start message.
        first_chat_message: TextMessage | MultiModalMessage | None = None
        if isinstance(task, str):
            first_chat_message = TextMessage(content=task, source="user")
        elif isinstance(task, TextMessage | MultiModalMessage):
            first_chat_message = task
        await self._runtime.publish_message(
            GroupChatStart(message=first_chat_message),
            topic_id=TopicId(type=self._group_topic_type, source=self._team_id),
        )

        # Start a coroutine to stop the runtime and signal that the output message queue is complete.
        async def stop_runtime() -> None:
            await self._runtime.stop_when_idle()
            await self._output_message_queue.put(None)

        shutdown_task = asyncio.create_task(stop_runtime())

        # Collect the output messages in order.
        output_messages: List[AgentMessage] = []
        # Yield the messages until the queue is signaled as complete with a None sentinel.
        while True:
            message = await self._output_message_queue.get()
            if message is None:
                break
            yield message
            output_messages.append(message)

        # Wait for the shutdown task to finish.
        await shutdown_task

        # Yield the final result.
        yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)

        # Indicate that the team is no longer running.
        self._is_running = False
    async def reset(self) -> None:
        """Reset the team and its participants to their initial state.

        The team must be stopped before it can be reset.

        Raises:
            RuntimeError: If the team has not been initialized or is currently running.

        Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:

        .. code-block:: python

            import asyncio
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.task import MaxMessageTermination
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_ext.models import OpenAIChatCompletionClient


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                agent1 = AssistantAgent("Assistant1", model_client=model_client)
                agent2 = AssistantAgent("Assistant2", model_client=model_client)

                termination = MaxMessageTermination(3)
                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

                stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
                async for message in stream:
                    print(message)

                # Reset the team.
                await team.reset()

                stream = team.run_stream(task="Count from 1 to 10, respond one at a time.")
                async for message in stream:
                    print(message)


            asyncio.run(main())
        """
        if not self._initialized:
            raise RuntimeError("The group chat has not been initialized. It must be run before it can be reset.")

        if self._is_running:
            raise RuntimeError("The group chat is currently running. It must be stopped before it can be reset.")
        self._is_running = True

        # Start the runtime.
        self._runtime.start()

        # Send a reset message to the group chat.
        await self._runtime.publish_message(
            GroupChatReset(),
            topic_id=TopicId(type=self._group_topic_type, source=self._team_id),
        )

        # Stop the runtime.
        await self._runtime.stop_when_idle()

        # Reset the output message queue.
        self._stop_reason = None
        while not self._output_message_queue.empty():
            self._output_message_queue.get_nowait()

        # Indicate that the team is no longer running.
        self._is_running = False
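
The class docstring above describes the intended extension pattern: subclass :class:`BaseGroupChatManager`, then subclass :class:`BaseGroupChat` and return a factory for that manager from ``_create_group_chat_manager_factory``. The following is a minimal, illustrative sketch of that pattern, not part of this module. It assumes the manager constructor accepts the same arguments that the factory receives here (as the keyword arguments passed in ``_init`` suggest) and that speaker selection is customized by overriding a ``select_speaker`` hook; the class names ``AlphabeticalGroupChatManager`` and ``AlphabeticalGroupChat`` are hypothetical. Check ``_base_group_chat_manager.py`` for the actual manager signature before relying on this.

    # Illustrative sketch only. All names used below (List, Callable, ChatAgent,
    # TerminationCondition, AgentMessage, BaseGroupChatManager, BaseGroupChat) are
    # already imported at the top of this module.


    class AlphabeticalGroupChatManager(BaseGroupChatManager):
        """Hypothetical manager that cycles through participants in name order."""

        async def select_speaker(self, thread: List[AgentMessage]) -> str:
            # ASSUMPTION: select_speaker is the speaker-selection hook and
            # self._participant_topic_types is populated by the base class.
            ordered = sorted(self._participant_topic_types)
            return ordered[len(thread) % len(ordered)]


    class AlphabeticalGroupChat(BaseGroupChat):
        """Hypothetical team that uses the manager above."""

        def __init__(
            self,
            participants: List[ChatAgent],
            termination_condition: TerminationCondition | None = None,
        ):
            super().__init__(
                participants,
                group_chat_manager_class=AlphabeticalGroupChatManager,
                termination_condition=termination_condition,
            )

        def _create_group_chat_manager_factory(
            self,
            group_topic_type: str,
            output_topic_type: str,
            participant_topic_types: List[str],
            participant_descriptions: List[str],
            termination_condition: TerminationCondition | None,
        ) -> Callable[[], AlphabeticalGroupChatManager]:
            # ASSUMPTION: the manager constructor takes these five arguments in
            # this order, mirroring the keyword arguments used in _init above.
            return lambda: AlphabeticalGroupChatManager(
                group_topic_type,
                output_topic_type,
                participant_topic_types,
                participant_descriptions,
                termination_condition,
            )

Under these assumptions, such a subclass would be used exactly like the built-in teams shown in the docstring examples above, e.g. ``team = AlphabeticalGroupChat([agent1, agent2], termination_condition=termination)`` followed by ``await team.run(...)``.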