Source code for archai.supergraph.nas.arch_module
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC
from typing import Iterable, Iterator, Optional, Tuple
from overrides import EnforceOverrides
from torch import nn
from archai.supergraph.nas.arch_params import ArchParams, NNTypes
class ArchModule(nn.Module, ABC, EnforceOverrides):
    """An ``nn.Module`` that keeps architecture parameters apart from regular weights.

    Architecture parameters are created with `create_arch_params()` (or
    attached with `set_arch_params()`) and read back with `arch_params()`.
    Regular, non-architecture parameters are available via `nonarch_params()`.
    """

    def __init__(self) -> None:
        super().__init__()

        # Params this module should use; they may be shared with another
        # module or created by this one. Starts out empty.
        self._arch_params = ArchParams.empty()

        # Params created and registered by this module itself; stays None
        # until `create_arch_params()` is called.
        self._owned_arch_params: Optional[ArchParams] = None

    def create_arch_params(self, named_params: Iterable[Tuple[str, NNTypes]]) -> None:
        """Create arch params owned by this module and make them the ones in use.

        Args:
            named_params: ``(name, value)`` pairs handed to ``ArchParams``
                with this module passed as the ``registrar``.

        Raises:
            RuntimeError: If this module already has a non-empty set of
                arch params.
        """
        # Check before constructing, since the ArchParams constructor is
        # given this module as its registrar.
        if len(self._arch_params) != 0:
            raise RuntimeError('Arch parameters for this module already exist')

        owned = ArchParams(named_params, registrar=self)
        self._owned_arch_params = owned
        self.set_arch_params(owned)

    def set_arch_params(self, arch_params: ArchParams) -> None:
        """Attach an existing ``ArchParams`` object as the params this module uses.

        Args:
            arch_params: The params to use from now on.

        Raises:
            RuntimeError: If this module already has a non-empty set of
                arch params.
        """
        if len(self._arch_params) != 0:
            raise RuntimeError('Arch parameters for this module already exist')

        self._arch_params = arch_params

    def arch_params(self, recurse=False, only_owned=False) -> ArchParams:
        """Return architecture parameters for this module.

        Args:
            recurse: Forwarded to ``ArchParams.from_module`` when
                ``only_owned`` is set; must be falsy otherwise.
            only_owned: If falsy, return the params this module currently
                uses (as set via `set_arch_params()`). If truthy, return
                ``ArchParams.from_module(self, recurse=...)`` instead.

        Returns:
            The selected ``ArchParams``.

        Raises:
            NotImplementedError: If ``recurse`` is truthy while
                ``only_owned`` is falsy.
        """
        # NOTE(review): the original comment here said lists are cached on
        # first call to make frequent calls faster; no caching is visible in
        # this method -- confirm whether ArchParams.from_module handles it.
        if only_owned:
            return ArchParams.from_module(self, recurse=bool(recurse))

        if recurse:
            raise NotImplementedError('Recursively getting shared and owned arch params not implemented yet')

        return self._arch_params

    def all_owned(self) -> ArchParams:
        """Shorthand for ``arch_params(recurse=True, only_owned=True)``."""
        return self.arch_params(recurse=True, only_owned=True)

    def nonarch_params(self, recurse: bool) -> Iterator[nn.Parameter]:
        """Return the non-architecture parameters of this module.

        Delegates to ``ArchParams.nonarch_from_module``.

        Args:
            recurse: Forwarded unchanged to ``ArchParams.nonarch_from_module``.
        """
        return ArchParams.nonarch_from_module(self, recurse)