In [1]:
## Load from parent directory if not installed
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before using `find_spec`.
import importlib.util
import sys

if importlib.util.find_spec("sammo") is None:
    # Fall back to the checkout two directories up so the tutorial can run
    # without `sammo` being installed.
    sys.path.append("../../")

# Path of the TSV cache file for this tutorial.
CACHE_FILE = "cache/custom_runners.tsv"

In [3]:
import getpass

import sammo
from sammo.base import Costs, LLMResult
from sammo.components import GenerateText, Output
from sammo.runners import BaseRunner, RestRunner
from sammo.utils import serialize_json

# Set sammo's logger to WARNING; we're only interested in warnings for now.
_ = sammo.setup_logger("WARNING")

# Custom Runners

To call different backends, `SAMMO` supports custom implementations via {class}`~sammo.base.Runner`. If you have a REST API that you'd like to call, the simplest way to do it is to inherit from {class}`~sammo.runners.RestRunner`.

In this tutorial, we will write a custom runner to generate text via a DeepInfra endpoint.

In [4]:
from sammo.components import GenerateText, Output
from sammo.utils import serialize_json

from sammo.base import LLMResult, Costs
from sammo.runners import RestRunner


class DeepInfraChat(RestRunner):
    """Runner that generates text through DeepInfra's ``/v1/inference/<model_id>`` REST endpoint.

    Only the request formatting (``generate_text``), the HTTP call (``_call_backend``)
    and the response conversion (``_to_llm_result``) are implemented here; request
    execution via ``_execute_request`` is inherited from ``RestRunner``.
    """

    async def generate_text(
        self,
        prompt: str,
        max_tokens: int | None = None,
        randomness: float | None = 0,
        seed: int = 0,
        priority: int = 0,
        **kwargs,
    ) -> LLMResult:
        """Generate a completion for ``prompt``.

        :param prompt: User prompt; wrapped in ``[INST] ... [/INST]`` instruction tags,
            the format used by Mistral/Mixtral instruct models.
        :param max_tokens: Maximum number of new tokens for this call. If the runner
            also has a ``_max_context_window``, the smaller of the two is sent.
        :param randomness: Sampling temperature, sent as ``temperature``.
        :param seed: Not sent to the API; only mixed into the fingerprint, so requests
            that differ only by seed get different fingerprints.
        :param priority: Passed through to ``_execute_request``.
        :param kwargs: Accepted for interface compatibility and ignored.
        :returns: The ``LLMResult`` produced by ``_execute_request``.
        """
        formatted_prompt = f"[INST] {prompt} [/INST]"
        request = dict(
            input=formatted_prompt,
            max_new_tokens=self._resolve_max_new_tokens(max_tokens),
            temperature=randomness,
        )
        fingerprint = serialize_json({"seed": seed, "generative_model_id": self._model_id, **request})
        return await self._execute_request(request, fingerprint, priority)

    def _resolve_max_new_tokens(self, max_tokens: int | None) -> int | None:
        """Combine the per-call ``max_tokens`` with the runner-level ``_max_context_window``.

        An explicit ``max_tokens`` is honoured, capped by ``_max_context_window`` when
        that is set. Without a per-call value, the runner-level value (possibly
        ``None``) is used unchanged.
        """
        limit = self._max_context_window
        if max_tokens is None:
            return limit
        if limit is None:
            return max_tokens
        return min(max_tokens, limit)

    async def _call_backend(self, request: dict) -> dict:
        """POST ``request`` to DeepInfra and return the decoded JSON response body.

        :param request: JSON payload built by ``generate_text``.
        :returns: The parsed JSON response.
        :raises: The HTTP error raised by ``response.raise_for_status()`` (an
            ``aiohttp.ClientResponseError``, assuming ``_get_session`` yields an aiohttp
            session) when the API answers with a 4xx/5xx status.
        """
        async with self._get_session() as session:
            async with session.post(
                f"https://api.deepinfra.com/v1/inference/{self._model_id}",
                json=request,
                headers={"Authorization": f"Bearer {self._api_config['api_key']}"},
            ) as response:
                # Fail on HTTP errors here; otherwise the error payload would be handed
                # to _to_llm_result and surface as an unhelpful KeyError.
                response.raise_for_status()
                return await response.json()

    def _to_llm_result(self, request: dict, json_data: dict, fingerprint: str | bytes) -> LLMResult:
        """Convert a DeepInfra JSON response into an ``LLMResult``.

        Uses ``results[0]["generated_text"]`` as the text and passes
        ``num_input_tokens`` and ``num_tokens`` positionally to ``Costs``
        (presumably input and output token counts - confirm against ``Costs``).
        ``request`` and ``fingerprint`` are part of the hook signature but unused here.
        """
        return LLMResult(
            json_data["results"][0]["generated_text"],
            costs=Costs(json_data["num_input_tokens"], json_data["num_tokens"]),
        )

In [6]:
# Pass the cache file defined in the setup cell; without `cache=`, CACHE_FILE would go unused.
runner = DeepInfraChat(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    api_config={"api_key": getpass.getpass("Enter your API key")},
    cache=CACHE_FILE,
)
print(Output(GenerateText("Generate a 50 word essay about horses.")).run(runner))

Enter your API key ········


+---------+-------------------------------------------------------------+
| input   | output                                                      |
| None    | Horses, majestic creatures, have accompanied humans for     |
|         | thousands of years, serving in transportation, agriculture, |
|         | and warfare. Today, they are cherished for companionship,   |
|         | sport, and therapy. With their powerful build, graceful     |
|         | movements, and intuitive nature, horses continue to inspire |
|         | and connect us to the natural world. Their enduring bond    |
|         | with humans is a testament to their intelligence and        |
|         | emotional depth.                                            |
+---------+-------------------------------------------------------------+
Constants: None


The three things we had to implement were:

1. `generate_text()`: To format the prompt into a request dictionary and compute a fingerprint for it (used for caching).
2. `_call_backend()`: To make the actual REST request.
3. `_to_llm_result()`: To convert the JSON response into an `LLMResult` instance.

That's it! The parent class will take care of all caching.