import copy
from opto.trace.containers import NodeContainer
from opto.trace.nodes import Node
[docs]
def recursive_conversion(true_func, false_func):
    """Build a converter that maps ``true_func`` over every Node in a nested container.

    The returned one-argument callable walks tuples, lists, dicts, sets and
    NodeContainer instances recursively:

    * a Node is replaced by ``true_func(node)`` (the base case);
    * a tuple, list or set is rebuilt from its converted elements (set results
      must therefore be hashable);
    * a dict is rebuilt with the same keys and converted values;
    * a NodeContainer is shallow-copied with ``copy.copy`` and each attribute
      found in the original's ``__dict__`` is reassigned on the copy with its
      converted value;
    * any other object is replaced by ``false_func(obj)``.

    Args:
        true_func (callable): the function to be applied to the nodes.
        false_func (callable): the function to be applied to the rest of the objects.

    Returns:
        callable: a function taking one object and returning its converted form.
    """

    def convert(obj):
        # Base case of the recursion: Nodes are handled before any container check.
        if isinstance(obj, Node):
            return true_func(obj)
        if isinstance(obj, tuple):
            return tuple(map(convert, obj))
        if isinstance(obj, list):
            return list(map(convert, obj))
        if isinstance(obj, dict):
            # Keys are kept as they are; only the values are converted.
            return dict((key, convert(value)) for key, value in obj.items())
        if isinstance(obj, set):
            return set(map(convert, obj))
        if isinstance(obj, NodeContainer):
            duplicate = copy.copy(obj)
            for attr_name, attr_value in obj.__dict__.items():
                setattr(duplicate, attr_name, convert(attr_value))
            return duplicate
        # Anything that is neither a Node nor a supported container.
        return false_func(obj)

    return convert
# TODO: add tests for this module and clean up the code.
[docs]
def apply_op(op, output, *args, **kwargs):
    """A broadcasting operation that applies an op to containers of Nodes.

    If every positional and keyword input is a Node (or there are no inputs at
    all), ``op(*args, **kwargs)`` is returned directly. Otherwise every non-Node
    input must have exactly the same type as ``output``, and the op is broadcast
    element-wise following the structure of ``output``:

    * Node inputs are passed unchanged to every element;
    * container inputs are indexed by the current element's key: list/tuple
      position, dict key, or NodeContainer attribute name.

    The recursion applies the same rules at every nesting level. Lists, dicts
    and NodeContainers are updated in place and returned. Tuples are immutable,
    so a new tuple is returned instead. Any other ``output`` type is returned
    unchanged.

    Args:
        op (callable): the operator to be applied.
        output (Any): the container to be updated; it also serves as the
            structural template for the broadcast.
        *args (Any): the positional inputs of the operator.
        **kwargs (Any): the keyword inputs of the operator.

    Returns:
        Any: ``op(*args, **kwargs)`` when all inputs are Nodes; otherwise
        ``output`` (or a new tuple when ``output`` is a tuple) with each element
        replaced by the broadcast result.

    Raises:
        AssertionError: if a non-Node input's type differs from ``output``'s,
            or a list/tuple input's length differs from ``output``'s.
    """
    inputs = list(args) + list(kwargs.values())
    if all(isinstance(x, Node) for x in inputs):
        # All inputs are Nodes (vacuously true when there are none): apply op directly.
        return op(*args, **kwargs)

    def admissible_type(x, base):
        # Exact type match is required (not isinstance) so that structures line up.
        return type(x) == type(base) or isinstance(x, Node)

    assert all(admissible_type(x, output) for x in inputs), (
        f"All inputs must be Nodes or of the same type as output "
        f"({type(output).__name__})."
    )

    def select(key, get):
        # Pick the `key`-th component of every container input; Nodes broadcast as-is.
        args_k = [x if isinstance(x, Node) else get(x, key) for x in args]
        kwargs_k = {
            name: value if isinstance(value, Node) else get(value, key)
            for name, value in kwargs.items()
        }
        return args_k, kwargs_k

    if isinstance(output, (list, tuple)):
        assert all(
            isinstance(x, Node) or len(output) == len(x) for x in inputs
        ), f"output {output} and inputs {inputs} are of different lengths."
        # Tuples cannot be assigned to, so work on a list copy and convert back;
        # lists are updated in place.
        items = output if isinstance(output, list) else list(output)
        for k, item in enumerate(items):
            _args, _kwargs = select(k, lambda container, key: container[key])
            items[k] = apply_op(op, item, *_args, **_kwargs)
        if isinstance(output, tuple):
            output = tuple(items)
    elif isinstance(output, dict):
        # Only existing keys are reassigned, so iterating while assigning is safe.
        for k, v in output.items():
            _args, _kwargs = select(k, lambda container, key: container[key])
            output[k] = apply_op(op, v, *_args, **_kwargs)
    elif isinstance(output, NodeContainer):
        # Broadcast over the instance attributes of the NodeContainer.
        for k, v in output.__dict__.items():
            _args, _kwargs = select(k, getattr)
            setattr(output, k, apply_op(op, v, *_args, **_kwargs))
    return output