Source code for autogen_ext.models.replay._replay_chat_completion_client

from __future__ import annotations

import logging
import warnings
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Sequence, Union

from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelFamily,
    ModelInfo,
    RequestUsage,
    validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from typing_extensions import Self

logger = logging.getLogger(EVENT_LOGGER_NAME)


class ReplayChatCompletionClientConfig(BaseModel):
    """ReplayChatCompletionClient configuration."""

    chat_completions: Sequence[Union[str, CreateResult]]
    model_info: Optional[ModelInfo] = None


class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompletionClientConfig]):
    """A mock chat completion client that replays predefined responses using an index-based approach.

    This class simulates a chat completion client by replaying a predefined list of responses. It supports
    both single completions and streaming responses. The responses can be either strings or CreateResult
    objects. The client uses an index-based approach to access the responses, which allows the state to be
    reset.

    .. note::
        The responses can be either strings or CreateResult objects.

    Args:
        chat_completions (Sequence[Union[str, CreateResult]]): A list of predefined responses to replay.

    Raises:
        ValueError("No more mock responses available"): If the list of provided outputs is exhausted.

    Examples:

        Simple chat completion client to return pre-defined responses.

        .. code-block:: python

            from autogen_core.models import UserMessage
            from autogen_ext.models.replay import ReplayChatCompletionClient


            async def example():
                chat_completions = [
                    "Hello, how can I assist you today?",
                    "I'm happy to help with any questions you have.",
                    "Is there anything else I can assist you with?",
                ]
                client = ReplayChatCompletionClient(chat_completions)
                messages = [UserMessage(content="What can you do?", source="user")]
                response = await client.create(messages)
                print(response.content)  # Output: "Hello, how can I assist you today?"

        Simple streaming chat completion client to return pre-defined responses.

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.replay import ReplayChatCompletionClient


            async def example():
                chat_completions = [
                    "Hello, how can I assist you today?",
                    "I'm happy to help with any questions you have.",
                    "Is there anything else I can assist you with?",
                ]
                client = ReplayChatCompletionClient(chat_completions)
                messages = [UserMessage(content="What can you do?", source="user")]

                async for token in client.create_stream(messages):
                    print(token, end="")  # Output: "Hello, how can I assist you today?"

                async for token in client.create_stream(messages):
                    print(token, end="")  # Output: "I'm happy to help with any questions you have."


            asyncio.run(example())

        Using `.reset` to reset the chat client state.

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.replay import ReplayChatCompletionClient


            async def example():
                chat_completions = [
                    "Hello, how can I assist you today?",
                ]
                client = ReplayChatCompletionClient(chat_completions)
                messages = [UserMessage(content="What can you do?", source="user")]
                response = await client.create(messages)
                print(response.content)  # Output: "Hello, how can I assist you today?"

                response = await client.create(messages)  # Raises ValueError("No more mock responses available")

                client.reset()  # Reset the client state (current index of message and token usages)
                response = await client.create(messages)
                print(response.content)  # Output: "Hello, how can I assist you today?" again


            asyncio.run(example())
    """

    __protocol__: ChatCompletionClient

    component_type = "replay_chat_completion_client"
    component_provider_override = "autogen_ext.models.replay.ReplayChatCompletionClient"
    component_config_schema = ReplayChatCompletionClientConfig

    # TODO: Support logprobs in Responses

    def __init__(
        self,
        chat_completions: Sequence[Union[str, CreateResult]],
        model_info: Optional[ModelInfo] = None,
    ):
        self.chat_completions = list(chat_completions)
        self.provided_message_count = len(self.chat_completions)
        if model_info is not None:
            self._model_info = model_info
            validate_model_info(self._model_info)
        else:
            self._model_info = ModelInfo(
                vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
            )
        self._total_available_tokens = 10000
        self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._current_index = 0
        self._cached_bool_value = True
        self._create_calls: List[Dict[str, Any]] = []

    @property
    def create_calls(self) -> List[Dict[str, Any]]:
        """Return the arguments of the calls made to the create method."""
        return self._create_calls
    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """Return the next completion from the list."""
        if self._current_index >= len(self.chat_completions):
            raise ValueError("No more mock responses available")

        response = self.chat_completions[self._current_index]
        _, prompt_token_count = self._tokenize(messages)
        if isinstance(response, str):
            _, output_token_count = self._tokenize(response)
            self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)
            response = CreateResult(
                finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
            )
        else:
            self._cur_usage = RequestUsage(
                prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens
            )

        self._update_total_usage()
        self._current_index += 1

        self._create_calls.append(
            {
                "messages": messages,
                "tools": tools,
                "json_output": json_output,
                "extra_create_args": extra_create_args,
                "cancellation_token": cancellation_token,
            }
        )
        return response
    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """Return the next completion as a stream."""
        if self._current_index >= len(self.chat_completions):
            raise ValueError("No more mock responses available")

        response = self.chat_completions[self._current_index]
        _, prompt_token_count = self._tokenize(messages)
        if isinstance(response, str):
            output_tokens, output_token_count = self._tokenize(response)
            self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)

            for i, token in enumerate(output_tokens):
                if i < len(output_tokens) - 1:
                    yield token + " "
                else:
                    yield token
            yield CreateResult(
                finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
            )
            self._update_total_usage()
        else:
            self._cur_usage = RequestUsage(
                prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens
            )
            yield response
            self._update_total_usage()

        self._current_index += 1
    async def close(self) -> None:
        pass

    def actual_usage(self) -> RequestUsage:
        return self._cur_usage

    def total_usage(self) -> RequestUsage:
        return self._total_usage

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        _, token_count = self._tokenize(messages)
        return token_count

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return max(
            0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
        )

    def set_cached_bool_value(self, value: bool) -> None:
        self._cached_bool_value = value
    def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> tuple[list[str], int]:
        total_tokens = 0
        all_tokens: List[str] = []
        if isinstance(messages, str):
            tokens = messages.split()
            total_tokens += len(tokens)
            all_tokens.extend(tokens)
        elif hasattr(messages, "content"):
            if isinstance(messages.content, str):  # type: ignore [reportAttributeAccessIssue]
                tokens = messages.content.split()  # type: ignore [reportAttributeAccessIssue]
                total_tokens += len(tokens)
                all_tokens.extend(tokens)
            else:
                logger.warning("Token count has been done only on string content")
        elif isinstance(messages, Sequence):
            for message in messages:
                if isinstance(message.content, str):  # type: ignore [reportAttributeAccessIssue, union-attr]
                    tokens = message.content.split()  # type: ignore [reportAttributeAccessIssue, union-attr]
                    total_tokens += len(tokens)
                    all_tokens.extend(tokens)
                else:
                    logger.warning("Token count has been done only on string content")
        return all_tokens, total_tokens

    def _update_total_usage(self) -> None:
        self._total_usage.completion_tokens += self._cur_usage.completion_tokens
        self._total_usage.prompt_tokens += self._cur_usage.prompt_tokens

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        """Return mock capabilities."""
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self._model_info

    @property
    def model_info(self) -> ModelInfo:
        return self._model_info
    def reset(self) -> None:
        """Reset the client's response index and usage counters to their initial values."""
        self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._current_index = 0

    def _to_config(self) -> ReplayChatCompletionClientConfig:
        return ReplayChatCompletionClientConfig(
            chat_completions=self.chat_completions,
            model_info=self._model_info,
        )

    @classmethod
    def _from_config(cls, config: ReplayChatCompletionClientConfig) -> Self:
        return cls(
            chat_completions=config.chat_completions,
            model_info=config.model_info,
        )
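
A short usage sketch (not part of the module source) of the test-inspection hooks implemented above: create_calls, set_cached_bool_value, actual_usage/total_usage, and reset. The function name inspect_example is illustrative; the imports mirror those used in the class docstring.

    import asyncio

    from autogen_core.models import UserMessage
    from autogen_ext.models.replay import ReplayChatCompletionClient


    async def inspect_example() -> None:
        client = ReplayChatCompletionClient(["First reply", "Second reply"])
        client.set_cached_bool_value(False)  # replayed results will report cached=False

        messages = [UserMessage(content="Hello there", source="user")]
        response = await client.create(messages)
        print(response.content)  # Output: "First reply"
        print(response.cached)  # Output: False

        # Every call to create() is recorded, which is useful for asserting on
        # the exact arguments an agent sent to the model.
        print(len(client.create_calls))  # Output: 1
        print(client.create_calls[0]["messages"][0].content)  # Output: "Hello there"

        # Usage is approximated by whitespace tokenization (see _tokenize above).
        print(client.actual_usage())  # usage of the most recent call
        print(client.total_usage())  # accumulated usage across calls

        client.reset()  # rewind to the first scripted response
        response = await client.create(messages)
        print(response.content)  # Output: "First reply"


    asyncio.run(inspect_example())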