Neural Architecture Search

Sections

  • Architecture Module
    • ArchModule
      • ArchModule.create_arch_params()
      • ArchModule.set_arch_params()
      • ArchModule.arch_params()
      • ArchModule.all_owned()
      • ArchModule.nonarch_params()
      • ArchModule.training
  • Architecture Parameters
    • ArchParams
      • ArchParams.param_by_kind()
      • ArchParams.paramlist_by_kind()
      • ArchParams.paramdict_by_kind()
      • ArchParams.has_kind()
      • ArchParams.from_module()
      • ArchParams.nonarch_from_module()
      • ArchParams.empty()
  • Architecture Trainer
    • ArchTrainer
      • ArchTrainer.compute_loss()
      • ArchTrainer.post_epoch()
  • Cell
    • Cell
      • Cell.ops()
      • Cell.forward()
      • Cell.training
  • DAG Edge
    • DagEdge
      • DagEdge.forward()
      • DagEdge.op()
      • DagEdge.training
  • Evaluater
    • EvalResult
    • Evaluater
      • Evaluater.evaluate()
      • Evaluater.train_model()
      • Evaluater.get_data()
      • Evaluater.create_model()
      • Evaluater.model_from_desc()
  • Experiment Runner
    • ExperimentRunner
      • ExperimentRunner.run_search()
      • ExperimentRunner.run_eval()
      • ExperimentRunner.run()
      • ExperimentRunner.copy_search_to_eval()
      • ExperimentRunner.model_desc_builder()
      • ExperimentRunner.searcher()
      • ExperimentRunner.evaluater()
      • ExperimentRunner.trainer_class()
      • ExperimentRunner.finalizers()
      • ExperimentRunner.get_expname()
      • ExperimentRunner.get_conf()
  • Finalizers
    • Finalizers
      • Finalizers.finalize_model()
      • Finalizers.finalize_cells()
      • Finalizers.finalize_cell()
      • Finalizers.finalize_node()
      • Finalizers.select_edges()
      • Finalizers.get_edge_ranks()
      • Finalizers.finalize_edge()
  • Model
    • Model
      • Model.summary()
      • Model.ops()
      • Model.forward()
      • Model.device_type()
      • Model.drop_path_prob()
      • Model.training
    • AuxTower
      • AuxTower.forward()
      • AuxTower.training
  • Model Description
    • ConvMacroParams
      • ConvMacroParams.clone()
    • OpDesc
      • OpDesc.clone()
      • OpDesc.clear_trainables()
      • OpDesc.state_dict()
      • OpDesc.load_state_dict()
    • EdgeDesc
      • EdgeDesc.clone()
      • EdgeDesc.clear_trainables()
      • EdgeDesc.state_dict()
      • EdgeDesc.load_state_dict()
    • NodeDesc
      • NodeDesc.clone()
      • NodeDesc.clear_trainables()
      • NodeDesc.state_dict()
      • NodeDesc.load_state_dict()
    • AuxTowerDesc
    • CellType
      • CellType.Regular
      • CellType.Reduction
    • CellDesc
      • CellDesc.clone()
      • CellDesc.clear_trainables()
      • CellDesc.state_dict()
      • CellDesc.load_state_dict()
      • CellDesc.reset_nodes()
      • CellDesc.nodes()
      • CellDesc.all_empty()
      • CellDesc.all_full()
    • ModelDesc
      • ModelDesc.reset_cells()
      • ModelDesc.clear_trainables()
      • ModelDesc.cell_descs()
      • ModelDesc.cell_type_count()
      • ModelDesc.clone()
      • ModelDesc.has_aux_tower()
      • ModelDesc.all_empty()
      • ModelDesc.all_full()
      • ModelDesc.state_dict()
      • ModelDesc.load_state_dict()
      • ModelDesc.save()
      • ModelDesc.load()
  • Model Description Builder
    • ModelDescBuilder
      • ModelDescBuilder.get_reduction_indices()
      • ModelDescBuilder.get_node_channels()
      • ModelDescBuilder.get_conf_cell()
      • ModelDescBuilder.get_conf_dataset()
      • ModelDescBuilder.get_conf_model_stems()
      • ModelDescBuilder.build()
      • ModelDescBuilder.build_cells()
      • ModelDescBuilder.get_node_count()
      • ModelDescBuilder.build_cell()
      • ModelDescBuilder.get_trainables_from()
      • ModelDescBuilder.get_ch()
      • ModelDescBuilder.build_cell_stems()
      • ModelDescBuilder.build_nodes_from_template()
      • ModelDescBuilder.build_nodes()
      • ModelDescBuilder.create_cell_templates()
      • ModelDescBuilder.build_model_pool()
      • ModelDescBuilder.build_logits_op()
      • ModelDescBuilder.get_cell_template()
      • ModelDescBuilder.get_cell_type()
      • ModelDescBuilder.build_cell_post_op()
      • ModelDescBuilder.build_aux_tower()
      • ModelDescBuilder.build_model_stems()
      • ModelDescBuilder.pre_build()
      • ModelDescBuilder.seed_cell()
  • NAS-Based Utilities
    • checkpoint_empty()
    • create_checkpoint()
    • get_model_stats()
  • Operations
    • Op
      • Op.create()
      • Op.get_trainables()
      • Op.set_trainables()
      • Op.register_op()
      • Op.finalize()
      • Op.ops()
      • Op.can_drop_path()
      • Op.training
    • PoolBN
      • PoolBN.forward()
      • PoolBN.training
    • SkipConnect
      • SkipConnect.forward()
      • SkipConnect.can_drop_path()
      • SkipConnect.training
    • FacConv
      • FacConv.forward()
      • FacConv.training
    • ReLUConvBN
      • ReLUConvBN.forward()
      • ReLUConvBN.training
    • ConvBNReLU
      • ConvBNReLU.forward()
      • ConvBNReLU.training
    • DilConv
      • DilConv.forward()
      • DilConv.training
    • SepConv
      • SepConv.forward()
      • SepConv.training
    • Identity
      • Identity.forward()
      • Identity.can_drop_path()
      • Identity.training
    • Zero
      • Zero.forward()
      • Zero.training
    • FactorizedReduce
      • FactorizedReduce.forward()
      • FactorizedReduce.training
    • StemBase
      • StemBase.training
    • StemConv3x3
      • StemConv3x3.forward()
      • StemConv3x3.can_drop_path()
      • StemConv3x3.training
    • StemConv3x3Relu
      • StemConv3x3Relu.forward()
      • StemConv3x3Relu.can_drop_path()
      • StemConv3x3Relu.training
    • StemConv3x3S4
      • StemConv3x3S4.forward()
      • StemConv3x3S4.can_drop_path()
      • StemConv3x3S4.training
    • StemConv3x3S4S2
      • StemConv3x3S4S2.forward()
      • StemConv3x3S4S2.can_drop_path()
      • StemConv3x3S4S2.training
    • AvgPool2d7x7
      • AvgPool2d7x7.forward()
      • AvgPool2d7x7.can_drop_path()
      • AvgPool2d7x7.training
    • PoolAdaptiveAvg2D
      • PoolAdaptiveAvg2D.forward()
      • PoolAdaptiveAvg2D.can_drop_path()
      • PoolAdaptiveAvg2D.training
    • PoolMeanTensor
      • PoolMeanTensor.forward()
      • PoolMeanTensor.can_drop_path()
      • PoolMeanTensor.training
    • LinearOp
      • LinearOp.forward()
      • LinearOp.can_drop_path()
      • LinearOp.training
    • MergeOp
      • MergeOp.forward()
      • MergeOp.can_drop_path()
      • MergeOp.training
    • ConcateChannelsOp
      • ConcateChannelsOp.forward()
      • ConcateChannelsOp.training
    • ProjectChannelsOp
      • ProjectChannelsOp.forward()
      • ProjectChannelsOp.training
    • DropPath_
      • DropPath_.extra_repr()
      • DropPath_.forward()
      • DropPath_.training
    • MultiOp
      • MultiOp.forward()
      • MultiOp.training
  • Random Finalizers
    • RandomFinalizers
      • RandomFinalizers.finalize_node()
  • Search Combinations
    • SearchCombinations
      • SearchCombinations.search()
      • SearchCombinations.is_better_metrics()
      • SearchCombinations.restore_checkpoint()
      • SearchCombinations.record_checkpoint()
      • SearchCombinations.get_combinations()
      • SearchCombinations.save_trained()
  • Searcher
    • ModelMetrics
    • SearchResult
    • Searcher
      • Searcher.search()
      • Searcher.clean_log_result()
      • Searcher.build_model_desc()
      • Searcher.get_data()
      • Searcher.finalize_model()
      • Searcher.search_model_desc()
      • Searcher.train_model_desc()
  • Model Description Visualizer
    • draw_model_desc()
    • draw_cell_desc()

Neural Architecture Search#

Architecture Module#

class archai.supergraph.nas.arch_module.ArchModule[source]#

ArchModule enhances nn.Module by making a clear separation between regular weights and architecture weights. Architecture parameters can be added using the create_arch_params() method and then accessed using the arch_params() method; a sketch follows the method list below.

create_arch_params(named_params: Iterable[Tuple[str, Parameter | ParameterDict | ParameterList]]) → None[source]#
set_arch_params(arch_params: ArchParams) → None[source]#
arch_params(recurse=False, only_owned=False) → ArchParams[source]#
all_owned() → ArchParams[source]#
nonarch_params(recurse: bool) → Iterator[Parameter][source]#
training: bool#
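
To make the separation concrete, here is a minimal sketch of a module that registers one architecture parameter alongside a regular weight. The MixedOpSketch class and its sizes are hypothetical; only the methods documented above are used.

    import torch
    import torch.nn as nn

    from archai.supergraph.nas.arch_module import ArchModule

    class MixedOpSketch(ArchModule):
        """Hypothetical module: one arch parameter ('alphas') plus a regular weight."""

        def __init__(self, num_candidates: int):
            super().__init__()
            self.proj = nn.Linear(8, 8)  # regular (non-architecture) weight
            # Register 'alphas' so it is surfaced via arch_params() and
            # excluded from nonarch_params().
            alphas = nn.Parameter(1e-3 * torch.randn(num_candidates))
            self.create_arch_params([('alphas', alphas)])

    m = MixedOpSketch(num_candidates=4)
    arch = m.arch_params(recurse=False, only_owned=True)  # ArchParams instance
    regular = list(m.nonarch_params(recurse=True))        # e.g., proj.weight, proj.bias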

Architecture Parameters#

class archai.supergraph.nas.arch_params.ArchParams(arch_params: Iterable[Tuple[str, Parameter | ParameterDict | ParameterList]], registrar: Module | None = None)[source]#

This class holds a set of learnable architecture parameters for a given module. For example, one instance of this class would hold the alphas for one instance of MixedOp. To share parameters, an instance of this class can be passed around. Different algorithms may add learnable parameters as they need; a construction sketch follows the method list below.

param_by_kind(kind: str | None) → Iterator[Parameter][source]#
paramlist_by_kind(kind: str | None) → Iterator[ParameterList][source]#
paramdict_by_kind(kind: str | None) → Iterator[ParameterDict][source]#
has_kind(kind: str) → bool[source]#
static from_module(module: Module, recurse: bool = False) → ArchParams[source]#
static nonarch_from_module(module: Module, recurse: bool = False) → Iterator[Parameter][source]#
static empty() → ArchParams[source]#
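
A minimal sketch of constructing and querying ArchParams directly, assuming the name given at construction acts as the parameter's "kind" (an assumption; check the source for the exact mapping):

    import torch
    from torch import nn

    from archai.supergraph.nas.arch_params import ArchParams

    alphas = nn.Parameter(torch.zeros(4))
    params = ArchParams([('alphas', alphas)])

    if params.has_kind('alphas'):
        for p in params.param_by_kind('alphas'):
            print(p.shape)  # torch.Size([4])

    # An empty container, e.g., as a default for ops without arch parameters:
    no_params = ArchParams.empty()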

Architecture Trainer#

class archai.supergraph.nas.arch_trainer.ArchTrainer(conf_train: Config, model: Model, checkpoint: CheckPoint | None)[source]#
compute_loss(lossfn: Callable, y: Tensor, logits: Tensor, aux_weight: float, aux_logits: Tensor | None) → Tensor[source]#
post_epoch(data_loaders: DataLoaders) → None[source]#

Cell#

class archai.supergraph.nas.cell.Cell(desc: CellDesc, affine: bool, droppath: bool, trainables_from: Cell | None)[source]#
ops() → Iterable[Op][source]#
forward(s0, s1)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool#

DAG Edge#

class archai.supergraph.nas.dag_edge.DagEdge(desc: EdgeDesc, affine: bool, droppath: bool, template_edge: DagEdge | None)[source]#
forward(inputs: List[Tensor])[source]#

op() → Op[source]#
training: bool#

Evaluater#

class archai.supergraph.nas.evaluater.EvalResult(train_metrics: Metrics)[source]#
class archai.supergraph.nas.evaluater.Evaluater[source]#
evaluate(conf_eval: Config, model_desc_builder: ModelDescBuilder) → EvalResult[source]#
train_model(conf_train: Config, model: Module, checkpoint: CheckPoint | None) → Metrics[source]#
get_data(conf_loader: Config) → DataLoaders[source]#
create_model(conf_eval: Config, model_desc_builder: ModelDescBuilder, final_desc_filename=None, full_desc_filename=None) → Module[source]#
model_from_desc(model_desc) → Model[source]#

Experiment Runner#

class archai.supergraph.nas.exp_runner.ExperimentRunner(config_filename: str, base_name: str, clean_expdir=False)[source]#
run_search(conf_search: Config) → SearchResult[source]#
run_eval(conf_eval: Config) → EvalResult[source]#
run(search=True, eval=True) → Tuple[SearchResult | None, EvalResult | None][source]#
copy_search_to_eval() → None[source]#
model_desc_builder() → ModelDescBuilder | None[source]#
searcher() → Searcher[source]#
evaluater() → Evaluater[source]#
abstract trainer_class() → Type[ArchTrainer] | None[source]#
finalizers() → Finalizers[source]#
get_expname(is_search_or_eval: bool) → str[source]#
get_conf(is_search_or_eval: bool) → Config[source]#
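
A minimal sketch of wiring an algorithm into ExperimentRunner by overriding the abstract trainer_class(); the subclass name and config filename are hypothetical.

    from typing import Optional, Type

    from archai.supergraph.nas.arch_trainer import ArchTrainer
    from archai.supergraph.nas.exp_runner import ExperimentRunner

    class MyExperimentRunner(ExperimentRunner):
        def trainer_class(self) -> Optional[Type[ArchTrainer]]:
            # Return the trainer implementing the search algorithm;
            # the base ArchTrainer suffices for plain training.
            return ArchTrainer

    runner = MyExperimentRunner(config_filename='confs/my_algo.yaml',
                                base_name='my_algo')
    search_result, eval_result = runner.run(search=True, eval=True)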

Finalizers#

class archai.supergraph.nas.finalizers.Finalizers[source]#

Provides base algorithms for finalizing the model, cells and edges, which can be overridden.

For op-level finalization, just put the logic in the op's finalize() method.

For model/cell/edge-level finalization, you can override the methods in this class to customize the behavior. To do so, create a new class in your algorithm's folder (for example, diversity/diversity_finalizers.py) that derives from Finalizers. Then, in your algorithm's exp_runner.py, return an instance of that class from its finalizers() method; a sketch follows the method list below.

finalize_model(model: Model, to_cpu=True, restore_device=True) → ModelDesc[source]#
finalize_cells(model: Model) → List[CellDesc][source]#
finalize_cell(cell: Cell, cell_index: int, model_desc: ModelDesc, *args, **kwargs) → CellDesc[source]#
finalize_node(node: ModuleList, node_index: int, node_desc: NodeDesc, max_final_edges: int, *args, **kwargs) → NodeDesc[source]#
select_edges(edge_desc_ranks: List[Tuple[EdgeDesc, float]], max_final_edges: int) → List[EdgeDesc][source]#
get_edge_ranks(node: ModuleList) → Tuple[List[EdgeDesc], List[Tuple[EdgeDesc, float]]][source]#
finalize_edge(edge) → Tuple[EdgeDesc, float | None][source]#
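
Following the recipe above, a minimal sketch of a custom Finalizers subclass (the class name and selection policy are illustrative):

    from archai.supergraph.nas.finalizers import Finalizers

    class DiversityFinalizers(Finalizers):
        def select_edges(self, edge_desc_ranks, max_final_edges):
            # Custom edge-selection policy goes here; this sketch simply
            # defers to the base ranking.
            return super().select_edges(edge_desc_ranks, max_final_edges)

    # In your algorithm's ExperimentRunner subclass:
    #     def finalizers(self) -> Finalizers:
    #         return DiversityFinalizers()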

Model#

class archai.supergraph.nas.model.Model(model_desc: ModelDesc, droppath: bool, affine: bool)[source]#
summary() → dict[source]#
ops() → Iterable[Op][source]#
forward(x) → Tuple[Tensor, Tensor | None][source]#

device_type() → str[source]#
drop_path_prob(p: float)[source]#

Set the drop path probability.

This will be called externally so that all DropPath_ modules get the new probability. Typically, this probability is reduced every epoch; a schedule sketch follows this class.

training: bool#
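
A minimal sketch of the external per-epoch schedule mentioned in drop_path_prob() above; the linear decay is an assumption, not the library's prescribed schedule.

    def apply_drop_path_schedule(model, epoch: int, epochs: int,
                                 initial_p: float = 0.2) -> None:
        # model is an archai.supergraph.nas.model.Model; the probability is
        # reduced linearly toward zero as training progresses.
        p = initial_p * (1.0 - epoch / max(1, epochs - 1))
        model.drop_path_prob(p)  # propagates to every DropPath_ module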
class archai.supergraph.nas.model.AuxTower(aux_tower_desc: AuxTowerDesc)[source]#
forward(x: Tensor)[source]#

training: bool#

Model Description#

Note: All classes in this file need to be deepcopy-compatible because descs are used as templates to create copies by the macro builder.

class archai.supergraph.nas.model_desc.ConvMacroParams(ch_in: int, ch_out: int)[source]#

Holds parameters that may be altered by the macro architecture.

clone() → ConvMacroParams[source]#
class archai.supergraph.nas.model_desc.OpDesc(name: str, params: dict, in_len: int, trainables: Mapping | None, children: List[OpDesc] | None = None, children_ins: List[int] | None = None)[source]#

Description of the op that resides on each edge.

clone(clone_trainables=True) → OpDesc[source]#
clear_trainables() → None[source]#
state_dict() → dict[source]#
load_state_dict(state_dict) → None[source]#
class archai.supergraph.nas.model_desc.EdgeDesc(op_desc: OpDesc, input_ids: List[int])[source]#

Description of an edge between two nodes in the cell.

clone(conv_params: ConvMacroParams | None, clear_trainables: bool) → EdgeDesc[source]#
clear_trainables() → None[source]#
state_dict() → dict[source]#
load_state_dict(state_dict) → None[source]#
class archai.supergraph.nas.model_desc.NodeDesc(edges: List[EdgeDesc], conv_params: ConvMacroParams)[source]#
clone()[source]#
clear_trainables() → None[source]#
state_dict() → dict[source]#
load_state_dict(state_dict) → None[source]#
class archai.supergraph.nas.model_desc.AuxTowerDesc(ch_in: int, n_classes: int, stride: int)[source]#
class archai.supergraph.nas.model_desc.CellType(value)[source]#

An enumeration.

Regular = 'regular'#
Reduction = 'reduction'#
class archai.supergraph.nas.model_desc.CellDesc(id: int, cell_type: CellType, conf_cell: Config, stems: List[OpDesc], stem_shapes: List[List[int | float]], nodes: List[NodeDesc], node_shapes: List[List[int | float]], post_op: OpDesc, out_shape: List[int | float], trainables_from: int)[source]#
clone(id: int) → CellDesc[source]#
clear_trainables() → None[source]#
state_dict() → dict[source]#
load_state_dict(state_dict) → None[source]#
reset_nodes(nodes: List[NodeDesc], node_shapes: List[List[int | float]], post_op: OpDesc, out_shape: List[int | float]) → None[source]#
nodes() → List[NodeDesc][source]#
all_empty() → bool[source]#
all_full() → bool[source]#
class archai.supergraph.nas.model_desc.ModelDesc(conf_model_desc: Config, model_stems: List[OpDesc], pool_op: OpDesc, cell_descs: List[CellDesc], aux_tower_descs: List[AuxTowerDesc | None], logits_op: OpDesc)[source]#
reset_cells(cell_descs: List[CellDesc], aux_tower_descs: List[AuxTowerDesc | None]) → None[source]#
clear_trainables() → None[source]#
cell_descs() → List[CellDesc][source]#
cell_type_count(cell_type: CellType) → int[source]#
clone() → ModelDesc[source]#
has_aux_tower() → bool[source]#
all_empty() → bool[source]#
all_full() → bool[source]#
state_dict() → dict[source]#
load_state_dict(state_dict) → None[source]#
save(filename: str, save_trainables=False) → str | None[source]#
static load(filename: str, load_trainables=False) → ModelDesc[source]#

Model Description Builder#

class archai.supergraph.nas.model_desc_builder.ModelDescBuilder[source]#
get_reduction_indices(conf_model_desc: Config) → List[int][source]#

Returns the cell indices that reduce HxW and double the channels.

get_node_channels(conf_model_desc: Config) → List[List[int]][source]#

Returns an array of channels for each node in each cell. All nodes are assumed to have the same number of output channels as input channels.

get_conf_cell() → Config[source]#
get_conf_dataset() → Config[source]#
get_conf_model_stems() → Config[source]#
build(conf_model_desc: Config, template: ModelDesc | None = None) → ModelDesc[source]#

Main entry point for the class.

build_cells(in_shapes: List[List[List[int | float]]], conf_model_desc: Config) → Tuple[List[CellDesc], List[AuxTowerDesc | None]][source]#
get_node_count(cell_index: int) → int[source]#
build_cell(in_shapes: List[List[List[int | float]]], conf_cell: Config, cell_index: int) → CellDesc[source]#
get_trainables_from(cell_index: int) → int[source]#
get_ch(shape: List[int | float]) → int[source]#
build_cell_stems(in_shapes: List[List[List[int | float]]], conf_cell: Config, cell_index: int) → Tuple[List[List[int | float]], List[OpDesc]][source]#
build_nodes_from_template(stem_shapes: List[List[int | float]], conf_cell: Config, cell_index: int) → Tuple[List[List[int | float]], List[NodeDesc]][source]#
build_nodes(stem_shapes: List[List[int | float]], conf_cell: Config, cell_index: int, cell_type: CellType, node_count: int, in_shape: List[int | float], out_shape: List[int | float]) → Tuple[List[List[int | float]], List[NodeDesc]][source]#
create_cell_templates(template: ModelDesc | None) → List[CellDesc | None][source]#
build_model_pool(in_shapes: List[List[List[int | float]]], conf_model_desc: Config) → OpDesc[source]#
build_logits_op(in_shapes: List[List[List[int | float]]], conf_model_desc: Config) → OpDesc[source]#
get_cell_template(cell_index: int) → CellDesc | None[source]#
get_cell_type(cell_index: int) → CellType[source]#
build_cell_post_op(stem_shapes: List[List[int | float]], node_shapes: List[List[int | float]], conf_cell: Config, cell_index: int) → Tuple[List[int | float], OpDesc][source]#
build_aux_tower(out_shape: List[int | float], conf_model_desc: Config, cell_index: int) → AuxTowerDesc | None[source]#
build_model_stems(in_shapes: List[List[List[int | float]]], conf_model_desc: Config) → List[OpDesc][source]#
pre_build(conf_model_desc: Config) → None[source]#

Hook for performing any setup before the build starts.

seed_cell(model_desc: ModelDesc) → None[source]#

NAS-Based Utilities#

archai.supergraph.nas.nas_utils.checkpoint_empty(checkpoint: CheckPoint | None) → bool[source]#
archai.supergraph.nas.nas_utils.create_checkpoint(conf_checkpoint: Config, resume: bool) → CheckPoint | None[source]#

Creates a checkpoint given its config. If resume is True, an attempt is made to load the existing checkpoint; otherwise, an empty checkpoint is created. A usage sketch follows below.

archai.supergraph.nas.nas_utils.get_model_stats(model: Model, input_tensor_shape=[1, 3, 32, 32], clone_model=True) → ModelStats[source]#
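
A minimal sketch of the resume semantics described above; conf_checkpoint is whatever checkpoint section your config defines.

    from archai.supergraph.nas import nas_utils

    def maybe_resume(conf_checkpoint, resume: bool):
        checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
        if nas_utils.checkpoint_empty(checkpoint):
            print('no existing state: starting fresh')
        return checkpoint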

Operations#

class archai.supergraph.nas.operations.Op[source]#
static create(op_desc: OpDesc, affine: bool, arch_params: ArchParams | None = None) → Op[source]#
get_trainables() → Mapping[source]#
set_trainables(state_dict) → None[source]#
static register_op(name: str, factory_fn: Callable, exists_ok=True) → None[source]#
finalize() → Tuple[OpDesc, float | None][source]#

For a trainable op, returns the final op and its rank.

ops() → Iterator[Tuple[Op, float]][source]#

Returns the constituent ops; if this op is primitive, just returns self.

can_drop_path() → bool[source]#
training: bool#
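
A minimal sketch of the registry pattern behind Op.register_op() and Op.create(); the factory signature shown is an assumption, so check register_op()'s source for the exact callback contract.

    from archai.supergraph.nas.model_desc import OpDesc
    from archai.supergraph.nas.operations import Identity, Op

    # Register a factory under a new name (callback signature assumed):
    Op.register_op('my_identity',
                   lambda op_desc, arch_params, affine: Identity(op_desc))

    # Later, Op.create() dispatches through the registry by name:
    desc = OpDesc('my_identity', params={}, in_len=1, trainables=None)
    op = Op.create(desc, affine=True)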
class archai.supergraph.nas.operations.PoolBN(pool_type: str, op_desc: OpDesc, affine: bool)[source]#

AvgPool or MaxPool - BN

forward(x)[source]#

training: bool#
class archai.supergraph.nas.operations.SkipConnect(op_desc: OpDesc, affine)[source]#
forward(x: Tensor) → Tensor[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.FacConv(op_desc: OpDesc, kernel_length: int, padding: int, affine: bool)[source]#

Factorized conv ReLU - Conv(Kx1) - Conv(1xK) - BN

forward(x)[source]#

training: bool#
class archai.supergraph.nas.operations.ReLUConvBN(op_desc: OpDesc, kernel_size: int, stride: int, padding: int, affine: bool)[source]#
forward(x)[source]#

training: bool#
class archai.supergraph.nas.operations.ConvBNReLU(op_desc: OpDesc, kernel_size: int, stride: int, padding: int, affine: bool)[source]#
forward(x)[source]#

training: bool#
class archai.supergraph.nas.operations.DilConv(op_desc: OpDesc, kernel_size: int, stride: int, padding: int, dilation: int, affine: bool)[source]#

(Dilated) depthwise separable conv: ReLU - (dilated) depthwise conv - pointwise conv - BN.

The effective kernel size is k + (k - 1)(d - 1) for kernel size k and dilation d; with dilation == 2, a 3x3 conv has a 5x5 receptive field and a 5x5 conv has a 9x9 receptive field.

forward(x)[source]#

training: bool#
class archai.supergraph.nas.operations.SepConv(op_desc: OpDesc, kernel_size: int, padding: int, affine: bool)[source]#

Depthwise separable conv: DilConv(dilation=1) * 2.

This is the same as two DilConv blocks stacked with dilation=1; a standalone sketch follows this class.

forward(x)[source]#

training: bool#
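
A standalone sketch (not the library's code) of the composition SepConv describes: two DilConv-style blocks with dilation=1, each being ReLU - depthwise conv - pointwise conv - BN.

    import torch.nn as nn

    def dil_block(ch: int, kernel: int, padding: int) -> nn.Sequential:
        # ReLU - depthwise conv - pointwise conv - BN, dilation=1
        return nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(ch, ch, kernel, padding=padding, groups=ch, bias=False),
            nn.Conv2d(ch, ch, 1, bias=False),
            nn.BatchNorm2d(ch))

    sep_conv_3x3 = nn.Sequential(dil_block(16, 3, 1), dil_block(16, 3, 1))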
class archai.supergraph.nas.operations.Identity(op_desc: OpDesc)[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.Zero(op_desc: OpDesc)[source]#

Represents no connection. The Zero op can be thought of as a 1x1 kernel with a fixed zero weight. For stride=1, it produces an output of the same dimensions as the input but with all zeros. With stride=2, it additionally drops every other pixel, halving the spatial dimensions; a standalone sketch follows this class.

forward(x)[source]#

training: bool#
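
A standalone sketch (not the library's code) replicating the behavior described above: multiply by zero, subsampling spatially when stride > 1.

    import torch
    import torch.nn as nn

    class ZeroSketch(nn.Module):
        def __init__(self, stride: int = 1):
            super().__init__()
            self.stride = stride

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.stride == 1:
                return x * 0.0
            # Keep every stride-th pixel in H and W, then zero the result.
            return x[:, :, ::self.stride, ::self.stride] * 0.0

    y = ZeroSketch(stride=2)(torch.randn(1, 3, 8, 8))
    print(y.shape)  # torch.Size([1, 3, 4, 4])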
class archai.supergraph.nas.operations.FactorizedReduce(op_desc: OpDesc, affine: bool)[source]#

Reduces feature map height/width by 2x while doubling the channels, using two 1x1 convs, each with stride=2.

forward(x)[source]#

training: bool#
class archai.supergraph.nas.operations.StemBase(reduction: int)[source]#

Abstract base class for model stems that enforces the reduction property, which indicates the amount of spatial map reduction performed by the stem, i.e., reduction=2 for each stride=2.

training: bool#
class archai.supergraph.nas.operations.StemConv3x3(op_desc: OpDesc, affine: bool)[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.StemConv3x3Relu(op_desc: OpDesc, affine: bool)[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.StemConv3x3S4(op_desc, affine: bool)[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.StemConv3x3S4S2(op_desc, affine: bool)[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.AvgPool2d7x7[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.PoolAdaptiveAvg2D[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.PoolMeanTensor[source]#
forward(x)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.LinearOp(op_desc: OpDesc)[source]#
forward(x: Tensor)[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.MergeOp(op_desc: OpDesc, affine: bool)[source]#
forward(states: List[Tensor])[source]#

can_drop_path() → bool[source]#
training: bool#
class archai.supergraph.nas.operations.ConcateChannelsOp(op_desc: OpDesc, affine: bool)[source]#
forward(states: List[Tensor])[source]#

training: bool#
class archai.supergraph.nas.operations.ProjectChannelsOp(op_desc: OpDesc, affine: bool)[source]#
forward(states: List[Tensor])[source]#

training: bool#
class archai.supergraph.nas.operations.DropPath_(p: float = 0.0)[source]#

Replaces values in a tensor with 0 with probability p. Ref: https://arxiv.org/abs/1605.07648. A standalone sketch follows this class.

extra_repr()[source]#

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x)[source]#

training: bool#
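
A standalone sketch (not the library's code) of the common drop-path formulation: during training, each sample in the batch is zeroed with probability p and the survivors are rescaled so the expected activation is unchanged. The per-sample masking here is an assumption about the library's exact details.

    import torch
    import torch.nn as nn

    class DropPathSketch(nn.Module):
        def __init__(self, p: float = 0.0):
            super().__init__()
            self.p = p

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if not self.training or self.p <= 0.0:
                return x
            keep = 1.0 - self.p
            # One Bernoulli draw per sample, broadcast over remaining dims.
            shape = (x.shape[0],) + (1,) * (x.dim() - 1)
            mask = (torch.rand(shape, device=x.device) < keep).to(x.dtype)
            return x / keep * mask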
class archai.supergraph.nas.operations.MultiOp(op_desc: OpDesc, affine: bool)[source]#
forward(x: Tensor | List[Tensor]) → Tensor[source]#

training: bool#

Random Finalizers#

class archai.supergraph.nas.random_finalizers.RandomFinalizers[source]#
finalize_node(node: ModuleList, node_index: int, node_desc: NodeDesc, max_final_edges: int, *args, **kwargs) → NodeDesc[source]#

Search Combinations#

class archai.supergraph.nas.search_combinations.SearchCombinations[source]#
search(conf_search: Config, model_desc_builder: ModelDescBuilder, trainer_class: Type[ArchTrainer] | None, finalizers: Finalizers) → SearchResult[source]#
is_better_metrics(metrics1: Metrics | None, metrics2: Metrics | None) → bool[source]#
restore_checkpoint(conf_search: Config, macro_combinations) → Tuple[int, SearchResult | None][source]#
record_checkpoint(macro_comb_i: int, best_result: SearchResult) → None[source]#
get_combinations(conf_search: Config) → Iterator[Tuple[int, int, int]][source]#
save_trained(conf_search: Config, reductions: int, cells: int, nodes: int, model_metrics: ModelMetrics) → None[source]#

Saves the model and metrics info into a log file.

Searcher#

class archai.supergraph.nas.searcher.ModelMetrics(model: Model, metrics: Metrics)[source]#
class archai.supergraph.nas.searcher.SearchResult(model_desc: ModelDesc | None, search_metrics: Metrics | None, train_metrics: Metrics | None)[source]#
class archai.supergraph.nas.searcher.Searcher[source]#
search(conf_search: Config, model_desc_builder: ModelDescBuilder | None, trainer_class: Type[ArchTrainer] | None, finalizers: Finalizers) → SearchResult[source]#
clean_log_result(conf_search: Config, search_result: SearchResult) → None[source]#
build_model_desc(model_desc_builder: ModelDescBuilder, conf_model_desc: Config, reductions: int, cells: int, nodes: int) → ModelDesc[source]#
get_data(conf_loader: Config) → DataLoaders[source]#
finalize_model(model: Model, finalizers: Finalizers) → ModelDesc[source]#
search_model_desc(conf_search: Config, model_desc: ModelDesc, trainer_class: Type[ArchTrainer] | None, finalizers: Finalizers) → Tuple[ModelDesc, Metrics | None][source]#
train_model_desc(model_desc: ModelDesc, conf_train: Config) → ModelMetrics | None[source]#

Trains the given model description.

Model Description Visualizer#

archai.supergraph.nas.vis_model_desc.draw_model_desc(model_desc: ModelDesc, filepath: str | None = None, caption: str | None = None) → Tuple[Digraph | None, Digraph | None][source]#
archai.supergraph.nas.vis_model_desc.draw_cell_desc(cell_desc: CellDesc, filepath: str | None = None, caption: str | None = None) → Digraph[source]#

Makes a DAG plot and optionally saves it to filepath as a .png; a usage sketch follows below.
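
A minimal usage sketch; the filename is hypothetical, and the assumption that the two returned Digraphs correspond to the regular and reduction cells should be checked against the source.

    from archai.supergraph.nas.model_desc import ModelDesc
    from archai.supergraph.nas.vis_model_desc import draw_model_desc

    model_desc = ModelDesc.load('final_model_desc.yaml')
    g1, g2 = draw_model_desc(model_desc,
                             filepath='model_desc_plot',
                             caption='Discovered cells')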
