# Source code for archai.supergraph.algos.nasbench101.nasbench101_op
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Optional
import torch
from overrides import overrides
from torch import Tensor
from archai.supergraph.nas.arch_params import ArchParams
from archai.supergraph.nas.model_desc import OpDesc
from archai.supergraph.nas.operations import Op
class NasBench101Op(Op):
    """Op that merges several input tensors and applies a single vertex op.

    The inputs are summed; when ``proj_first`` is set in ``op_desc.params``
    the first input is passed through the op registered under the name
    ``'convbnrelu_1x1'`` before being added to the others. The sum is then
    fed to the vertex op named by ``op_desc.params['vertex_op']``.
    """

    def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams], affine: bool):
        """Create the vertex op and, if requested, the first-input projection.

        Args:
            op_desc: Description of this op. ``params['vertex_op']`` names the
                op created for the vertex, ``params['proj_first']`` tells
                whether the first input needs projection, and ``in_len`` is
                the number of inputs ``forward`` expects.
            arch_params: Not used by this class; child ops are created with
                ``arch_params=None``.
            affine: Forwarded to ``Op.create`` for every child op.

        Raises:
            KeyError: If ``'vertex_op'`` or ``'proj_first'`` is missing from
                ``op_desc.params``.
        """
        super().__init__()

        vertex_op_name = op_desc.params['vertex_op']
        proj_first = op_desc.params['proj_first']  # first input needs projection

        # forward() branches on this flag, so it has to live on the instance;
        # keeping it only as a local made forward() fail with AttributeError.
        self._proj_first = proj_first
        self._in_len = op_desc.in_len

        self._vertex_op = Op.create(
            OpDesc(vertex_op_name, params=op_desc.params,
                   in_len=1, trainables=None),
            affine=affine, arch_params=None)

        # The projection op is only built when it will actually be used.
        if proj_first:
            self._proj_op = Op.create(
                OpDesc('convbnrelu_1x1', params=op_desc.params,
                       in_len=1, trainables=None),
                affine=affine, arch_params=None)
        else:
            self._proj_op = None

    @overrides
    def forward(self, x: List[Tensor]):
        """Sum the inputs (projecting the first if configured) and apply the vertex op.

        Args:
            x: List of input tensors; its length must equal the ``in_len`` of
                the ``op_desc`` this op was built from. A bare tensor is
                rejected.

        Returns:
            The output of the vertex op applied to the summed inputs.
        """
        assert not isinstance(x, torch.Tensor), \
            'expected a list of tensors, not a single tensor'
        assert len(x) == self._in_len, \
            f'expected {self._in_len} inputs but got {len(x)}'

        x0 = self._proj_op(x[0]) if self._proj_first else x[0]
        # sum() starts from the int 0, so an empty tail simply yields 0 + x0.
        s = sum(x[1:]) + x0
        out = self._vertex_op(s)

        return out