# Source code for autogen_agentchat.teams._group_chat._round_robin_group_chat

import asyncio
from typing import Any, Callable, List, Mapping, Sequence

from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self

from ...base import ChatAgent, TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory
from ...state import RoundRobinManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination


class RoundRobinGroupChatManager(BaseGroupChatManager):
    """A group chat manager that selects the next speaker in a round-robin fashion.

    Speakers are chosen by cycling through ``participant_names`` in order,
    starting at index 0. The position in the cycle is included in the saved
    state, so a manager restored via :meth:`load_state` continues the cycle
    where the saved manager left off.
    """

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        emit_team_events: bool,
    ) -> None:
        super().__init__(
            name,
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_names,
            participant_descriptions,
            output_message_queue,
            termination_condition,
            max_turns,
            message_factory,
            emit_team_events,
        )
        # Index into ``self._participant_names`` of the participant that the
        # next call to ``select_speaker`` will return.
        self._next_speaker_index = 0

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """Accept any group state; round-robin selection imposes no constraints."""
        pass

    async def reset(self) -> None:
        """Reset the turn counter, message thread, termination condition and speaker cycle."""
        self._current_turn = 0
        self._message_thread.clear()
        if self._termination_condition is not None:
            await self._termination_condition.reset()
        self._next_speaker_index = 0

    async def save_state(self) -> Mapping[str, Any]:
        """Return a dict snapshot of the message thread, current turn and next speaker index."""
        state = RoundRobinManagerState(
            message_thread=[message.dump() for message in self._message_thread],
            current_turn=self._current_turn,
            next_speaker_index=self._next_speaker_index,
        )
        return state.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the manager from a mapping previously produced by :meth:`save_state`.

        ``RoundRobinManagerState.model_validate`` raises if ``state`` does not
        match the expected schema. Messages are rebuilt with the team's
        message factory.
        """
        round_robin_state = RoundRobinManagerState.model_validate(state)
        self._message_thread = [self._message_factory.create(message) for message in round_robin_state.message_thread]
        self._current_turn = round_robin_state.current_turn
        self._next_speaker_index = round_robin_state.next_speaker_index

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
        """Select a speaker from the participants in a round-robin fashion.

        ``thread`` is ignored; selection depends only on the stored cycle position.

        .. note::

            This method always returns a single speaker.
        """
        participant_count = len(self._participant_names)
        # Normalize first: a restored ``next_speaker_index`` may come from a team
        # with more participants. Indexing with it directly would raise
        # IndexError after the stored index had already been advanced.
        current_speaker_index = self._next_speaker_index % participant_count
        self._next_speaker_index = (current_speaker_index + 1) % participant_count
        return self._participant_names[current_speaker_index]


class RoundRobinGroupChatConfig(BaseModel):
    """The declarative configuration for RoundRobinGroupChat.

    Holds the serializable subset of the team's constructor arguments.
    """

    # Serialized component models of the team's participants.
    participants: list[ComponentModel]
    # Serialized termination condition; None when the team has none.
    termination_condition: ComponentModel | None = None
    # Maximum number of turns before stopping; None means no limit.
    max_turns: int | None = None
    # Whether the team emits team events.
    emit_team_events: bool = False


class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
    """A team that runs a group chat with participants taking turns in a round-robin fashion
    to publish a message to all.

    If a single participant is in the team, the participant will be the only speaker.

    Args:
        participants (List[ChatAgent]): The participants in the group chat.
        termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
            Without a termination condition, the group chat will run indefinitely.
        max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
        runtime (AgentRuntime, optional): The agent runtime for the team; passed through to :class:`BaseGroupChat`. Defaults to None.
        custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
            If you are using custom message types or your agents produce custom message types, you need to specify them here.
            Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
        emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.

    Raises:
        ValueError: If no participants are provided or if participant names are not unique.

    Examples:

    A team with one participant with tools:

        .. code-block:: python

            import asyncio
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                async def get_weather(location: str) -> str:
                    return f"The weather in {location} is sunny."

                assistant = AssistantAgent(
                    "Assistant",
                    model_client=model_client,
                    tools=[get_weather],
                )
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([assistant], termination_condition=termination)
                await Console(team.run_stream(task="What's the weather in New York?"))


            asyncio.run(main())

    A team with multiple participants:

        .. code-block:: python

            import asyncio
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import RoundRobinGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                agent1 = AssistantAgent("Assistant1", model_client=model_client)
                agent2 = AssistantAgent("Assistant2", model_client=model_client)
                termination = TextMentionTermination("TERMINATE")
                team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
                await Console(team.run_stream(task="Tell me some jokes."))


            asyncio.run(main())
    """

    component_config_schema = RoundRobinGroupChatConfig
    component_provider_override = "autogen_agentchat.teams.RoundRobinGroupChat"

    # TODO: Add * to the constructor to separate the positional parameters from the kwargs.
    # This may be a breaking change so let's wait until a good time to do it.
    def __init__(
        self,
        participants: List[ChatAgent],
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
        emit_team_events: bool = False,
    ) -> None:
        super().__init__(
            participants,
            group_chat_manager_name="RoundRobinGroupChatManager",
            group_chat_manager_class=RoundRobinGroupChatManager,
            termination_condition=termination_condition,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
            emit_team_events=emit_team_events,
        )

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], RoundRobinGroupChatManager]:
        """Return a zero-argument factory that builds a :class:`RoundRobinGroupChatManager`.

        Every argument is captured and forwarded unchanged to the manager.
        ``emit_team_events`` is read from this team when the factory is invoked.
        """

        def _factory() -> RoundRobinGroupChatManager:
            return RoundRobinGroupChatManager(
                name,
                group_topic_type,
                output_topic_type,
                participant_topic_types,
                participant_names,
                participant_descriptions,
                output_message_queue,
                termination_condition,
                max_turns,
                message_factory,
                self._emit_team_events,
            )

        return _factory

    def _to_config(self) -> RoundRobinGroupChatConfig:
        """Serialize this team into a :class:`RoundRobinGroupChatConfig`.

        Only participants, the termination condition, ``max_turns`` and
        ``emit_team_events`` are captured; ``runtime`` and
        ``custom_message_types`` are not part of the config.
        """
        participants = [participant.dump_component() for participant in self._participants]
        termination_condition = (
            self._termination_condition.dump_component() if self._termination_condition is not None else None
        )
        return RoundRobinGroupChatConfig(
            participants=participants,
            termination_condition=termination_condition,
            max_turns=self._max_turns,
            emit_team_events=self._emit_team_events,
        )

    @classmethod
    def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self:
        """Build a team from a :class:`RoundRobinGroupChatConfig`.

        Participants and the optional termination condition are loaded as components.
        """
        participants = [ChatAgent.load_component(participant) for participant in config.participants]
        termination_condition = (
            TerminationCondition.load_component(config.termination_condition)
            if config.termination_condition is not None
            else None
        )
        return cls(
            participants,
            termination_condition=termination_condition,
            max_turns=config.max_turns,
            emit_team_events=config.emit_team_events,
        )