from __future__ import annotations
import logging
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union
from autogen_core import EVENT_LOGGER_NAME, CancellationToken
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelFamily,
    ModelInfo,
    RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema
logger = logging.getLogger(EVENT_LOGGER_NAME)
class ReplayChatCompletionClient(ChatCompletionClient):
    """
    A mock chat completion client that replays predefined responses using an index-based approach.

    This class simulates a chat completion client by replaying a predefined list of responses.
    It supports both single completion and streaming responses. The responses can be either
    strings or CreateResult objects. The client keeps an index into the list of responses to
    decide which one to return next, which allows the state to be reset with ``reset()``.

    .. note::
        The responses can be either strings or CreateResult objects.

    Args:
        chat_completions (Sequence[Union[str, CreateResult]]): A list of predefined responses to replay.

    Raises:
        ValueError("No more mock responses available"): If the list of provided outputs is exhausted.

    Examples:

    Simple chat completion client to return pre-defined responses.

        .. code-block:: python

            import asyncio

            from autogen_ext.models.replay import ReplayChatCompletionClient
            from autogen_core.models import UserMessage


            async def example():
                chat_completions = [
                    "Hello, how can I assist you today?",
                    "I'm happy to help with any questions you have.",
                    "Is there anything else I can assist you with?",
                ]
                client = ReplayChatCompletionClient(chat_completions)
                messages = [UserMessage(content="What can you do?", source="user")]
                response = await client.create(messages)
                print(response.content)  # Output: "Hello, how can I assist you today?"


            asyncio.run(example())

    Simple streaming chat completion client to return pre-defined responses

        .. code-block:: python

            import asyncio

            from autogen_ext.models.replay import ReplayChatCompletionClient
            from autogen_core.models import UserMessage


            async def example():
                chat_completions = [
                    "Hello, how can I assist you today?",
                    "I'm happy to help with any questions you have.",
                    "Is there anything else I can assist you with?",
                ]
                client = ReplayChatCompletionClient(chat_completions)
                messages = [UserMessage(content="What can you do?", source="user")]

                async for token in client.create_stream(messages):
                    print(token, end="")  # Output: "Hello, how can I assist you today?"

                async for token in client.create_stream(messages):
                    print(token, end="")  # Output: "I'm happy to help with any questions you have."


            asyncio.run(example())

    Using `.reset` to reset the chat client state

        .. code-block:: python

            import asyncio

            from autogen_ext.models.replay import ReplayChatCompletionClient
            from autogen_core.models import UserMessage


            async def example():
                chat_completions = [
                    "Hello, how can I assist you today?",
                ]
                client = ReplayChatCompletionClient(chat_completions)
                messages = [UserMessage(content="What can you do?", source="user")]
                response = await client.create(messages)
                print(response.content)  # Output: "Hello, how can I assist you today?"

                try:
                    await client.create(messages)
                except ValueError as e:
                    print(e)  # Output: "No more mock responses available"

                client.reset()  # Reset the client state (current index of message and token usages)
                response = await client.create(messages)
                print(response.content)  # Output: "Hello, how can I assist you today?" again


            asyncio.run(example())
    """

    __protocol__: ChatCompletionClient

    # TODO: Support FunctionCall in responses
    # TODO: Support logprobs in Responses
    # TODO: Support model capabilities

    def __init__(
        self,
        chat_completions: Sequence[Union[str, CreateResult]],
    ):
        # Copy into a fresh list so later changes to the caller's sequence do not affect replay.
        self.chat_completions = list(chat_completions)
        self.provided_message_count = len(self.chat_completions)
        self._model_info = ModelInfo(
            vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
        )
        # Nominal token budget reported by remaining_tokens(); not enforced anywhere else.
        self._total_available_tokens = 10000
        self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        # Position of the next response in self.chat_completions.
        self._current_index = 0

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """Return the next completion from the list.

        A string response is wrapped in a ``CreateResult`` with ``finish_reason="stop"`` and
        ``cached=True``; a ``CreateResult`` response is returned as-is. ``tools``,
        ``json_output``, ``extra_create_args`` and ``cancellation_token`` are accepted for
        interface compatibility and are not used.

        Raises:
            ValueError: If every predefined response has already been replayed.
        """
        if self._current_index >= len(self.chat_completions):
            raise ValueError("No more mock responses available")

        response = self.chat_completions[self._current_index]
        _, prompt_token_count = self._tokenize(messages)
        if isinstance(response, str):
            _, output_token_count = self._tokenize(response)
            self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)
            response = CreateResult(finish_reason="stop", content=response, usage=self._cur_usage, cached=True)
        else:
            # Prompt tokens are recounted from `messages`; completion tokens come from the stored result.
            self._cur_usage = RequestUsage(
                prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens
            )

        self._update_total_usage()
        self._current_index += 1
        return response

    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """Return the next completion as a stream.

        A string response is streamed as its whitespace-separated tokens, with a single space
        appended to every token except the last; no ``CreateResult`` is yielded for it. A
        ``CreateResult`` response is yielded once, as-is.

        The replay index and the total usage are only advanced after the generator has been
        fully consumed; abandoning the stream early leaves the same response queued next.

        Raises:
            ValueError: On first iteration, if every predefined response has already been replayed.
        """
        if self._current_index >= len(self.chat_completions):
            raise ValueError("No more mock responses available")

        response = self.chat_completions[self._current_index]
        _, prompt_token_count = self._tokenize(messages)
        if isinstance(response, str):
            output_tokens, output_token_count = self._tokenize(response)
            self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)
            last_index = len(output_tokens) - 1
            for i, token in enumerate(output_tokens):
                # Re-insert the separator that split() removed, except after the final token.
                yield token + " " if i < last_index else token
            self._update_total_usage()
        else:
            self._cur_usage = RequestUsage(
                prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens
            )
            yield response
            self._update_total_usage()

        self._current_index += 1

    def actual_usage(self) -> RequestUsage:
        """Return the usage recorded for the most recent request."""
        return self._cur_usage

    def total_usage(self) -> RequestUsage:
        """Return the usage accumulated since construction or the last ``reset()``."""
        return self._total_usage

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Return the whitespace-token count of the string contents of ``messages``; ``tools`` is ignored."""
        _, token_count = self._tokenize(messages)
        return token_count

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Return the nominal budget minus total usage so far, floored at 0.

        ``messages`` and ``tools`` are accepted for interface compatibility and are not counted.
        """
        return max(
            0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
        )

    def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> tuple[list[str], int]:
        """Split text into whitespace-delimited pseudo-tokens.

        Accepts a raw string, a single message (any object with a ``content`` attribute), or a
        sequence of messages. Only string ``content`` is tokenized; non-string content is
        skipped and a warning is logged. Any other input yields no tokens.

        Returns:
            A ``(tokens, token_count)`` tuple.
        """
        if isinstance(messages, str):
            tokens = messages.split()
            return tokens, len(tokens)

        # Normalize a single message and a sequence of messages to one iterable.
        if hasattr(messages, "content"):
            batch: Sequence[Any] = [messages]
        elif isinstance(messages, Sequence):
            batch = messages
        else:
            batch = []

        all_tokens: List[str] = []
        for message in batch:
            content = message.content  # type: ignore [reportAttributeAccessIssue, union-attr]
            if isinstance(content, str):
                all_tokens.extend(content.split())
            else:
                # No extra positional args: logging treats them as %-format arguments for the message.
                logger.warning("Token count has been done only on string content")
        return all_tokens, len(all_tokens)

    def _update_total_usage(self) -> None:
        """Add the most recent request's usage to the running total (mutates ``_total_usage`` in place)."""
        self._total_usage.completion_tokens += self._cur_usage.completion_tokens
        self._total_usage.prompt_tokens += self._cur_usage.prompt_tokens

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        """Return mock capabilities."""
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self._model_info

    @property
    def model_info(self) -> ModelInfo:
        """Return the static model info: no vision, function calling, or JSON output; unknown family."""
        return self._model_info

    def reset(self) -> None:
        """Reset the client state and usage to its initial state."""
        self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._current_index = 0