Source code for opto.trace.operators

from __future__ import annotations
import trace
from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:  # to prevent circular import
    from opto.trace.nodes import Node
from opto.trace.bundle import bundle
import copy


@bundle()
def clone(x: Any):
    """This is a clone operator of x."""
    return copy.deepcopy(x)


[docs] def identity(x: Any): # identity(x) behaves the same as x.clone() return x.clone()
# Unary operators and functions @bundle() def pos(x: Any): """This is a pos operator of x.""" return +x @bundle() def neg(x: Any): """This is a neg operator of x.""" return -x @bundle() def abs(x: Any): """This is an abs operator of x.""" return abs(x) @bundle() def invert(x: Any): """This is an invert operator of x.""" return ~x @bundle() def round(x: Any, n: Any): """This is a round operator of x.""" return round(x, n) @bundle() def floor(x: Any): """This is a floor operator of x.""" import math return math.floor(x) @bundle() def ceil(x: Any): """This is a ceil operator of x.""" import math return math.ceil(x) @bundle() def trunc(x: Any): """This is a trunc operator of x.""" import math return math.trunc(x) # Normal arithmetic operators @bundle() def add(x: Any, y: Any): """This is an add operator of x and y.""" return x + y @bundle() def subtract(x: Any, y: Any): """This is a subtract operator of x and y.""" return x - y @bundle() def multiply(x: Any, y: Any): """This is a multiply operator of x and y.""" return x * y @bundle() def floor_divide(x: Any, y: Any): """This is a floor_divide operator of x and y.""" return x // y @bundle() def divide(x: Any, y: Any): """This is a divide operator of x and y.""" return x / y @bundle() def mod(x: Any, y: Any): """This is a mod operator of x and y.""" return x % y @bundle() def node_divmod(x: Any, y: Any): """This is a divmod operator of x and y.""" return divmod(x, y) @bundle() def power(x: Any, y: Any): """This is a power operator of x and y.""" return x**y @bundle() def lshift(x: Any, y: Any): """This is a lshift operator of x and y.""" return x << y @bundle() def rshift(x: Any, y: Any): """This is a rshift operator of x and y.""" return x >> y @bundle() def and_(x: Any, y: Any): """This is an and operator of x and y.""" return x & y @bundle() def or_(x: Any, y: Any): """This is an or operator of x and y.""" return x | y @bundle() def xor(x: Any, y: Any): """This is a xor operator of x and y.""" return x ^ y # Comparison methods @bundle() def lt(x: Any, y: Any): """This is a lt operator of x and y.""" return x < y @bundle() def le(x: Any, y: Any): """This is a le operator of x and y.""" return x <= y @bundle() def eq(x: Any, y: Any): """This is an eq operator of x and y.""" return x == y @bundle() def neq(x: Any, y: Any): """This is a not eq operator of x and y.""" return x != y @bundle() def ne(x: Any, y: Any): """This is a ne operator of x and y.""" return x != y @bundle() def ge(x: Any, y: Any): """This is a ge operator of x and y.""" return x >= y @bundle() def gt(x: Any, y: Any): """This is a gt operator of x and y.""" return x > y # logical operators @bundle() def cond(condition: Any, x: Any, y: Any): """This selects x if condition is True, otherwise y.""" x, y, condition = x, y, condition # This makes sure all data are read return x if condition else y @bundle() def not_(x: Any): """This is a not operator of x.""" return not x @bundle() def is_(x: Any, y: Any): """Whether x is equal to y.""" return x is y @bundle() def is_not(x: Any, y: Any): """Whether x is not equal to y.""" return x is not y @bundle() def in_(x: Any, y: Any): """Whether x is in y.""" return x in y @bundle() def not_in(x: Any, y: Any): """Whether x is not in y.""" return x not in y # Indexing and slicing @bundle() def getitem(x: Any, index: Any): """This is a getitem operator of x based on index.""" return x[index] @bundle() def pop(x: Any, index: Any): """This is a pop operator of x based on index.""" return x.pop(index) @bundle() def len_(x: Any): """This is a len operator of x.""" return len(x) # String operators @bundle() def ord_(x: Any): """The unicode number of a character.""" return ord(x) @bundle() def chr_(x: Any): """The character of a unicode number.""" return chr(x) @bundle() def concat(x: Any, y: Any): """This is a concatenation operator of x and y.""" return x + y @bundle() def lower(x: Any): """This makes all characters in x lower case.""" return x.lower() @bundle() def upper(x: Any): """This makes all characters in x upper case.""" return x.upper() @bundle() def title(x: Any): """This makes the first character to upper case and the rest to lower case.""" return x.title() @bundle() def swapcase(x: Any): """Swaps the case of all characters: uppercase character to lowercase and vice-versa.""" return x.swapcase() @bundle() def capitalize(x: Any): """Converts the first character of a string to uppercase.""" return x.capitalize() @bundle() def split(x: Any, y: Any, maxsplit: Any = -1): """Splits the string by finding a substring y in string x, return the first part and second part of string x without y.""" return x.split(y, maxsplit) @bundle() def strip(x: Any, chars=None): """Removes the leading and trailing characters of x.""" return x.strip(chars) @bundle() def replace(x: Any, old: Any, new: Any, count: Any = -1): """Replaces all occurrences of substring y in string x with z.""" return x.replace(old, new, count) @bundle() def format(x: Any, *args, **kwargs): """Fills in a string template with content, str.format().""" return x.format(*args, **kwargs) @bundle() def join(x: Any, *y: Any): """Joins a sequence y with different strs with x: "\n".join(["a", "b", "c"]) -> "a\nb\nc".""" return x.join(y) @bundle() def node_getattr(obj: Node, attr: str): """This operator gets attr of obj.""" return getattr(obj, attr) @bundle( _process_inputs=False, allow_external_dependencies=True, ) def call(fun: Node, *args, **kwargs): """This operator calls the function `fun` with args (args_0, args_1, etc.) and kwargs. If there are no args or kwargs, i.e. call(fun=function_name), the function takes no input.""" # Run the function as it is fun = fun._data # Call the node with the input arguments assert callable(fun), "The function must be callable." output = fun(*args, **kwargs) return output @bundle() def to_list(x: Any): """This converts x to a list.""" return list(x) @bundle() def make_list(*args): """This creates a list from the arguments.""" return list(args) @bundle() def to_dict(x: Any): """This converts x to a dictionary.""" return dict(x) @bundle() def make_dict(**kwargs): """This creates a dictionary from the keyword arguments.""" return kwargs @bundle() def to_set(x: Any): """This converts x to a set.""" return set(x) @bundle() def make_set(*args): """This creates a set from the arguments.""" return set(args) @bundle() def to_tuple(x: Any): """This converts x to a tuple.""" return tuple(x) @bundle() def make_tuple(*args): """This creates a tuple from the arguments.""" return tuple(args) # dict operators @bundle() def keys(x: Dict): """Return the keys of a dictionary x as a list.""" if not isinstance(x, dict): raise AttributeError(f"{type(x)} object has no attribute 'values'.") return [k for k in x.keys()] @bundle() def values(x: Dict): """Return the values of a dictionary x as a list.""" if not isinstance(x, dict): raise AttributeError(f"{type(x)} object has no attribute 'values'.") return [k for k in x.values()] # dict in-place operators @bundle() def dict_update(x: Dict, y: Dict): """Update the dictionary x with the dictionary y.""" x = copy.copy(x) x.update(y) return x @bundle() def dict_pop(x: Dict, key: Any): """Pop the key from the dictionary x.""" x = copy.copy(x) x.pop(key) return x @bundle() def dict_popitem(x: Dict): """Pop the last item from the dictionary x.""" x = copy.copy(x) x.popitem() return x # list in-place operators @bundle() def list_append(x: Any, y: Any): """Append y to x.""" x = copy.copy(x) x.append(y) return x @bundle() def list_clear(x: Any): """Clear x.""" x = copy.copy(x) x.clear() return x @bundle() def list_extend(x: Any, y: Any): """Extend x with y.""" x = copy.copy(x) x.extend(y) return x @bundle() def list_insert(x: Any, index: Any, y: Any): """Insert y at index in x.""" x = copy.copy(x) x.insert(index, y) return x @bundle() def list_pop(x: Any, index: Any): """Pop the index from x.""" x = copy.copy(x) x.pop(index) return x @bundle() def list_remove(x: Any, y: Any): """Remove y from x.""" x = copy.copy(x) x.remove(y) return x @bundle() def list_reverse(x: Any): """Reverse x.""" x = copy.copy(x) x.reverse() return x @bundle() def list_sort(x: Any, key: Any = None, reverse: Any = False): """Sort x.""" x = copy.copy(x) x.sort(key=key, reverse=reverse) return x # set in-place operators @bundle() def set_add(x: Any, y: Any): """Add y to x.""" x = copy.copy(x) x.add(y) return x @bundle() def set_clear(x: Any): """Clear x.""" x = copy.copy(x) x.clear() return x @bundle() def set_discard(x: Any, y: Any): """Discard y from x.""" x = copy.copy(x) x.discard(y) return x @bundle() def set_intersection_update(x: Any, y: Any): """Update x with the intersection of x and y.""" x = copy.copy(x) x.intersection_update(y) return x @bundle() def set_pop(x: Any): """Pop an element from x.""" x = copy.copy(x) x.pop() return x @bundle() def set_remove(x: Any, y: Any): """Remove y from x.""" x = copy.copy(x) x.remove(y) return x @bundle() def set_symmetric_difference_update(x: Any, y: Any): """Update x with the symmetric difference of x and y.""" x = copy.copy(x) x.symmetric_difference_update(y) return x @bundle() def set_update(x: Any, y: Any): """Update x with y.""" x = copy.copy(x) x.update(y) return x @bundle() def call_llm(system_prompt, *user_prompts, **kwargs): """Query the language model of system_prompt with user_prompts.""" messages = [{"role": "system", "content": system_prompt}] for user_prompt in user_prompts: messages.append({"role": "user", "content": user_prompt}) from opto.utils.llm import AutoGenLLM llm = AutoGenLLM() response = llm(messages=messages, **kwargs) return response.choices[0].message.content