@dataclass
class TraceGraph(AbstractFeedback):
    """Feedback container used by GraphPropagator.

    Attributes:
        graph: ``(level, node)`` pairs describing a subgraph. ``__add__``
            merges two graphs with ``heapq.merge``, so each ``graph`` is
            expected to be ordered by ``level`` (roots to leaves).
        user_feedback: Feedback attached to the subgraph. When two graphs
            that both carry a non-None ``user_feedback`` are added, the two
            values must be equal.
    """

    graph: List[Tuple[int, Node]]  # a priority queue of nodes in the subgraph, ordered from roots to leaves
    user_feedback: Any

    def __add__(self, other):
        """Merge two feedback graphs into a new TraceGraph.

        Nodes present in both graphs are kept once; duplicates are detected by
        identity (``id``), not by ``==``. Assuming both inputs are ordered by
        level, the merged ``graph`` is ordered by level as well.

        Raises:
            AssertionError: if at least one graph is non-empty but neither
                carries user feedback, or if both carry different user feedback.
        """
        if self.empty() and other.empty():
            return TraceGraph(graph=[], user_feedback=None)

        # If one of them is not empty, one must contain the user feedback
        assert not (self.user_feedback is None and other.user_feedback is None), "One of the user feedback should not be None."
        if self.user_feedback is None or other.user_feedback is None:
            user_feedback = self.user_feedback if other.user_feedback is None else other.user_feedback
        else:  # both are not None
            assert self.user_feedback == other.user_feedback, "user feedback should be the same for all children"
            user_feedback = self.user_feedback

        # Deduplicate by identity: `in` on Nodes uses __eq__, which checks the
        # value rather than the identity. A set of ids gives O(1) lookups
        # instead of rescanning `other.graph` for every entry of `self.graph`.
        other_ids = {id(item[1]) for item in other.graph}
        complement = [item for item in self.graph if id(item[1]) not in other_ids]
        # heapq.merge keys on the level only, so Nodes are never compared.
        graph = list(heapq.merge(complement, other.graph, key=lambda item: item[0]))
        return TraceGraph(graph=graph, user_feedback=user_feedback)

    @classmethod
    def expand(cls, node: MessageNode):
        """Return the subgraph within a MessageNode.

        If ``node.info["output"]`` is itself a MessageNode, ``backward`` is run
        on that output. The resulting feedback is then collected (via
        ``sum_feedback``) from the roots: the output's parameter and
        expandable dependencies plus the positional and keyword inputs
        recorded in ``node.info["inputs"]``. The roots' previous feedback is
        restored afterwards, even if ``backward`` raises. Otherwise an empty
        TraceGraph is returned.
        """
        assert isinstance(node, MessageNode)
        output = node.info["output"]
        if not isinstance(output, MessageNode):
            return TraceGraph(graph=[], user_feedback=None)

        # these are the nodes where we will collect the feedback
        roots = (
            list(output.parameter_dependencies)
            + list(output.expandable_dependencies)
            + list(node.info["inputs"]["args"])
            + list(node.info["inputs"]["kwargs"].values())
        )

        # Save the old feedback, since we need to call backward again; it is
        # restored below. A list of (node, feedback) pairs is used instead of
        # a dict keyed by node so that value-based Node equality cannot make
        # two distinct roots share one entry.
        saved_feedback = [(p, p._feedback) for p in roots]
        try:
            for p in roots:
                p.zero_feedback()
            output.backward("", retain_graph=True)
            subgraph = sum_feedback(roots)
        finally:
            # restore the old feedback, also when backward() failed
            for p, feedback in saved_feedback:
                p._feedback = feedback
        return subgraph

    def visualize(self, simple_visualization=True, reverse_plot=False, print_limit=100):
        """Render the subgraph as a graphviz ``Digraph``.

        Args:
            simple_visualization: Accepted for interface compatibility; not
                read by this method.
            reverse_plot: If True, edges are drawn child -> parent instead of
                parent -> child.
            print_limit: Forwarded to ``NodeVizStyleGuideColorful``.

        Returns:
            A ``graphviz.Digraph`` with one vertex per node in ``self.graph``
            and an edge for every parent that is also part of the subgraph.
        """
        from graphviz import Digraph

        nvsg = NodeVizStyleGuideColorful(print_limit=print_limit)
        queue = sorted(self.graph, key=lambda x: x[0])  # sort by level
        digraph = Digraph()

        if len(queue) == 1 and len(queue[0][1].parents) == 0:
            # This is a root. Nothing to propagate
            root = queue[0][1]
            digraph.node(root.py_name, **nvsg.get_attrs(root))
            return digraph

        # Traverse the list to determine the relationship between nodes and
        # add an edge where there is one. Only a lower-level node can have a
        # parent at a higher level, hence iterating in level order.
        for _level, node in queue:
            digraph.node(node.py_name, **nvsg.get_attrs(node))
            # is there a faster way to determine child/parent relationship!?
            for parent in node.parents:
                if self._itemize(parent) in queue:
                    # the parent is in the subgraph: connect it
                    if reverse_plot:
                        edge = (node.py_name, parent.py_name)
                    else:
                        edge = (parent.py_name, node.py_name)
                    digraph.edge(*edge)
                    digraph.node(parent.py_name, **nvsg.get_attrs(parent))
        return digraph
class GraphPropagator(Propagator):
    """A propagator that collects all the nodes seen in the path."""

    def _propagate(self, child: MessageNode):
        """Compute the feedback passed from ``child`` to each of its parents.

        The aggregated feedback of ``child`` is extended with ``child``'s
        parents. Every parent receives the same TraceGraph object. Hidden
        dependencies of ``child`` (parameters not visible at this graph
        level) receive that feedback directly through ``_add_feedback``.

        Returns:
            A dict mapping each parent of ``child`` to the propagated TraceGraph.
        """
        # TraceGraph.__add__ merges graphs with heapq.merge, which requires
        # each input to be ordered by level; child.parents carries no such
        # guarantee. Sort by level only (stable) so Nodes are never compared.
        graph = sorted(((p.level, p) for p in child.parents), key=lambda item: item[0])
        feedback = self.aggregate(child.feedback) + TraceGraph(graph=graph, user_feedback=None)
        assert isinstance(feedback, TraceGraph)

        # For including the external dependencies on parameters not visible
        # in the current graph level
        for param in child.hidden_dependencies:
            assert isinstance(param, ParameterNode)
            param._add_feedback(child, feedback)
        return {parent: feedback for parent in child.parents}

    def aggregate(self, feedback: Dict[Node, List[TraceGraph]]):
        """Aggregate feedback from multiple children.

        Args:
            feedback: Maps each child to a list holding exactly one TraceGraph.

        Returns:
            The sum of every child's TraceGraph, or an empty TraceGraph when
            ``feedback`` is empty.

        Raises:
            AssertionError: if a child has more or fewer than one entry, or an
                entry is not a TraceGraph.
        """
        assert all(len(v) == 1 for v in feedback.values())
        assert all(isinstance(v[0], TraceGraph) for v in feedback.values())
        values = [sum(v) for v in feedback.values()]
        if not values:
            return TraceGraph(graph=[], user_feedback=None)
        return sum(values)