from __future__ import annotations
import trace
from typing import TYPE_CHECKING, Any, Dict
if TYPE_CHECKING: # to prevent circular import
from opto.trace.nodes import Node
from opto.trace.bundle import bundle
import copy
@bundle()
def clone(x: Any):
    """This is a clone operator of x."""
    # copy.deepcopy recursively duplicates nested containers, so the result
    # holds copies rather than references to x's mutable contents.
    duplicate = copy.deepcopy(x)
    return duplicate
def identity(x: Any):
    """Return ``x.clone()``.

    identity(x) behaves the same as x.clone(). Unlike the operators around it,
    this function is not wrapped with ``@bundle()``. It simply delegates to the
    ``clone`` method of ``x``, so ``x`` must provide one.
    """
    return x.clone()
# Unary operators and functions
@bundle()
def pos(x: Any):
    """This is a pos operator of x."""
    # Unary plus; dispatches to type(x).__pos__.
    positive = +x
    return positive
@bundle()
def neg(x: Any):
    """This is a neg operator of x."""
    # Unary minus; dispatches to type(x).__neg__.
    negated = -x
    return negated
@bundle()
def abs(x: Any):
    """This is an abs operator of x."""
    # This operator's own name shadows the built-in ``abs`` at module level.
    # A bare ``abs(x)`` here would therefore look up this module-level operator
    # again instead of the built-in, so call the built-in explicitly.
    import builtins

    return builtins.abs(x)
@bundle()
def invert(x: Any):
    """This is an invert operator of x."""
    # Unary ~ (bitwise NOT for ints); dispatches to type(x).__invert__.
    inverted = ~x
    return inverted
@bundle()
def round(x: Any, n: Any = None):
    """This is a round operator of x."""
    # Rounds x to n decimal digits. As with the built-in's ``ndigits``,
    # n=None rounds to the nearest integer.
    # This operator's own name shadows the built-in ``round`` at module level,
    # so a bare ``round(x, n)`` here would call this operator again instead of
    # the built-in. Call the built-in explicitly.
    import builtins

    return builtins.round(x, n)
@bundle()
def floor(x: Any):
    """This is a floor operator of x."""
    # math.floor delegates to x.__floor__ when x is not a float.
    from math import floor as math_floor

    return math_floor(x)
@bundle()
def ceil(x: Any):
    """This is a ceil operator of x."""
    # math.ceil delegates to x.__ceil__ when x is not a float.
    from math import ceil as math_ceil

    return math_ceil(x)
@bundle()
def trunc(x: Any):
    """This is a trunc operator of x."""
    # math.trunc drops the fractional part (rounds toward zero) via x.__trunc__.
    from math import trunc as math_trunc

    return math_trunc(x)
# Binary arithmetic and bitwise operators
@bundle()
def add(x: Any, y: Any):
    """This is an add operator of x and y."""
    # Binary +: Python tries x.__add__, then y.__radd__.
    total = x + y
    return total
@bundle()
def subtract(x: Any, y: Any):
    """This is a subtract operator of x and y."""
    # Binary -: x.__sub__ / y.__rsub__.
    difference = x - y
    return difference
@bundle()
def multiply(x: Any, y: Any):
    """This is a multiply operator of x and y."""
    # Binary *: also covers sequence repetition, e.g. "ab" * 2.
    product = x * y
    return product
@bundle()
def floor_divide(x: Any, y: Any):
    """This is a floor_divide operator of x and y."""
    # Floor division (//): the quotient rounded toward negative infinity for numbers.
    quotient = x // y
    return quotient
@bundle()
def divide(x: Any, y: Any):
    """This is a divide operator of x and y."""
    # True division (/); ints yield a float.
    quotient = x / y
    return quotient
@bundle()
def mod(x: Any, y: Any):
    """This is a mod operator of x and y."""
    # % is the remainder for numbers and printf-style formatting when x is a str.
    remainder = x % y
    return remainder
@bundle()
def node_divmod(x: Any, y: Any):
    """This is a divmod operator of x and y."""
    # Built-in divmod returns the pair (x // y, x % y).
    pair = divmod(x, y)
    return pair
@bundle()
def power(x: Any, y: Any):
    """This is a power operator of x and y."""
    # Exponentiation: x raised to the power y.
    raised = x ** y
    return raised
@bundle()
def lshift(x: Any, y: Any):
    """This is a lshift operator of x and y."""
    # Left shift (<<); for ints, x shifted left by y bits.
    shifted = x << y
    return shifted
@bundle()
def rshift(x: Any, y: Any):
    """This is a rshift operator of x and y."""
    # Right shift (>>); for ints, x shifted right by y bits.
    shifted = x >> y
    return shifted
@bundle()
def and_(x: Any, y: Any):
    """This is an and operator of x and y."""
    # Binary & (bitwise AND for ints, intersection for sets), not logical `and`.
    combined = x & y
    return combined
@bundle()
def or_(x: Any, y: Any):
    """This is an or operator of x and y."""
    # Binary | (bitwise OR for ints, union for sets/dicts), not logical `or`.
    combined = x | y
    return combined
@bundle()
def xor(x: Any, y: Any):
    """This is a xor operator of x and y."""
    # Binary ^ (bitwise XOR for ints, symmetric difference for sets).
    combined = x ^ y
    return combined
# Comparison methods
@bundle()
def lt(x: Any, y: Any):
    """This is a lt operator of x and y."""
    # Rich comparison x < y; the result type is whatever __lt__ returns.
    outcome = x < y
    return outcome
@bundle()
def le(x: Any, y: Any):
    """This is a le operator of x and y."""
    # Rich comparison x <= y.
    outcome = x <= y
    return outcome
@bundle()
def eq(x: Any, y: Any):
    """This is an eq operator of x and y."""
    # Value equality (==), as opposed to identity (is).
    outcome = x == y
    return outcome
@bundle()
def neq(x: Any, y: Any):
    """This is a not eq operator of x and y."""
    # Value inequality (!=); same computation as `ne`.
    outcome = x != y
    return outcome
@bundle()
def ne(x: Any, y: Any):
    """This is a ne operator of x and y."""
    # Value inequality (!=); same computation as `neq`.
    outcome = x != y
    return outcome
@bundle()
def ge(x: Any, y: Any):
    """This is a ge operator of x and y."""
    # Rich comparison x >= y.
    outcome = x >= y
    return outcome
@bundle()
def gt(x: Any, y: Any):
    """This is a gt operator of x and y."""
    # Rich comparison x > y.
    outcome = x > y
    return outcome
# Logical, identity, and membership operators
@bundle()
def cond(condition: Any, x: Any, y: Any):
    """This selects x if condition is True, otherwise y."""
    # Reference every input before branching, because the conditional
    # expression below refers to only one of x / y.
    # NOTE(review): why all inputs must be "read" is not visible here
    # (presumably a bundle/tracing requirement). Keep this line as is.
    x, y, condition = x, y, condition # This makes sure all data are read
    # The branch is chosen by the truthiness of condition, not a strict `is True` check.
    return x if condition else y
@bundle()
def not_(x: Any):
    """This is a not operator of x."""
    # Logical negation of x's truthiness; always returns a bool.
    return False if x else True
@bundle()
def is_(x: Any, y: Any):
    """Whether x is the same object as y (identity check, not equality)."""
    # `is` compares object identity; use `eq` for value equality.
    return x is y
@bundle()
def is_not(x: Any, y: Any):
    """Whether x is not the same object as y (identity check, not equality)."""
    # `is not` compares object identity; use `ne` for value inequality.
    return x is not y
@bundle()
def in_(x: Any, y: Any):
    """Whether x is in y."""
    # Membership test via y.__contains__ (or iteration as a fallback).
    found = x in y
    return found
@bundle()
def not_in(x: Any, y: Any):
    """Whether x is not in y."""
    # `x not in y` is defined as the negation of the membership test.
    return not (x in y)
# Indexing and slicing
@bundle()
def getitem(x: Any, index: Any):
    """This is a getitem operator of x based on index."""
    # Plain subscription: whatever x.__getitem__ accepts works
    # (e.g. an int position, a dict key, or a slice object).
    item = x[index]
    return item
@bundle()
def pop(x: Any, index: Any):
    """This is a pop operator of x based on index."""
    # Returns the removed element.
    # NOTE(review): unlike list_pop / dict_pop in this module, this pops from x
    # itself (no copy), so the input is mutated. Confirm this is intended.
    removed = x.pop(index)
    return removed
@bundle()
def len_(x: Any):
    """This is a len operator of x."""
    # Built-in len (the trailing underscore avoids shadowing it at module level).
    size = len(x)
    return size
# String operators
@bundle()
def ord_(x: Any):
    """The unicode number of a character."""
    # x must be a one-character string; returns its Unicode code point.
    code_point = ord(x)
    return code_point
@bundle()
def chr_(x: Any):
    """The character of a unicode number."""
    # Inverse of ord_: maps an integer code point to a one-character string.
    character = chr(x)
    return character
@bundle()
def concat(x: Any, y: Any):
    """This is a concatenation operator of x and y."""
    # Concatenation through the + operator (same computation as `add`).
    joined = x + y
    return joined
@bundle()
def lower(x: Any):
    """This makes all characters in x lower case."""
    # Delegates to x.lower(); x is not modified.
    lowered = x.lower()
    return lowered
@bundle()
def upper(x: Any):
    """This makes all characters in x upper case."""
    # Delegates to x.upper(); x is not modified.
    uppered = x.upper()
    return uppered
@bundle()
def title(x: Any):
    """This makes the first character of each word upper case and the remaining characters lower case, as str.title()."""
    # str.title works word by word, e.g. "hello WORLD" -> "Hello World".
    return x.title()
@bundle()
def swapcase(x: Any):
    """Swaps the case of all characters: uppercase character to lowercase and vice-versa."""
    # Delegates to x.swapcase(); x is not modified.
    swapped = x.swapcase()
    return swapped
@bundle()
def capitalize(x: Any):
    """Converts the first character of a string to uppercase and the rest to lowercase, as str.capitalize()."""
    # e.g. "hELLO" -> "Hello".
    return x.capitalize()
@bundle()
def split(x: Any, y: Any, maxsplit: Any = -1):
    """Splits the string x at each occurrence of the separator y (at most maxsplit splits; -1 means no limit) and returns the list of parts, as str.split(y, maxsplit)."""
    # The separator itself is not included in the returned parts.
    return x.split(y, maxsplit)
@bundle()
def strip(x: Any, chars=None):
    """Removes the leading and trailing characters of x."""
    # chars=None strips whitespace; otherwise any characters in `chars` are stripped.
    trimmed = x.strip(chars)
    return trimmed
@bundle()
def replace(x: Any, old: Any, new: Any, count: Any = -1):
    """Replaces occurrences of substring old in string x with new and returns the result; count limits the number of replacements (-1 replaces all), as str.replace(old, new, count)."""
    # Strings are immutable, so x itself is unchanged.
    return x.replace(old, new, count)
@bundle()
def format(x: Any, *args, **kwargs):
    """Fills in a string template with content, str.format()."""
    # x is the template; positional/keyword arguments fill {} / {name} fields.
    filled = x.format(*args, **kwargs)
    return filled
@bundle()
def join(x: Any, *y: Any):
    """Joins a sequence y with different strs with x: "\n".join(["a", "b", "c"]) -> "a\nb\nc"."""
    # y collects the remaining positional arguments into a tuple, and each of
    # them becomes one joined piece, e.g. join("-", "a", "b") -> "a-b".
    # NOTE(review): passing a single list (as in the docstring example) would
    # make that list one element of y and fail in str.join. Confirm how callers pass it.
    pieces = y
    return x.join(pieces)
@bundle()
def node_getattr(obj: Node, attr: str):
    """This operator gets attr of obj."""
    # Built-in getattr without a default: a missing attribute raises AttributeError.
    value = getattr(obj, attr)
    return value
@bundle(
    _process_inputs=False,
    allow_external_dependencies=True,
)
def call(fun: Node, *args, **kwargs):
    """This operator calls the function `fun` with args (args_0, args_1, etc.) and kwargs. If there are no args or kwargs, i.e. call(fun=function_name), the function takes no input."""
    # The decorator sets _process_inputs=False and the callable is read from
    # fun._data, so `fun` arrives as a node rather than plain data.
    # Presumably args/kwargs do too (TODO confirm against bundle). They are
    # forwarded exactly as received.
    target = fun._data
    assert callable(target), "The function must be callable."
    return target(*args, **kwargs)
@bundle()
def to_list(x: Any):
    """This converts x to a list."""
    # Materializes any iterable into a new list.
    as_list = list(x)
    return as_list
@bundle()
def make_list(*args):
    """This creates a list from the arguments."""
    # Positional arguments, in order, as a new list.
    return [*args]
@bundle()
def to_dict(x: Any):
    """This converts x to a dictionary."""
    # Accepts a mapping or an iterable of key/value pairs.
    as_dict = dict(x)
    return as_dict
@bundle()
def make_dict(**kwargs):
    """This creates a dictionary from the keyword arguments."""
    # Python already packs the keyword arguments into a fresh dict; return it.
    collected = kwargs
    return collected
@bundle()
def to_set(x: Any):
    """This converts x to a set."""
    # Elements of x must be hashable; duplicates collapse.
    as_set = set(x)
    return as_set
@bundle()
def make_set(*args):
    """This creates a set from the arguments."""
    # Unpacking into a set display; no arguments yields an empty set.
    return {*args}
@bundle()
def to_tuple(x: Any):
    """This converts x to a tuple."""
    # Materializes any iterable into a tuple.
    as_tuple = tuple(x)
    return as_tuple
@bundle()
def make_tuple(*args):
    """This creates a tuple from the arguments."""
    # *args is already a tuple (tuple(args) would return this same object).
    return args
# dict operators
@bundle()
def keys(x: Dict):
    """Return the keys of a dictionary x as a list."""
    # Only real dicts (and subclasses) are accepted. Anything else is reported
    # the way Python reports a missing attribute, naming the attribute actually
    # requested ('keys', not 'values').
    if not isinstance(x, dict):
        raise AttributeError(f"{type(x)} object has no attribute 'keys'.")
    return list(x.keys())
@bundle()
def values(x: Dict):
    """Return the values of a dictionary x as a list."""
    # Guard-first rewrite: dicts (and subclasses) return their values in
    # insertion order; any other type is reported like a missing attribute.
    if isinstance(x, dict):
        return list(x.values())
    raise AttributeError(f"{type(x)} object has no attribute 'values'.")
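# dict "in-place" operators: each applies the method to a shallow copy of x and returns that copy, leaving the input unchanged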
@bundle()
def dict_update(x: Dict, y: Dict):
    """Update the dictionary x with the dictionary y."""
    # Work on a shallow copy so the input x is left untouched; return the copy.
    result = copy.copy(x)
    result.update(y)
    return result
@bundle()
def dict_pop(x: Dict, key: Any):
    """Pop the key from the dictionary x."""
    # Shallow copy, remove `key` (KeyError if absent), return the reduced
    # dict. The popped value itself is discarded.
    result = copy.copy(x)
    result.pop(key)
    return result
@bundle()
def dict_popitem(x: Dict):
    """Pop the last item from the dictionary x."""
    # Shallow copy, drop its most recently inserted item (KeyError if empty),
    # return the copy. The removed pair is discarded.
    result = copy.copy(x)
    result.popitem()
    return result
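# list "in-place" operators: each applies the method to a shallow copy of x and returns that copy, leaving the input unchanged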
@bundle()
def list_append(x: Any, y: Any):
    """Append y to x."""
    # Non-mutating: append to a shallow copy of x and return that copy.
    result = copy.copy(x)
    result.append(y)
    return result
@bundle()
def list_clear(x: Any):
    """Clear x."""
    # Non-mutating: clear a shallow copy of x and return the (now empty) copy.
    result = copy.copy(x)
    result.clear()
    return result
@bundle()
def list_extend(x: Any, y: Any):
    """Extend x with y."""
    # Non-mutating: extend a shallow copy of x with the items of y and return it.
    result = copy.copy(x)
    result.extend(y)
    return result
@bundle()
def list_insert(x: Any, index: Any, y: Any):
    """Insert y at index in x."""
    # Non-mutating: insert y before position `index` in a shallow copy of x.
    result = copy.copy(x)
    result.insert(index, y)
    return result
@bundle()
def list_pop(x: Any, index: Any):
    """Pop the index from x."""
    # Non-mutating: remove the element at `index` from a shallow copy and
    # return the shortened copy. The removed element is discarded.
    result = copy.copy(x)
    result.pop(index)
    return result
@bundle()
def list_remove(x: Any, y: Any):
    """Remove y from x."""
    # Non-mutating: drop the first occurrence of y from a shallow copy
    # (list.remove raises ValueError if y is absent).
    result = copy.copy(x)
    result.remove(y)
    return result
@bundle()
def list_reverse(x: Any):
    """Reverse x."""
    # Non-mutating: reverse a shallow copy of x and return it.
    result = copy.copy(x)
    result.reverse()
    return result
@bundle()
def list_sort(x: Any, key: Any = None, reverse: Any = False):
    """Sort x."""
    # Non-mutating: sort a shallow copy of x (preserving its type) and return it.
    result = copy.copy(x)
    result.sort(key=key, reverse=reverse)
    return result
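# set "in-place" operators: each applies the method to a shallow copy of x and returns that copy, leaving the input unchanged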
@bundle()
def set_add(x: Any, y: Any):
    """Add y to x."""
    # Non-mutating: add y to a shallow copy of x and return the copy.
    result = copy.copy(x)
    result.add(y)
    return result
@bundle()
def set_clear(x: Any):
    """Clear x."""
    # Non-mutating: clear a shallow copy of x and return the (now empty) copy.
    result = copy.copy(x)
    result.clear()
    return result
@bundle()
def set_discard(x: Any, y: Any):
    """Discard y from x."""
    # Non-mutating: discard y from a shallow copy (no error if y is absent).
    result = copy.copy(x)
    result.discard(y)
    return result
@bundle()
def set_intersection_update(x: Any, y: Any):
    """Update x with the intersection of x and y."""
    # Non-mutating: keep only the elements also in y, on a shallow copy of x.
    result = copy.copy(x)
    result.intersection_update(y)
    return result
@bundle()
def set_pop(x: Any):
    """Pop an element from x."""
    # Non-mutating: remove an arbitrary element from a shallow copy (KeyError
    # if empty) and return the copy. The removed element is discarded.
    result = copy.copy(x)
    result.pop()
    return result
@bundle()
def set_remove(x: Any, y: Any):
    """Remove y from x."""
    # Non-mutating: remove y from a shallow copy (KeyError if y is absent).
    result = copy.copy(x)
    result.remove(y)
    return result
@bundle()
def set_symmetric_difference_update(x: Any, y: Any):
    """Update x with the symmetric difference of x and y."""
    # Non-mutating: keep elements in exactly one of x / y, on a shallow copy of x.
    result = copy.copy(x)
    result.symmetric_difference_update(y)
    return result
@bundle()
def set_update(x: Any, y: Any):
    """Update x with y."""
    # Non-mutating: add all elements of y to a shallow copy of x and return it.
    result = copy.copy(x)
    result.update(y)
    return result
@bundle()
def call_llm(system_prompt, *user_prompts, **kwargs):
    """Query the language model of system_prompt with user_prompts."""
    from opto.utils.llm import LLM

    # Message list in role/content form: one system message, then one user
    # message per positional prompt, in the order given.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend({"role": "user", "content": prompt} for prompt in user_prompts)
    # Extra keyword arguments are forwarded to the LLM call unchanged.
    response = LLM()(messages=messages, **kwargs)
    return response.choices[0].message.content