# Source code for autogen_agentchat.base._termination

import asyncio
from abc import ABC, abstractmethod
from typing import List, Sequence

from ..messages import AgentMessage, StopMessage


class TerminatedException(BaseException):
    """Signal that a termination condition was used after it was already reached.

    Note that this derives from :class:`BaseException`, not :class:`Exception`,
    so a generic ``except Exception`` handler will not catch it; callers must
    catch ``TerminatedException`` (or ``BaseException``) explicitly.
    """
class TerminationCondition(ABC):
    """A stateful condition that determines when a conversation should be terminated.

    A termination condition is a callable that takes the sequence of
    ``AgentMessage`` objects received since the last time the condition was
    called, and returns a ``StopMessage`` if the conversation should be
    terminated, or ``None`` otherwise.

    Once a termination condition has been reached, it must be reset before it
    can be used again.

    Termination conditions can be combined using the AND (``&``) and OR (``|``)
    operators.

    Example:

        .. code-block:: python

            import asyncio
            from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination


            async def main() -> None:
                # Terminate the conversation after 10 turns or if the text "TERMINATE" is mentioned.
                cond1 = MaxMessageTermination(10) | TextMentionTermination("TERMINATE")

                # Terminate the conversation after 10 turns and if the text "TERMINATE" is mentioned.
                cond2 = MaxMessageTermination(10) & TextMentionTermination("TERMINATE")

                # ...

                # Reset the termination condition.
                await cond1.reset()
                await cond2.reset()


            asyncio.run(main())
    """

    @property
    @abstractmethod
    def terminated(self) -> bool:
        """Check if the termination condition has been reached."""
        ...

    @abstractmethod
    async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
        """Check whether the conversation should be terminated.

        The decision is based on the messages received since the last time the
        condition was called.

        Args:
            messages: The messages received since the last time the condition
                was called.

        Returns:
            StopMessage | None: A StopMessage if the conversation should be
            terminated, or None otherwise.

        Raises:
            TerminatedException: If the termination condition has already been
                reached.
        """
        ...

    @abstractmethod
    async def reset(self) -> None:
        """Reset the termination condition so that it can be used again."""
        ...

    def __and__(self, other: "TerminationCondition") -> "TerminationCondition":
        """Combine two termination conditions with an AND operation.

        Args:
            other: The condition to combine with this one.

        Returns:
            TerminationCondition: A composite condition wrapping ``self`` and
            ``other`` (an ``_AndTerminationCondition``).
        """
        return _AndTerminationCondition(self, other)

    def __or__(self, other: "TerminationCondition") -> "TerminationCondition":
        """Combine two termination conditions with an OR operation.

        Args:
            other: The condition to combine with this one.

        Returns:
            TerminationCondition: A composite condition wrapping ``self`` and
            ``other`` (an ``_OrTerminationCondition``).
        """
        return _OrTerminationCondition(self, other)
class _AndTerminationCondition(TerminationCondition):
    """Composite condition that is terminated only when all of its children are.

    Stop messages returned by children are accumulated across calls, so the
    final StopMessage joins the content and source of every child that fired,
    including those that fired on earlier calls.
    """

    def __init__(self, *conditions: TerminationCondition) -> None:
        """Store the child conditions to be combined with AND semantics.

        Args:
            *conditions: The child termination conditions.
        """
        self._conditions = conditions
        # Stop messages collected from children so far; cleared by reset().
        self._stop_messages: List[StopMessage] = []

    @property
    def terminated(self) -> bool:
        """Return True when every child condition reports it is terminated."""
        return all(condition.terminated for condition in self._conditions)

    async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
        """Forward the messages to every child that has not terminated yet.

        Args:
            messages: The messages received since the last time the condition
                was called.

        Returns:
            StopMessage | None: None if any child called in this round returned
            None; otherwise a StopMessage whose content and source are the
            ", "-joined content and source of all collected stop messages.

        Raises:
            TerminatedException: If every child condition is already terminated.
        """
        if self.terminated:
            raise TerminatedException("Termination condition has already been reached.")

        # Only children that have not terminated yet are called again.
        stop_messages = await asyncio.gather(
            *[condition(messages) for condition in self._conditions if not condition.terminated]
        )

        # Remember every stop message produced in this round, even if the
        # composite as a whole is not finished yet.
        self._stop_messages.extend(
            stop_message for stop_message in stop_messages if stop_message is not None
        )

        if any(stop_message is None for stop_message in stop_messages):
            # At least one remaining child has not reached termination.
            return None

        content = ", ".join(stop_message.content for stop_message in self._stop_messages)
        source = ", ".join(stop_message.source for stop_message in self._stop_messages)
        return StopMessage(content=content, source=source)

    async def reset(self) -> None:
        """Reset every child condition and discard the collected stop messages."""
        for condition in self._conditions:
            await condition.reset()
        self._stop_messages.clear()


class _OrTerminatedError(TerminatedException, RuntimeError):
    """Raised when an already-terminated OR condition is called again.

    It subclasses TerminatedException so the error matches what the AND
    composite raises in the same situation, and RuntimeError so that callers
    that previously caught the RuntimeError raised here keep working.
    """


class _OrTerminationCondition(TerminationCondition):
    """Composite condition that is terminated as soon as any child is."""

    def __init__(self, *conditions: TerminationCondition) -> None:
        """Store the child conditions to be combined with OR semantics.

        Args:
            *conditions: The child termination conditions.
        """
        self._conditions = conditions

    @property
    def terminated(self) -> bool:
        """Return True when at least one child condition reports it is terminated."""
        return any(condition.terminated for condition in self._conditions)

    async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
        """Forward the messages to every child condition.

        Args:
            messages: The messages received since the last time the condition
                was called.

        Returns:
            StopMessage | None: None if no child returned a stop message;
            otherwise a StopMessage whose content and source are the
            ", "-joined content and source of the stop messages returned in
            this call.

        Raises:
            TerminatedException: If any child condition is already terminated.
                The raised exception is also a RuntimeError for backward
                compatibility.
        """
        if self.terminated:
            raise _OrTerminatedError("Termination condition has already been reached")

        stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])

        # Keep only the children that asked to stop in this round.
        fired = [stop_message for stop_message in stop_messages if stop_message is not None]
        if not fired:
            return None

        content = ", ".join(stop_message.content for stop_message in fired)
        source = ", ".join(stop_message.source for stop_message in fired)
        return StopMessage(content=content, source=source)

    async def reset(self) -> None:
        """Reset every child condition."""
        for condition in self._conditions:
            await condition.reset()