Source code for opto.trace.iterators

from opto.trace.nodes import node, Node, ExceptionNode
from typing import Any

# TODO: remove unused import
# from opto.trace.bundle import bundle
import opto.trace.operators as ops
from opto.trace.errors import ExecutionError


# Iterate a Node wrapping a container (Node[List]) as a sequence of element Nodes (List[Node]).
def iterate(x: Any):
    """Return an iterator object for node of list, tuple, set, or dict.

    ``x`` is lifted into a ``Node`` with ``node`` when it is not one already.
    The iterator is then chosen from ``x.type``:

    * ``list``, ``tuple`` or ``str`` (or a subclass): ``SeqIterable`` over ``x``.
    * ``set``: ``x`` is converted with ``ops.to_list`` and the resulting node
      is wrapped in ``SeqIterable``.
    * ``dict``: ``SeqIterable`` over the node returned by ``x.keys()``.

    Raises:
        ExecutionError: for any other type. It wraps an ``ExceptionNode``
            built from a ``TypeError``, with ``inputs=[x]`` and the error
            message stored under ``info["traceback"]``.
    """
    if not isinstance(x, Node):
        x = node(x)
    wrapped_type = x.type

    if issubclass(wrapped_type, (list, tuple, str)):
        return SeqIterable(x)
    if issubclass(wrapped_type, set):
        # Sets are unordered and not indexable, so materialize a list node first.
        return SeqIterable(ops.to_list(x))
    if issubclass(wrapped_type, dict):
        return SeqIterable(x.keys())

    # Not a supported container: record the failure as an exception node in the graph.
    message = f"TypeError: Cannot unpack non-iterable {type(x._data)} object"
    error_node = ExceptionNode(
        TypeError(message),
        inputs=[x],
        info={"traceback": message},
    )
    raise ExecutionError(error_node)
# Lists, tuples, and strings share one sequence iterator; sets (after conversion to a list) and dict keys reuse it too.
class SeqIterable:
    """Iterator over the elements of a ``Node`` that wraps an indexable sequence.

    Each step indexes the wrapping node itself (``wrapped_list[i]``), not its
    raw data. It asserts that the element is a ``Node`` whose ``parents``
    include ``wrapped_list``, i.e. that indexing recorded the container as a
    parent. Calling ``iter()`` on the object rewinds it to the first element.
    """

    def __init__(self, wrapped_list):
        assert isinstance(wrapped_list, Node)
        self._index = 0
        self.wrapped_list = wrapped_list

    def __iter__(self):
        # Rewind so the same object can be iterated more than once.
        self._index = 0
        return self

    def __next__(self):
        position = self._index
        if position >= len(self.wrapped_list._data):
            raise StopIteration
        element = self.wrapped_list[position]
        self._index = position + 1
        assert isinstance(element, Node)
        assert self.wrapped_list in element.parents
        return element
class DictIterable:
    """Iterator yielding ``(key, value)`` pairs for a ``Node`` that wraps a dict.

    The keys node is computed once at construction with ``ops.keys``. Each
    step indexes that keys node to get ``key``, then indexes the wrapped dict
    node with it to get the value. It asserts that ``wrapped_dict`` is among
    the value's ``parents``. Calling ``iter()`` on the object rewinds it to
    the first key.
    """

    def __init__(self, wrapped_dict):
        assert isinstance(wrapped_dict, Node)
        self._index = 0
        self.wrapped_dict = wrapped_dict
        # NOTE(review): presumably a node wrapping an indexable sequence of keys,
        # since __next__ indexes it by position — confirm against ops.keys.
        self.keys = ops.keys(wrapped_dict)

    def __iter__(self):
        # Rewind so the same object can be iterated more than once.
        self._index = 0
        return self

    def __next__(self):
        # Measure the raw data, as SeqIterable does, instead of relying on
        # the keys node implementing __len__.
        if self._index < len(self.keys._data):
            key = self.keys[self._index]
            result = (key, self.wrapped_dict[key])
            self._index += 1
            assert self.wrapped_dict in result[1].parents
            return result
        else:
            raise StopIteration