from dataclasses import dataclass
from typing import Any, List, Dict, Tuple
from opto.trace.nodes import (
Node,
MessageNode,
ParameterNode,
get_op_name,
IDENTITY_OPERATORS,
NodeVizStyleGuideColorful,
)
from opto.trace.propagators.propagators import Propagator, AbstractFeedback
import heapq
from opto.trace.utils import sum_feedback, contain
@dataclass
class TraceGraph(AbstractFeedback):
    """Feedback container used by GraphPropagator.

    Holds the subgraph of nodes collected during backward propagation together
    with the user feedback that started the propagation.

    Attributes:
        graph: ``(level, node)`` pairs kept sorted by ``level`` in ascending
            order, i.e. from roots to leaves. ``__add__`` relies on this
            ordering when it merges two graphs with ``heapq.merge``.
        user_feedback: the feedback supplied at the output node, or ``None`` for
            a partial graph that carries no feedback of its own.
    """

    # a priority queue of nodes in the subgraph, ordered from roots to leaves
    graph: List[Tuple[int, Node]]
    # feedback given at the output; None when this graph carries none
    user_feedback: Any

    def empty(self):
        """Return True when the graph has no nodes and no user feedback."""
        return len(self.graph) == 0 and self.user_feedback is None

    def __add__(self, other):
        """Merge two TraceGraphs into a new one.

        Nodes of ``self`` that also appear in ``other`` (compared by identity,
        not by value) are dropped, so shared nodes appear once. The result stays
        sorted by level, provided both operands are sorted.

        Unless both operands are empty, at least one must carry user feedback.
        If both carry it, the two values must be equal.

        Raises:
            AssertionError: if neither operand has user feedback (and they are
                not both empty), or if both have differing user feedback.
        """
        if self.empty() and other.empty():
            return TraceGraph(graph=[], user_feedback=None)
        # If one of them is not empty, one must contain the user feedback
        assert not (
            self.user_feedback is None and other.user_feedback is None
        ), "One of the user feedback should not be None."
        if self.user_feedback is None or other.user_feedback is None:
            user_feedback = (
                self.user_feedback
                if other.user_feedback is None
                else other.user_feedback
            )
        else:  # both are not None
            assert (
                self.user_feedback == other.user_feedback
            ), "user feedback should be the same for all children"
            user_feedback = self.user_feedback
        # Compare by identity: `in` on the nodes themselves would use __eq__,
        # which checks the value rather than the identity. Using a set of ids
        # also makes each membership test O(1) instead of a list scan.
        other_ids = {id(node) for _, node in other.graph}
        complement = [x for x in self.graph if id(x[1]) not in other_ids]
        # Both inputs are sorted by level, so a linear merge keeps the result
        # sorted. The key compares levels only, never the nodes themselves.
        graph = list(heapq.merge(complement, other.graph, key=lambda x: x[0]))
        return TraceGraph(graph=graph, user_feedback=user_feedback)

    @classmethod
    def expand(cls, node: MessageNode):
        """Return the subgraph within a MessageNode.

        Applies only when ``node.info["output"]`` is itself a MessageNode. In
        that case:

        - The roots are that output's parameter and expandable dependencies,
          plus the call's positional and keyword inputs.
        - The roots' feedback is cleared.
        - ``backward("", retain_graph=True)`` is run from the inner output.
        - The collected feedback is summed with ``sum_feedback(roots)``.

        The roots' previous feedback is restored afterwards, even if the
        backward pass raises.

        Returns:
            The collected subgraph, or an empty TraceGraph when the node's
            output is not a MessageNode.
        """
        assert isinstance(node, MessageNode)
        output = node.info["output"]
        if not isinstance(output, MessageNode):
            return TraceGraph(graph=[], user_feedback=None)
        # these are the nodes where we will collect the feedback
        roots = (
            list(output.parameter_dependencies)
            + list(output.expandable_dependencies)
            + list(node.info["inputs"]["args"])
            + list(node.info["inputs"]["kwargs"].values())
        )
        # Remove old feedback, since we need to call backward again. It is
        # snapshotted for every root before any root is cleared.
        old_feedback = {p: p._feedback for p in roots}
        try:
            for p in roots:
                p.zero_feedback()
            output.backward("", retain_graph=True)
            subgraph = sum_feedback(roots)
        finally:
            # restore the old feedback, even if backward/sum_feedback raised
            for p, feedback in old_feedback.items():
                p._feedback = feedback
        return subgraph

    def __len__(self):
        """Number of ``(level, node)`` entries in the graph."""
        return len(self.graph)

    def __iter__(self):
        """Iterate over the ``(level, node)`` entries in stored order."""
        return iter(self.graph)

    def _itemize(self, node):
        """Return the ``(level, node)`` entry used to store ``node`` in ``graph``."""
        return (node.level, node)

    def visualize(self, simple_visualization=True, reverse_plot=False, print_limit=100):
        """Render the subgraph as a graphviz ``Digraph``.

        Args:
            simple_visualization: not used by this implementation; kept for
                API compatibility.
            reverse_plot: if True, draw edges child -> parent instead of
                parent -> child.
            print_limit: forwarded to ``NodeVizStyleGuideColorful``.

        Returns:
            graphviz.Digraph: every node in the graph is drawn. Edges into a
            node are drawn only when all of that node's parents belong to the
            graph.
        """
        from graphviz import Digraph

        nvsg = NodeVizStyleGuideColorful(print_limit=print_limit)
        queue = sorted(self.graph, key=lambda x: x[0])  # sort by level
        digraph = Digraph()

        if len(queue) == 1 and len(queue[0][1].parents) == 0:
            # This is a root. Nothing to propagate
            root = queue[0][1]
            digraph.node(root.py_name, **nvsg.get_attrs(root))
            return digraph

        # Traverse the nodes in level order and add an edge for each
        # parent/child relationship inside the subgraph. Lower-level nodes
        # can only be parents of higher-level ones.
        nodes_in_queue = {node for _, node in queue}
        for _, node in queue:
            digraph.node(node.py_name, **nvsg.get_attrs(node))
            # Only connect a node when every one of its parents is in the subgraph.
            if all(contain(nodes_in_queue, parent) for parent in node.parents):
                for parent in node.parents:
                    edge = (
                        (node.py_name, parent.py_name)
                        if reverse_plot
                        else (parent.py_name, node.py_name)
                    )
                    digraph.edge(*edge)
                    digraph.node(parent.py_name, **nvsg.get_attrs(parent))
        return digraph
class GraphPropagator(Propagator):
    """A propagator that collects all the nodes seen in the path."""

    def init_feedback(self, node, feedback: Any):
        """Start propagation with a TraceGraph holding only ``node`` and ``feedback``."""
        return TraceGraph(graph=[(node.level, node)], user_feedback=feedback)

    def _propagate(self, child: MessageNode):
        """Propagate ``child``'s aggregated feedback to its parents.

        The parents are added to the aggregated TraceGraph, and the combined
        graph is also recorded on every hidden parameter dependency of
        ``child``.

        Returns:
            dict: maps each parent of ``child`` to the combined TraceGraph.
        """
        # TraceGraph.__add__ merges with heapq.merge, which requires each input
        # to be sorted by level (roots to leaves). ``child.parents`` carries no
        # such ordering guarantee, so sort it first. Sort on the level only, so
        # Node objects are never compared on ties.
        graph = sorted(((p.level, p) for p in child.parents), key=lambda x: x[0])
        feedback = self.aggregate(child.feedback) + TraceGraph(
            graph=graph, user_feedback=None
        )
        assert isinstance(feedback, TraceGraph)

        # For including the external dependencies on parameters not visible
        # in the current graph level
        for param in child.hidden_dependencies:
            assert isinstance(param, ParameterNode)
            param._add_feedback(child, feedback)

        return {parent: feedback for parent in child.parents}

    def aggregate(self, feedback: Dict[Node, List[TraceGraph]]):
        """Aggregate feedback from multiple children.

        Args:
            feedback: maps each child to a list holding exactly one TraceGraph.

        Returns:
            The sum of all children's TraceGraphs, or an empty TraceGraph when
            ``feedback`` is empty. ``sum`` starts from ``0``; this presumably
            relies on ``0 + TraceGraph`` being supported by the AbstractFeedback
            base, which is not shown here.
        """
        assert all(len(v) == 1 for v in feedback.values())
        assert all(isinstance(v[0], TraceGraph) for v in feedback.values())
        values = [sum(v) for v in feedback.values()]
        if len(values) == 0:
            return TraceGraph(graph=[], user_feedback=None)
        else:
            return sum(values)