Source code for opto.trace.modules
import os
import pickle
import copy
from opto.trace.containers import ParameterContainer
from opto.trace.nodes import ParameterNode
[docs]
def model(cls):
    """
    Wrap a class with this decorator. This helps collect parameters for the optimizer. This decorated class cannot be pickled.

    The decorator returns a new class, ``ModelWrapper``, that inherits from
    both ``cls`` and ``Module`` (``cls`` comes first in the bases, so its
    attributes take precedence in the method resolution order).

    Args:
        cls: The class to wrap.

    Returns:
        A subclass of ``cls`` and ``Module`` named ``ModelWrapper`` that
        carries the docstring of ``cls``.
    """

    # The wrapper is created inside this function, so its qualified name is
    # local to ``model``; pickle looks classes up by qualified name, which is
    # why the decorated class cannot be pickled.
    class ModelWrapper(cls, Module):
        pass

    # A class body without a docstring gets ``__doc__ = None`` (class
    # docstrings are not inherited), so without this the decorated class
    # would lose the documentation written on ``cls``.
    ModelWrapper.__doc__ = cls.__doc__
    return ModelWrapper
[docs]
class Module(ParameterContainer):
    """Module is a ParameterContainer which has a forward method.

    Calling an instance delegates to ``forward``. ``save`` and ``load``
    persist and restore the dictionary returned by ``parameters_dict()``
    (inherited; not defined in this class) using pickle.
    """

    def forward(self, *args, **kwargs):
        """Compute the output of the module; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Call ``forward`` with the given arguments and return its result."""
        return self.forward(*args, **kwargs)

    def save(self, file_name):
        """Save the parameters of the model to a file.

        A deep copy of ``self.parameters_dict()`` is pickled to ``file_name``.
        The parent directory of ``file_name`` is created if it is missing.

        Args:
            file_name: Path of the file to write (opened in binary mode).
        """
        directory = os.path.dirname(file_name)
        # Truthiness instead of `!= ""`: for a bytes path without a directory
        # part, dirname returns b"", which is not equal to "" and would send
        # an empty path to os.makedirs (raising FileNotFoundError).
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(file_name, "wb") as f:
            pickle.dump(copy.deepcopy(self.parameters_dict()), f)

    def load(self, file_name):
        """Load the parameters of the model from a file.

        The file is unpickled and the result is passed to ``_set``.

        Args:
            file_name: Path of a file previously written by ``save``.
        """
        with open(file_name, "rb") as f:
            # NOTE(security): pickle.load can execute arbitrary code while
            # deserializing; only load files that come from a trusted source.
            loaded_data = pickle.load(f)
        self._set(loaded_data)

    def _set(self, new_parameters):
        """Set the parameters of the model from a dictionary.

        new_parameters is a ParameterContainer or a parameter dict.

        Every key of ``self.parameters_dict()`` must be present in the new
        parameters. For keys that already exist in this module, the existing
        entry's ``_set`` is called with the new value. Any other key is
        attached to this module as a new attribute, provided no instance
        attribute of that name exists yet.

        Raises:
            AssertionError: If any of the checks below fails. The checks are
                plain ``assert`` statements, so they are skipped when Python
                runs with ``-O``.
        """
        assert isinstance(new_parameters, (dict, ParameterContainer)), (
            "new_parameters must be a dict or a ParameterContainer, got "
            + type(new_parameters).__name__
        )
        if isinstance(new_parameters, ParameterContainer):
            new_parameters_dict = new_parameters.parameters_dict()
        else:
            new_parameters_dict = new_parameters  # dictionary

        parameters_dict = self.parameters_dict()

        assert all(
            k in new_parameters_dict for k in parameters_dict
        ), """ Not all model parameters are in the new parameters dictionary. """

        for k, v in new_parameters_dict.items():
            if k in parameters_dict:  # if the parameter exists
                assert isinstance(v, (ParameterNode, ParameterContainer)), (
                    "value for existing parameter %r must be a ParameterNode "
                    "or a ParameterContainer, got %s" % (k, type(v).__name__)
                )
                parameters_dict[k]._set(v)
            else:  # if the parameter does not exist
                # Refuse to overwrite an instance attribute that is not a
                # parameter of this module.
                assert k not in self.__dict__, (
                    "cannot add parameter %r: an attribute with that name "
                    "already exists" % (k,)
                )
                setattr(self, k, v)