Source code for opto.trace.propagators.graph_propagator

from dataclasses import dataclass
from typing import Any, List, Dict, Tuple
from opto.trace.nodes import Node, MessageNode, ParameterNode, get_op_name, IDENTITY_OPERATORS, NodeVizStyleGuideColorful
from opto.trace.propagators.propagators import Propagator, AbstractFeedback
import heapq
from opto.trace.utils import sum_feedback

@dataclass
class TraceGraph(AbstractFeedback):
    """Feedback container used by GraphPropagator.

    Attributes:
        graph: ``(level, node)`` pairs for the nodes in the subgraph, kept
            ordered by level so that roots come before leaves.
        user_feedback: The feedback handed to ``GraphPropagator.init_feedback``;
            ``None`` for partial graphs that do not carry it (e.g. the
            parents-only graph built in ``GraphPropagator._propagate``).
    """

    graph: List[Tuple[int, Node]]  # (level, node) pairs, ordered from roots to leaves
    user_feedback: Any

    def __add__(self, other):
        """Merge two TraceGraphs.

        Nodes are deduplicated by identity (entries of ``self`` whose node also
        appears in ``other`` are dropped), and the merged list is ordered by
        level. For equal levels, the surviving entries of ``self`` come before
        those of ``other``.

        Raises:
            AssertionError: if both ``user_feedback`` values are None, or if
                both are set but not equal.
        """
        assert not (
            self.user_feedback is None and other.user_feedback is None
        ), "One of the user feedback should not be None."
        if self.user_feedback is None or other.user_feedback is None:
            user_feedback = self.user_feedback if other.user_feedback is None else other.user_feedback
        else:  # both are not None
            assert self.user_feedback == other.user_feedback, "user feedback should be the same for all children"
            user_feedback = self.user_feedback

        # Deduplicate by identity: `in` on nodes would go through Node.__eq__,
        # which compares values rather than identity. A set keeps lookups O(1).
        other_ids = {id(item[1]) for item in other.graph}
        complement = [x for x in self.graph if id(x[1]) not in other_ids]
        # heapq.merge silently yields out-of-order output when an input is not
        # already sorted, and callers may pass unsorted lists. A stable sort on
        # the level only (never comparing Nodes) gives the same order as
        # heapq.merge for sorted inputs and a correct order otherwise. Timsort
        # detects the two pre-sorted runs, so the common case stays near-linear.
        graph = sorted(complement + list(other.graph), key=lambda x: x[0])
        return TraceGraph(graph=graph, user_feedback=user_feedback)

    @classmethod
    def expand(cls, node: MessageNode):
        """Return the subgraph within a MessageNode.

        If ``node.info['output']`` is itself a MessageNode, backward is run
        from that inner output (with ``retain_graph=True``). The feedback that
        reaches its parameter dependencies, its expandable dependencies, and
        the call's positional and keyword inputs is then combined with
        ``sum_feedback``. Whatever feedback those nodes held beforehand is
        restored afterwards, even if the backward pass raises.

        Otherwise an empty TraceGraph is returned.
        """
        assert isinstance(node, MessageNode)
        output = node.info['output']
        if not isinstance(output, MessageNode):
            return TraceGraph(graph=[], user_feedback=None)

        inputs = node.info['inputs']
        # These are the nodes where we will collect the feedback. `args` is
        # wrapped in list() so that a tuple of positional inputs also works.
        roots = (
            list(output.parameter_dependencies)
            + list(output.expandable_dependencies)
            + list(inputs['args'])
            + list(inputs['kwargs'].values())
        )

        # Remove old feedback, since we need to call backward again. It is
        # stashed as (node, feedback) pairs so that restoring does not depend
        # on how Node implements hashing or equality.
        old_feedback = [(p, p._feedback) for p in roots]
        for p in roots:
            p.zero_feedback()
        try:
            output.backward('', retain_graph=True)
            subgraph = sum_feedback(roots)
        finally:
            # Restore the old feedback even if backward/sum_feedback raised.
            for p, feedback in old_feedback:
                p._feedback = feedback
        return subgraph

    def __len__(self):
        """Number of ``(level, node)`` entries in the graph."""
        return len(self.graph)

    def __iter__(self):
        """Iterate over the ``(level, node)`` entries in stored order."""
        return iter(self.graph)

    def _itemize(self, node):
        """Build the ``(level, node)`` entry used to store ``node`` in ``graph``."""
        return (node.level, node)

    def visualize(self, simple_visualization=True, reverse_plot=False, print_limit=100):
        """Render the subgraph as a ``graphviz.Digraph``.

        Args:
            simple_visualization: Not used by this method.
            reverse_plot: If True, edges point from child to parent instead of
                from parent to child.
            print_limit: Forwarded to ``NodeVizStyleGuideColorful``.

        Returns:
            A ``graphviz.Digraph`` with one vertex per node and an edge for
            every parent/child pair whose endpoints are both in the graph.
        """
        from graphviz import Digraph

        nvsg = NodeVizStyleGuideColorful(print_limit=print_limit)
        queue = sorted(self.graph, key=lambda x: x[0])  # sort by level
        digraph = Digraph()

        if len(queue) == 1 and len(queue[0][1].parents) == 0:
            # This is a root. Nothing to propagate
            digraph.node(queue[0][1].py_name, **nvsg.get_attrs(queue[0][1]))
            return digraph

        # Membership is checked by identity. Testing `(level, parent) in queue`
        # would scan the list each time and fall back to Node.__eq__, a value
        # comparison that can match a different node holding equal data.
        ids_in_graph = {id(n) for _, n in queue}
        for _, node in queue:
            digraph.node(node.py_name, **nvsg.get_attrs(node))
            for parent in node.parents:
                if id(parent) in ids_in_graph:
                    # Only draw edges whose endpoints are both in this subgraph.
                    edge = (node.py_name, parent.py_name) if reverse_plot else (parent.py_name, node.py_name)
                    digraph.edge(*edge)
                    digraph.node(parent.py_name, **nvsg.get_attrs(parent))
        return digraph
class GraphPropagator(Propagator):
    """A propagator that collects all the nodes seen in the path."""

    def init_feedback(self, node, feedback: Any):
        """Start a TraceGraph holding only ``node`` and the given user feedback."""
        return TraceGraph(graph=[(node.level, node)], user_feedback=feedback)

    def _propagate(self, child: MessageNode):
        """Propagate the accumulated TraceGraph from ``child`` to its parents.

        The feedback collected at ``child`` is aggregated and merged with a
        graph of ``child``'s parents. The merged TraceGraph is attached to
        every hidden parameter dependency of ``child`` via ``_add_feedback``.

        Returns:
            A dict mapping each parent of ``child`` to that same merged
            TraceGraph.
        """
        # TraceGraph.graph is kept ordered by level (roots to leaves), but
        # child.parents may be listed in any order. Sort on the level only so
        # that ties never compare Node objects.
        graph = sorted(((p.level, p) for p in child.parents), key=lambda x: x[0])
        feedback = self.aggregate(child.feedback) + TraceGraph(graph=graph, user_feedback=None)
        assert isinstance(feedback, TraceGraph)

        # For including the external dependencies on parameters not visible
        # in the current graph level
        for param in child.hidden_dependencies:
            assert isinstance(param, ParameterNode)
            param._add_feedback(child, feedback)

        return {parent: feedback for parent in child.parents}

    def aggregate(self, feedback: Dict[Node, List[TraceGraph]]):
        """Aggregate feedback from multiple children.

        Args:
            feedback: Maps each child to a list holding exactly one TraceGraph.

        Returns:
            The sum of all children's TraceGraphs, or an empty TraceGraph with
            no user feedback when ``feedback`` is empty.
        """
        assert all(len(v) == 1 for v in feedback.values())
        assert all(isinstance(v[0], TraceGraph) for v in feedback.values())
        # sum() starts from 0. Presumably AbstractFeedback handles `0 + feedback`
        # (it is not visible here); this mirrors the original behavior.
        values = [sum(v) for v in feedback.values()]
        if not values:
            return TraceGraph(graph=[], user_feedback=None)
        return sum(values)