Source code for archai.supergraph.nas.arch_params
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import UserDict
from typing import Iterable, Iterator, Optional, Tuple, Union
from torch import nn
_param_suffix = '_arch_param' # all arch parameter names must have this suffix
NNTypes = Union[nn.Parameter, nn.ParameterDict, nn.ParameterList]
[docs]class ArchParams(UserDict):
"""This class holds set of learnable architecture parameter(s) for a given module. For example, one instance of this class would hold alphas for one instance of MixedOp. For sharing parameters, instance of this class can be passed around. Different algorithms may add learnable parameters for their need."""
def __init__(self, arch_params:Iterable[Tuple[str, NNTypes]], registrar:Optional[nn.Module]=None):
"""Create architecture parameters and register them
Arguments:
registrar {Optional[nn.Module]} -- If this parameter is beingly newly created instead of being shared by other module then owner should be specified. When owner is not None, this method will create a variable in the owning module with suffix _arch_param so that the parameter gets registered with Pytorch and becomes available in module's .parameters() calls.
"""
super().__init__()
for name, param in arch_params:
self.data[name] = param
if registrar is not None:
setattr(registrar, name + _param_suffix, param)
def __setitem__(self, name:str, param:NNTypes)->None:
raise RuntimeError(f'ArchParams is immutable hence adding/updating key {name} is not allowed.')
def __delitem__(self, name:str) -> None:
raise RuntimeError(f'ArchParams is immutable hence removing key {name} is not allowed.')
def _by_kind(self, kind:Optional[str])->Iterator[NNTypes]:
# TODO: may be optimize to avoid split() calls?
for name, param in self.items():
if kind is None or name.split('.')[-1]==kind:
yield param
[docs] def param_by_kind(self, kind:Optional[str])->Iterator[nn.Parameter]:
# TODO: enforce type checking if debugger is active?
return self._by_kind(kind) # type: ignore
[docs] def paramlist_by_kind(self, kind:Optional[str])->Iterator[nn.ParameterList]:
# TODO: enforce type checking if debugger is active?
return self._by_kind(kind) # type: ignore
[docs] def paramdict_by_kind(self, kind:Optional[str])->Iterator[nn.ParameterDict]:
# TODO: enforce type checking if debugger is active?
return self._by_kind(kind) # type: ignore
[docs] def has_kind(self, kind:str)->bool:
# TODO: may be optimize to avoid split() calls?
for name in self.keys():
if name.split('.')[-1]==kind:
return True
return False
[docs] @staticmethod
def from_module(module:nn.Module, recurse:bool=False)->'ArchParams':
suffix_len = len(_param_suffix)
# Pytorch named params have . in name for each module, we pick last part and remove _arch_params prefix
arch_params = ((name[:-suffix_len], param) \
for name, param in module.named_parameters(recurse=recurse)
if name.endswith(_param_suffix))
return ArchParams(arch_params)
[docs] @staticmethod
def nonarch_from_module(module:nn.Module, recurse:bool=False)->Iterator[nn.Parameter]:
# Pytorch named params have . in name for each module, we pick last part and remove _arch_params prefix
return (param for name, param in module.named_parameters(recurse=recurse)
if not name.endswith(_param_suffix))
[docs] @staticmethod
def empty()->'ArchParams':
return ArchParams([])