Show code cell source
# Imports for the custom-runner tutorial: SAMMO operators (GenerateText, Output),
# JSON fingerprinting, and the result/cost types used by runner implementations.
import sammo
import getpass
from sammo.components import GenerateText, Output
from sammo.utils import serialize_json
from sammo.base import LLMResult, Costs
# NOTE(review): BaseRunner is imported here, but the tutorial code below
# inherits from RestRunner (imported inside the code listing) — confirm
# which base class this cell actually needs.
from sammo.runners import BaseRunner
_ = sammo.setup_logger("WARNING") # we're only interested in warnings for now
Custom Runners
To call different backends, SAMMO supports custom implementations via `Runner`. If you have a REST API that you'd like to call, the simplest way to do it is to inherit from `RestRunner`.
In this tutorial, we will write a custom runner to generate text via a DeepInfra endpoint.
1from sammo.components import GenerateText, Output
2from sammo.utils import serialize_json
3
4from sammo.base import LLMResult, Costs
5from sammo.runners import RestRunner
6
7
class DeepInfraChat(RestRunner):
    """Custom SAMMO runner that generates text via a DeepInfra inference endpoint.

    Only three pieces are implemented; the :class:`RestRunner` parent handles
    request scheduling, retries, and caching.
    """

    async def generate_text(
        self,
        prompt: str,
        max_tokens: int | None = None,
        randomness: float | None = 0,
        seed: int = 0,
        priority: int = 0,
        **kwargs,
    ) -> LLMResult:
        """Format the prompt, compute a cache fingerprint, and dispatch the request.

        :param prompt: Raw prompt text to send to the model.
        :param max_tokens: Upper bound on generated tokens; falls back to the
            runner's configured context window when not given.
        :param randomness: Sampling temperature (0 = deterministic).
        :param seed: Only used as part of the cache fingerprint.
        :param priority: Scheduling priority passed to the parent runner.
        :return: The parsed :class:`LLMResult` for this request.
        """
        # Mixtral-instruct models expect the [INST] ... [/INST] chat template.
        formatted_prompt = f"[INST] {prompt} [/INST]"
        request = dict(
            input=formatted_prompt,
            # Bug fix: prefer the caller-supplied max_tokens and only fall back
            # to the configured context window. The original had the operands
            # reversed ("self._max_context_window or max_tokens"), so the
            # max_tokens argument was silently ignored whenever a context
            # window was configured.
            max_new_tokens=max_tokens or self._max_context_window,
            temperature=randomness,
        )
        # The fingerprint keys the cache: same seed + model + request => cache hit.
        fingerprint = serialize_json({"seed": seed, "generative_model_id": self._model_id, **request})
        return await self._execute_request(request, fingerprint, priority)

    async def _call_backend(self, request: dict) -> dict:
        """POST the request dict to the DeepInfra inference API and return the JSON reply."""
        async with self._get_session() as session:
            async with session.post(
                f"https://api.deepinfra.com/v1/inference/{self._model_id}",
                json=request,
                headers={"Authorization": f"Bearer {self._api_config['api_key']}"},
            ) as response:
                return await response.json()

    def _to_llm_result(self, request: dict, json_data: dict, fingerprint: str | bytes) -> LLMResult:
        """Convert the backend's JSON payload into an LLMResult with token costs."""
        return LLMResult(
            json_data["results"][0]["generated_text"],
            costs=Costs(json_data["num_input_tokens"], json_data["num_tokens"]),
        )
# Instantiate the custom runner, prompting interactively for the DeepInfra API key.
api_key = getpass.getpass("Enter your API key")
runner = DeepInfraChat("mistralai/Mixtral-8x7B-Instruct-v0.1", api_config={"api_key": api_key})

# Build the prompt pipeline, execute it against the runner, and print the result table.
essay_prompt = GenerateText("Generate a 50 word essay about horses.")
print(Output(essay_prompt).run(runner))
+---------+-------------------------------------------------------------+
| input | output |
+=========+=============================================================+
| None | Horses, majestic creatures, have accompanied humans for |
| | thousands of years, serving in transportation, agriculture, |
| | and warfare. Today, they are cherished for companionship, |
| | sport, and therapy. With their powerful build, graceful |
| | movements, and intuitive nature, horses continue to inspire |
| | and connect us to the natural world. Their enduring bond |
| | with humans is a testament to their intelligence and |
| | emotional depth. |
+---------+-------------------------------------------------------------+
Constants: None
The three things we had to implement were:

generate_text(): To format the prompt into a dictionary and compute a fingerprint for it.

_call_backend(): To make the actual REST request.

_to_llm_result(): To convert the JSON object into an LLM result instance.

That's it! The parent class will take care of all caching.