# Source code for autogen_core.model_context._head_and_tail_chat_completion_context
from typing import List
from pydantic import BaseModel
from typing_extensions import Self
from .._component_config import Component
from .._types import FunctionCall
from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
from ._chat_completion_context import ChatCompletionContext
class HeadAndTailChatCompletionContextConfig(BaseModel):
    """Configuration schema for `HeadAndTailChatCompletionContext`."""

    head_size: int  # number of oldest messages kept at the start of the view
    tail_size: int  # number of most recent messages kept at the end of the view
    initial_messages: List[LLMMessage] | None = None  # optional messages to seed the context with
class HeadAndTailChatCompletionContext(ChatCompletionContext, Component[HeadAndTailChatCompletionContextConfig]):
    """A chat completion context that keeps a view of the first n and last m messages,
    where n is the head size and m is the tail size. The head and tail sizes
    are set at initialization.

    Args:
        head_size (int): The size of the head.
        tail_size (int): The size of the tail.
        initial_messages (List[LLMMessage] | None): The initial messages.
    """

    component_config_schema = HeadAndTailChatCompletionContextConfig
    component_provider_override = "autogen_core.model_context.HeadAndTailChatCompletionContext"

    def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
        super().__init__(initial_messages)
        # Fail fast on a misconfigured context rather than silently producing
        # empty head/tail slices later.
        if head_size <= 0:
            raise ValueError("head_size must be greater than 0.")
        if tail_size <= 0:
            raise ValueError("tail_size must be greater than 0.")
        self._head_size = head_size
        self._tail_size = tail_size

    async def get_messages(self) -> List[LLMMessage]:
        """Get at most `head_size` oldest messages and `tail_size` most recent messages.

        When messages are elided, a placeholder ``UserMessage`` reporting the
        number of skipped messages is inserted between the head and the tail.

        Returns:
            List[LLMMessage]: head + placeholder + tail when truncation is
            needed; otherwise a copy of all messages.
        """
        # Hoisted early-exit: when there is nothing to skip, avoid building
        # and trimming the head/tail slices at all.
        num_skipped = len(self._messages) - self._head_size - self._tail_size
        if num_skipped <= 0:
            # Not enough messages to fill the head and tail: return everything.
            # Return a copy so callers cannot mutate the internal message list.
            return list(self._messages)
        head_messages = self._messages[: self._head_size]
        # Do not let the head end on a pending function call: its matching
        # result message would fall outside the head, leaving the call
        # unanswered in the model's view of the conversation.
        if (
            head_messages
            and isinstance(head_messages[-1], AssistantMessage)
            and isinstance(head_messages[-1].content, list)
            and all(isinstance(item, FunctionCall) for item in head_messages[-1].content)
        ):
            # Remove the trailing function-call message from the head.
            head_messages = head_messages[:-1]
        tail_messages = self._messages[-self._tail_size :]
        # Symmetrically, drop a leading function-execution result whose
        # originating call was elided along with the middle of the conversation.
        if tail_messages and isinstance(tail_messages[0], FunctionExecutionResultMessage):
            tail_messages = tail_messages[1:]
        # NOTE(review): messages trimmed from the head/tail above are not
        # reflected in `num_skipped`; the placeholder reports only the elided
        # middle span, matching the original behavior.
        placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
        return head_messages + placeholder_messages + tail_messages

    def _to_config(self) -> HeadAndTailChatCompletionContextConfig:
        # Serialize the construction parameters; `self._initial_messages` is
        # maintained by the ChatCompletionContext base class.
        return HeadAndTailChatCompletionContextConfig(
            head_size=self._head_size, tail_size=self._tail_size, initial_messages=self._initial_messages
        )

    @classmethod
    def _from_config(cls, config: HeadAndTailChatCompletionContextConfig) -> Self:
        # Alternate constructor used by the component-config machinery.
        return cls(head_size=config.head_size, tail_size=config.tail_size, initial_messages=config.initial_messages)