# Source code for opto.trace.containers

import inspect
from collections import UserDict, UserList
from opto.trace.nodes import ParameterNode
import functools

class NodeContainer:
    """ An identifier for a container of nodes.

    This is a marker base class with no behavior of its own; in this module
    ParameterContainer derives from it. Presumably it exists so containers
    can be recognized with ``isinstance`` checks elsewhere (TODO confirm).
    """
def trainable_method(method):
    """Report whether ``method`` is a trainable bundled function.

    Args:
        method: Any object, typically an attribute read off a container.

    Returns:
        ``method.trainable`` when ``method`` is a ``FunModule``;
        ``False`` for anything else (including ``None``).
    """
    # Imported at call time rather than at module level; presumably this
    # avoids a circular import with opto.trace.bundle (TODO confirm).
    from opto.trace.bundle import FunModule

    if not isinstance(method, FunModule):
        return False
    return method.trainable
class ParameterContainer(NodeContainer):
    """ A container of parameter nodes. """

    def parameters(self):
        """Return a flattened list of all the parameters in ``parameters_dict``.

        Nested ParameterContainers are expanded recursively, so the result
        contains only ParameterNode objects. This is useful for optimization.

        Returns:
            list: ParameterNode objects, in ``parameters_dict`` iteration order.

        Raises:
            ValueError: if ``parameters_dict`` yields a value that is neither a
                ParameterNode nor a ParameterContainer (possible when a
                subclass overrides ``parameters_dict``).
        """
        parameters = []
        for value in self.parameters_dict().values():
            if isinstance(value, ParameterNode):
                parameters.append(value)
            elif isinstance(value, ParameterContainer):
                parameters.extend(value.parameters())
            else:
                raise ValueError("The model contains an unknown parameter type.")
        return parameters

    def parameters_dict(self):
        """Return a dictionary of all the parameters in the model.

        Both trainable and non-trainable parameters are included. Members are
        discovered with ``inspect.getmembers(self)``. A member is collected
        under its attribute name when it is:

        * a ``functools.partial`` whose wrapped callable is bound to a
          trainable method (its ``.parameter`` is stored);
        * itself a trainable method (its ``.parameter`` is stored);
        * a ParameterNode or a ParameterContainer (stored as-is).

        Returns:
            dict: attribute name -> ParameterNode or ParameterContainer.
        """
        parameters = {}
        for name, attr in inspect.getmembers(self):
            if isinstance(attr, functools.partial):
                # Original note: "this is a class method". Presumably a bundled
                # method accessed via an instance shows up as a partial over a
                # bound method whose __self__ is the FunModule. Partials over
                # plain functions have no __self__; treat those as
                # non-trainable instead of raising AttributeError.
                method = getattr(attr.func, "__self__", None)
                if trainable_method(method):
                    parameters[name] = method.parameter
            elif trainable_method(attr):  # method attribute
                parameters[name] = attr.parameter
            elif isinstance(attr, (ParameterNode, ParameterContainer)):
                parameters[name] = attr
        assert all(isinstance(v, (ParameterNode, ParameterContainer)) for v in parameters.values())
        return parameters  # include both trainable and non-trainable parameters
class Seq(UserList, ParameterContainer):
    """A parameter container with a length and an index.

    Seq is defined as having a length and an index. Python's list/tuple are
    presumably converted to Seq elsewhere in the library (TODO confirm).

    Construction:
        * ``Seq(x)``, where ``x`` has both ``__len__`` and ``__getitem__``,
          wraps the elements of ``x``.
        * Any other call, e.g. ``Seq(a, b, c)``, wraps the positional
          arguments themselves.
    """

    def __init__(self, *args):
        single_sequence_arg = (
            len(args) == 1
            and hasattr(args[0], "__len__")
            and hasattr(args[0], "__getitem__")
        )
        items = args[0] if single_sequence_arg else args
        super().__init__(initlist=items)

    def parameters_dict(self):
        """Return a dictionary of the parameters held directly in this Seq.

        Both trainable and non-trainable parameters are included. Elements
        that are neither ParameterNodes nor ParameterContainers are skipped.

        Returns:
            dict: a ParameterNode is keyed by its ``.name``; a nested
            ParameterContainer is keyed by ``str(container)``.
        """
        parameters = {}
        for element in self.data:
            if isinstance(element, ParameterNode):
                parameters[element.name] = element
            elif isinstance(element, ParameterContainer):
                # TODO: what is the name of the container? str() is a stand-in
                # and could collide if two containers render identically.
                parameters[str(element)] = element
        assert all(isinstance(v, (ParameterNode, ParameterContainer)) for v in parameters.values())
        return parameters
class Map(UserDict, ParameterContainer):
    """A parameter container mapping keys to values.

    Map is defined as key and value; Python's dict is presumably converted to
    Map elsewhere in the library (TODO confirm).
    """

    def __init__(self, mapping=None):
        """Create a Map from ``mapping``, or an empty Map when omitted.

        ``mapping`` defaults to None so that ``Map()`` works. This also fixes
        the inherited ``UserDict.fromkeys``, which calls ``cls()`` with no
        arguments.
        """
        super().__init__(mapping)

    def parameters_dict(self):
        """Return a dictionary of the parameters held in this Map.

        Both trainable and non-trainable parameters are included, and both
        values and keys are inspected:

        * a ParameterNode value is stored under its own map key;
        * a ParameterContainer value is stored under ``str(value)``;
        * a ParameterNode key is stored under ``str(key)``.

        Raises:
            Exception: if a key is a ParameterContainer.
        """
        parameters = {}
        for k, v in self.data.items():
            if isinstance(v, ParameterNode):
                parameters[k] = v
            elif isinstance(v, ParameterContainer):
                parameters[str(v)] = v  # TODO: what is the name of the container?

            if isinstance(k, ParameterNode):
                parameters[str(k)] = k
            elif isinstance(k, ParameterContainer):
                raise Exception("The key of a Map cannot be a container.")

        assert all(isinstance(v, (ParameterNode, ParameterContainer)) for v in parameters.values())
        return parameters
#