Custom Runners

import sammo
import getpass

from sammo.components import GenerateText, Output
from sammo.utils import serialize_json

from sammo.base import LLMResult, Costs
from sammo.runners import BaseRunner

_ = sammo.setup_logger("WARNING")  # we're only interested in warnings for now


To call different backends, SAMMO supports custom implementations via Runner. If you have a REST API that you'd like to call, the simplest way to do it is to inherit from RestRunner.

In this tutorial, we will write a custom runner to generate text via a DeepInfra endpoint.

from sammo.components import GenerateText, Output
from sammo.utils import serialize_json

from sammo.base import LLMResult, Costs
from sammo.runners import RestRunner


class DeepInfraChat(RestRunner):
    async def generate_text(
        self,
        prompt: str,
        max_tokens: int | None = None,
        randomness: float | None = 0,
        seed: int = 0,
        priority: int = 0,
        **kwargs,
    ) -> LLMResult:
        # Wrap the prompt in the instruction format expected by Mixtral.
        formatted_prompt = f"[INST] {prompt} [/INST]"
        request = dict(
            input=formatted_prompt,
            # Prefer the caller's limit; fall back to the configured context window.
            max_new_tokens=max_tokens or self._max_context_window,
            temperature=randomness,
        )
        # The fingerprint uniquely identifies this request so the parent
        # class can cache and deduplicate identical calls.
        fingerprint = serialize_json({"seed": seed, "generative_model_id": self._model_id, **request})
        return await self._execute_request(request, fingerprint, priority)

    async def _call_backend(self, request: dict) -> dict:
        async with self._get_session() as session:
            async with session.post(
                f"https://api.deepinfra.com/v1/inference/{self._model_id}",
                json=request,
                headers={"Authorization": f"Bearer {self._api_config['api_key']}"},
            ) as response:
                return await response.json()

    def _to_llm_result(self, request: dict, json_data: dict, fingerprint: str | bytes) -> LLMResult:
        return LLMResult(
            json_data["results"][0]["generated_text"],
            costs=Costs(json_data["num_input_tokens"], json_data["num_tokens"]),
        )
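
For reference, _to_llm_result assumes the endpoint returns a JSON payload of roughly the shape below. This is a hedged reconstruction based only on the fields the code reads; the values are illustrative, not real API output.

# Hedged reconstruction of the DeepInfra response shape, based solely on
# the fields read in _to_llm_result above; values are illustrative.
example_response = {
    "results": [{"generated_text": "Horses are majestic creatures ..."}],
    "num_input_tokens": 12,  # tokens consumed by the prompt
    "num_tokens": 60,  # tokens generated by the model
}
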
runner = DeepInfraChat(
    "mistralai/Mixtral-8x7B-Instruct-v0.1", api_config={"api_key": getpass.getpass("Enter your API key")}
)
print(Output(GenerateText("Generate a 50 word essay about horses.")).run(runner))
+---------+-------------------------------------------------------------+
| input   | output                                                      |
+=========+=============================================================+
| None    | Horses, majestic creatures, have accompanied humans for     |
|         | thousands of years, serving in transportation, agriculture, |
|         | and warfare. Today, they are cherished for companionship,   |
|         | sport, and therapy. With their powerful build, graceful     |
|         | movements, and intuitive nature, horses continue to inspire |
|         | and connect us to the natural world. Their enduring bond    |
|         | with humans is a testament to their intelligence and        |
|         | emotional depth.                                            |
+---------+-------------------------------------------------------------+
Constants: None

The three things we had to implement were

  1. generate_text(): To format the prompt into a request dictionary and compute a fingerprint for caching (see the sketch after this list).

  2. _call_backend(): To make the actual REST request.

  3. _to_llm_result(): To convert the returned JSON object into an LLMResult instance.
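
To make the role of the fingerprint concrete, here is a minimal sketch; it assumes only that serialize_json returns the same byte string for equal inputs, which is what makes it usable as a cache key.

from sammo.utils import serialize_json

request = dict(input="[INST] Hi [/INST]", max_new_tokens=100, temperature=0)
fp1 = serialize_json({"seed": 0, "generative_model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1", **request})
fp2 = serialize_json({"seed": 0, "generative_model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1", **request})

# Identical requests produce identical fingerprints, so a repeated call
# can be answered from the cache without hitting the API.
assert fp1 == fp2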

That's it! The parent class will take care of all caching.
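
As a quick illustration of the caching we get for free, the sketch below assumes that the inherited constructor accepts a cache argument pointing to a cache file, as SAMMO's built-in runners do; this is an assumption about the base class, not something shown above.

runner = DeepInfraChat(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    api_config={"api_key": getpass.getpass("Enter your API key")},
    cache="deepinfra.cache.tsv",  # assumed: file-backed cache, as in other SAMMO runners
)
# The first call goes to the API; the second yields the same fingerprint
# and should be served from the cache.
print(Output(GenerateText("Generate a 50 word essay about horses.")).run(runner))
print(Output(GenerateText("Generate a 50 word essay about horses.")).run(runner))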