Source code for autogen_ext.tools.mcp._host._sampling

import base64
import io
from abc import ABC, abstractmethod
from typing import Any, Dict

from autogen_core import Image
from autogen_core._component_config import Component, ComponentBase, ComponentModel
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    FinishReasons,
    LLMMessage,
    ModelInfo,
    SystemMessage,
    UserMessage,
)
from PIL import Image as PILImage
from pydantic import BaseModel

from mcp import types as mcp_types
from mcp.types import StopReason


def parse_sampling_content(
    content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent,
    model_info: ModelInfo | None = None,
) -> str | Image:
    """Convert MCP content types to AutoGen content types.

    Handles text and image content conversion, with vision model validation for images.

    Args:
        content: MCP content object (text, image, or audio; audio is not yet supported)
        model_info: Optional model information for vision capability checking

    Returns:
        Converted content as string or Image object

    Raises:
        RuntimeError: If image content is provided but model doesn't support vision
        ValueError: If content type is unsupported
    """
    if content.type == "text":
        return content.text
    elif content.type == "image":
        if model_info and not model_info.get("vision", False):
            model_family = model_info.get("family", "unknown")
            raise RuntimeError(f"model {model_family} does not support vision.")

        # Decode base64 image data and create PIL Image
        image_data = base64.b64decode(content.data)
        pil_image = PILImage.open(io.BytesIO(image_data))
        return Image.from_pil(pil_image)
    else:
        raise ValueError(f"Unsupported content type: {content.type}")
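
# A minimal usage sketch (comments only, not executed): `png_base64` is a
# hypothetical base64-encoded PNG string and `model_client` a hypothetical
# ChatCompletionClient.
#
#   text = parse_sampling_content(mcp_types.TextContent(type="text", text="hello"))
#   # -> "hello"
#   image = parse_sampling_content(
#       mcp_types.ImageContent(type="image", data=png_base64, mimeType="image/png"),
#       model_info=model_client.model_info,  # raises RuntimeError if vision is False
#   )
#   # -> autogen_core.Image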


def parse_sampling_message(message: mcp_types.SamplingMessage, model_info: ModelInfo | None = None) -> LLMMessage:
    """Convert MCP sampling messages to AutoGen LLM messages.

    Args:
        message: MCP sampling message with role and content
        model_info: Optional model information for content parsing

    Returns:
        Converted AutoGen LLM message (UserMessage or AssistantMessage)

    Raises:
        ValueError: If message role is not recognized
        AssertionError: If assistant message content is not text
    """
    content = parse_sampling_content(message.content, model_info=model_info)
    if message.role == "user":
        return UserMessage(
            source="user",
            content=[content],
        )
    elif message.role == "assistant":
        assert isinstance(content, str), "Assistant messages only support string content."
        return AssistantMessage(
            source="assistant",
            content=content,
        )
    else:
        raise ValueError(f"Unrecognized message role: {message.role}")
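
# Sketch of the conversion: a "user" sampling message becomes a UserMessage
# whose content list holds the parsed text or Image; an "assistant" message
# must carry text and becomes an AssistantMessage.
#
#   msg = mcp_types.SamplingMessage(
#       role="user",
#       content=mcp_types.TextContent(type="text", text="What is 2 + 2?"),
#   )
#   parse_sampling_message(msg)
#   # -> UserMessage(source="user", content=["What is 2 + 2?"])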


def finish_reason_to_stop_reason(finish_reason: FinishReasons) -> StopReason:
    """Convert AutoGen finish reasons to MCP stop reasons.

    Args:
        finish_reason: AutoGen completion finish reason

    Returns:
        Corresponding MCP stop reason
    """
    if finish_reason == "stop":
        return "endTurn"
    elif finish_reason == "length":
        return "maxTokens"
    else:
        return finish_reason
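
# The mapping in brief; finish reasons without an MCP equivalent pass through
# unchanged, which MCP's StopReason type permits (it also accepts plain strings):
#
#   finish_reason_to_stop_reason("stop")            # -> "endTurn"
#   finish_reason_to_stop_reason("length")          # -> "maxTokens"
#   finish_reason_to_stop_reason("function_calls")  # -> "function_calls"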


def create_request_params_to_extra_create_args(params: mcp_types.CreateMessageRequestParams) -> Dict[str, Any]:
    """Convert MCP request parameters to AutoGen extra create arguments.

    Args:
        params: MCP message creation request parameters

    Returns:
        Dictionary of extra arguments for AutoGen chat completion client
    """
    # TODO: Need to support all ChatCompletionClients
    extra_create_args: dict[str, Any] = {"max_tokens": params.maxTokens}
    if params.temperature is not None:
        extra_create_args["temperature"] = params.temperature
    if params.stopSequences is not None:
        extra_create_args["stop"] = params.stopSequences
    return extra_create_args
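
# Example, assuming an OpenAI-style client where `max_tokens`, `temperature`,
# and `stop` are valid create arguments (see the TODO above):
#
#   params = mcp_types.CreateMessageRequestParams(
#       messages=[], maxTokens=256, temperature=0.2, stopSequences=["\n\n"]
#   )
#   create_request_params_to_extra_create_args(params)
#   # -> {"max_tokens": 256, "temperature": 0.2, "stop": ["\n\n"]}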


class Sampler(ABC, ComponentBase[BaseModel]):
    """Abstract base class for MCP sampling handlers."""

    component_type = "mcp_sampler"

    @abstractmethod
    async def sample(
        self, params: mcp_types.CreateMessageRequestParams
    ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: ...


class ChatCompletionClientSamplerConfig(BaseModel):
    """Configuration for :class:`ChatCompletionClientSampler`."""

    client_config: ComponentModel | Dict[str, Any]


class ChatCompletionClientSampler(Sampler, Component[ChatCompletionClientSamplerConfig]):
    """A sampler that fulfills MCP sampling requests using a :class:`ChatCompletionClient`."""

    component_config_schema = ChatCompletionClientSamplerConfig
    component_provider_override = "autogen_ext.tools.mcp.ChatCompletionClientSampler"

    def __init__(self, model_client: ChatCompletionClient):
        self._model_client = model_client

    async def sample(
        self, params: mcp_types.CreateMessageRequestParams
    ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
        # Convert MCP messages to AutoGen format using the parsers above.
        autogen_messages: list[LLMMessage] = []

        # Add the system prompt if provided.
        if params.systemPrompt:
            autogen_messages.append(SystemMessage(content=params.systemPrompt))

        # Parse the sampling messages.
        for msg in params.messages:
            autogen_messages.append(parse_sampling_message(msg, model_info=self._model_client.model_info))

        # Use the model client to generate a response.
        extra_create_args = create_request_params_to_extra_create_args(params)
        response = await self._model_client.create(messages=autogen_messages, extra_create_args=extra_create_args)

        # Extract text content from the response.
        if isinstance(response.content, str):
            response_text = response.content
        else:
            from pydantic_core import to_json

            # Handle function calls by converting them to a JSON string representation.
            response_text = to_json(response.content).decode()

        return mcp_types.CreateMessageResult(
            role="assistant",
            content=mcp_types.TextContent(type="text", text=response_text),
            model=self._model_client.model_info["family"],
            stopReason=finish_reason_to_stop_reason(response.finish_reason),
        )

    def _to_config(self) -> BaseModel:
        return ChatCompletionClientSamplerConfig(client_config=self._model_client.dump_component())

    @classmethod
    def _from_config(cls, config: ChatCompletionClientSamplerConfig) -> "ChatCompletionClientSampler":
        return ChatCompletionClientSampler(model_client=ChatCompletionClient.load_component(config.client_config))
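

# End-to-end sketch, assuming `autogen_ext.models.openai.OpenAIChatCompletionClient`
# is available; the model name is illustrative:
#
#   from autogen_ext.models.openai import OpenAIChatCompletionClient
#
#   sampler = ChatCompletionClientSampler(OpenAIChatCompletionClient(model="gpt-4o"))
#   result = await sampler.sample(
#       mcp_types.CreateMessageRequestParams(
#           messages=[
#               mcp_types.SamplingMessage(
#                   role="user",
#                   content=mcp_types.TextContent(type="text", text="Say hi."),
#               )
#           ],
#           maxTokens=64,
#       )
#   )
#   # result is a CreateMessageResult whose content holds the model's reply.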