Batch Optimization#

We provide an example of how to update parameters on a batch of data. In these toy examples, we show different ways to update the parameters of traced functions on data containing multiple inputs. For simplicity, we first consider full-batch updates without random sampling.

%pip install trace-opt

First, we consider a small linear regression problem. To perform an update on multiple inputs at once, we compute the loss for each input, sum the losses, and make a single backward call telling the optimizer to minimize the total. Since the optimizer can see the whole graph, it understands how the different inputs and labels are paired and how each pair is evaluated by the loss function.

import random
import numpy as np

random.seed(0)
np.random.seed(0)

from opto import trace
from opto.optimizers import OptoPrime


def true_fun(x):
    return 2 * x - 3

inputs = [3, 2, 1, 5, 4]
outputs = [true_fun(x) for x in inputs]
N = len(inputs)


@trace.bundle()
def loss(y_hat, y):
    """ A least squares loss function. """
    return (y_hat - y) ** 2


def compute_loss(inputs, outputs):
    """ Sum the per-example losses into a single traced loss node. """
    total = 0
    for x, y in zip(inputs, outputs):
        y_hat = fun(x)  # `fun` is defined below; it is only looked up at call time
        total += loss(y_hat, y)
    return total

trace.GRAPH.clear()

@trace.bundle(trainable=True)
def fun(x):
    """ A linear predictor function """
    return 0

optimizer = OptoPrime(fun.parameters())

ls = []
for i in range(15):
    try:
        l = compute_loss(inputs, outputs)
        target = l
        feedback = 'Minimize loss'
        print(f'Iteration {i} Loss: {l.data}')
        ls.append(l.data)
    except trace.ExecutionError as e:
        # if the traced program raises, optimize against the exception node instead;
        # note the loss is not logged for such iterations
        target = e.exception_node
        feedback = str(e.exception_node.data)

    optimizer.zero_feedback()
    optimizer.backward(target, feedback)
    optimizer.step()

# plot ls
import matplotlib.pyplot as plt
plt.plot(ls)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()
Iteration 0 Loss: 85
Iteration 1 Loss: 10
Iteration 2 Loss: 10
Iteration 3 Loss: 7.5
Iteration 4 Loss: 122.8125
Iteration 5 Loss: 80.3125
Iteration 6 Loss: 12.8125
Iteration 7 Loss: 10.0
Iteration 8 Loss: 7.5
Iteration 9 Loss: 8.150000000000002
Iteration 10 Loss: 6.449999999999999
Iteration 11 Loss: 8.150000000000002
Iteration 12 Loss: 9.037500000000001
Iteration 13 Loss: 9.427
(Figure: training loss over iterations for full-batch updates.)

In contrast, if we update the parameter without batching, in a purely online fashion one example at a time, the optimization can be noisier. (We still log the full-batch loss each iteration so the two runs are comparable.)

trace.GRAPH.clear()

@trace.bundle(trainable=True)
def fun(x):
    """ A linear predictor function """
    return 0

optimizer = OptoPrime(fun.parameters())

ls = []
for i in range(15):
    try:
        # log the full-batch loss for plotting, but update on a single example
        l_eval = compute_loss(inputs, outputs)
        print(f'Iteration {i} Loss: {l_eval.data}')
        ls.append(l_eval.data)

        ind = np.random.randint(0, N)  # already uniform over [0, N); no modulo needed
        target = compute_loss([inputs[ind]], [outputs[ind]])
        feedback = 'Minimize loss'
    except trace.ExecutionError as e:
        target = e.exception_node
        feedback = str(e.exception_node.data)

    optimizer.zero_feedback()
    optimizer.backward(target, feedback)
    optimizer.step()

# plot ls
import matplotlib.pyplot as plt
plt.plot(ls)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()
Iteration 0 Loss: 85
Iteration 1 Loss: 15
Iteration 2 Loss: 10
Iteration 4 Loss: 10
Iteration 5 Loss: 6
Iteration 6 Loss: 6
Iteration 7 Loss: 5
Iteration 8 Loss: 5
Iteration 9 Loss: 1
Iteration 10 Loss: 0
Iteration 11 Loss: 0
Iteration 12 Loss: 0
Iteration 13 Loss: 9
Iteration 14 Loss: 120
(Figure: training loss over iterations for online, single-example updates.)
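
A common middle ground between these two extremes is to sample a small minibatch per iteration and sum the losses over it, exactly as in the full-batch loop. Below is a minimal sketch reusing the imports and helpers defined above (the minibatch size of 2 is an arbitrary illustrative choice, not a recommendation):

trace.GRAPH.clear()

@trace.bundle(trainable=True)
def fun(x):
    """ A linear predictor function """
    return 0

optimizer = OptoPrime(fun.parameters())

batch_size = 2  # arbitrary; any 1 <= batch_size <= N works
for i in range(15):
    try:
        # sample a minibatch of indices without replacement and sum its losses
        idx = random.sample(range(N), batch_size)
        target = compute_loss([inputs[j] for j in idx], [outputs[j] for j in idx])
        feedback = 'Minimize loss'
    except trace.ExecutionError as e:
        target = e.exception_node
        feedback = str(e.exception_node.data)

    optimizer.zero_feedback()
    optimizer.backward(target, feedback)
    optimizer.step()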

Batching Non-Commutative Feedbacks#

In the earlier numerical example, the loss was commutative, so we could simply accumulate batch_loss += loss(each_input). What if the feedback received is not commutative? This often happens with non-numeric (e.g., textual) feedback. Here we show a simple design pattern for using trace and OptoPrime for batch optimization in such cases.

from opto.trace import bundle

@bundle(trainable=False)
def concat(*items):
    """ Concatenate the items into a single string """
    output = ''
    for i, item in enumerate(items):
        output += f'ID [{i}]: {item}\n'  # tag each item with its position in the batch
    return output

Note that the concat function, when called with a batch of feedbacks, concatenates them into a single string and tags each element with an identifier. This way, when the optimizer is given a batch of outputs and a corresponding batch of feedbacks, it can disambiguate which feedback corresponds to which output.
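
As a quick illustration (a hedged example with made-up strings; because concat is bundled it returns a node, so .data retrieves the resulting string):

example = concat('wrong answer', 'correct answer')
print(example.data)
# ID [0]: wrong answer
# ID [1]: correct answer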

@bundle(trainable=True)
def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.
    '''
    lst = sorted(lst)
    return lst

def get_feedback(predict, target):
    if predict == target:
        return "test case passed!"
    else:
        return "test case failed!"
    
from opto.optimizers import OptoPrime

test_ground_truths = [[1, 4, 2, 3], [5, 5, 5, 5], [], [4, 9, 5, 8, 6, 7]]
test_inputs = [[1, 2, 3, 4], [5, 5, 5, 5], [], [9, 8, 7, 6, 5, 4]]

optimizer = OptoPrime(strange_sort_list.parameters())

outputs = []
feedbacks = []
for i in range(len(test_inputs)):
    try:
        test_output = strange_sort_list(test_inputs[i])
        feedback = get_feedback(test_output, test_ground_truths[i])
    except trace.ExecutionError as e:
        feedback = e.exception_node.data
        test_output = e.exception_node
    feedbacks.append(feedback)
    
    # eq records the comparison as a node in the graph
    correctness = test_output.eq(test_ground_truths[i])
    outputs.append(correctness)

batched_feedback = concat(*feedbacks)
batched_outputs = concat(*outputs)
optimizer.zero_feedback()
# the target is the traced batched output; the feedback is passed as a plain string
optimizer.backward(batched_outputs, batched_feedback.data)
optimizer.step(verbose=True)
Prompt
 
You're tasked to solve a coding/algorithm problem. You will see the instruction, the code, the documentation of each function used in the code, and the feedback about the execution result.

Specifically, a problem will be composed of the following parts:
- #Instruction: the instruction which describes the things you need to do or the question you should answer.
- #Code: the code defined in the problem.
- #Documentation: the documentation of each function used in #Code. The explanation might be incomplete and just contain high-level description. You can use the values in #Others to help infer how those functions work.
- #Variables: the input variables that you can change.
- #Constraints: the constraints or descriptions of the variables in #Variables.
- #Inputs: the values of other inputs to the code, which are not changeable.
- #Others: the intermediate values created through the code execution.
- #Outputs: the result of the code output.
- #Feedback: the feedback about the code's execution result.

In #Variables, #Inputs, #Outputs, and #Others, the format is:

<data_type> <variable_name> = <value>

If <type> is (code), it means <value> is the source code of a python code, which may include docstring and definitions.

Output_format: Your output should be in the following json format, satisfying the json syntax:

{{
"reasoning": <Your reasoning>,
"answer": <Your answer>,
"suggestion": {{
    <variable_1>: <suggested_value_1>,
    <variable_2>: <suggested_value_2>,
}}
}}

In "reasoning", explain the problem: 1. what the #Instruction means 2. what the #Feedback on #Output means to #Variables considering how #Variables are used in #Code and other values in #Documentation, #Inputs, #Others. 3. Reasoning about the suggested changes in #Variables (if needed) and the expected result.

If #Instruction asks for an answer, write it down in "answer".

If you need to suggest a change in the values of #Variables, write down the suggested values in "suggestion". Remember you can change only the values in #Variables, not others. When <type> of a variable is (code), you should write the new definition in the format of python code without syntax errors, and you should not change the function name or the function signature.

If no changes or answer are needed, just output TERMINATE.

Now you see problem instance:

================================

#Instruction
You need to change the <value> of the variables in #Variables to improve the output in accordance to #Feedback.

#Code
eval84 = eval(lst=lst0, __code=__code1)
eval85 = eval(lst=lst1, __code=__code1)
eval86 = eval(lst=lst2, __code=__code1)
eval87 = eval(lst=lst3, __code=__code1)
eq0 = eq(x=eval84, y=list0)
eq1 = eq(x=eval85, y=list1)
eq2 = eq(x=eval86, y=list2)
eq3 = eq(x=eval87, y=list3)
concat1 = concat(args_0=eq0, args_1=eq1, args_2=eq2, args_3=eq3)

#Documentation
[eval] This operator eval(__code, *args, **kwargs) evaluates the code block, where __code is the code (str) and *args and **kwargs are the arguments of the function. The output is the result of the evaluation, i.e., __code(*args, **kwargs).
[eq] This is an eq operator of x and y. .
[concat] Concatenate the items into a single string .

#Variables
(code) __code1:def strange_sort_list(lst):
    '''
    Given list of integers, return list in strange order.
    Strange sorting, is when you start with the minimum value,
    then maximum of the remaining integers, then minimum and so on.
    '''
    lst = sorted(lst)
    return lst

#Constraints
(code) __code1: The code should start with:
def strange_sort_list(lst):

#Inputs
(list) lst1=[5, 5, 5, 5]
(list) lst2=[]
(list) lst3=[9, 8, 7, 6, 5, 4]
(list) lst0=[1, 2, 3, 4]
(list) list1=[5, 5, 5, 5]
(list) list2=[]
(list) list3=[4, 9, 5, 8, 6, 7]
(list) list0=[1, 4, 2, 3]

#Others
(list) eval85=[5, 5, 5, 5]
(list) eval86=[]
(list) eval87=[4, 5, 6, 7, 8, 9]
(list) eval84=[1, 2, 3, 4]
(bool) eq0=False
(bool) eq1=True
(bool) eq2=True
(bool) eq3=False

#Outputs
(str) concat1=ID [0]: False
ID [1]: True
ID [2]: True
ID [3]: False


#Feedback
ID [0]: test case failed!
ID [1]: test case passed!
ID [2]: test case passed!
ID [3]: test case failed!


================================


Your response:

LLM response:
 {
    "reasoning": "The feedback indicates that the eval functions are used to generate sorted lists through a given sort function defined in __code1. The eq functions compare the output of these sort functions with predefined lists (list0, list1, list2, list3). The tests for lst0 and lst3 have failed, indicating that the output from the strange_sort_list function did not match the expected pattern given in list0 and list3. From the documentation and the given list patterns, 'strange_sort_list' should alternate between the smallest and largest remaining elements in the list, but the function in __code1 currently only sorts the list in ascending order. Thus, lst0 and lst3 are not sorting correctly into their expected 'strange' order.",
    "answer": "Change __code1 to implement the strange sorting pattern by alternating between selecting minimum and maximum elements.",
    "suggestion": {
        "__code1": "def strange_sort_list(lst):\n    lst = sorted(lst)\n    result = []\n    while lst:\n        result.append(lst.pop(0))  # append and remove the first (minimum)\n        if lst:  # check to avoid popping from an empty list\n            result.append(lst.pop()) # append and remove the last (maximum)\n    return result\n"
    }
}
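
Once optimizer.step() has applied the suggested definition to the trainable parameter, we can sanity-check the update by re-running the test cases. A minimal sketch, assuming the suggestion above was accepted:

for x, y in zip(test_inputs, test_ground_truths):
    out = strange_sort_list(x)  # now executes the updated __code parameter
    print(x, '->', out.data, get_feedback(out.data, y))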