Source code for autogen_agentchat.agents._message_filter_agent

from typing import AsyncGenerator, List, Literal, Optional, Sequence, Union

from autogen_core import CancellationToken, Component, ComponentModel
from pydantic import BaseModel

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage

# ------------------------------
# Message Filter Config
# ------------------------------


class PerSourceFilter(BaseModel):
    """Selection rule for the messages coming from one source.

    A rule names a ``source`` and, optionally, which end of that source's
    messages to keep (``position``) and how many (``count``). When either
    ``position`` or ``count`` is ``None`` no trimming takes place and every
    message from ``source`` is kept.
    """

    # Value compared against each message's ``source`` attribute.
    source: str

    # Which end of the matching messages to keep; ``None`` keeps them all.
    position: Optional[Literal["first", "last"]] = None

    # How many messages to keep from the chosen end; ``None`` means no limit.
    count: Optional[int] = None
class MessageFilterConfig(BaseModel):
    """Complete filter specification: an ordered list of per-source rules."""

    # Rules are evaluated in list order, one :class:`PerSourceFilter` per entry.
    per_source: List[PerSourceFilter]
# ------------------------------
# Component Config
# ------------------------------


class MessageFilterAgentConfig(BaseModel):
    """Serializable configuration used to dump and load a ``MessageFilterAgent``."""

    # Name given to the filtering wrapper agent.
    name: str

    # Component model of the agent being wrapped.
    wrapped_agent: ComponentModel

    # Filter rules applied to incoming messages before they are forwarded.
    filter: MessageFilterConfig


# ------------------------------
# Message Filter Agent
# ------------------------------
class MessageFilterAgent(BaseChatAgent, Component[MessageFilterAgentConfig]):
    """
    A wrapper agent that filters incoming messages before passing them to the inner agent.

    .. warning::

        This is an experimental feature, and the API will change in the future releases.

    This is useful in scenarios like multi-agent workflows where an agent should only
    process a subset of the full message history—for example, only the last message
    from each upstream agent, or only the first message from a specific source.

    Filtering is configured using :class:`MessageFilterConfig`, which supports:

    - Filtering by message source (e.g., only messages from "user" or another agent)
    - Selecting the first N or last N messages from each source
    - If position is `None`, all messages from that source are included
    - If count is `None`, all messages from that source are included
    - A count of zero or less (with a position set) selects no messages from that source

    Messages whose source is not named by any filter are dropped. The filtered
    messages are grouped in the order the per-source filters are listed, not in
    their original interleaved order.

    This agent is compatible with both direct message passing and team-based execution
    such as :class:`~autogen_agentchat.teams.GraphFlow`.

    Example:
        >>> agent_a = MessageFilterAgent(
        ...     name="A",
        ...     wrapped_agent=some_other_agent,
        ...     filter=MessageFilterConfig(
        ...         per_source=[
        ...             PerSourceFilter(source="user", position="first", count=1),
        ...             PerSourceFilter(source="B", position="last", count=2),
        ...         ]
        ...     ),
        ... )

    Example use case with Graph:
        Suppose you have a looping multi-agent graph: A → B → A → B → C.

        You want:

        - A to only see the user message and the last message from B
        - B to see the user message, last message from A, and its own prior responses (for reflection)
        - C to see the user message and the last message from B

        Wrap the agents like so:

        >>> agent_a = MessageFilterAgent(
        ...     name="A",
        ...     wrapped_agent=agent_a_inner,
        ...     filter=MessageFilterConfig(
        ...         per_source=[
        ...             PerSourceFilter(source="user", position="first", count=1),
        ...             PerSourceFilter(source="B", position="last", count=1),
        ...         ]
        ...     ),
        ... )

        >>> agent_b = MessageFilterAgent(
        ...     name="B",
        ...     wrapped_agent=agent_b_inner,
        ...     filter=MessageFilterConfig(
        ...         per_source=[
        ...             PerSourceFilter(source="user", position="first", count=1),
        ...             PerSourceFilter(source="A", position="last", count=1),
        ...             PerSourceFilter(source="B", position="last", count=10),
        ...         ]
        ...     ),
        ... )

        >>> agent_c = MessageFilterAgent(
        ...     name="C",
        ...     wrapped_agent=agent_c_inner,
        ...     filter=MessageFilterConfig(
        ...         per_source=[
        ...             PerSourceFilter(source="user", position="first", count=1),
        ...             PerSourceFilter(source="B", position="last", count=1),
        ...         ]
        ...     ),
        ... )

        Then define the graph:

        >>> graph = DiGraph(
        ...     nodes={
        ...         "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
        ...         "B": DiGraphNode(
        ...             name="B",
        ...             edges=[
        ...                 DiGraphEdge(target="C", condition="exit"),
        ...                 DiGraphEdge(target="A", condition="loop"),
        ...             ],
        ...         ),
        ...         "C": DiGraphNode(name="C", edges=[]),
        ...     },
        ...     default_start_node="A",
        ... )

        This will ensure each agent sees only what is needed for its decision or action logic.
    """

    component_config_schema = MessageFilterAgentConfig
    component_provider_override = "autogen_agentchat.agents.MessageFilterAgent"

    def __init__(
        self,
        name: str,
        wrapped_agent: BaseChatAgent,
        filter: MessageFilterConfig,
    ) -> None:
        """Wrap ``wrapped_agent`` so it only receives messages selected by ``filter``.

        Args:
            name: Name of this wrapper agent.
            wrapped_agent: The inner agent that receives the filtered messages.
            filter: Per-source rules deciding which messages are forwarded.
        """
        super().__init__(name=name, description=f"{wrapped_agent.description} (with message filtering)")
        self._wrapped_agent = wrapped_agent
        self._filter = filter

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """Message types produced by the wrapped agent."""
        return self._wrapped_agent.produced_message_types

    def _apply_filter(self, messages: Sequence[BaseChatMessage]) -> Sequence[BaseChatMessage]:
        """Return the subset of ``messages`` selected by the configured filters.

        Each per-source filter is processed in order: messages whose ``source``
        matches are collected and, when both ``position`` and ``count`` are set,
        trimmed to the first or last ``count`` of them. The selections are
        concatenated in filter order.

        Args:
            messages: The full incoming message sequence.

        Returns:
            The selected messages, grouped by filter in ``per_source`` order.
        """
        result: List[BaseChatMessage] = []
        for source_filter in self._filter.per_source:
            msgs = [m for m in messages if m.source == source_filter.source]
            position = source_filter.position
            count = source_filter.count
            # Compare against None explicitly: a truthiness test would treat
            # count=0 like "no limit" and forward every message.
            if count is not None and position in ("first", "last"):
                # Clamp so a negative count cannot leak raw slice semantics.
                limit = max(count, 0)
                if limit == 0:
                    # msgs[-0:] would be the whole list, so handle zero here.
                    msgs = []
                elif position == "first":
                    msgs = msgs[:limit]
                else:
                    msgs = msgs[-limit:]
            result.extend(msgs)
        return result

    async def on_messages(
        self,
        messages: Sequence[BaseChatMessage],
        cancellation_token: CancellationToken,
    ) -> Response:
        """Filter ``messages`` and forward them to the wrapped agent's ``on_messages``."""
        filtered = self._apply_filter(messages)
        return await self._wrapped_agent.on_messages(filtered, cancellation_token)

    async def on_messages_stream(
        self,
        messages: Sequence[BaseChatMessage],
        cancellation_token: CancellationToken,
    ) -> AsyncGenerator[Union[BaseAgentEvent, BaseChatMessage, Response], None]:
        """Filter ``messages`` and re-yield everything the wrapped agent streams."""
        filtered = self._apply_filter(messages)
        async for item in self._wrapped_agent.on_messages_stream(filtered, cancellation_token):
            yield item

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Delegate the reset to the wrapped agent."""
        await self._wrapped_agent.on_reset(cancellation_token)

    def _to_config(self) -> MessageFilterAgentConfig:
        """Build the serializable configuration for this agent."""
        return MessageFilterAgentConfig(
            name=self.name,
            wrapped_agent=self._wrapped_agent.dump_component(),
            filter=self._filter,
        )

    @classmethod
    def _from_config(cls, config: MessageFilterAgentConfig) -> "MessageFilterAgent":
        """Recreate the agent from ``config``, loading the wrapped agent component first."""
        wrapped = BaseChatAgent.load_component(config.wrapped_agent)
        return cls(
            name=config.name,
            wrapped_agent=wrapped,
            filter=config.filter,
        )