⚡️ First: 5-Minute Basics#

Writing a Trace program involves three simple steps.

  1. Write a normal Python program.

  2. Decorate the program with Trace primitives: node and @bundle.

  3. Use a Trace-graph-compatible optimizer to optimize the program.

In this short tutorial, we will look at a simple example of how Trace enables code optimization on runnable Python programs. Then, we show how you can use Trace to build a code-generation and verification framework in about a dozen lines of code, by alternating between optimizing the code and the unit tests.
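As a preview, below is a minimal sketch of these three steps on a made-up toy function (greet is purely illustrative; actually running the optimizer requires an LLM backend configured through the OAI_CONFIG_LIST file described later in this tutorial):

import autogen
from opto.trace import node, bundle
from opto.optimizers import OptoPrime

# Steps 1 and 2: write a normal Python function and decorate it with a Trace primitive
@bundle(trainable=True)
def greet(name):
    return "Hi, " + name

# Calling the bundled function works as usual; Trace records the execution
output = greet(node("Trace"))

# Step 3: use a Trace-compatible optimizer to improve the trainable code
optimizer = OptoPrime(greet.parameters(),
                      config_list=autogen.config_list_from_json("OAI_CONFIG_LIST"))
optimizer.zero_feedback()
optimizer.backward(output, "Use a more formal greeting.")
optimizer.step()  # asks the configured LLM to propose a new function body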

%pip install trace-opt
import os
import ipywidgets as widgets
from IPython.display import display

# Function to save the environment variable and API key
def save_env_variable(env_name, api_key):
    # Validate inputs
    if not env_name.strip():
        print("⚠️ Environment variable name cannot be empty.")
        return
    if not api_key.strip():
        print("⚠️ API key cannot be empty.")
        return
    
    # Store the API key as an environment variable
    os.environ[env_name] = api_key
    globals()[env_name] = api_key  # Set it as a global variable
    print(f"✅ API key has been set for environment variable: {env_name}")

# Create the input widgets
env_name_input = widgets.Text(
    value="OPENAI_API_KEY",  # Default value
    description="Env Name:",
    placeholder="Enter env variable name (e.g., MY_API_KEY)",
)

api_key_input = widgets.Password(
    description="API Key:",
    placeholder="Enter your API key",
)

# Create the button to submit the inputs
submit_button = widgets.Button(description="Set API Key")

# Display the widgets
display(env_name_input, api_key_input, submit_button)

# Callback function for the button click
def on_button_click(b):
    env_name = env_name_input.value
    api_key = api_key_input.value
    save_env_variable(env_name, api_key)

# Attach the callback to the button
submit_button.on_click(on_button_click)

Start with a Normal Python Program#

To illustrate the fundamental difference between Trace and many other LLM libraries, let us look at a coding problem from OpenAI's HumanEval dataset.

def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.

    Examples:
    strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]
    strange_sort_list([5, 5, 5, 5]) == [5, 5, 5, 5]
    strange_sort_list([]) == []
    '''
    lst = sorted(lst)
    return lst
strange_sort_list([1, 2, 3, 4])
[1, 2, 3, 4]

Our first attempt is not very successful: we tried one of the example inputs and the output is incorrect. Can we leverage an LLM to help us automatically figure out the right solution?

Trace this Program#

We use decorators like @bundle to wrap Python functions. A bundled function behaves like any other Python function.

from opto.trace import node, bundle

@bundle(trainable=True)
def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.

    Examples:
    strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]
    strange_sort_list([5, 5, 5, 5]) == [5, 5, 5, 5]
    strange_sort_list([]) == []
    '''
    lst = sorted(lst)
    return lst

test_input = [1, 2, 3, 4]
test_output = strange_sort_list(test_input)
print(test_output)
MessageNode: (eval:0, dtype=<class 'list'>, data=[1, 2, 3, 4])

Note that even though strange_sort_list is now bundled, it still executes like a normal Python function. The returned value, however, is a MessageNode. Trace automatically converts the inputs to Node objects and starts tracing operations. We can see the output is still the same as before.
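For example, we can recover the plain Python value wrapped inside the MessageNode with .data (we will use this again later in the tutorial):

print(test_output.data)   # the raw list wrapped by the MessageNode: [1, 2, 3, 4]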

correctness = test_output.eq([1, 4, 2, 3])
correctness.backward("test failed", visualize=True, print_limit=25)
[Trace graph visualization produced by backward()]

We are not just tracing how the program runs; we are also tracing how we verify the correctness of the output, all without any manual intervention. Our design philosophy is to support uninterrupted Python programming.

Note

You may have noticed that we used test_output.eq([1, 4, 2, 3]) instead of the more conventional test_output == [1, 4, 2, 3]. You can try it; both statements execute without error. However, test_output == [1, 4, 2, 3] returns a Python boolean object False, while test_output.eq([1, 4, 2, 3]) returns a Trace MessageNode. We have taken great effort to make Node objects work just like regular Python objects, but there are some corner cases to watch out for.
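To make the distinction concrete, here is a small comparison using test_output from the cell above; printing the types shows a plain bool versus a Trace MessageNode:

plain = test_output == [1, 4, 2, 3]    # an ordinary Python bool (False here), not traced
traced = test_output.eq([1, 4, 2, 3])  # a Trace MessageNode, so the comparison stays in the graph
print(type(plain), type(traced))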

Optimize with Feedback#

We can use Trace and an optimizer defined in opto.optimizers to learn a program that passes this test! But first, we need to provide some signal to the optimizer. We call this signal feedback. The more intelligent the optimizer is, the less guidance we need to give in the feedback. Think of this as providing instructions that help LLM-based optimizers come up with better solutions.

Note

You might be thinking: correctness is already a boolean variable, and it even has the name correctness. Why do we need to explicitly provide a feedback string to the optimizer? Because from a bare False or True it is hard to know whether the test has passed or not; we could be testing against a positive or a negative target. Therefore, it is important to pass in explicit feedback.

Tip

What would we have to do without Trace? If you are using a ChatGPT/Bard/Claude interface, you would copy and paste the code into the interface and ask it to write the code for you, copy the response back into a code editor with a Python runtime, execute the code, compare the output with the ground truth, and then type test case failed! back into the chat interface. Trace simplifies all of these steps.

def get_feedback(predict, target):
    if predict == target:
        return "test case passed!"
    else:
        return "test case failed!"

In order for the optimization code to run, create a file named OAI_CONFIG_LIST in the same folder as this notebook. This file should follow AutoGen's OAI_CONFIG_LIST format, since we load it with autogen.config_list_from_json below.

The code below looks like typical PyTorch code. We can break it down into a few steps:

  1. We first import an optimizer from opto.optimizers.

  2. The bundled function strange_sort_list has a convenience method .parameters() that collects all of its trainable parameters. In this case, the trainable parameter is just the function code itself.

  3. We perform the backward pass on a MessageNode object, which here is correctness, and we pass the feedback to the optimizer as well.

  4. We call optimizer.step() to perform the actual optimization.

import autogen
from opto.optimizers import OptoPrime
from opto import trace

test_ground_truth = [1, 4, 2, 3]
test_input = [1, 2, 3, 4]

epoch = 2

optimizer = OptoPrime(strange_sort_list.parameters(),
                      config_list=autogen.config_list_from_json("OAI_CONFIG_LIST"))

for i in range(epoch):
    print(f"Training Epoch {i}")
    try:
        test_output = strange_sort_list(test_input)
        feedback = get_feedback(test_output, test_ground_truth)
    except trace.ExecutionError as e:
        feedback = e.exception_node.data
        test_output = e.exception_node
    
    correctness = test_output.eq(test_ground_truth)
    
    if correctness:
        break

    optimizer.zero_feedback()
    optimizer.backward(correctness, feedback)
    optimizer.step()
Training Epoch 0
Training Epoch 1

We can visualize what the optimizer is doing. Trace first constructs a tuple (g0, f0), which corresponds to (Trace graph, feedback). The Trace graph is a representation of what happened during execution. The feedback is provided by the get_feedback function we defined above.

The LLM-based optimizer is then asked to generate a rationale/reasoning r1, and then an improvement a1 is proposed.

Note

The Trace graph contains a #Code section. However, this does not look like the strange_sort_list function we wrote; the actual function is in the #Variables section. This is because we chose a specific optimizer, OptoPrime, which frames the optimization problem as a code "debugging" problem and uses the #Code section to represent the computation flow.

from opto.trace.utils import render_opt_step

render_opt_step(0, optimizer)

Trace Graph

#Code
eval1 = eval(lst=lst1, __code=__code0)
eq1 = eq(x=eval1, y=list1)

#Documentation
[eval] This operator eval(__code, *args, **kwargs) evaluates the code block, where __code is the code (str) and *args and **kwargs are the arguments of the function. The output is the result of the evaluation, i.e., __code(*args, **kwargs).
[eq] This is an eq operator of x and y. .

#Variables
(code) __code0:def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.

    Examples:
    strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]
    strange_sort_list([5, 5, 5, 5]) == [5, 5, 5, 5]
    strange_sort_list([]) == []
    '''
    lst = sorted(lst)
    return lst

#Constraints
(code) __code0: The code should start with:
def strange_sort_list(lst):

#Inputs
(list) lst1=[]
(list) list1=[1, 4, 2, 3]

#Outputs
(bool) eq1=False

g0

Feedback: test case failed!

f0

Reasoning: The objective in the code is to utilize a function defined in __code0, which is supposed to perform a 'strange sort' on a given list. The 'strange_sort_list' should alternate between the minimum and maximum elements of the list. The expectation is that the function will convert the input 'lst1' which is [1, 2, 3, 4] to a list sorted in a 'strange' order as [1, 4, 2, 3] as per the definition in the code documentation linked to __code0. However, the current implementation of the function in __code0 simply sorts the list in ascending order and returns it, as indicated by the output list [1, 2, 3, 4], which does not match what the 'strange_sort_list' function is supposed to implement. This discordance has led to the 'eq1' result being False as the expected [1, 4, 2, 3] does not equal [1, 2, 3, 4].

r1

Improvement

__code0:

def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.

    Examples:
    strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]
    strange_sort_list([5, 5, 5, 5]) == [5, 5, 5, 5]
    strange_sort_list([]) == []
    '''
    result = []
    ascending = True
    while lst:
        if ascending:
            value = min(lst)
            lst.remove(value)
        else:
            value = max(lst)
            lst.remove(value)
        result.append(value)
        ascending = not ascending
    return result

a1

Note

We can examine what the learned function looks like by printing out its content below. Calling .data on a node returns the wrapped data inside the node.

print(strange_sort_list.parameters()[0].data)
def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.

    Examples:
    strange_sort_list([1, 2, 3, 4]) == [1, 4, 2, 3]
    strange_sort_list([5, 5, 5, 5]) == [5, 5, 5, 5]
    strange_sort_list([]) == []
    '''
    result = []
    ascending = True
    while lst:
        if ascending:
            value = min(lst)
            lst.remove(value)
        else:
            value = max(lst)
            lst.remove(value)
        result.append(value)
        ascending = not ascending
    return result

The Optimized Function is Runnable#

You might wonder: has this function actually changed? Yes, it has. In fact, we can pass in a different test example and see whether it works.

new_test_input = [5, 3, 2, 5]
new_test_output = strange_sort_list(new_test_input)
print(new_test_output)
MessageNode: (eval:3, dtype=<class 'list'>, data=[2, 5, 3, 5])

Now the answer is correct! (The numbers alternate between min, max, min, max.)

Save and Load the Optimized Function#

In fact, just like in PyTorch, you can easily save and load this function. We will discuss this in more detail in the next section.

strange_sort_list.save("strange_sort_list.pkl")
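To restore it later, a matching load call can be used (this is a sketch; the exact save/load API is covered in the next section):

strange_sort_list.load("strange_sort_list.pkl")  # assumed counterpart of save; restores the trained function body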

Coder-Verifier Framework using Trace#

The functionality provided by Trace may seem simple and straightforward so far. But using these primitives as basic building blocks, we can construct LLM-based workflows that are surprisingly complex and powerful.

Earlier, we used a unit test provided by the dataset. What if we don't have any unit tests, or we are unsure whether the tests are exhaustive? Trace takes an optimization perspective that allows you to write multiple functions: we can write a function that proposes unit tests for another function. Let's call this function a verifier.

Tip

Trace does not force users to wrap Python functions into strings in order to send them to an optimizer. This preserves a natural Python programming flow; Trace manages the function body under the hood.

Note

Note that we removed the examples from the function docstring below!

from opto.trace import node, bundle, GRAPH
GRAPH.clear()

@bundle(trainable=True)
def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.
    '''
    return sorted(lst)

@bundle(trainable=True)
def verifier():
    """
    For a coding problem described below:
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.

    Return a test input and an expected output to verify if the function is implemented correctly.
    """
    test_input = [1,2,3]
    expected_output = [1,2,3]
    return test_input, expected_output

test_input, expected_output =  verifier()
test_output = strange_sort_list(test_input)
print(test_output)
verifier_passed = test_output.eq(expected_output)
verifier_passed.backward("synthetic test result", visualize=True, print_limit=25)
MessageNode: (eval:1, dtype=<class 'list'>, data=[1, 2, 3])
[Trace graph visualization of the synthetic test]

Wow, it looks like we passed the synthetic test! Encouraging! Except the proposed unit test is wrong. This is where the optimization perspective becomes important: both the unit-test function and the code-writing function need to be optimized.

Let's come up with a game that helps the code function write better code.

Given a pair of ground truth input and expected output, we define the following procedure:

  1. Create two functions: coder and verifier.

  2. The coder keeps optimizing its code until the verifier lets it pass.

  3. The coder then takes the ground-truth input and is checked against the ground-truth output.

  4. If the coder fails, the verifier optimizes its internal tests.

Tip

In the Trace framework, feedback plays a role similar to a loss or metric in classical machine learning. We use it to guide the optimizer to change the function's behavior.

import autogen
from opto.optimizers import OptoPrime

GRAPH.clear()

test_ground_truth = [1, 4, 2, 3]
test_ground_truth_input = [1, 2, 3, 4]

epoch = 2

code_optimizer = OptoPrime(strange_sort_list.parameters(),
                          config_list=autogen.config_list_from_json("OAI_CONFIG_LIST"),
                          ignore_extraction_error=True)
verifier_optimizer = OptoPrime(verifier.parameters(),
                              config_list=autogen.config_list_from_json("OAI_CONFIG_LIST"),
                              ignore_extraction_error=True)

for i in range(epoch):
    print(f"Training Epoch {i}")

    # Step 2: Coder optimizes till it passes the verifier
    verifier_passed = False
    while not verifier_passed:
        print(f"Code function update")
        try:
            test_input, expected_output =  verifier()
            test_output = strange_sort_list(test_input)
            # note that we use the same feedback function as before
            feedback = get_feedback(test_output.data, expected_output.data)
        except trace.ExecutionError as e:
            feedback = e.exception_node.data
            test_output = e.exception_node
            expected_output = None
        verifier_passed = test_output.eq(expected_output)
        
        code_optimizer.zero_feedback()
        code_optimizer.backward(verifier_passed, feedback, retain_graph=True)
        code_optimizer.step()

    # Step 3: Coder is checked by ground truth
    try:
        test_output = strange_sort_list(test_ground_truth_input)
        feedback = get_feedback(test_output, test_ground_truth)
    except trace.ExecutionError as e:
        feedback = e.exception_node.data
        test_output = e.exception_node
    correctness = test_output.eq(test_ground_truth)
    
    if correctness:
        break

    # Step 4: If the Coder fails, Verifier needs to propose better unit tests
    print(f"Verifier update")
    feedback = "Verifier proposed wrong test_input, expected_output that do not validate the implementation of the function."
    verifier_optimizer.zero_feedback()
    verifier_optimizer.backward(verifier_passed, feedback)
    verifier_optimizer.step()
Training Epoch 0
Code function update
Verifier update
Training Epoch 1
Code function update
Code function update
Code function update
Code function update
print(strange_sort_list.parameters()[0])
ParameterNode: (__code:0, dtype=<class 'str'>, data=def strange_sort_list(lst):
    result = []
    min_index = 0
    max_index = len(lst) - 1
    sorted_lst = sorted(lst)
    while min_index <= max_index:
        result.append(sorted_lst[min_index])
        min_index += 1
        if min_index <= max_index:
            result.append(sorted_lst[max_index])
            max_index -= 1
    return result)
print(verifier.parameters()[0])
ParameterNode: (__code:1, dtype=<class 'str'>, data=def verifier():
    """
    For a coding problem described below:
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.

    Return a test input and an expected output to verify if the function is implemented correctly.
    """
    test_input = [3, 1, 2]
    expected_output = [1, 3, 2]
    return test_input, expected_output)

Note

The test input [3, 1, 2] proposed by the verifier does not appear in the function docstring examples.

What’s Next?#

We can see that the verifier can come up with good examples to test whether the code function is correct, and the code function itself does not need ground-truth answers to update. All of this can be achieved within Trace's optimization framework.

This is not the end of what Trace can do. Next, we will demonstrate how to write a Python class that can be optimized by Trace, using what we learned here.