import asyncio
import logging
import re
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
LLMMessage,
ModelFamily,
SystemMessage,
UserMessage,
)
from pydantic import BaseModel
from typing_extensions import Self
from ... import TRACE_LOGGER_NAME
from ...agents import BaseChatAgent
from ...base import ChatAgent, TerminationCondition
from ...messages import (
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
MessageFactory,
ModelClientStreamingChunkEvent,
SelectorEvent,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
from ._events import GroupChatTermination
# Logger used for debug/warning traces emitted during speaker selection.
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)

# A selector function receives the conversation thread and returns the name of
# the next speaker, or ``None`` to defer the decision to the model.
SyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None]
AsyncSelectorFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]]
SelectorFuncType = Union[SyncSelectorFunc, AsyncSelectorFunc]

# A candidate function receives the conversation thread and returns the names
# the model is allowed to choose from for the next turn.
SyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]]
AsyncCandidateFunc = Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]]
CandidateFuncType = Union[SyncCandidateFunc, AsyncCandidateFunc]
class SelectorGroupChatManager(BaseGroupChatManager):
    """A group chat manager that selects the next speaker using a ChatCompletion
    model and a custom selector function.

    For every turn, :meth:`select_speaker` proceeds in this order:

    1. If a ``selector_func`` was provided and it returns a name, that name is used.
    2. Otherwise a candidate list is built: from ``candidate_func`` if provided,
       else from all participants (minus the previous speaker when repeated
       speakers are not allowed).
    3. If more than one candidate remains, the model is asked to pick one;
       a single remaining candidate is selected directly.
    """

    def __init__(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
        model_client: ChatCompletionClient,
        selector_prompt: str,
        allow_repeated_speaker: bool,
        selector_func: Optional[SelectorFuncType],
        max_selector_attempts: int,
        candidate_func: Optional[CandidateFuncType],
        emit_team_events: bool,
        model_context: ChatCompletionContext | None,
        model_client_streaming: bool = False,
    ) -> None:
        """Create the manager.

        The first ten positional arguments and ``emit_team_events`` are forwarded
        unchanged to :class:`BaseGroupChatManager`.

        Args:
            model_client: Client used to ask the model for the next speaker.
            selector_prompt: Template formatted with ``roles``, ``participants``
                and ``history`` before being sent to the model.
            allow_repeated_speaker: Whether the previous speaker may be selected again.
            selector_func: Optional sync or async override for speaker selection.
            max_selector_attempts: Maximum number of model queries per selection.
            candidate_func: Optional sync or async function returning the candidate names.
            model_context: Context holding the messages used to build the history;
                an :class:`UnboundedChatCompletionContext` is created when ``None``.
            model_client_streaming: Whether to query the model via ``create_stream``.
        """
        super().__init__(
            name,
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_names,
            participant_descriptions,
            output_message_queue,
            termination_condition,
            max_turns,
            message_factory,
            emit_team_events,
        )
        self._model_client = model_client
        self._selector_prompt = selector_prompt
        self._previous_speaker: str | None = None
        self._allow_repeated_speaker = allow_repeated_speaker
        self._selector_func = selector_func
        # Decide once whether the callbacks must be awaited.
        self._is_selector_func_async = iscoroutinefunction(self._selector_func)
        self._max_selector_attempts = max_selector_attempts
        self._candidate_func = candidate_func
        self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
        self._model_client_streaming = model_client_streaming
        if model_context is not None:
            self._model_context = model_context
        else:
            self._model_context = UnboundedChatCompletionContext()
        self._cancellation_token = CancellationToken()

    async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
        """No validation is performed for the selector group chat."""
        pass

    async def reset(self) -> None:
        """Clear the turn counter, message thread, model context, termination
        condition and previous speaker."""
        self._current_turn = 0
        self._message_thread.clear()
        await self._model_context.clear()
        if self._termination_condition is not None:
            await self._termination_condition.reset()
        self._previous_speaker = None

    async def save_state(self) -> Mapping[str, Any]:
        """Return the message thread, current turn and previous speaker as a
        dumped :class:`SelectorManagerState`."""
        state = SelectorManagerState(
            message_thread=[msg.dump() for msg in self._message_thread],
            current_turn=self._current_turn,
            previous_speaker=self._previous_speaker,
        )
        return state.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the manager from a mapping produced by :meth:`save_state`.

        The chat messages of the restored thread are replayed into the model context.
        """
        selector_state = SelectorManagerState.model_validate(state)
        self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread]
        # NOTE(review): the model context is not cleared first, so loading into a
        # manager whose context already holds messages appends to them -- confirm
        # this is intended.
        await self._add_messages_to_context(
            self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)]
        )
        self._current_turn = selector_state.current_turn
        self._previous_speaker = selector_state.previous_speaker

    @staticmethod
    async def _add_messages_to_context(
        model_context: ChatCompletionContext,
        messages: Sequence[BaseChatMessage],
    ) -> None:
        """
        Add incoming messages to the model context.

        For a :class:`HandoffMessage`, the LLM messages carried in its ``context``
        are added before the handoff message itself.
        """
        for msg in messages:
            if isinstance(msg, HandoffMessage):
                for llm_msg in msg.context:
                    await model_context.add_message(llm_msg)
            await model_context.add_message(msg.to_model_message())

    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
        """Append ``messages`` to the thread and add the chat messages among
        them (events are skipped) to the model context."""
        self._message_thread.extend(messages)
        base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
        await self._add_messages_to_context(self._model_context, base_chat_messages)

    async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
        """Selects the next speaker in a group chat using a ChatCompletion client,
        with the selector function as override if it returns a speaker name.

        .. note::

            This method always returns a list containing a single speaker name.

        A key assumption is that the agent type is the same as the topic type, which we use as the agent name.

        Raises:
            ValueError: If the selector function returns an unknown name, or the
                candidate function returns an empty list or unknown names.
        """
        # Use the selector function if provided.
        if self._selector_func is not None:
            if self._is_selector_func_async:
                async_selector_func = cast(AsyncSelectorFunc, self._selector_func)
                speaker = await async_selector_func(thread)
            else:
                sync_selector_func = cast(SyncSelectorFunc, self._selector_func)
                speaker = sync_selector_func(thread)
            if speaker is not None:
                if speaker not in self._participant_names:
                    raise ValueError(
                        f"Selector function returned an invalid speaker name: {speaker}. "
                        f"Expected one of: {self._participant_names}."
                    )
                # Record the choice so that the repeated-speaker filter and the
                # fallback in `_select_speaker` see who actually spoke last,
                # not the last speaker that happened to be chosen by the model.
                self._previous_speaker = speaker
                # Skip the model based selection.
                return [speaker]

        # Use the candidate function to filter participants if provided
        if self._candidate_func is not None:
            if self._is_candidate_func_async:
                async_candidate_func = cast(AsyncCandidateFunc, self._candidate_func)
                participants = await async_candidate_func(thread)
            else:
                sync_candidate_func = cast(SyncCandidateFunc, self._candidate_func)
                participants = sync_candidate_func(thread)
            if not participants:
                raise ValueError("Candidate function must return a non-empty list of participant names.")
            if not all(p in self._participant_names for p in participants):
                raise ValueError(
                    f"Candidate function returned invalid participant names: {participants}. "
                    f"Expected one of: {self._participant_names}."
                )
        else:
            # Construct the candidate agent list to be selected from, skip the previous speaker if not allowed.
            if self._previous_speaker is not None and not self._allow_repeated_speaker:
                participants = [p for p in self._participant_names if p != self._previous_speaker]
            else:
                participants = list(self._participant_names)

        assert len(participants) > 0

        # Construct agent roles.
        # Each agent should appear on a single line, so inner whitespace
        # (including newlines in descriptions) is collapsed to single spaces.
        roles = "\n".join(
            re.sub(r"\s+", " ", f"{topic_type}: {description}").strip()
            for topic_type, description in zip(self._participant_names, self._participant_descriptions, strict=True)
        )

        # Select the next speaker.
        if len(participants) > 1:
            agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts)
        else:
            agent_name = participants[0]
        self._previous_speaker = agent_name
        trace_logger.debug(f"Selected speaker: {agent_name}")
        return [agent_name]

    def construct_message_history(self, message_history: List[LLMMessage]) -> str:
        """Render the user and assistant messages of ``message_history`` as a
        ``"<source>: <content>"`` transcript; other message types are skipped."""
        # Construct the history of the conversation.
        history_messages: List[str] = []
        for msg in message_history:
            if isinstance(msg, (UserMessage, AssistantMessage)):
                message = f"{msg.source}: {msg.content}"
                history_messages.append(
                    message.rstrip() + "\n\n"
                )  # Create some consistency for how messages are separated in the transcript
        history: str = "\n".join(history_messages)
        return history

    async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str:
        """Ask the model to choose one of ``participants``, retrying with feedback.

        A response is accepted when it mentions exactly one participant name,
        that name is one of ``participants``, and it does not repeat the previous
        speaker while repeats are disallowed. After ``max_attempts`` failed
        attempts the previous speaker is returned if there is one, otherwise
        the first entry of ``participants``.
        """
        model_context_messages = await self._model_context.get_messages()
        model_context_history = self.construct_message_history(model_context_messages)

        select_speaker_prompt = self._selector_prompt.format(
            roles=roles, participants=str(participants), history=model_context_history
        )

        select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage]
        if ModelFamily.is_openai(self._model_client.model_info["family"]):
            select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
        else:
            # Many other models need a UserMessage to respond to
            select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="user")]

        num_attempts = 0
        while num_attempts < max_attempts:
            num_attempts += 1
            if self._model_client_streaming:
                chunk: CreateResult | str = ""
                async for _chunk in self._model_client.create_stream(messages=select_speaker_messages):
                    chunk = _chunk
                    if self._emit_team_events:
                        if isinstance(chunk, str):
                            await self._output_message_queue.put(
                                ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name)
                            )
                        else:
                            assert isinstance(chunk, CreateResult)
                            assert isinstance(chunk.content, str)
                            await self._output_message_queue.put(
                                SelectorEvent(content=chunk.content, source=self._name)
                            )
                # The last chunk must be CreateResult.
                assert isinstance(chunk, CreateResult)
                response = chunk
            else:
                response = await self._model_client.create(messages=select_speaker_messages)
            assert isinstance(response.content, str)
            select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))

            # NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
            # This is because the model may still select the previous speaker, and we want to catch that.
            mentions = self._mentioned_agents(response.content, self._participant_names)
            if len(mentions) == 1:
                agent_name = next(iter(mentions))
                if (
                    not self._allow_repeated_speaker
                    and self._previous_speaker is not None
                    and agent_name == self._previous_speaker
                ):
                    trace_logger.debug(f"Model selected the previous speaker: {agent_name} (attempt {num_attempts})")
                    feedback = (
                        f"Repeated speaker is not allowed, please select a different name from: {str(participants)}."
                    )
                elif agent_name not in participants:
                    # Mentions are matched against every participant, so the model can
                    # name one that the candidate list excluded; reject and ask again.
                    trace_logger.debug(f"Model selected a non-candidate name: {agent_name} (attempt {num_attempts})")
                    feedback = f"{agent_name} is not an available candidate. Please select from: {str(participants)}."
                else:
                    # Valid selection
                    trace_logger.debug(f"Model selected a valid name: {agent_name} (attempt {num_attempts})")
                    return agent_name
            elif len(mentions) == 0:
                trace_logger.debug(f"Model failed to select a valid name: {response.content} (attempt {num_attempts})")
                feedback = f"No valid name was mentioned. Please select from: {str(participants)}."
            else:
                trace_logger.debug(f"Model selected multiple names: {str(mentions)} (attempt {num_attempts})")
                feedback = (
                    f"Expected exactly one name to be mentioned. Please select only one from: {str(participants)}."
                )
            # Tell the model what was wrong so the next attempt can correct it.
            select_speaker_messages.append(UserMessage(content=feedback, source="user"))

        if self._previous_speaker is not None:
            trace_logger.warning(f"Model failed to select a speaker after {max_attempts}, using the previous speaker.")
            return self._previous_speaker
        trace_logger.warning(
            f"Model failed to select a speaker after {max_attempts} and there was no previous speaker, using the first participant."
        )
        return participants[0]

    def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
        """Counts the number of times each agent is mentioned in the provided message content.

        Agent names will match under any of the following conditions (all case-sensitive):

        - Exact name match
        - If the agent name has underscores it will match with spaces instead (e.g. 'Story_writer' == 'Story writer')
        - If the agent name has underscores it will match with '\\_' instead of '_' (e.g. 'Story_writer' == 'Story\\_writer')

        Args:
            message_content (str): The content of the message to search.
            agent_names (List[str]): The agent names to look for in the message content.

        Returns:
            Dict[str, int]: a counter for mentioned agents; names that are never
            mentioned are omitted.
        """
        mentions: Dict[str, int] = dict()
        for name in agent_names:
            # Finds agent mentions, taking word boundaries into account,
            # accommodates escaping underscores and underscores as spaces
            regex = (
                r"(?<=\W)("
                + re.escape(name)
                + r"|"
                + re.escape(name.replace("_", " "))
                + r"|"
                + re.escape(name.replace("_", r"\_"))
                + r")(?=\W)"
            )
            # Pad the message so names at the very start or end still have a
            # non-word character on both sides.
            count = len(re.findall(regex, f" {message_content} "))
            if count > 0:
                mentions[name] = count
        return mentions
class SelectorGroupChatConfig(BaseModel):
    """The declarative configuration for SelectorGroupChat.

    Produced by ``SelectorGroupChat._to_config`` and consumed by
    ``SelectorGroupChat._from_config``. The custom ``selector_func`` and
    ``candidate_func`` callables are not part of this configuration.
    """

    # Dumped components of the participating agents.
    participants: List[ComponentModel]
    # Dumped component of the model client used for speaker selection.
    model_client: ComponentModel
    # Dumped termination condition, or None when the team has none.
    termination_condition: ComponentModel | None = None
    # Maximum number of turns, or None for no limit.
    max_turns: int | None = None
    # Prompt template for speaker selection; required (no default here).
    selector_prompt: str
    # Whether the previous speaker may be selected again; required (no default here).
    allow_repeated_speaker: bool
    # Not serialized yet -- kept as a placeholder for a future field:
    # selector_func: ComponentModel | None
    # Maximum number of model attempts per speaker selection.
    max_selector_attempts: int = 3
    # Whether team events are emitted.
    emit_team_events: bool = False
    # Whether the model client is queried in streaming mode.
    model_client_streaming: bool = False
    # Dumped model context component, or None when none was supplied.
    model_context: ComponentModel | None = None
class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
    """A group chat team whose participants take turns publishing a message
    to all, using a ChatCompletion model to select the next speaker after each message.

    Args:
        participants (List[ChatAgent]): The participants in the group chat,
            must have unique names and at least two participants.
        model_client (ChatCompletionClient): The ChatCompletion model client used
            to select the next speaker.
        termination_condition (TerminationCondition, optional): The termination condition for the group chat. Defaults to None.
            Without a termination condition, the group chat will run indefinitely.
        max_turns (int, optional): The maximum number of turns in the group chat before stopping. Defaults to None, meaning no limit.
        runtime (AgentRuntime, optional): The agent runtime, forwarded to :class:`BaseGroupChat`. Defaults to None.
        selector_prompt (str, optional): The prompt template to use for selecting the next speaker.
            Available fields: '{roles}', '{participants}', and '{history}'.
            `{participants}` is the names of candidates for selection. The format is `["<name1>", "<name2>", ...]`.
            `{roles}` is a newline-separated list of names and descriptions of all participants. The format for each line is: `"<name>: <description>"`.
            `{history}` is the conversation history, with messages separated by blank lines. The format for each message is: `"<name>: <message content>"`.
            The template is not validated at construction; formatting errors surface when the model is first asked to select a speaker.
        allow_repeated_speaker (bool, optional): Whether to include the previous speaker in the list of candidates to be selected for the next turn.
            Defaults to False. The model may still select the previous speaker -- such a selection is rejected and the model is
            asked again, which counts against `max_selector_attempts`.
        max_selector_attempts (int, optional): The maximum number of attempts to select a speaker using the model. Defaults to 3.
            If the model fails to select a speaker after the maximum number of attempts, the previous speaker will be used if available,
            otherwise the first participant will be used.
        selector_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[str | None]], optional): A custom selector
            function that takes the conversation history and returns the name of the next speaker.
            If provided, this function will be used to override the model to select the next speaker.
            If the function returns None, the model will be used to select the next speaker.
            This function is not included in the serialized component configuration.
        candidate_func (Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], List[str]], Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], Awaitable[List[str]]], optional):
            A custom function that takes the conversation history and returns a filtered list of candidates for the next speaker
            selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
            This function is only used if `selector_func` is not set or returns None. When it is set, the candidate list is
            not filtered by `allow_repeated_speaker`.
            This function is not included in the serialized component configuration.
        custom_message_types (List[type[BaseAgentEvent | BaseChatMessage]], optional): A list of custom message types that will be used in the group chat.
            If you are using custom message types or your agents produces custom message types, you need to specify them here.
            Make sure your custom message types are subclasses of :class:`~autogen_agentchat.messages.BaseAgentEvent` or :class:`~autogen_agentchat.messages.BaseChatMessage`.
        emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
        model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.
        model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving
            :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection. The initial messages will be cleared when the team is reset.

    Raises:
        ValueError: If the number of participants is less than two.

    Examples:

    A team with multiple participants:

        .. code-block:: python

            import asyncio
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import SelectorGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                async def lookup_hotel(location: str) -> str:
                    return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."

                async def lookup_flight(origin: str, destination: str) -> str:
                    return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."

                async def book_trip() -> str:
                    return "Your trip is booked!"

                travel_advisor = AssistantAgent(
                    "Travel_Advisor",
                    model_client,
                    tools=[book_trip],
                    description="Helps with travel planning.",
                )
                hotel_agent = AssistantAgent(
                    "Hotel_Agent",
                    model_client,
                    tools=[lookup_hotel],
                    description="Helps with hotel booking.",
                )
                flight_agent = AssistantAgent(
                    "Flight_Agent",
                    model_client,
                    tools=[lookup_flight],
                    description="Helps with flight booking.",
                )
                termination = TextMentionTermination("TERMINATE")
                team = SelectorGroupChat(
                    [travel_advisor, hotel_agent, flight_agent],
                    model_client=model_client,
                    termination_condition=termination,
                )
                await Console(team.run_stream(task="Book a 3-day trip to new york."))


            asyncio.run(main())

    A team with a custom selector function:

        .. code-block:: python

            import asyncio
            from typing import Sequence
            from autogen_ext.models.openai import OpenAIChatCompletionClient
            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.teams import SelectorGroupChat
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.ui import Console
            from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")

                def check_calculation(x: int, y: int, answer: int) -> str:
                    if x + y == answer:
                        return "Correct!"
                    else:
                        return "Incorrect!"

                agent1 = AssistantAgent(
                    "Agent1",
                    model_client,
                    description="For calculation",
                    system_message="Calculate the sum of two numbers",
                )
                agent2 = AssistantAgent(
                    "Agent2",
                    model_client,
                    tools=[check_calculation],
                    description="For checking calculation",
                    system_message="Check the answer and respond with 'Correct!' or 'Incorrect!'",
                )

                def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | None:
                    if len(messages) == 1 or messages[-1].to_text() == "Incorrect!":
                        return "Agent1"
                    if messages[-1].source == "Agent1":
                        return "Agent2"
                    return None

                termination = TextMentionTermination("Correct!")
                team = SelectorGroupChat(
                    [agent1, agent2],
                    model_client=model_client,
                    selector_func=selector_func,
                    termination_condition=termination,
                )

                await Console(team.run_stream(task="What is 1 + 1?"))


            asyncio.run(main())

    A team with custom model context:

        .. code-block:: python

            import asyncio

            from autogen_core.model_context import BufferedChatCompletionContext
            from autogen_ext.models.openai import OpenAIChatCompletionClient

            from autogen_agentchat.agents import AssistantAgent
            from autogen_agentchat.conditions import TextMentionTermination
            from autogen_agentchat.teams import SelectorGroupChat
            from autogen_agentchat.ui import Console


            async def main() -> None:
                model_client = OpenAIChatCompletionClient(model="gpt-4o")
                model_context = BufferedChatCompletionContext(buffer_size=5)

                async def lookup_hotel(location: str) -> str:
                    return f"Here are some hotels in {location}: hotel1, hotel2, hotel3."

                async def lookup_flight(origin: str, destination: str) -> str:
                    return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3."

                async def book_trip() -> str:
                    return "Your trip is booked!"

                travel_advisor = AssistantAgent(
                    "Travel_Advisor",
                    model_client,
                    tools=[book_trip],
                    description="Helps with travel planning.",
                )
                hotel_agent = AssistantAgent(
                    "Hotel_Agent",
                    model_client,
                    tools=[lookup_hotel],
                    description="Helps with hotel booking.",
                )
                flight_agent = AssistantAgent(
                    "Flight_Agent",
                    model_client,
                    tools=[lookup_flight],
                    description="Helps with flight booking.",
                )
                termination = TextMentionTermination("TERMINATE")
                team = SelectorGroupChat(
                    [travel_advisor, hotel_agent, flight_agent],
                    model_client=model_client,
                    termination_condition=termination,
                    model_context=model_context,
                )
                await Console(team.run_stream(task="Book a 3-day trip to new york."))


            asyncio.run(main())
    """

    component_config_schema = SelectorGroupChatConfig
    component_provider_override = "autogen_agentchat.teams.SelectorGroupChat"

    def __init__(
        self,
        participants: List[ChatAgent],
        model_client: ChatCompletionClient,
        *,
        termination_condition: TerminationCondition | None = None,
        max_turns: int | None = None,
        runtime: AgentRuntime | None = None,
        selector_prompt: str = """You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.
{history}
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
""",
        allow_repeated_speaker: bool = False,
        max_selector_attempts: int = 3,
        selector_func: Optional[SelectorFuncType] = None,
        candidate_func: Optional[CandidateFuncType] = None,
        custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
        emit_team_events: bool = False,
        model_client_streaming: bool = False,
        model_context: ChatCompletionContext | None = None,
    ):
        """Create the team; see the class docstring for the meaning of each argument."""
        super().__init__(
            participants,
            group_chat_manager_name="SelectorGroupChatManager",
            group_chat_manager_class=SelectorGroupChatManager,
            termination_condition=termination_condition,
            max_turns=max_turns,
            runtime=runtime,
            custom_message_types=custom_message_types,
            emit_team_events=emit_team_events,
        )
        # Validate the participants.
        if len(participants) < 2:
            raise ValueError("At least two participants are required for SelectorGroupChat.")
        # Selector-specific settings are stored here and handed to the manager
        # when the manager factory is invoked.
        self._selector_prompt = selector_prompt
        self._model_client = model_client
        self._allow_repeated_speaker = allow_repeated_speaker
        self._selector_func = selector_func
        self._max_selector_attempts = max_selector_attempts
        self._candidate_func = candidate_func
        self._model_client_streaming = model_client_streaming
        self._model_context = model_context

    def _create_group_chat_manager_factory(
        self,
        name: str,
        group_topic_type: str,
        output_topic_type: str,
        participant_topic_types: List[str],
        participant_names: List[str],
        participant_descriptions: List[str],
        output_message_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination],
        termination_condition: TerminationCondition | None,
        max_turns: int | None,
        message_factory: MessageFactory,
    ) -> Callable[[], BaseGroupChatManager]:
        """Return a zero-argument factory that builds a :class:`SelectorGroupChatManager`
        from the given arguments plus this team's selector settings."""
        return lambda: SelectorGroupChatManager(
            name,
            group_topic_type,
            output_topic_type,
            participant_topic_types,
            participant_names,
            participant_descriptions,
            output_message_queue,
            termination_condition,
            max_turns,
            message_factory,
            self._model_client,
            self._selector_prompt,
            self._allow_repeated_speaker,
            self._selector_func,
            self._max_selector_attempts,
            self._candidate_func,
            self._emit_team_events,
            self._model_context,
            self._model_client_streaming,
        )

    def _to_config(self) -> SelectorGroupChatConfig:
        """Dump this team to its declarative configuration.

        ``selector_func`` and ``candidate_func`` are not serialized.
        """
        return SelectorGroupChatConfig(
            participants=[participant.dump_component() for participant in self._participants],
            model_client=self._model_client.dump_component(),
            termination_condition=self._termination_condition.dump_component() if self._termination_condition else None,
            max_turns=self._max_turns,
            selector_prompt=self._selector_prompt,
            allow_repeated_speaker=self._allow_repeated_speaker,
            max_selector_attempts=self._max_selector_attempts,
            # selector_func=self._selector_func.dump_component() if self._selector_func else None,
            emit_team_events=self._emit_team_events,
            model_client_streaming=self._model_client_streaming,
            model_context=self._model_context.dump_component() if self._model_context else None,
        )

    @classmethod
    def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
        """Build a team from its declarative configuration.

        The resulting team has no ``selector_func`` or ``candidate_func`` because
        neither is part of the configuration.
        """
        return cls(
            participants=[BaseChatAgent.load_component(participant) for participant in config.participants],
            model_client=ChatCompletionClient.load_component(config.model_client),
            termination_condition=TerminationCondition.load_component(config.termination_condition)
            if config.termination_condition
            else None,
            max_turns=config.max_turns,
            selector_prompt=config.selector_prompt,
            allow_repeated_speaker=config.allow_repeated_speaker,
            max_selector_attempts=config.max_selector_attempts,
            # selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[BaseAgentEvent | BaseChatMessage]], str | None])
            # if config.selector_func
            # else None,
            emit_team_events=config.emit_team_events,
            model_client_streaming=config.model_client_streaming,
            model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
        )