Source code for opto.optimizers.optimizer

from typing import Any, List, Dict

from opto.trace.nodes import ParameterNode, Node
from opto.trace.propagators import GraphPropagator
from opto.trace.propagators.propagators import Propagator
from opto.trace.utils import sum_feedback


class AbstractOptimizer:
    """An optimizer is responsible for updating the parameters based on the feedback."""

    def __init__(self, parameters: List[ParameterNode], *args, **kwargs):
        assert type(parameters) is list
        assert all([isinstance(p, ParameterNode) for p in parameters])
        self.parameters = parameters

    def step(self):
        """Update the parameters based on the feedback."""
        raise NotImplementedError

    def zero_feedback(self):
        """Reset the feedback."""
        raise NotImplementedError

    @property
    def propagator(self):
        """Return a Propagator object that can be used to propagate feedback in the backward pass."""
        raise NotImplementedError

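# Illustrative note (not part of the original module): concrete optimizers built
# on this interface are driven in a feedback loop, calling step() to consume the
# accumulated feedback and zero_feedback() to reset it before the next round.
# A hypothetical concrete sketch is given at the end of this file.
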
class Optimizer(AbstractOptimizer):
    def __init__(self, parameters: List[ParameterNode], *args, propagator: Propagator = None, **kwargs):
        super().__init__(parameters)
        propagator = propagator if propagator is not None else self.default_propagator()
        assert isinstance(propagator, Propagator)
        self._propagator = propagator

    @property
    def propagator(self):
        return self._propagator

    @property
    def trace_graph(self):
        """Aggregate the graphs of all the parameters."""
        return sum_feedback(self.parameters)

    def step(self, *args, **kwargs):
        update_dict = self.propose(*args, **kwargs)
        self.update(update_dict)

    def propose(self, *args, **kwargs):
        """Propose the new data of the parameters based on the feedback."""
        return self._step(*args, **kwargs)

    def update(self, update_dict: Dict[ParameterNode, Any]):
        """Update the trainable parameters given a dictionary of new data."""
        for p, d in update_dict.items():
            if p.trainable:
                p._data = d

    def zero_feedback(self):
        for p in self.parameters:
            p.zero_feedback()

    # Subclass should implement the methods below.
    def _step(self, *args, **kwargs) -> Dict[ParameterNode, Any]:
        """Return the new data of parameter nodes based on the feedback."""
        raise NotImplementedError

    def default_propagator(self):
        """Return the default Propagator object of the optimizer."""
        return GraphPropagator()

    def backward(self, node: Node, *args, **kwargs):
        """Propagate the feedback backward."""
        return node.backward(*args, propagator=self.propagator, **kwargs)
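
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# Optimizer only needs to implement `_step`, which maps the propagated feedback
# to new data for each parameter. The class name and the placeholder proposal
# below are hypothetical; a real optimizer would derive the new data from
# `self.trace_graph` or the parameters' feedback (e.g. by querying an LLM).


class ConstantProposalOptimizer(Optimizer):
    """Toy sketch: proposes fixed placeholder data for every trainable parameter."""

    def _step(self, *args, **kwargs) -> Dict[ParameterNode, Any]:
        # Return a dict mapping each trainable parameter to its proposed new data.
        return {p: "placeholder data proposed from feedback" for p in self.parameters if p.trainable}


# A typical driving loop for the sketch above (the `params` list, the `output`
# node, and the feedback string are hypothetical):
#
#     opt = ConstantProposalOptimizer(params)
#     opt.backward(output, "feedback on the output")  # propagate feedback to the parameters
#     opt.step()                                      # propose new data and apply it
#     opt.zero_feedback()                             # clear feedback before the next iteration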