6. RAG

In Retrieval-Augmented Generation (RAG), a set of results is retrieved via some ranking algorithm and then added to the prompt to provide additional context. SAMMO has built-in in-memory RAG for smaller-scale use cases, but you can implement your own Component to connect it to other backends.
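Before wiring this up with SAMMO, here is a minimal, library-agnostic sketch of the pattern. It uses a toy keyword-overlap scorer as the ranking function purely for illustration; SAMMO's built-in retriever ranks by embedding similarity instead, as we will see below.

# Minimal RAG pattern (illustrative, not SAMMO code): score documents against
# the query, keep the top-k, and splice them into the prompt as extra context.
def keyword_overlap(query: str, doc: str) -> int:
    # Toy ranking function: number of shared lowercase tokens.
    return len(set(query.lower().split()) & set(doc.lower().split()))

def build_rag_prompt(query: str, documents: list[str], k: int = 3) -> str:
    top_k = sorted(documents, key=lambda d: keyword_overlap(query, d), reverse=True)[:k]
    context = "\n".join(f"- {d}" for d in top_k)
    return f"Context:\n{context}\n\nQuestion: {query}"

docs = ["Article about carrot", "Article about mango", "Article about cucumber"]
print(build_rag_prompt("Which article is about carrot?", docs, k=2))

SAMMO's components automate exactly this retrieve-and-insert step for us.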

# %load -r 3:25 _init.py
import pathlib
import sammo
from sammo.runners import OpenAIChat, OpenAIEmbedding
from sammo.base import Template, EvaluationScore
from sammo.components import Output, GenerateText, ForEach, Union
from sammo.extractors import ExtractRegex
from sammo.data import DataTable
import json
import requests
import os

if not "OPENAI_API_KEY" in os.environ:
    raise ValueError("Please set the environment variable OPENAI_API_KEY'.")

_ = sammo.setup_logger("WARNING")  # we're only interested in warnings for now

runner = OpenAIChat(
    model_id="gpt-3.5-turbo-16k",
    api_config={"api_key": os.getenv("OPENAI_API_KEY")},
    cache=os.getenv("CACHE_FILE", "cache.tsv"),
    timeout=30,
)

In addition to the normal LLM runner, we also need to provide a runner for the embeddings:

embedder = OpenAIEmbedding(
    model_id="text-embedding-3-small",
    api_config={"api_key": os.getenv("OPENAI_API_KEY")},
    rate_limit=10,
    cache=os.getenv("EMBEDDING_FILE", "embeddings.tsv"),
)

Let’s generate a bit of mock data:

fruits = [
    {"category": "fruit", "name": x, "description": f"Article about {x}"}
    for x in ["mango", "banana", "apple", "orange", "grapes"]
]
vegetables = [
    {"category": "vegetable", "name": x, "description": f"Article about {x}"}
    for x in ["cucumber", "tomato", "carrot", "onion", "garlic"]
]
data = DataTable(fruits + vegetables)
d_fewshot, d_train = data.random_split(9, 1)
d_train
+------------------------------------------------------------+----------+
| input                                                      | output   |
+============================================================+==========+
| {'category': 'vegetable', 'name': 'carrot', 'description': | None     |
| 'Article about carrot'}                                    |          |
+------------------------------------------------------------+----------+
Constants: None

Our query item is a vegetable, so we expect the few-shot retriever to return only vegetables as well. The next step is to assemble the prompt.

from sammo.instructions import Section, InputData, EmbeddingFewshotExamples, MetaPrompt
from sammo.dataformatters import PlainFormatter

structure = [
    Section(
        "Examples",
        EmbeddingFewshotExamples(
            embedder,
            d_fewshot,
            n_examples=3,
            budget="relative",
        ),
    ),
    Section(
        "Question",
        "How many vegetables and how many fruits are above?",
    ),
]
rag_prompt = Output(MetaPrompt(structure, render_as="markdown", data_formatter=PlainFormatter()).with_extractor())

The EmbeddingFewshotExamples component renders the input column of the DataTable in its canonical Python format before embedding it.
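To make that concrete, the string embedded for each candidate is essentially the input dict rendered as a Python literal (an illustrative sketch; the exact serialization is handled by SAMMO):

# Illustrative only: each few-shot candidate's input dict is rendered as a
# Python literal, and that string is what gets embedded.
row = {"category": "fruit", "name": "mango", "description": "Article about mango"}
print(repr(row))

Time to run it: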

results = rag_prompt.run(runner, d_train)
results
+------------------------------------------------------------+------------------------------------------------+
| input                                                      | output                                         |
+============================================================+================================================+
| {'category': 'vegetable', 'name': 'carrot', 'description': | ['There are 3 vegetables and 0 fruits above.'] |
| 'Article about carrot'}                                    |                                                |
+------------------------------------------------------------+------------------------------------------------+
Constants: None

Nice! We indeed got the best matches. Let’s see what was going on behind the scenes.

results.outputs[0].plot_call_trace()

We can see that the EmbeddingFewshotExamples component retrieved all relevant vegetables (click on it to see its output) from our dataset before including them in the prompt.
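Behind that retrieval step is a nearest-neighbour search over the embedding vectors. As a rough, self-contained illustration (not SAMMO's actual implementation), ranking candidates by cosine similarity looks like this:

# Conceptual sketch of embedding-based retrieval (illustrative, not SAMMO internals):
# rank candidate texts by cosine similarity between their vectors and the query vector.
import numpy as np

def cosine_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def top_k_by_similarity(query_vec, candidate_vecs, k=3):
    scores = [cosine_similarity(query_vec, v) for v in candidate_vecs]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

# Placeholder vectors stand in for real embeddings; with actual embeddings,
# semantically close items (here: the other vegetable articles) score highest.
rng = np.random.default_rng(0)
query_vec = rng.normal(size=8)
candidate_vecs = [rng.normal(size=8) for _ in range(9)]
print(top_k_by_similarity(query_vec, candidate_vecs, k=3))

Because the carrot query embeds close to the other vegetable articles, those end up among the three retrieved examples, which is why the model reports three vegetables and no fruits.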