Basics#

Node and MessageNode#

trace is a computational graph framework for tracing and optimizing code. Its core data structure is the node, a container for Python objects. To create a node, call the node function, which returns a Node object. To access the content of a node, use its data attribute.

!pip install trace-opt
from opto.trace import node, GRAPH

def print_node(node):
    print(node)
    print(f"parents: {[p.name for p in node.parents]}")


x = node(1)  # node of int
print("node of int", x.data)
x = node("string")  # node of str
print(x.data)
x = node([1, 2, 3])  # node of list
print(x.data)
x = node({"a": 1, "b": 2})  # node of dict
print(x.data)


class Foo:
    def __init__(self, x):
        self.x = x
        self.secret = "secret"

    def print(self, val):
        print(val)


x = node(Foo("foo"))  # node of a class instance
print(x.data)
node of int 1
string
[1, 2, 3]
{'a': 1, 'b': 2}
<__main__.Foo object at 0x1062ad8d0>

When a computation is performed using the contents of nodes, the result is also a node. This allows for the creation of a computation graph. The computation graph is a directed acyclic graph where the edges indicate the data dependencies.

Nodes that are defined manually can be marked as trainable by setting their trainable attribute to True; such nodes are a subclass of Node called ParameterNode. Nodes that are created automatically as a result of computations are a different subclass of Node called MessageNode.
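
To make the graph-building idea concrete, here is a minimal sketch in plain Python. It is not how trace is implemented; TinyNode and TinyMessageNode are hypothetical names used only for illustration. It shows how overloading an operator lets a result remember its inputs as parents, which is what turns a sequence of computations into a directed acyclic graph.

```python
class TinyNode:
    """Hypothetical stand-in for a node: wraps a value and remembers its parents."""

    def __init__(self, data, parents=()):
        self.data = data
        self.parents = list(parents)

    def __add__(self, other):
        # Wrap plain Python values so that both operands are nodes.
        if not isinstance(other, TinyNode):
            other = TinyNode(other)
        # The result is a new node whose parents are the two operands.
        return TinyMessageNode(self.data + other.data, parents=[self, other])


class TinyMessageNode(TinyNode):
    """Hypothetical stand-in for a node created as the result of a computation."""


x = TinyNode(1)
y = TinyNode(2)
z = x + y   # created automatically by the computation
w = z + 10  # the int 10 is wrapped in a node first

print(type(z).__name__, z.data, [p.data for p in z.parents])
print(type(w).__name__, w.data, [p.data for p in w.parents])
```

Following the parents lists from w back to x and y recovers the data dependencies of the whole computation.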

Nodes can be copied in two ways: with clone or with detach.

# clone returns a MessageNode whose parent is the original node
x_clone = x.clone()
assert x in x_clone.parents
assert x_clone.data != x.data
assert x_clone.data.x == x.data.x
print(x_clone.data)
# detach returns a new Node which is not connected to the original node
x_detach = x.detach()
assert len(x_detach.parents) == 0
assert x_detach.data != x.data
assert x_detach.data.x == x.data.x
<__main__.Foo object at 0x00000233CF901F40>
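
The inequality assertions above rely on standard Python behavior: a class without a custom __eq__ compares by identity, so a copy of an object is never equal to the original even though its attributes match. The sketch below shows this with copy.deepcopy alone. It assumes, as the assertions suggest, that clone and detach copy the wrapped data; Box is a hypothetical class.

```python
import copy


class Box:
    """Hypothetical class with no __eq__, so == falls back to identity."""

    def __init__(self, x):
        self.x = x


original = Box("foo")
duplicate = copy.deepcopy(original)

print(duplicate != original)      # the copy is a different object
print(duplicate.x == original.x)  # but its attribute values match
```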

trace overloads Python’s magic methods that explicitly return a value (such as __add__), except for logical operations such as __bool__ and for setters. (The comparison magic methods compare nodes according to the computation graph, as explained later, rather than comparing the data.)

When nodes are used with these magic methods, the output is a MessageNode whose parents are the inputs of the method. The description attribute of a MessageNode documents the method’s function.

# Basic arithmetic operations
x = node(1, name="node_x")
y = node(3, name="node_y")
z = x / y
z2 = x / 3  # the int 3 would be converted to a node automatically
print(z)
print_node(z)
print("\n")

# Index a node
dict_node = node({"a": 1, "b": 2}, name="dict_node")
a = dict_node["a"]
print_node(a)
print("len(dict_node) =", dict_node.len())

print("\n")

# Getting class attribute and calling class method
x = node(Foo("foo"))
x.call("print", "hello world")
print_node(x.getattr("secret"))
MessageNode: (divide:0, dtype=<class 'float'>, data=0.3333333333333333)
MessageNode: (divide:0, dtype=<class 'float'>, data=0.3333333333333333)
parents: ['node_x:0', 'node_y:0']


MessageNode: (getitem:0, dtype=<class 'int'>, data=1)
parents: ['dict_node:0', 'str:1']
len(dict_node) = MessageNode: (len_:0, dtype=<class 'int'>, data=2)


Node: (str:3, dtype=<class 'str'>, data=hello world)
MessageNode: (node_getattr:1, dtype=<class 'str'>, data=secret)
parents: ['Foo:1', 'str:4']

Operations on Nodes#

For equivalence relations between nodes, we follow the PyTorch convention.

In order to work with Python’s control flow statements like if and while, the result of the comparison (a boolean value) is not a node and therefore is not traced.

x = node(True)
if x:
    print("True")

x = node([1, 2, 3])
print(1 in x)

result = 1 in x # result is not a node
try:
    result.backward()
except AttributeError:  # a plain bool has no backward() method
    print("result is not a node, therefore, we cannot call backward() on it, but we can use `x.eq(1)` function to trace the comparison.")

# In order to trace the comparison, we need to use `.eq` method
result = x.eq(1)
result.backward(visualize=True, print_limit=15)
True
True
result is not a node, therefore, we cannot call backward() on it, but we can use `x.eq(1)` function to trace the comparison.
[Figure: computation graph of x.eq(1), rendered by backward(visualize=True)]

When we use the in operator, two nodes that contain the same value are considered equal. To check whether a specific node object is inside a list, use the contain function.

from opto.trace.utils import contain

x = node(1)
y = [x, node(2), node(3)]
y2 = [node(1), node(2), node(3)]

print("x is in y", contain(y, x))
print("x is not in y2", contain(y2, x))

# x is in y and y2 if we use `in` operator
print("When we use `in` operator, x is in y", x in y)
print("When we use `in` operator, x is also in y2", x in y2)
x is in y True
x is not in y2 False
When we use `in` operator, x is in y True
When we use `in` operator, x is also in y2 True

Warning

When a node is used with a logical operator such as and, or, or not, the output does not always behave the same way, because the result depends on how Python evaluates the expression.

x = node(True)

y = True and x  # Node
print("True and x:", y)
y = x and True  # True
print("x and True:", y)
y = node(True) and x  # Node
print("node(True) and x:", y)
y = x and node(True)  # Node
print("x and node(True):", y)

print('\n')

y = False and x  # False
print("False and x:", y)
y = x and False  # False
print("x and False:", y)
y = node(False) and x  # Node
print("node(False) and x:", y)
y = x and node(False)  # Node
print("x and node(False):", y)

print('\n')

x = node(False)

y = True and x  # Node
print("True and x:", y)
y = x and True  # Node
print("x and True:", y)
y = node(True) and x  # Node
print("node(True) and x:", y)
y = x and node(True)  # Node
print("x and node(True):", y)

print('\n')

y = False and x  # False
print("False and x:", y)
y = x and False  # Node
print("x and False:", y)
y = node(False) and x  # Node
print("node(False) and x:", y)
y = x and node(False)  # Node
print("x and node(False):", y)
True and x: Node: (bool:18, dtype=<class 'bool'>, data=True)
x and True: True
node(True) and x: Node: (bool:18, dtype=<class 'bool'>, data=True)
x and node(True): Node: (bool:20, dtype=<class 'bool'>, data=True)


False and x: False
x and False: False
node(False) and x: Node: (bool:21, dtype=<class 'bool'>, data=False)
x and node(False): Node: (bool:22, dtype=<class 'bool'>, data=False)


True and x: Node: (bool:23, dtype=<class 'bool'>, data=False)
x and True: Node: (bool:23, dtype=<class 'bool'>, data=False)
node(True) and x: Node: (bool:23, dtype=<class 'bool'>, data=False)
x and node(True): Node: (bool:23, dtype=<class 'bool'>, data=False)


False and x: False
x and False: Node: (bool:23, dtype=<class 'bool'>, data=False)
node(False) and x: Node: (bool:25, dtype=<class 'bool'>, data=False)
x and node(False): Node: (bool:23, dtype=<class 'bool'>, data=False)
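
The pattern in the output above follows from Python's short-circuit rule: a and b evaluates to a if a is falsy, and to b otherwise, without converting the chosen operand to bool. The sketch below reproduces the rule with a hypothetical Flag class that only defines truthiness; it does not use trace.

```python
class Flag:
    """Hypothetical wrapper that only defines truthiness, like a node of bool."""

    def __init__(self, data):
        self.data = data

    def __bool__(self):
        return bool(self.data)


t = Flag(True)
f = Flag(False)

r1 = True and t   # left is truthy, so the result is the right operand (the wrapper)
r2 = t and True   # left is truthy, so the result is the plain bool True
r3 = False and t  # left is falsy, so the result is the plain bool False
r4 = f and True   # left is falsy, so the result is the wrapper f itself

print(r1 is t, r2 is True, r3 is False, r4 is f)
```

Whether the result is a wrapper or a plain bool therefore depends only on operand order and truthiness, which is why the traced and untraced cases alternate.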

Nodes can encapsulate any Python object, including functions. Here are a few examples.

def fun(x):
    return x + 1


fun_node = node(fun)
y = fun_node(node(1))
print(f"output: {y}\nparents {[(p.name, p.data) for p in y.parents]}")
print("\n\n")


class Foo:

    def __init__(self):
        self.node = node(1)
        self.non_node = 2

    def trace_fun(self):
        return self.node * 2

    def non_trace_fun(self):
        return self.non_node * 2


foo = node(Foo())

try:
    foo.node
    foo.trace_fun()
except AttributeError:
    print("The attribute of the wrapped object cannot be directly accessed. Instead use getattr() or call()")


attr = foo.getattr("node")
print(f"foo_node: {attr}\nparents {[(p.name, p.data) for p in attr.parents]}")


attr = foo.getattr("non_node")
print(f"non_node: {attr}\nparents {[(p.name, p.data) for p in attr.parents]}")


fun = foo.getattr("non_trace_fun")
y = fun()
print(f"output: {y}\nparents {[(p.name, p.data) for p in y.parents]}")

try:
    fun = foo.getattr("trace_fun")
    y = fun()
except AssertionError as e:
    print(e)

y = foo.call("non_trace_fun")
print(f"output: {y}\nparents {[(p.name, p.data) for p in y.parents]}")

try:
    y = foo.call("trace_fun")
except AssertionError as e:
    print(e)
output: MessageNode: (call:1, dtype=<class 'int'>, data=2)
parents [('function:0', <function fun at 0x00000233CF64AC10>), ('int:3', 1)]



The attribute of the wrapped object cannot be directly accessed. Instead use getattr() or call()
foo_node: MessageNode: (node_getattr:2, dtype=<class 'int'>, data=1)
parents [('Foo:2', <__main__.Foo object at 0x00000233CF730FA0>), ('str:5', 'node')]
non_node: MessageNode: (node_getattr:3, dtype=<class 'int'>, data=2)
parents [('Foo:2', <__main__.Foo object at 0x00000233CF730FA0>), ('str:6', 'non_node')]
output: MessageNode: (call:2, dtype=<class 'int'>, data=4)
parents [('node_getattr:4', <bound method Foo.non_trace_fun of <__main__.Foo object at 0x00000233CF730FA0>>)]
output: MessageNode: (call:4, dtype=<class 'int'>, data=4)
parents [('node_getattr:6', <bound method Foo.non_trace_fun of <__main__.Foo object at 0x00000233CF730FA0>>)]

Use Bundle to Write Custom Node Operators#

In addition to magic methods, we can use bundle to write custom traceable methods. A method decorated with bundle needs a description, in the format [method_name] description. bundle automatically adds all nodes whose data attribute is used within the function as parents of the output MessageNode.

Given a function fun, the decorated function by default unpacks all the inputs (i.e., it extracts the data inside the nodes), sends them to fun, and then creates a MessageNode that wraps the output of fun and whose parents are all the nodes used in this operation.

from opto.trace import bundle, GRAPH
from opto.trace.nodes import Node

GRAPH.clear()

@bundle()
def add(x):
    """
    Add 1 to input x
    """
    return x + 1


x = node(1, name="node_x")
z = add(x)
print_node(z)
print("\n")


@bundle()
def add(x, y):
    """
    Add input x and input y
    """
    return x + y


x = node(1, name="node_x")
y = node(2, name="node_y")
z = add(x, y)
print_node(z)
print("\n")

# The output is a single node whose data is a tuple of the two values


@bundle()
def pass_through(x, y):
    """
    No operation, just return inputs
    """
    return x, y


x = node(1, name="node_x")
y = node(2, name="node_y")
z = pass_through(x, y)
print(z)

assert isinstance(z, Node)
assert isinstance(z.data, tuple)
assert len(z.data) == 2
print("\n")
MessageNode: (add:0, dtype=<class 'int'>, data=2)
parents: ['node_x:0']


MessageNode: (add:1, dtype=<class 'int'>, data=3)
parents: ['node_x:1', 'node_y:0']


MessageNode: (pass_through:0, dtype=<class 'tuple'>, data=(1, 2))
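
The unpack, call, and wrap steps described at the start of this section can be sketched in plain Python. This is a simplified illustration, not the real bundle: tiny_bundle and TinyNode are hypothetical names, the description is assumed to come from the docstring, and every node passed as an argument is treated as a parent (the real decorator tracks which nodes are actually used).

```python
import functools


class TinyNode:
    """Hypothetical stand-in for a node with parents and a description."""

    def __init__(self, data, parents=(), description=""):
        self.data = data
        self.parents = list(parents)
        self.description = description


def tiny_bundle(fun):
    """Hypothetical decorator: unwrap node inputs, call fun, wrap the output."""

    @functools.wraps(fun)
    def wrapper(*args):
        # Nodes among the inputs become parents of the output.
        used_nodes = [a for a in args if isinstance(a, TinyNode)]
        # fun itself only ever sees the raw data.
        raw_args = [a.data if isinstance(a, TinyNode) else a for a in args]
        output = fun(*raw_args)
        description = f"[{fun.__name__}] {(fun.__doc__ or '').strip()}"
        return TinyNode(output, parents=used_nodes, description=description)

    return wrapper


@tiny_bundle
def add(x, y):
    """Add input x and input y"""
    return x + y


x = TinyNode(1)
z = add(x, 2)
print(z.data, len(z.parents), z.description)
```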

Visualize Graph#

The graph of nodes can be visualized by calling the backward method of a node. (Later we will cover how backward also sends feedback across the graph.)

from opto.trace.nodes import GRAPH

GRAPH.clear()  # to remove all the nodes
x = node(1, name="node_x")
y = node(2, name="node_y")
a = x + y
b = x + 1
final = a + b
final.backward(visualize=True)
[Figure: computation graph of final = a + b, rendered by backward(visualize=True)]
GRAPH.clear()
x = node(True)
one = node(1)
zero = node(0)
print(x, one, zero)
# Logical operations are not traceable
y = one if x.data else zero
y.backward(visualize=True)
Node: (bool:0, dtype=<class 'bool'>, data=True) Node: (int:0, dtype=<class 'int'>, data=1) Node: (int:1, dtype=<class 'int'>, data=0)
[Figure: graph of the untraced conditional, rendered by backward(visualize=True)]
# This is traceable
@bundle(allow_external_dependencies=True)
def fun(x):
    """
    Return one if input x is True, otherwise return zero
    """
    return one.data if x else zero.data


y = fun(x)
y.backward(visualize=True)
[Figure: graph of the traced conditional fun(x), rendered by backward(visualize=True)]
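
As a rough picture of what such a drawing contains, the sketch below walks a hand-written parent table for the first example in this section (a = x + y, b = x + 1, final = a + b) and lists the edges a visualization would draw. It is only an illustration of traversing parents, not how backward is implemented; the table, the helper, and the name "one" for the automatically wrapped constant are all hypothetical.

```python
# Parents of each node in the graph built by: a = x + y, b = x + 1, final = a + b.
parents = {
    "final": ["a", "b"],
    "a": ["x", "y"],
    "b": ["x", "one"],
    "x": [],
    "y": [],
    "one": [],
}


def ancestors_in_visit_order(start):
    """Walk from start back to the roots, visiting each node exactly once."""
    seen, order, stack = set(), [], [start]
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        order.append(name)
        stack.extend(parents[name])
    return order


order = ancestors_in_visit_order("final")
# One (parent, child) edge per data dependency.
edges = [(p, child) for child in order for p in parents[child]]
print(order)
print(edges)
```

Note that x is visited once even though it feeds both a and b, yet both of its outgoing edges appear in the edge list.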

Broadcasting#

Using apply_op, we can broadcast node operators to a container of nodes. A container of nodes is either a list, tuple, dict, or a subclass of the abstract class NodeContainer. apply_op recursively applies the operator to all nodes in the container.

from opto.trace import apply_op, node, NodeContainer
from opto.trace import operators as ops

import copy

# Using list as a node container
x = [node(1), node(2), 1]
y = [node(3), node(4), 2]
z = copy.deepcopy(x)
z = apply_op(ops.add, z, x, y)
print("x", [x[0].data, x[1].data, x[2]])
print("y", [y[0].data, y[1].data, y[2]])
print("Elements in z should be added, except for the last one. Value: ", [z[0].data, z[1].data, z[2]])


# Using dict as a node container
x = dict(a=node(1), b=0)
y = dict(a=node(3), b=0)
z = copy.deepcopy(x)
z = apply_op(ops.add, z, x, y)
print(f"{x['a'].data}+{y['a'].data}={z['a'].data}")
print(f"{x['b']}=={y['b']}=={z['b']}")

# Using a custom class as a node container


class Foo(NodeContainer):
    def __init__(self, x):
        self.x = node(x)
        self.y = [node(1), node(2)]
        self.z = 1


x = Foo("x")
y = Foo("y")
x_plus_y = Foo("template")
x_plus_y = apply_op(ops.add, x_plus_y, x, y)
print("x_plus_y.x should be added. Value: ", x_plus_y.x.data)
print("x_plus_y.y should be added. Value: ", [n.data for n in x_plus_y.y])
print("x_plus_y.z should be not added, just 1. Value: ", x_plus_y.z)
x [1, 2, 1]
y [3, 4, 2]
Elements in z should be added, except for the last one. Value:  [4, 6, 1]
1+3=4
0==0==0
x_plus_y.x should be added. Value:  xy
x_plus_y.y should be added. Value:  [2, 4]
x_plus_y.z should be not added, just 1. Value:  1
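
The recursive behavior can be sketched without trace. The function below is a hypothetical tiny_apply, not the real apply_op: it walks an output template alongside the input containers, applies the operator wherever the template holds a node, and leaves non-node leaves untouched. It handles only lists and dicts to stay short.

```python
import operator


class TinyNode:
    """Hypothetical stand-in for a node."""

    def __init__(self, data):
        self.data = data


def tiny_apply(op, output, *containers):
    """Hypothetical recursive broadcast of op over lists and dicts of nodes."""
    if isinstance(output, TinyNode):
        # Leaf node: apply the operator to the data of the matching inputs.
        return TinyNode(op(*(c.data for c in containers)))
    if isinstance(output, list):
        return [tiny_apply(op, o, *items) for o, *items in zip(output, *containers)]
    if isinstance(output, dict):
        return {k: tiny_apply(op, v, *(c[k] for c in containers)) for k, v in output.items()}
    return output  # non-node leaves are left as they are in the template


x = [TinyNode(1), TinyNode(2), 1]
y = [TinyNode(3), TinyNode(4), 2]
z = tiny_apply(operator.add, list(x), x, y)
print([z[0].data, z[1].data, z[2]])

d = tiny_apply(
    operator.add,
    {"a": TinyNode(0), "b": 0},
    {"a": TinyNode(1), "b": 0},
    {"a": TinyNode(3), "b": 0},
)
print(d["a"].data, d["b"])
```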

Nodes and Python Data Structure#

We can create a node over a Python data structure such as a dictionary, tuple, set, or list. trace handles iteration automatically, so you can wrap a node around any of these data structures and use it like a normal Python object.

from opto.trace import node

args = node({"arg1", "arg2"}, trainable=False)
for a in args:
    print(a)

a.backward(visualize=True)
MessageNode: (getitem:0, dtype=<class 'str'>, data=arg2)
MessageNode: (getitem:1, dtype=<class 'str'>, data=arg1)
[Figure: computation graph of the last element a, rendered by backward(visualize=True)]
parms = node([1, 2], trainable=False)
args = node(["arg1", "arg2"], trainable=False)

for a, p in zip(args, parms):
    print(a, p)

p.backward(visualize=True)
MessageNode: (getitem:2, dtype=<class 'str'>, data=arg1) MessageNode: (getitem:3, dtype=<class 'int'>, data=1)
MessageNode: (getitem:4, dtype=<class 'str'>, data=arg2) MessageNode: (getitem:5, dtype=<class 'int'>, data=2)
[Figure: computation graph of the last element p, rendered by backward(visualize=True)]
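
The getitem nodes in the output above suggest that iterating over a node yields one child node per element. A minimal sketch of that idea for lists follows; it is an assumption about the general approach rather than trace's actual code, and TinyNode is a hypothetical name.

```python
class TinyNode:
    """Hypothetical stand-in for a node wrapping a list."""

    def __init__(self, data, parents=()):
        self.data = data
        self.parents = list(parents)

    def __getitem__(self, index):
        # Indexing produces a child node whose parent is the container node.
        return TinyNode(self.data[index], parents=[self])

    def __iter__(self):
        # Iterate by index so that every element comes back as a child node.
        return (self[i] for i in range(len(self.data)))


args = TinyNode(["arg1", "arg2"])
items = list(args)
print([item.data for item in items])
print(all(item.parents[0] is args for item in items))
```

Because each element keeps the container as its parent, calling a graph traversal on the last element still reaches the original list.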