Module tinytroupe.clients.ollama_client

Expand source code
import logging
import os
import pickle
import time

import requests

from tinytroupe import config_manager, utils

logger = logging.getLogger("tinytroupe")


class OllamaClient:
    """
    A client for interacting with the Ollama API using direct HTTP requests.
    """

    @config_manager.config_defaults(
        cache_api_calls="cache_api_calls", cache_file_name="cache_file_name"
    )
    def __init__(self, cache_api_calls=None, cache_file_name=None) -> None:
        logger.debug("Initializing OllamaClient")
        self.base_url = config_manager.get("base_url", "http://localhost:11434/v1")
        logger.debug(f"base_url set to {self.base_url}")

        # Set up caching
        self.cache_api_calls = cache_api_calls
        self.cache_file_name = cache_file_name
        if self.cache_api_calls:
            self.api_cache = self._load_cache()

    def set_api_cache(self, cache_api_calls, cache_file_name=None):
        """
        Enables or disables the caching of API calls.

        Args:
            cache_api_calls (bool): Whether to enable caching of API calls.
            cache_file_name (str): The name of the file to use for caching API calls.
        """
        self.cache_api_calls = cache_api_calls
        self.cache_file_name = cache_file_name
        if self.cache_api_calls:
            # load the cache, if any
            self.api_cache = self._load_cache()

    @config_manager.config_defaults(
        model="model",
        temperature="temperature",
        top_p="top_p",
        frequency_penalty="frequency_penalty",
        presence_penalty="presence_penalty",
        num_ctx="num_ctx",
        timeout="timeout",
        max_attempts="max_attempts",
        waiting_time="waiting_time",
        exponential_backoff_factor="exponential_backoff_factor",
        response_format=None,
        echo=None,
    )
    def send_message(
        self,
        current_messages,
        dedent_messages=True,
        model=None,
        temperature=None,
        max_completion_tokens=None,  # Ollama doesn't use max_completion_tokens
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        stop=None,
        num_ctx=None,
        timeout=None,
        max_attempts=None,
        waiting_time=None,
        exponential_backoff_factor=None,
        n=1,
        response_format=None,
        enable_pydantic_model_return=False,
        echo=False,
    ):
        """
        Sends a message to the Ollama API and returns the response.
        """

        # Imported here to avoid a circular import.
        from tinytroupe.clients import InvalidRequestError

        def aux_exponential_backoff():
            nonlocal waiting_time
            logger.info(
                f"Request failed. Waiting {waiting_time} seconds between requests..."
            )
            time.sleep(waiting_time)
            waiting_time = waiting_time * exponential_backoff_factor

        # Prepare the API parameters
        chat_api_params = {
            "model": model,
            "messages": current_messages,
            "options": {
                "temperature": temperature,
                "top_p": top_p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "stop": stop,
                "num_ctx": num_ctx,  # special Ollama parameter for the input size
            },
            "stream": False,
            "n": n,
        }

        # remove any parameter that is None, so we use the API defaults
        chat_api_params = {k: v for k, v in chat_api_params.items() if v is not None}
        # ... within options too
        chat_api_params["options"] = {
            k: v for k, v in chat_api_params["options"].items() if v is not None
        }

        i = 0
        while i < max_attempts:
            try:
                i += 1

                start_time = time.monotonic()
                logger.debug(f"Sending request to Ollama API. Attempt {i}")

                # Check cache first
                cache_key = str((model, chat_api_params))
                if self.cache_api_calls and (cache_key in self.api_cache):
                    response = self.api_cache[cache_key]
                else:
                    logger.info(
                        f"Waiting {waiting_time} seconds before next API request..."
                    )
                    time.sleep(waiting_time)

                    # Make the API call
                    response = self._make_request(
                        "chat/completions",
                        method="POST",
                        json=chat_api_params,
                        timeout=timeout,
                    )

                    # Cache the response if caching is enabled
                    if self.cache_api_calls:
                        self.api_cache[cache_key] = response
                        self._save_cache()

                end_time = time.monotonic()
                logger.debug(
                    f"Got response in {end_time - start_time:.2f} seconds after {i} attempts"
                )

                # Extract and return the relevant part of the response
                return utils.sanitize_dict(self._extract_response(response))

            except requests.exceptions.RequestException as e:
                logger.error(f"[{i}] Request error: {e}")
                if "Invalid request" in str(e):
                    raise InvalidRequestError(str(e))
                aux_exponential_backoff()

            except Exception as e:
                logger.error(f"[{i}] Error: {e}")
                aux_exponential_backoff()

        logger.error(f"Failed to get response after {max_attempts} attempts")
        return None

    def _make_request(self, endpoint, method="POST", **kwargs):
        """
        Makes a request to the Ollama API.
        """
        url = f"{self.base_url}/{endpoint}"
        logger.debug(f"Making {method} request to {url}")
        logger.debug(f"Request parameters: {kwargs}")

        response = requests.request(method, url, **kwargs)
        response.raise_for_status()

        return response.json()

    def _extract_response(self, response):
        """
        Extracts the relevant information from the API response.
        """
        logger.debug(f"Extracting from response: {response}")
        try:
            return {
                "role": response["choices"][0]["message"]["role"],
                "content": response["choices"][0]["message"]["content"],
            }
        except (KeyError, IndexError) as e:
            logger.error(f"Error extracting response: {e}")
            logger.error(f"Response structure: {response}")
            raise ValueError("Invalid response format from Ollama") from e

    def _save_cache(self):
        """
        Saves the API cache to disk using pickle.
        """
        with open(self.cache_file_name, "wb") as f:
            pickle.dump(self.api_cache, f)

    def _load_cache(self):
        """
        Loads the API cache from disk.
        """
        if os.path.exists(self.cache_file_name):
            with open(self.cache_file_name, "rb") as f:
                return pickle.load(f)
        return {}

    def get_models(self):
        """
        Gets the list of available models from Ollama.
        """
        try:
            response = self._make_request("models", method="GET")
            return response.get("models", [])
        except Exception as e:
            logger.error(f"Error getting models: {e}")
            return []

    def _count_tokens(self, messages: list, model: str):
        """
        Count the number of tokens in a list of messages using Ollama's API.

        Args:
            messages (list): A list of dictionaries representing the conversation history.
            model (str): The name of the model to use for encoding the string.

        Returns:
            int or None: The number of tokens in the messages, or None if an error occurs.
        """
        try:
            # Combine all message content into a single string
            combined_text = ""
            for message in messages:
                # Add role/name if present
                if "name" in message:
                    combined_text += f"{message['name']}: "
                if "role" in message:
                    combined_text += f"{message['role']}: "
                # Add message content
                if "content" in message:
                    combined_text += f"{message['content']}\n"

            # Prepare the request payload
            payload = {
                "model": model,
                "input": combined_text,
                "options": {
                    "temperature": 0  # Set to 0 since we only care about token count
                },
            }

            # Make the request to Ollama's API
            # The OpenAI-compatible endpoints live under /v1, but the native
            # Ollama API (including /api/embed) is served from the root URL,
            # so strip the /v1 suffix before calling it.
            temp_url = self.base_url.replace("/v1", "")
            response = requests.post(f"{temp_url}/api/embed", json=payload)
            response.raise_for_status()

            # Extract token count from response
            data = response.json()
            token_count = data.get("prompt_eval_count", 0)

            return token_count

        except requests.exceptions.RequestException as e:
            logger.error(f"Error making request to Ollama API: {e}")
            return None
        except Exception as e:
            logger.error(f"Error counting tokens: {e}")
            return None
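
To illustrate the wire format, here is a minimal sketch of the request that send_message builds, issued directly with requests (the model name and option values are illustrative; stream is always False in this client):

import requests

payload = {
    "model": "llama3",  # illustrative; use any locally pulled model
    "messages": [{"role": "user", "content": "Hello"}],
    "options": {"temperature": 0.7, "num_ctx": 4096},  # Ollama-specific options
    "stream": False,
    "n": 1,
}
resp = requests.post(
    "http://localhost:11434/v1/chat/completions", json=payload, timeout=60
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])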

Classes

class OllamaClient (cache_api_calls=None, cache_file_name=None)

A client for interacting with the Ollama API using direct HTTP requests.
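
A minimal usage sketch, assuming an Ollama server on the default port and a locally pulled model (the model name below is illustrative):

client = OllamaClient()
reply = client.send_message(
    current_messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    model="llama3",
)
if reply is not None:  # send_message returns None after exhausting retries
    print(reply["content"])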

Expand source code
class OllamaClient:
    """
    A client for interacting with the Ollama API using direct HTTP requests.
    """

    @config_manager.config_defaults(
        cache_api_calls="cache_api_calls", cache_file_name="cache_file_name"
    )
    def __init__(self, cache_api_calls=None, cache_file_name=None) -> None:
        logger.debug("Initializing OllamaClient")
        self.base_url = config_manager.get("base_url", "http://localhost:11434/v1")
        logger.debug(f"base_url set to {self.base_url}")

        # Set up caching
        self.cache_api_calls = cache_api_calls
        self.cache_file_name = cache_file_name
        if self.cache_api_calls:
            self.api_cache = self._load_cache()

    def set_api_cache(self, cache_api_calls, cache_file_name=None):
        """
        Enables or disables the caching of API calls.

        Args:
            cache_api_calls (bool): Whether to enable caching of API calls.
            cache_file_name (str): The name of the file to use for caching API calls.
        """
        self.cache_api_calls = cache_api_calls
        self.cache_file_name = cache_file_name
        if self.cache_api_calls:
            # load the cache, if any
            self.api_cache = self._load_cache()

    @config_manager.config_defaults(
        model="model",
        temperature="temperature",
        top_p="top_p",
        frequency_penalty="frequency_penalty",
        presence_penalty="presence_penalty",
        num_ctx="num_ctx",
        timeout="timeout",
        max_attempts="max_attempts",
        waiting_time="waiting_time",
        exponential_backoff_factor="exponential_backoff_factor",
        response_format=None,
        echo=None,
    )
    def send_message(
        self,
        current_messages,
        dedent_messages=True,
        model=None,
        temperature=None,
        max_completion_tokens=None,  # Ollama doesn't use max_completion_tokens
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        stop=None,
        num_ctx=None,
        timeout=None,
        max_attempts=None,
        waiting_time=None,
        exponential_backoff_factor=None,
        n=1,
        response_format=None,
        enable_pydantic_model_return=False,
        echo=False,
    ):
        """
        Sends a message to the Ollama API and returns the response.
        """

        # Imported here to avoid a circular import.
        from tinytroupe.clients import InvalidRequestError

        def aux_exponential_backoff():
            nonlocal waiting_time
            logger.info(
                f"Request failed. Waiting {waiting_time} seconds between requests..."
            )
            time.sleep(waiting_time)
            waiting_time = waiting_time * exponential_backoff_factor

        # Prepare the API parameters
        chat_api_params = {
            "model": model,
            "messages": current_messages,
            "options": {
                "temperature": temperature,
                "top_p": top_p,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "stop": stop,
                "num_ctx": num_ctx,  # special Ollama parameter for the input size
            },
            "stream": False,
            "n": n,
        }

        # remove any parameter that is None, so we use the API defaults
        chat_api_params = {k: v for k, v in chat_api_params.items() if v is not None}
        # ... within options too
        chat_api_params["options"] = {
            k: v for k, v in chat_api_params["options"].items() if v is not None
        }

        i = 0
        while i < max_attempts:
            try:
                i += 1

                start_time = time.monotonic()
                logger.debug(f"Sending request to Ollama API. Attempt {i}")

                # Check cache first
                cache_key = str((model, chat_api_params))
                if self.cache_api_calls and (cache_key in self.api_cache):
                    response = self.api_cache[cache_key]
                else:
                    logger.info(
                        f"Waiting {waiting_time} seconds before next API request..."
                    )
                    time.sleep(waiting_time)

                    # Make the API call
                    response = self._make_request(
                        "chat/completions",
                        method="POST",
                        json=chat_api_params,
                        timeout=timeout,
                    )

                    # Cache the response if caching is enabled
                    if self.cache_api_calls:
                        self.api_cache[cache_key] = response
                        self._save_cache()

                end_time = time.monotonic()
                logger.debug(
                    f"Got response in {end_time - start_time:.2f} seconds after {i} attempts"
                )

                # Extract and return the relevant part of the response
                return utils.sanitize_dict(self._extract_response(response))

            except requests.exceptions.RequestException as e:
                logger.error(f"[{i}] Request error: {e}")
                if "Invalid request" in str(e):
                    raise InvalidRequestError(str(e))
                aux_exponential_backoff()

            except Exception as e:
                logger.error(f"[{i}] Error: {e}")
                aux_exponential_backoff()

        logger.error(f"Failed to get response after {max_attempts} attempts")
        return None

    def _make_request(self, endpoint, method="POST", **kwargs):
        """
        Makes a request to the Ollama API.
        """
        url = f"{self.base_url}/{endpoint}"
        logger.debug(f"Making {method} request to {url}")
        logger.debug(f"Request parameters: {kwargs}")

        response = requests.request(method, url, **kwargs)
        response.raise_for_status()

        return response.json()

    def _extract_response(self, response):
        """
        Extracts the relevant information from the API response.
        """
        logger.debug(f"Extracting from response: {response}")
        try:
            return {
                "role": response["choices"][0]["message"]["role"],
                "content": response["choices"][0]["message"]["content"],
            }
        except (KeyError, IndexError) as e:
            logger.error(f"Error extracting response: {e}")
            logger.error(f"Response structure: {response}")
            raise ValueError("Invalid response format from Ollama") from e

    def _save_cache(self):
        """
        Saves the API cache to disk using pickle.
        """
        with open(self.cache_file_name, "wb") as f:
            pickle.dump(self.api_cache, f)

    def _load_cache(self):
        """
        Loads the API cache from disk.
        """
        if os.path.exists(self.cache_file_name):
            with open(self.cache_file_name, "rb") as f:
                return pickle.load(f)
        return {}

    def get_models(self):
        """
        Gets the list of available models from Ollama.
        """
        try:
            response = self._make_request("models", method="GET")
            # The OpenAI-compatible /v1/models endpoint returns the list under
            # "data"; fall back to "models" for native-style responses.
            return response.get("data", response.get("models", []))
        except Exception as e:
            logger.error(f"Error getting models: {e}")
            return []

    def _count_tokens(self, messages: list, model: str):
        """
        Count the number of tokens in a list of messages using Ollama's API.

        Args:
            messages (list): A list of dictionaries representing the conversation history.
            model (str): The name of the model to use for encoding the string.

        Returns:
            int or None: The number of tokens in the messages, or None if an error occurs.
        """
        try:
            # Combine all message content into a single string
            combined_text = ""
            for message in messages:
                # Add role/name if present
                if "name" in message:
                    combined_text += f"{message['name']}: "
                if "role" in message:
                    combined_text += f"{message['role']}: "
                # Add message content
                if "content" in message:
                    combined_text += f"{message['content']}\n"

            # Prepare the request payload
            payload = {
                "model": model,
                "input": combined_text,
                "options": {
                    "temperature": 0  # Set to 0 since we only care about token count
                },
            }

            # Make the request to Ollama's API
            # The OpenAI-compatible endpoints live under /v1, but the native
            # Ollama API (including /api/embed) is served from the root URL,
            # so strip the /v1 suffix before calling it.
            temp_url = self.base_url.replace("/v1", "")
            response = requests.post(f"{temp_url}/api/embed", json=payload)
            response.raise_for_status()

            # Extract token count from response
            data = response.json()
            token_count = data.get("prompt_eval_count", 0)

            return token_count

        except requests.exceptions.RequestException as e:
            logger.error(f"Error making request to Ollama API: {e}")
            return None
        except Exception as e:
            logger.error(f"Error counting tokens: {e}")
            return None

Methods

def get_models(self)

Gets the list of available models from Ollama.

Expand source code
def get_models(self):
    """
    Gets the list of available models from Ollama.
    """
    try:
        response = self._make_request("models", method="GET")
        # The OpenAI-compatible /v1/models endpoint returns the list under
        # "data"; fall back to "models" for native-style responses.
        return response.get("data", response.get("models", []))
    except Exception as e:
        logger.error(f"Error getting models: {e}")
        return []
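
A small usage sketch (the exact fields of each entry depend on the endpoint's response format):

client = OllamaClient()
for m in client.get_models():
    print(m)
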
def send_message(self, current_messages, dedent_messages=True, model=None, temperature=None, max_completion_tokens=None, top_p=None, frequency_penalty=None, presence_penalty=None, stop=None, num_ctx=None, timeout=None, max_attempts=None, waiting_time=None, exponential_backoff_factor=None, n=1, response_format=None, enable_pydantic_model_return=False, echo=False)

Sends a message to the Ollama API and returns the response.

Expand source code
@config_manager.config_defaults(
    model="model",
    temperature="temperature",
    top_p="top_p",
    frequency_penalty="frequency_penalty",
    presence_penalty="presence_penalty",
    num_ctx="num_ctx",
    timeout="timeout",
    max_attempts="max_attempts",
    waiting_time="waiting_time",
    exponential_backoff_factor="exponential_backoff_factor",
    response_format=None,
    echo=None,
)
def send_message(
    self,
    current_messages,
    dedent_messages=True,
    model=None,
    temperature=None,
    max_completion_tokens=None,  # Ollama doesn't use max_completion_tokens
    top_p=None,
    frequency_penalty=None,
    presence_penalty=None,
    stop=None,
    num_ctx=None,
    timeout=None,
    max_attempts=None,
    waiting_time=None,
    exponential_backoff_factor=None,
    n=1,
    response_format=None,
    enable_pydantic_model_return=False,
    echo=False,
):
    """
    Sends a message to the Ollama API and returns the response.
    """

    # Imported here to avoid a circular import.
    from tinytroupe.clients import InvalidRequestError

    def aux_exponential_backoff():
        nonlocal waiting_time
        logger.info(
            f"Request failed. Waiting {waiting_time} seconds between requests..."
        )
        time.sleep(waiting_time)
        waiting_time = waiting_time * exponential_backoff_factor

    # Prepare the API parameters
    chat_api_params = {
        "model": model,
        "messages": current_messages,
        "options": {
            "temperature": temperature,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "stop": stop,
            "num_ctx": num_ctx,  # special Ollama parameter for the input size
        },
        "stream": False,
        "n": n,
    }

    # remove any parameter that is None, so we use the API defaults
    chat_api_params = {k: v for k, v in chat_api_params.items() if v is not None}
    # ... within options too
    chat_api_params["options"] = {
        k: v for k, v in chat_api_params["options"].items() if v is not None
    }

    i = 0
    while i < max_attempts:
        try:
            i += 1

            start_time = time.monotonic()
            logger.debug(f"Sending request to Ollama API. Attempt {i}")

            # Check cache first
            cache_key = str((model, chat_api_params))
            if self.cache_api_calls and (cache_key in self.api_cache):
                response = self.api_cache[cache_key]
            else:
                logger.info(
                    f"Waiting {waiting_time} seconds before next API request..."
                )
                time.sleep(waiting_time)

                # Make the API call
                response = self._make_request(
                    "chat/completions",
                    method="POST",
                    json=chat_api_params,
                    timeout=timeout,
                )

                # Cache the response if caching is enabled
                if self.cache_api_calls:
                    self.api_cache[cache_key] = response
                    self._save_cache()

            end_time = time.monotonic()
            logger.debug(
                f"Got response in {end_time - start_time:.2f} seconds after {i} attempts"
            )

            # Extract and return the relevant part of the response
            return utils.sanitize_dict(self._extract_response(response))

        except requests.exceptions.RequestException as e:
            logger.error(f"[{i}] Request error: {e}")
            if "Invalid request" in str(e):
                raise InvalidRequestError(str(e))
            aux_exponential_backoff()

        except Exception as e:
            logger.error(f"[{i}] Error: {e}")
            aux_exponential_backoff()

    logger.error(f"Failed to get response after {max_attempts} attempts")
    return None
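
For example, a call that overrides the retry policy (values are illustrative; unspecified parameters fall back to config_manager defaults):

reply = client.send_message(
    current_messages=[{"role": "user", "content": "Summarize TinyTroupe in one line."}],
    model="llama3",
    temperature=0.2,
    max_attempts=5,
    waiting_time=1,
    exponential_backoff_factor=2,
)
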
def set_api_cache(self, cache_api_calls, cache_file_name=None)

Enables or disables the caching of API calls.

Args: cache_api_calls (bool): Whether to enable caching of API calls. cache_file_name (str): The name of the file to use for caching API calls.

Expand source code
def set_api_cache(self, cache_api_calls, cache_file_name=None):
    """
    Enables or disables the caching of API calls.

    Args:
        cache_api_calls (bool): Whether to enable caching of API calls.
        cache_file_name (str): The name of the file to use for caching API calls.
    """
    self.cache_api_calls = cache_api_calls
    self.cache_file_name = cache_file_name
    if self.cache_api_calls:
        # load the cache, if any
        self.api_cache = self._load_cache()
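
A sketch of enabling the on-disk cache (the file name is illustrative; the cache is pickled, so only load cache files you trust):

client = OllamaClient()
client.set_api_cache(True, cache_file_name="ollama_api_cache.pickle")
# Repeated identical requests are now served from the pickle file
# instead of hitting the API again.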