def model(cls):
    """Wrap a class with this decorator so it becomes a ``Module``.

    This helps collect parameters for the optimizer. The returned class
    subclasses both ``cls`` and ``Module``, so instances get ``Module``'s
    methods (such as ``save`` and ``load``) on top of ``cls``'s own behavior.

    The wrapper copies the decorated class's ``__name__``, ``__qualname__``,
    ``__module__`` and ``__doc__``, so it displays like the original class.
    Pickle looks classes up by module and qualified name. It can therefore
    locate the wrapper when the class is decorated at module level under its
    own name, provided the instance state is itself picklable. Classes
    decorated inside a function still cannot be pickled by reference.

    Args:
        cls: The class to wrap.

    Returns:
        A new class deriving from ``(cls, Module)``.
    """

    # ``cls`` comes first so its methods take precedence in the MRO.
    # This ordering also stays valid when ``cls`` already derives from Module.
    class ModelWrapper(cls, Module):
        pass

    # Without these assignments the wrapper would report itself as
    # ``model.<locals>.ModelWrapper`` and lose the user's docstring.
    ModelWrapper.__name__ = cls.__name__
    ModelWrapper.__qualname__ = cls.__qualname__
    ModelWrapper.__module__ = cls.__module__
    ModelWrapper.__doc__ = cls.__doc__
    return ModelWrapper
class Module(ParameterContainer):
    """Module is a ParameterContainer which has a forward method.

    On top of the container behavior inherited from ``ParameterContainer``,
    this class can persist its parameters to disk (``save`` / ``load``) and
    overwrite them in place (``_set``).
    """

    def save(self, file_name):
        """Save the parameters of the model to a file.

        A deep copy of ``self.parameters_dict()`` is pickled to ``file_name``.
        Missing parent directories are created first.

        Args:
            file_name: Destination path of the pickle file.
        """
        # A bare file name has an empty dirname, and os.makedirs("") would
        # raise, so only create directories when the path actually has one.
        directory = os.path.dirname(file_name)
        if directory != "":
            os.makedirs(directory, exist_ok=True)
        with open(file_name, "wb") as f:
            # NOTE(review): the deep copy is presumably deliberate, to
            # serialize parameters detached from the live model; confirm
            # before removing it.
            pickle.dump(copy.deepcopy(self.parameters_dict()), f)

    def load(self, file_name):
        """Load the parameters of the model from a file written by ``save``.

        The unpickled object is applied to this model via ``_set``.

        Args:
            file_name: Path of a pickle file, typically produced by ``save``.

        Warning:
            ``pickle.load`` can execute arbitrary code while unpickling, so
            only load files from trusted sources.
        """
        with open(file_name, "rb") as f:
            # SECURITY: pickle is unsafe on untrusted input (see docstring).
            loaded_data = pickle.load(f)
        self._set(loaded_data)

    def _set(self, new_parameters):
        """Set the parameters of the model from a dictionary.

        Args:
            new_parameters: A ParameterContainer or a parameter dict mapping
                names to parameters. It must contain every parameter the
                model currently has.
                - Entries naming an existing parameter update that parameter
                  in place through its own ``_set``.
                - Any other entries are attached to the model as new
                  attributes.

        Raises:
            AssertionError: Raised before anything is modified, in any of
                these cases:
                - ``new_parameters`` is neither a dict nor a
                  ParameterContainer.
                - It lacks one of the model's parameters.
                - A value for an existing parameter is not a ParameterNode
                  or ParameterContainer.
                - A new name collides with an attribute already in the
                  instance ``__dict__``.
                These are ``assert`` statements, so they are skipped when
                Python runs with ``-O``.
        """
        assert isinstance(new_parameters, (dict, ParameterContainer))
        if isinstance(new_parameters, ParameterContainer):
            new_parameters_dict = new_parameters.parameters_dict()
        else:
            new_parameters_dict = new_parameters  # dictionary

        parameters_dict = self.parameters_dict()

        # Name the missing keys so a mismatched checkpoint is easy to diagnose.
        missing = [k for k in parameters_dict if k not in new_parameters_dict]
        assert not missing, (
            "Not all model parameters are in the new parameters dictionary. "
            f"Missing: {missing}"
        )

        # Validate every entry before mutating anything. Otherwise a bad
        # entry late in the dict would leave the model partially updated.
        for k, v in new_parameters_dict.items():
            if k in parameters_dict:
                assert isinstance(v, (ParameterNode, ParameterContainer))
            else:
                assert k not in self.__dict__

        for k, v in new_parameters_dict.items():
            if k in parameters_dict:  # update the existing parameter in place
                parameters_dict[k]._set(v)
            else:  # attach a parameter the model does not have yet
                setattr(self, k, v)