Source code for opto.trace.propagators.propagators
from typing import Any, List, Dict, Tuple
from opto.trace.nodes import Node, MessageNode, get_op_name
[docs]
class AbstractPropagator:
    """Base interface for objects that move feedback from a node to its parents.

    Subclasses implement :meth:`propagate`; calling the propagator wraps that
    method with sanity checks on both its input and its output.
    """

    def __call__(self, child: MessageNode):
        """Propagate the feedback held by ``child`` to the parents of ``child``.

        Args:
            child: The node whose feedback is propagated; must be a ``MessageNode``.

        Returns:
            The dict produced by :meth:`propagate`, mapping each parent of
            ``child`` to the feedback it receives.

        Raises:
            AssertionError: (only when assertions are enabled) if ``child`` is not
                a ``MessageNode``, if any entry of ``child.feedback`` holds more
                than one item, or if the result is not a dict with an entry for
                every parent of ``child``.
        """
        assert isinstance(child, MessageNode)
        # A MessageNode keeps at most one feedback item per entry.
        assert all(len(items) <= 1 for items in child.feedback.values())

        result = self.propagate(child)

        # The result must be a dict that has an entry for every parent.
        assert isinstance(result, dict)
        assert all(parent in result for parent in child.parents)
        return result

    def propagate(self, child: MessageNode) -> Dict[Node, Any]:
        """Compute the feedback to send to each of ``child.parents``.

        Returns:
            A dict whose keys are the parents of ``child`` and whose values are
            the feedback propagated to them.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
[docs]
class AbstractFeedback:
    """Container for feedback moved around by propagators.

    Subclasses must implement ``__add__`` so that feedback objects can be
    combined, including with the built-in ``sum``.
    """

    def __add__(self, other):
        raise NotImplementedError

    def __radd__(self, other):
        # sum() starts from the integer 0, so 0 acts as the additive identity.
        # NOTE(review): any other left operand is combined as ``self + other``
        # (operands swapped), which assumes __add__ is commutative — confirm.
        return self if other == 0 else self.__add__(other)
[docs]
class Propagator(AbstractPropagator):
    """Propagator whose behavior can be customized per operator.

    A propagate function registered for an operator name (see :meth:`register`)
    handles every child whose description names that operator. All other
    children go through :meth:`_propagate`, which subclasses implement.
    """

    def __init__(self):
        # operator name -> custom propagate function for that operator
        self.override = {}

    def register(self, operator_name, propagate_function):
        """Route children produced by ``operator_name`` to ``propagate_function``.

        ``propagate_function`` is called with the child node and its return value
        is used as the result of :meth:`propagate`. Registering the same operator
        name again replaces the previous function.
        """
        self.override[operator_name] = propagate_function

    def propagate(self, child: MessageNode) -> Dict[Node, Any]:
        """Dispatch to the registered override for the child's operator, if any.

        The operator name is extracted from ``child.description`` with
        ``get_op_name``. Without a registered override, :meth:`_propagate` is used.
        """
        operator_name = get_op_name(child.description)
        if operator_name not in self.override:
            # No custom rule for this operator: use the default implementation.
            return self._propagate(child)
        return self.override[operator_name](child)

    def init_feedback(self, node: Node, feedback: Any):
        """Turn raw ``feedback`` given to ``node`` into the object that is propagated.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def _propagate(self, child: MessageNode) -> Dict[Node, Any]:
        """Default propagation, used when no override matches the child's operator.

        Implementations compute the feedback for ``child.parents`` from
        ``child.description``, ``child.data`` and ``child.feedback``, and return
        a dict mapping each parent to its propagated feedback.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
# Note:
# If len(feedback) > 1, this node has two or more child nodes, so their
# feedback may need to be merged before it is propagated further.
# TODO: add tests.
[docs]
class SumPropagator(Propagator):
    """Propagator that gives every parent the combined feedback of the child.

    Feedback stored under the ``"user"`` key is passed through unchanged (it
    must then be the only feedback). Otherwise the single item from each entry
    of ``child.feedback`` is combined: strings are concatenated, and any other
    type is combined with the built-in ``sum``.
    """

    # Sentinel that lets ``init_feedback`` accept both the base-class calling
    # convention ``(node, feedback)`` and the legacy one-argument ``(feedback)``.
    _NO_FEEDBACK = object()

    def init_feedback(self, node: Node = None, feedback: Any = _NO_FEEDBACK):
        """Return the raw feedback unchanged; summing needs no wrapper object.

        Matches the ``Propagator.init_feedback(node, feedback)`` signature so
        this propagator works wherever the base-class contract is used. It still
        accepts the older single-argument form ``init_feedback(feedback)``.

        Args:
            node: The node receiving the feedback (unused). In the legacy
                single-argument form, this positional slot holds the feedback.
            feedback: The raw feedback.

        Returns:
            The feedback itself.
        """
        if feedback is self._NO_FEEDBACK:
            # Legacy form: the only positional argument is the feedback.
            return node
        return feedback

    def _propagate(self, child: MessageNode) -> Dict[Node, Any]:
        """Compute the feedback that every parent of ``child`` receives.

        Returns:
            A dict mapping each parent of ``child`` to the same combined feedback.

        Raises:
            AssertionError: if user feedback is mixed with other feedback, if
                there is no feedback to combine, or if the feedback items do not
                all have the same type.
        """
        if "user" in child.feedback:
            assert len(child.feedback) == 1, "user feedback should be the only feedback"
            assert len(child.feedback["user"]) == 1, "user feedback should be a single item"
            feedback = child.feedback["user"][0]
        else:
            # Entries may hold zero or one item (AbstractPropagator.__call__ only
            # checks len <= 1); skip empty ones instead of failing with IndexError.
            feedback_list = [items[0] for items in child.feedback.values() if items]
            assert len(feedback_list) > 0, "no feedback to propagate"
            first_type = type(feedback_list[0])
            assert all(type(f) is first_type for f in feedback_list), "error in propagate"
            if isinstance(feedback_list[0], str):
                feedback = "".join(feedback_list)
            else:
                # sum() starts from 0; feedback objects must handle 0 + x
                # (e.g. via an __radd__ that treats 0 as the identity).
                feedback = sum(feedback_list)
        return {parent: feedback for parent in child.parents}