Source code for autogen_ext.models.ollama._ollama_client

import asyncio
import inspect
import json
import logging
import math
import re
import warnings
from dataclasses import dataclass
from typing import (
    Any,
    AsyncGenerator,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Union,
    cast,
)

import tiktoken
from autogen_core import (
    EVENT_LOGGER_NAME,
    TRACE_LOGGER_NAME,
    CancellationToken,
    Component,
    FunctionCall,
    Image,
)
from autogen_core.logging import LLMCallEvent, LLMStreamEndEvent, LLMStreamStartEvent
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelFamily,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
)
from autogen_core.tools import Tool, ToolSchema
from ollama import AsyncClient, ChatResponse, Message
from ollama import Image as OllamaImage
from ollama import Tool as OllamaTool
from ollama._types import ChatRequest
from pydantic import BaseModel
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Self, Unpack

from . import _model_info
from .config import BaseOllamaClientConfiguration, BaseOllamaClientConfigurationConfigModel

# LLM call/stream events are emitted on `logger`; diagnostics go to `trace_logger`.
logger = logging.getLogger(EVENT_LOGGER_NAME)
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)

# TODO: support more kwargs. Can't automate the list like we can with openai or azure because ollama uses an untyped kwargs blob for initialization.
ollama_init_kwargs = {"host"}

# TODO: add kwarg checking logic later
# create_kwargs = set(completion_create_params.CompletionCreateParamsBase.__annotations__.keys()) | set(
#     ("timeout", "stream")
# )
# # Only single choice allowed
# disallowed_create_args = set(["stream", "messages", "function_call", "functions", "n"])
# required_create_args: Set[str] = set(["model"])


def _ollama_client_from_config(config: Mapping[str, Any]) -> AsyncClient:
    """Build an Ollama ``AsyncClient`` from the client-level keys of ``config``.

    Only keys listed in ``ollama_init_kwargs`` are forwarded to the client
    constructor; every other entry is ignored. ``config`` is not modified.
    """
    client_kwargs = {key: value for key, value in dict(config).items() if key in ollama_init_kwargs}
    return AsyncClient(**client_kwargs)


# Look up the ``model_fields`` member of ollama's ``ChatRequest`` via ``inspect.getmembers``
# and keep its value (the second element of the matching (name, value) pair).
ollama_chat_request_fields: dict[str, Any] = [m for m in inspect.getmembers(ChatRequest) if m[0] == "model_fields"][0][
    1
]
# Keys accepted as create arguments: every ``ChatRequest`` field name plus an explicit
# list of known names (including the legacy "response_format" key).
OLLAMA_VALID_CREATE_KWARGS_KEYS = set(ollama_chat_request_fields.keys()) | set(
    ("model", "messages", "tools", "stream", "format", "options", "keep_alive", "response_format")
)
# NOTE: "response_format" is a special case that we handle for backwards compatibility.
# It is going to be deprecated in the future.


def _create_args_from_config(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Filter ``config`` down to the keys accepted as Ollama create arguments.

    Keys are compared case-insensitively against ``OLLAMA_VALID_CREATE_KWARGS_KEYS``
    and are lower-cased in the result. Unrecognized keys are dropped and reported
    on the trace logger.

    Args:
        config: Candidate create arguments.

    Returns:
        A new dict containing only the recognized keys (lower-cased).

    Warns:
        DeprecationWarning: If ``"response_format"`` is present in ``config``.
    """
    if "response_format" in config:
        warnings.warn(
            "Using response_format will be deprecated. Use json_output instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    create_args: Dict[str, Any] = {}
    dropped_keys: List[str] = []
    for key, value in config.items():
        lowered = key.lower()
        if lowered in OLLAMA_VALID_CREATE_KWARGS_KEYS:
            create_args[lowered] = value
        else:
            dropped_keys.append(key)

    # Only log when something was actually dropped; otherwise every call would
    # emit a misleading "dropped keys: []" line.
    if dropped_keys:
        trace_logger.info(f"Dropped the following unrecognized keys from create_args: {dropped_keys}")

    return create_args


# TODO check types
# oai_system_message_schema = type2schema(ChatCompletionSystemMessageParam)
# oai_user_message_schema = type2schema(ChatCompletionUserMessageParam)
# oai_assistant_message_schema = type2schema(ChatCompletionAssistantMessageParam)
# oai_tool_message_schema = type2schema(ChatCompletionToolMessageParam)


def type_to_role(message: LLMMessage) -> str:  # return type: Message.role
    """Return the Ollama chat role string for ``message``.

    System, user and assistant messages map to their own role; any other
    message type is reported as ``"tool"``.
    """
    role_by_type = (
        (SystemMessage, "system"),
        (UserMessage, "user"),
        (AssistantMessage, "assistant"),
    )
    for message_type, role in role_by_type:
        if isinstance(message, message_type):
            return role
    return "tool"


def user_message_to_ollama(message: UserMessage) -> Sequence[Message]:
    """Convert a user message into one or more Ollama ``Message`` objects.

    A plain-string message becomes a single user message. For list content,
    every string part becomes its own user message and every image is attached
    to the most recent message (or to a new content-less message when the
    image comes first).

    Raises:
        ValueError: If the source name is invalid or a content part is neither
            a string nor an ``Image``.
    """
    assert_valid_name(message.source)

    if isinstance(message.content, str):
        # name=message.source is not forwarded. TODO: No name parameter in Ollama
        return [Message(content=message.content, role="user")]

    converted: List[Message] = []
    for part in message.content:
        if isinstance(part, str):
            converted.append(Message(content=part, role="user"))
            continue
        if not isinstance(part, Image):
            raise ValueError(f"Unknown content type: {part}")

        # TODO: should images go into their own message? Should each image get its own message?
        encoded_image = OllamaImage(value=part.to_base64())
        if not converted:
            converted.append(Message(role="user", images=[encoded_image]))
        elif converted[-1].images is None:
            converted[-1].images = [encoded_image]
        else:
            converted[-1].images.append(encoded_image)  # type: ignore
    return converted


def system_message_to_ollama(message: SystemMessage) -> Message:
    """Convert a system message into an Ollama ``Message`` with role ``"system"``."""
    system_text = message.content
    return Message(content=system_text, role="system")


def _func_args_to_ollama_args(args: str) -> Dict[str, Any]:
    return json.loads(args)  # type: ignore


def func_call_to_ollama(message: FunctionCall) -> Message.ToolCall:
    """Convert a ``FunctionCall`` into an Ollama tool call.

    The JSON-encoded ``arguments`` string is decoded into a dict first.
    """
    ollama_function = Message.ToolCall.Function(
        name=message.name,
        arguments=_func_args_to_ollama_args(message.arguments),
    )
    return Message.ToolCall(function=ollama_function)


def tool_message_to_ollama(
    message: FunctionExecutionResultMessage,
) -> Sequence[Message]:
    """Convert each function execution result into an Ollama ``"tool"`` message."""
    tool_messages: List[Message] = []
    for execution_result in message.content:
        tool_messages.append(Message(content=execution_result.content, role="tool"))
    return tool_messages


def assistant_message_to_ollama(
    message: AssistantMessage,
) -> Message:
    """Convert an assistant message into an Ollama ``"assistant"`` message.

    List content is treated as function calls and sent as ``tool_calls``;
    any other content is sent as the message text. The source name is
    validated but not forwarded.

    Raises:
        ValueError: If the source name is invalid.
    """
    assert_valid_name(message.source)

    if not isinstance(message.content, list):
        return Message(content=message.content, role="assistant")

    ollama_tool_calls = [func_call_to_ollama(call) for call in message.content]
    return Message(tool_calls=ollama_tool_calls, role="assistant")


def to_ollama_type(message: LLMMessage) -> Sequence[Message]:
    """Convert any supported message into a sequence of Ollama ``Message`` objects.

    Dispatches on the message type; anything that is not a system, user or
    assistant message is handled as a function execution result message.
    """
    if isinstance(message, SystemMessage):
        return [system_message_to_ollama(message)]
    if isinstance(message, UserMessage):
        return user_message_to_ollama(message)
    if isinstance(message, AssistantMessage):
        return [assistant_message_to_ollama(message)]
    return tool_message_to_ollama(message)


# TODO: Is this correct? Do we need this?
def calculate_vision_tokens(image: Image, detail: str = "auto") -> int:
    """Estimate the token cost of an image using a tile-based formula.

    The image size (read from ``image.image.size``) is first fitted inside a
    2048 x 2048 square, then shrunk so its short side is 768 when both sides
    exceed 768, and finally split into 512-pixel tiles. The result is a base
    cost of 85 plus 170 per tile.

    Args:
        image: Image whose ``image.size`` gives ``(width, height)``.
        detail: ``"low"`` returns the base cost only; any other value uses
            the tile computation.

    Returns:
        The estimated number of tokens.
    """
    MAX_LONG_EDGE = 2048
    BASE_TOKEN_COUNT = 85
    TOKENS_PER_TILE = 170
    MAX_SHORT_EDGE = 768
    TILE_SIZE = 512

    if detail == "low":
        return BASE_TOKEN_COUNT

    width, height = image.image.size

    # Step 1: fit inside a MAX_LONG_EDGE x MAX_LONG_EDGE square if necessary.
    if max(width, height) > MAX_LONG_EDGE:
        ratio = width / height
        if ratio > 1:
            # Landscape: the width is the long edge.
            width, height = MAX_LONG_EDGE, int(MAX_LONG_EDGE / ratio)
        else:
            # Portrait or square: the height is the long edge.
            width, height = int(MAX_LONG_EDGE * ratio), MAX_LONG_EDGE

    # Step 2: if both sides still exceed MAX_SHORT_EDGE, shrink the short side to it.
    ratio = width / height
    if min(width, height) > MAX_SHORT_EDGE:
        if ratio > 1:
            width, height = int(MAX_SHORT_EDGE * ratio), MAX_SHORT_EDGE
        else:
            width, height = MAX_SHORT_EDGE, int(MAX_SHORT_EDGE / ratio)

    # Step 3: count the TILE_SIZE tiles needed to cover the image.
    tile_count = math.ceil(width / TILE_SIZE) * math.ceil(height / TILE_SIZE)
    return BASE_TOKEN_COUNT + TOKENS_PER_TILE * tile_count


def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
    """Return a new ``RequestUsage`` holding the field-wise sum of both usages."""
    total_prompt_tokens = usage1.prompt_tokens + usage2.prompt_tokens
    total_completion_tokens = usage1.completion_tokens + usage2.completion_tokens
    return RequestUsage(
        prompt_tokens=total_prompt_tokens,
        completion_tokens=total_completion_tokens,
    )


# Ollama's tools follow a stricter protocol than OAI or us. While OAI accepts a map of [str, Any], Ollama requires a map of [str, Property] where Property is a typed object containing a type and description. Therefore, only the keys "type" and "description" will be converted from the properties blob in the tool schema
def convert_tools(
    tools: Sequence[Tool | ToolSchema],
) -> List[OllamaTool]:
    """Convert tools or tool schemas into Ollama tool definitions.

    Only the ``type`` and ``description`` of each parameter property are
    carried over, together with the ``required`` list when present.

    Raises:
        ValueError: If any converted tool has an invalid name.
    """
    converted: List[OllamaTool] = []
    for tool in tools:
        if isinstance(tool, Tool):
            schema = tool.schema
        else:
            assert isinstance(tool, dict)
            schema = tool

        parameters = schema.get("parameters")
        properties: Mapping[str, OllamaTool.Function.Parameters.Property] | None = None
        required = None
        if parameters is not None:
            properties = {
                prop_name: OllamaTool.Function.Parameters.Property(
                    type=prop_schema["type"],
                    description=prop_schema.get("description"),
                )
                for prop_name, prop_schema in parameters["properties"].items()
            }
            required = parameters.get("required")

        ollama_function = OllamaTool.Function(
            name=schema["name"],
            description=schema.get("description", ""),
            parameters=OllamaTool.Function.Parameters(
                required=required,
                properties=properties,
            ),
        )
        converted.append(OllamaTool(function=ollama_function))

    # Check if all tools have valid names.
    for ollama_tool in converted:
        assert_valid_name(ollama_tool["function"]["name"])
    return converted


def normalize_name(name: str) -> str:
    """
    Sanitize a name produced by an LLM so it cannot break the API.

    Every character other than letters, digits, "_" and "-" is replaced with "_",
    and the result is truncated to 64 characters.

    Prefer assert_valid_name for validating user configuration or input.
    """
    sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
    return sanitized[:64]


def assert_valid_name(name: str) -> str:
    """
    Ensure that a configured name is valid and return it unchanged.

    A valid name is non-empty, consists only of letters, digits, "_" and "-",
    and is at most 64 characters long.

    For munging LLM responses use normalize_name to ensure LLM specified names don't break the API.

    Raises:
        ValueError: If the name contains a disallowed character, is empty, or is
            longer than 64 characters.
    """
    # fullmatch is required: with re.match, "$" also matches just before a
    # trailing newline, so a name such as "tool\n" would wrongly pass.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
    return name


# TODO: Does this need to change?
def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
    """Map a provider stop reason onto one of the ``FinishReasons`` values.

    The comparison is case-insensitive. ``None`` and any unrecognized value
    map to ``"unknown"``.

    Args:
        stop_reason: The raw stop/done reason reported for a response.

    Returns:
        The normalized finish reason.
    """
    if stop_reason is None:
        return "unknown"

    KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
        "stop": "stop",
        "end_turn": "stop",
        # Generation was cut off by the token limit; keep that distinction
        # instead of collapsing it into "unknown".
        "length": "length",
        "tool_calls": "function_calls",
    }

    return KNOWN_STOP_MAPPINGS.get(stop_reason.lower(), "unknown")


@dataclass
class CreateParams:
    """Processed inputs for a single Ollama chat request.

    Built by ``BaseOllamaChatCompletionClient._process_create_args`` and consumed
    by ``create`` / ``create_stream``.
    """

    # Messages already converted to Ollama ``Message`` objects.
    messages: Sequence[Message]
    # Tools already converted to Ollama tool definitions (may be empty).
    tools: Sequence[OllamaTool]
    # Value passed as the ``format`` argument of the chat call: "json", a JSON schema, or None.
    format: Optional[Union[Literal["", "json"], JsonSchemaValue]]
    # Remaining keyword arguments forwarded to the chat call.
    create_args: Dict[str, Any]

class BaseOllamaChatCompletionClient(ChatCompletionClient):
    def __init__(
        self,
        client: AsyncClient,
        *,
        create_args: Dict[str, Any],
        model_capabilities: Optional[ModelCapabilities] = None,  # type: ignore
        model_info: Optional[ModelInfo] = None,
    ):
        """Initialize the client.

        Args:
            client: The Ollama async client used to issue chat requests.
            create_args: Default arguments for every chat call; must contain ``"model"``.
            model_capabilities: Deprecated alternative to ``model_info``.
            model_info: Capabilities of the model. When neither this nor
                ``model_capabilities`` is given, the info is looked up by model name.

        Raises:
            KeyError: If ``create_args`` has no ``"model"`` entry.
            ValueError: If both ``model_capabilities`` and ``model_info`` are given,
                if no model info is known for the model name, or if a JSON-object
                ``response_format`` is requested from a model without JSON output.
        """
        self._client = client
        self._model_name = create_args["model"]

        if model_capabilities is not None and model_info is not None:
            raise ValueError("model_capabilities and model_info are mutually exclusive")
        if model_info is not None:
            self._model_info = model_info
        elif model_capabilities is not None:
            warnings.warn("model_capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
            # Copy before adding "family" so the caller's dict is not mutated.
            info = cast(ModelInfo, dict(model_capabilities))
            info["family"] = ModelFamily.UNKNOWN
            self._model_info = info
        else:
            try:
                self._model_info = _model_info.get_info(self._model_name)
            except KeyError as err:
                raise ValueError("model_info is required when model name is not a known Ollama model") from err

        # "model" is guaranteed to be present at this point (checked above).
        self._resolved_model: Optional[str] = self._model_name
        self._model_class: Optional[str] = _model_info.resolve_model_class(self._model_name)

        response_format = create_args.get("response_format")
        if (
            not self._model_info["json_output"]
            and isinstance(response_format, dict)
            and response_format.get("type") == "json_object"
        ):
            raise ValueError("Model does not support JSON output.")

        self._create_args = create_args
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        # Ollama doesn't have IDs for tools, so we just increment a counter
        self._tool_id = 0

    @classmethod
    def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
        """Create an ``OllamaChatCompletionClient`` from keyword arguments in ``config``."""
        client = OllamaChatCompletionClient(**config)
        return client

    def get_create_args(self) -> Mapping[str, Any]:
        return self._create_args

    def _process_create_args(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool | type[BaseModel]],
        extra_create_args: Mapping[str, Any],
    ) -> CreateParams:
        """Validate the inputs of a chat call and convert them to Ollama types.

        Args:
            messages: Conversation messages to send.
            tools: Tools (or tool schemas) offered to the model.
            json_output: ``True`` for JSON mode, ``False`` for text mode, a Pydantic
                model class for schema-constrained output, or ``None`` to leave the
                format to ``create_args``.
            extra_create_args: Per-call arguments that override the configured ones.

        Returns:
            The converted messages and tools, the resolved ``format`` value and the
            remaining create arguments.

        Raises:
            ValueError: If the output-format options conflict or are of the wrong
                type, or if the request needs a capability (JSON output, vision,
                function calling) that the model info marks as unsupported.
        """
        # Copy the create args and overwrite anything in extra_create_args
        create_args = self._create_args.copy()
        create_args.update(extra_create_args)
        create_args = _create_args_from_config(create_args)

        response_format_value: JsonSchemaValue | Literal["json"] | None = None

        if "response_format" in create_args:
            warnings.warn(
                "Using response_format will be deprecated. Use json_output instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            value = create_args["response_format"]
            if isinstance(value, type) and issubclass(value, BaseModel):
                response_format_value = value.model_json_schema()
                # Remove response_format from create_args to prevent passing it twice.
                del create_args["response_format"]
            else:
                raise ValueError(f"response_format must be a Pydantic model class, not {type(value)}")

        if json_output is not None:
            if self.model_info["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output.")
            if json_output is True:
                # JSON mode.
                response_format_value = "json"
            elif json_output is False:
                # Text mode.
                response_format_value = None
            elif isinstance(json_output, type) and issubclass(json_output, BaseModel):
                if response_format_value is not None:
                    raise ValueError(
                        "response_format and json_output cannot be set to a Pydantic model class at the same time. "
                        "Use json_output instead."
                    )
                # Beta client mode with Pydantic model class.
                response_format_value = json_output.model_json_schema()
            else:
                raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}")

        if "format" in create_args:
            # Handle the case where format is set from create_args.
            if json_output is not None:
                raise ValueError("json_output and format cannot be set at the same time. Use json_output instead.")
            # Raise explicitly instead of asserting: asserts are stripped under -O,
            # which would let "format" silently override response_format.
            if response_format_value is not None:
                raise ValueError("response_format and format cannot be set at the same time. Use json_output instead.")
            # Pop format from create_args to prevent passing it twice.
            response_format_value = create_args.pop("format")

        # TODO: allow custom handling.
        # For now we raise an error if images are present and vision is not supported
        if self.model_info["vision"] is False:
            for message in messages:
                if (
                    isinstance(message, UserMessage)
                    and isinstance(message.content, list)
                    and any(isinstance(x, Image) for x in message.content)
                ):
                    raise ValueError("Model does not support vision and image was provided")

        ollama_messages = [item for m in messages for item in to_ollama_type(m)]

        if self.model_info["function_calling"] is False and len(tools) > 0:
            raise ValueError("Model does not support function calling and tools were provided")

        converted_tools = convert_tools(tools)

        return CreateParams(
            messages=ollama_messages,
            tools=converted_tools,
            format=response_format_value,
            create_args=create_args,
        )

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        """Send a non-streaming chat request and return the complete result.

        Args:
            messages: Conversation messages to send.
            tools: Tools (or tool schemas) offered to the model.
            json_output: Output format selector; see ``_process_create_args``.
            extra_create_args: Per-call arguments that override the configured ones.
            cancellation_token: If given, linked to the request so it can be cancelled.

        Returns:
            A ``CreateResult`` whose content is either the response text or the list
            of requested function calls, together with the token usage.

        Raises:
            ValueError: If the inputs fail validation in ``_process_create_args``.
        """
        # TODO: kwarg checking logic for extra_create_args
        create_params = self._process_create_args(
            messages,
            tools,
            json_output,
            extra_create_args,
        )
        future = asyncio.ensure_future(
            self._client.chat(  # type: ignore
                # The model name is supplied through create_args.
                messages=create_params.messages,
                tools=create_params.tools if len(create_params.tools) > 0 else None,
                stream=False,
                format=create_params.format,
                **create_params.create_args,
            )
        )
        if cancellation_token is not None:
            cancellation_token.link_future(future)
        result: ChatResponse = await future

        usage = RequestUsage(
            # TODO backup token counting
            prompt_tokens=result.prompt_eval_count if result.prompt_eval_count is not None else 0,
            completion_tokens=(result.eval_count if result.eval_count is not None else 0),
        )

        logger.info(
            LLMCallEvent(
                messages=[m.model_dump() for m in create_params.messages],
                response=result.model_dump(),
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
            )
        )

        if self._resolved_model is not None and self._resolved_model != result.model:
            warnings.warn(
                f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
                "Model mapping in autogen_ext.models.ollama may be incorrect.",
                stacklevel=2,
            )

        # Detect whether it is a function call or not.
        # We don't rely on done_reason as it is not always accurate, depending on the API used.
        content: Union[str, List[FunctionCall]]
        # Truthiness check: an empty tool_calls list is treated as a plain text reply.
        if result.message.tool_calls:
            # TODO: What are possible values for done_reason?
            if result.done_reason != "tool_calls":
                warnings.warn(
                    f"Finish reason mismatch: {result.done_reason} != tool_calls "
                    "when tool_calls are present. Finish reason may not be accurate. "
                    "This may be due to the API used that is not returning the correct finish reason.",
                    stacklevel=2,
                )
            # TODO: Is this still an error condition?
            if result.message.content is not None and result.message.content != "":
                warnings.warn(
                    "Both tool_calls and content are present in the message. "
                    "This is unexpected. content will be ignored, tool_calls will be used.",
                    stacklevel=2,
                )
            # Ollama does not provide tool call ids, so hand out one counter value
            # per call; several calls in one response must not share an id.
            content = []
            for tool_call in result.message.tool_calls:
                content.append(
                    FunctionCall(
                        id=str(self._tool_id),
                        arguments=json.dumps(tool_call.function.arguments),
                        name=normalize_name(tool_call.function.name),
                    )
                )
                self._tool_id += 1
            finish_reason = "tool_calls"
        else:
            finish_reason = result.done_reason or ""
            content = result.message.content or ""

        # Ollama currently doesn't provide logprobs.
        # Currently open ticket: https://github.com/ollama/ollama/issues/2415
        response = CreateResult(
            finish_reason=normalize_stop_reason(finish_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=None,
        )

        self._total_usage = _add_usage(self._total_usage, usage)
        self._actual_usage = _add_usage(self._actual_usage, usage)

        return response

    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """Send a streaming chat request.

        Yields each non-empty text fragment as it arrives and, as the last item,
        a ``CreateResult`` with the full content and token usage.

        Args:
            messages: Conversation messages to send.
            tools: Tools (or tool schemas) offered to the model.
            json_output: Output format selector; see ``_process_create_args``.
            extra_create_args: Per-call arguments that override the configured ones.
            cancellation_token: If given, linked to the request and to every chunk
                read so the stream can be cancelled.

        Raises:
            ValueError: If the inputs fail validation in ``_process_create_args`` or
                the stream reports the stop reason ``"function_call"``.
        """
        # TODO: kwarg checking logic for extra_create_args
        create_params = self._process_create_args(
            messages,
            tools,
            json_output,
            extra_create_args,
        )
        stream_future = asyncio.ensure_future(
            self._client.chat(  # type: ignore
                # The model name is supplied through create_args.
                messages=create_params.messages,
                tools=create_params.tools if len(create_params.tools) > 0 else None,
                stream=True,
                format=create_params.format,
                **create_params.create_args,
            )
        )
        if cancellation_token is not None:
            cancellation_token.link_future(stream_future)
        stream = await stream_future

        chunk = None
        stop_reason = None
        content_chunks: List[str] = []
        full_tool_calls: List[FunctionCall] = []
        first_chunk = True
        while True:
            chunk_future = asyncio.ensure_future(anext(stream))
            if cancellation_token is not None:
                cancellation_token.link_future(chunk_future)
            try:
                chunk = await chunk_future
            except StopAsyncIteration:
                # `chunk` keeps the last chunk received; usage is read from it below.
                break

            if first_chunk:
                first_chunk = False
                # Emit the start event.
                logger.info(
                    LLMStreamStartEvent(
                        messages=[m.model_dump() for m in create_params.messages],
                    )
                )
            # Remember the first done_reason reported by a finished chunk.
            stop_reason = chunk.done_reason if chunk.done and stop_reason is None else stop_reason

            # Collect text content.
            if chunk.message.content is not None:
                content_chunks.append(chunk.message.content)
                if len(chunk.message.content) > 0:
                    yield chunk.message.content

            # Collect tool calls. This must not be skipped when content is present:
            # a chunk can carry tool calls together with a (possibly empty) content string.
            if chunk.message.tool_calls is not None:
                for tool_call in chunk.message.tool_calls:
                    full_tool_calls.append(
                        FunctionCall(
                            id=str(self._tool_id),
                            arguments=json.dumps(tool_call.function.arguments),
                            name=normalize_name(tool_call.function.name),
                        )
                    )
                    # Ollama provides no tool call ids; give every call its own counter value.
                    self._tool_id += 1

            # TODO: logprobs currently unsupported in ollama.
            # See: https://github.com/ollama/ollama/issues/2415

        # Token counts are taken from the last chunk received, when it reports them.
        prompt_tokens = chunk.prompt_eval_count if chunk and chunk.prompt_eval_count else 0
        completion_tokens = chunk.eval_count if chunk and chunk.eval_count else 0

        if stop_reason == "function_call":
            raise ValueError("Function calls are not supported in this context")

        content: Union[str, List[FunctionCall]]
        if full_tool_calls:
            # Tool calls take precedence over text, matching `create`.
            content = full_tool_calls
            stop_reason = "tool_calls"
        else:
            content = "".join(content_chunks)

        usage = RequestUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )

        result = CreateResult(
            finish_reason=normalize_stop_reason(stop_reason),
            content=content,
            usage=usage,
            cached=False,
            logprobs=None,
        )

        # Emit the end event.
        logger.info(
            LLMStreamEndEvent(
                response=result.model_dump(),
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
            )
        )

        self._total_usage = _add_usage(self._total_usage, usage)
        self._actual_usage = _add_usage(self._actual_usage, usage)

        yield result

    async def close(self) -> None:
        """Do nothing; this client performs no cleanup on close."""
        return None  # ollama has no close method?

    def actual_usage(self) -> RequestUsage:
        return self._actual_usage

    def total_usage(self) -> RequestUsage:
        return self._total_usage

    # TODO: probably needs work
    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Estimate the number of prompt tokens for ``messages`` and ``tools``.

        Text is encoded with the tiktoken encoding for the configured model,
        falling back to ``cl100k_base`` when tiktoken does not know the model.
        Images in user messages are costed with ``calculate_vision_tokens``.
        Fixed per-message, per-reply and per-tool overheads are added.

        Args:
            messages: Conversation messages to count.
            tools: Tools (or tool schemas) that would be sent with the request.

        Returns:
            The estimated token count.
        """
        model = self._create_args["model"]
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            trace_logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")
        tokens_per_message = 3
        num_tokens = 0

        # Message tokens.
        for message in messages:
            num_tokens += tokens_per_message
            for ollama_message_part in to_ollama_type(message):
                if ollama_message_part.content is not None:
                    num_tokens += len(encoding.encode(ollama_message_part.content))
            # Images only occur as parts of a user message's list content, so they
            # are counted from the original message rather than the converted one.
            if isinstance(message, UserMessage) and isinstance(message.content, list):
                for part in message.content:
                    if isinstance(part, Image):
                        num_tokens += calculate_vision_tokens(part)
        # TODO: every model family has its own message sequence.
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>

        # Tool tokens.
        ollama_tools = convert_tools(tools)
        for tool in ollama_tools:
            function = tool["function"]
            tool_tokens = len(encoding.encode(function["name"]))
            if "description" in function:
                tool_tokens += len(encoding.encode(function["description"]))
            tool_tokens -= 2
            if "parameters" in function:
                parameters = function["parameters"]
                if "properties" in parameters:
                    assert isinstance(parameters["properties"], dict)
                    for propertiesKey in parameters["properties"]:  # pyright: ignore
                        assert isinstance(propertiesKey, str)
                        tool_tokens += len(encoding.encode(propertiesKey))
                        v = parameters["properties"][propertiesKey]  # pyright: ignore
                        # NOTE(review): this loop expects iterating `v` to yield field names.
                        # If `v` is a pydantic model, iteration yields (name, value) pairs and
                        # every field falls through to the warning below - confirm.
                        for field in v:  # pyright: ignore
                            if field == "type":
                                tool_tokens += 2
                                tool_tokens += len(encoding.encode(v["type"]))  # pyright: ignore
                            elif field == "description":
                                tool_tokens += 2
                                tool_tokens += len(encoding.encode(v["description"]))  # pyright: ignore
                            elif field == "enum":
                                tool_tokens -= 3
                                for o in v["enum"]:  # pyright: ignore
                                    tool_tokens += 3
                                    tool_tokens += len(encoding.encode(o))  # pyright: ignore
                            else:
                                trace_logger.warning(f"Not supported field {field}")
                    tool_tokens += 11
                    if len(parameters["properties"]) == 0:  # pyright: ignore
                        tool_tokens -= 2
            num_tokens += tool_tokens
        num_tokens += 12
        return num_tokens

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Return the model's token limit minus the estimated tokens for ``messages`` and ``tools``.

        The limit is looked up via ``_model_info.get_token_limit`` for the
        configured model; the used amount comes from :meth:`count_tokens`.
        """
        configured_model = self._create_args["model"]
        limit = _model_info.get_token_limit(configured_model)
        used = self.count_tokens(messages, tools=tools)
        return limit - used

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        """Deprecated accessor: emit a ``DeprecationWarning`` and return ``self._model_info``."""
        warnings.warn(
            message="capabilities is deprecated, use model_info instead",
            category=DeprecationWarning,
            stacklevel=2,
        )
        return self._model_info

    @property
    def model_info(self) -> ModelInfo:
        """Return the model info object held in ``self._model_info``."""
        info = self._model_info
        return info


# TODO: see if response_format can just be a json blob instead of a BaseModel
class OllamaChatCompletionClient(BaseOllamaChatCompletionClient, Component[BaseOllamaClientConfigurationConfigModel]):
    """Chat completion client for Ollama hosted models.

    Ollama must be installed and the appropriate model pulled.

    Args:
        model (str): Which Ollama model to use.
        host (optional, str): Model host url.
        response_format (optional, pydantic.BaseModel): The format of the response. If provided, the response will be parsed into this format as json.
        options (optional, Mapping[str, Any] | Options): Additional options to pass to the Ollama client.
        model_info (optional, ModelInfo): The capabilities of the model. **Required if the model is not listed in the ollama model info.**

    Note:
        Only models with 200k+ downloads (as of Jan 21, 2025), + phi4, deepseek-r1 have pre-defined model infos. See `this file <https://github.com/microsoft/autogen/blob/main/python/packages/autogen-ext/src/autogen_ext/models/ollama/_model_info.py>`__ for the full list. An entry for one model encompasses all parameter variants of that model.

    To use this client, you must install the `ollama` extension:

    .. code-block:: bash

        pip install "autogen-ext[ollama]"

    The following code snippet shows how to use the client with an Ollama model:

    .. code-block:: python

        from autogen_ext.models.ollama import OllamaChatCompletionClient
        from autogen_core.models import UserMessage

        ollama_client = OllamaChatCompletionClient(
            model="llama3",
        )

        result = await ollama_client.create([UserMessage(content="What is the capital of France?", source="user")])  # type: ignore
        print(result)

    To load the client from a configuration, you can use the `load_component` method:

    .. code-block:: python

        from autogen_core.models import ChatCompletionClient

        config = {
            "provider": "OllamaChatCompletionClient",
            "config": {"model": "llama3"},
        }

        client = ChatCompletionClient.load_component(config)

    To output structured data, you can use the `response_format` argument:

    .. code-block:: python

        from autogen_ext.models.ollama import OllamaChatCompletionClient
        from autogen_core.models import UserMessage
        from pydantic import BaseModel


        class StructuredOutput(BaseModel):
            first_name: str
            last_name: str


        ollama_client = OllamaChatCompletionClient(
            model="llama3",
            response_format=StructuredOutput,
        )
        result = await ollama_client.create([UserMessage(content="Who was the first man on the moon?", source="user")])  # type: ignore
        print(result)

    Note:
        Tool usage in ollama is stricter than in its OpenAI counterparts. While OpenAI accepts a map of [str, Any], Ollama requires a map of [str, Property] where Property is a typed object containing ``type`` and ``description`` fields. Therefore, only the keys ``type`` and ``description`` will be converted from the properties blob in the tool schema.

    To view the full list of available configuration options, see the :py:class:`BaseOllamaClientConfigurationConfigModel` class.
    """

    component_type = "model"
    component_config_schema = BaseOllamaClientConfigurationConfigModel
    component_provider_override = "autogen_ext.models.ollama.OllamaChatCompletionClient"

    def __init__(self, **kwargs: Unpack[BaseOllamaClientConfiguration]):
        """Build the Ollama client and request arguments from keyword configuration.

        Raises:
            ValueError: If ``model`` is not supplied.
        """
        if "model" not in kwargs:
            raise ValueError("model is required for OllamaChatCompletionClient")

        # Work on a private copy so the caller's mapping is never modified.
        copied_args: Dict[str, Any] = dict(kwargs)

        # These two entries describe the model rather than the request, so they are
        # removed before the remaining arguments are used to build the client and
        # the create-args, and are handed to the base class separately.
        model_capabilities: Optional[ModelCapabilities] = copied_args.pop("model_capabilities", None)  # type: ignore
        model_info: Optional[ModelInfo] = copied_args.pop("model_info", None)

        client = _ollama_client_from_config(copied_args)
        create_args = _create_args_from_config(copied_args)
        # NOTE(review): ``model_info``/``model_capabilities`` have already been stripped here,
        # so ``_to_config()`` will not include them — confirm this is intended.
        self._raw_config: Dict[str, Any] = copied_args
        super().__init__(
            client=client, create_args=create_args, model_capabilities=model_capabilities, model_info=model_info
        )

    def __getstate__(self) -> Dict[str, Any]:
        """Return a copy of the instance dict with ``_client`` set to ``None`` for pickling."""
        state = self.__dict__.copy()
        state["_client"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore the instance dict and rebuild ``_client`` from the stored raw config."""
        self.__dict__.update(state)
        self._client = _ollama_client_from_config(state["_raw_config"])

    def _to_config(self) -> BaseOllamaClientConfigurationConfigModel:
        """Return a config model built from a copy of the stored raw configuration."""
        copied_config = self._raw_config.copy()
        return BaseOllamaClientConfigurationConfigModel(**copied_config)

    @classmethod
    def _from_config(cls, config: BaseOllamaClientConfigurationConfigModel) -> Self:
        """Create a client from a config model, omitting fields that are ``None``."""
        copied_config = config.model_copy().model_dump(exclude_none=True)
        return cls(**copied_config)