# Source code for archai.discrete_search.search_spaces.benchmark.natsbench_tss

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import random
import re
import warnings
from pathlib import Path
from typing import Any, List, Optional

import nats_bench
import numpy as np
import torch
import yaml
from overrides import overrides

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.search_space import (
    BayesOptSearchSpace,
    EvolutionarySearchSpace,
)


class NatsbenchTssSearchSpace(EvolutionarySearchSpace, BayesOptSearchSpace):
    """Search space for NATS-Bench-TSS.

    Architectures are identified by an integer index into the NATS-Bench-TSS
    API and exposed as `ArchaiModel` objects whose `archid` has the form
    ``natsbench-tss-<index>``.
    """

    OPS = ["none", "avg_pool_3x3", "nor_conv_1x1", "nor_conv_3x3", "skip_connect"]

    try:
        from xautodl.models import get_cell_based_tiny_net
    except ImportError as err:
        raise ImportError(
            "xautodl installation not found. "
            "Please install it using `pip install git+https://github.com/D-X-Y/AutoDL-Projects/`"
        ) from err

    # The import above stores a plain function as a class attribute. Without this
    # wrapper, `self.get_cell_based_tiny_net(config)` would be a bound-method call
    # and pass `self` as the first positional argument.
    get_cell_based_tiny_net = staticmethod(get_cell_based_tiny_net)

    def __init__(self, natsbench_location: str, base_dataset: str, seed: Optional[int] = 1) -> None:
        """Initialize the search space.

        Args:
            natsbench_location: Path to the NATS-Bench-TSS dataset.
            base_dataset: Base dataset used for training the models. Must be one of
                'cifar10', 'cifar100' or 'ImageNet16-120'.
            seed: Random seed used by `random_sample` and `mutate`.

        Raises:
            AssertionError: If `base_dataset` is not one of the supported datasets.
            FileNotFoundError: If `natsbench_location` does not exist.

        """

        self.natsbench_location = Path(natsbench_location)
        self.base_dataset = base_dataset

        # Kept as an `assert` so the exception type seen by existing callers is unchanged
        assert base_dataset in [
            "cifar10",
            "cifar100",
            "ImageNet16-120",
        ], "`base_dataset` must be one of ['cifar10', 'cifar100', 'ImageNet16-120']"

        if not self.natsbench_location.exists():
            raise FileNotFoundError(
                "The provided path to `natsbench_location` (" f"{self.natsbench_location.absolute()}) does not exist"
            )

        self.api = nats_bench.create(natsbench_location, "tss", fast_mode=True, verbose=False)
        self.rng = random.Random(seed)
        self.archid_pattern = re.compile("natsbench-tss-([0-9]+)")

    def _get_op_list(self, string: str) -> List[str]:
        """Extract the list of operation names from an architecture string.

        Reused from https://github.com/naszilla/naszilla/blob/master/naszilla/nas_bench_201/cell_201.py.

        Args:
            string: Architecture string, split on ``|``.

        Returns:
            The part before ``~`` of every token, skipping the tokens at
            positions 0, 2, 5 and 9 of the split.

        """

        skipped_positions = (0, 2, 5, 9)
        tokens = string.split("|")

        return [token.split("~")[0] for i, token in enumerate(tokens) if i not in skipped_positions]

    def _get_string_from_ops(self, ops: List[str]) -> str:
        """Build an architecture string from a list of operation names.

        Reused from https://github.com/naszilla/naszilla/blob/master/naszilla/nas_bench_201/cell_201.py.

        Args:
            ops: Operation names, one per entry of the fixed node table below.

        Returns:
            String of the form ``|op~n|+|op~n|op~n|+|...|``.

        """

        # Node index paired with each operation, in order
        nodes = [0, 0, 1, 0, 1, 2]
        last = len(nodes) - 1

        pieces = ["|"]
        for i, op in enumerate(ops):
            pieces.append(f"{op}~{nodes[i]}|")

            # A `+` separator is emitted whenever the next entry restarts at node 0
            if i < last and nodes[i + 1] == 0:
                pieces.append("+|")

        return "".join(pieces)

    def model_from_natsbench_tss(self, natsbench_id: int) -> Any:
        """Get a model from NATS-Bench-TSS dataset.

        Args:
            natsbench_id: NATS-Bench-TSS identifier.

        Returns:
            Model from NATS-Bench-TSS dataset.

        """

        config = self.api.get_net_config(natsbench_id, self.base_dataset)
        return self.get_cell_based_tiny_net(config)

    @overrides
    def save_arch(self, model: ArchaiModel, path: str) -> None:
        """Write the architecture identifier and metadata of `model` to a YAML file.

        Args:
            model: Model whose `archid` and `metadata` are saved.
            path: Destination file path.

        """

        payload = {"archid": model.archid, **model.metadata}

        # Context manager guarantees the file is flushed and closed
        with open(path, "w", encoding="utf-8") as fp:
            yaml.safe_dump(payload, fp)

    @overrides
    def load_arch(self, path: str) -> ArchaiModel:
        """Load an architecture from a YAML file written by `save_arch`.

        A warning is emitted when the stored dataset differs from `base_dataset`;
        the returned model is always built for `base_dataset`.

        Args:
            path: Path to the YAML file.

        Returns:
            Model built from the stored NATS-Bench-TSS identifier.

        Raises:
            ValueError: If the stored `archid` does not match the expected pattern.

        """

        with open(path, encoding="utf-8") as fp:
            metadata = yaml.safe_load(fp)

        natsbenchid = self.archid_pattern.match(metadata["archid"])
        if not natsbenchid:
            raise ValueError(f'Architecture {metadata["archid"]} does not belong to `NatsbenchTssSearchSpace`. ')

        if metadata["dataset"] != self.base_dataset:
            warnings.warn(
                f'Architecture loaded from {path} was saved using a different dataset ({metadata["dataset"]})'
                f" than `NatsbenchTssSearchSpace` base dataset ({self.base_dataset})"
            )

        idx = int(natsbenchid.group(1))

        return ArchaiModel(
            arch=self.model_from_natsbench_tss(idx),
            archid=f"natsbench-tss-{idx}",
            metadata={"dataset": self.base_dataset},
        )

    @overrides
    def load_model_weights(self, model: ArchaiModel, path: str) -> None:
        """Load a state dict from `path` into `model.arch`.

        Args:
            model: Model whose underlying architecture receives the weights.
            path: Path to a file readable by `torch.load`.

        """

        model.arch.load_state_dict(torch.load(path))

    @overrides
    def save_model_weights(self, model: ArchaiModel, path: str) -> None:
        """Save the state dict of `model.arch` to `path`.

        Args:
            model: Model whose underlying architecture weights are saved.
            path: Destination file path.

        """

        torch.save(model.arch.state_dict(), path)

    @overrides
    def random_sample(self) -> ArchaiModel:
        """Sample a random architecture using the seeded generator.

        Returns:
            Model for an index drawn from ``[0, len(self.api))``.

        """

        # `randrange` excludes the upper bound; `randint(0, len(self.api))` could
        # return `len(self.api)`, which is not a valid index
        idx = self.rng.randrange(len(self.api))

        return ArchaiModel(
            arch=self.model_from_natsbench_tss(idx),
            archid=f"natsbench-tss-{idx}",
            metadata={"dataset": self.base_dataset},
        )

    @overrides
    def mutate(self, model: ArchaiModel) -> ArchaiModel:
        """Return a random neighbor of `model` that differs in exactly one operation.

        Reused from https://github.com/naszilla/naszilla/blob/master/naszilla/nas_bench_201/cell_201.py.

        Args:
            model: Model to be mutated.

        Returns:
            Model for the selected neighbor architecture.

        Raises:
            ValueError: If `model.archid` does not match the expected pattern.

        """

        # First get the string representation of the current architecture
        natsbenchid = self.archid_pattern.match(model.archid)
        if not natsbenchid:
            raise ValueError(f"Architecture {model.archid} does not belong to the `NatsbenchTssSearchSpace`. ")

        natsbenchid = int(natsbenchid.group(1))
        string_rep = self.api.get_net_config(natsbenchid, self.base_dataset)["arch_str"]

        # Enumerates every architecture obtained by replacing one operation with a different one
        ops = self._get_op_list(string_rep)
        nbhd_strs = []
        for position, current_op in enumerate(ops):
            for candidate in self.OPS:
                if candidate == current_op:
                    continue

                new_ops = ops.copy()
                new_ops[position] = candidate
                nbhd_strs.append(self._get_string_from_ops(new_ops))

        # Picks one neighbor architecture as the mutation, using the seeded
        # generator so that results are reproducible for a given `seed`
        mutation_str = self.rng.choice(nbhd_strs)
        mutation_natsbenchid = self.api.archstr2index[mutation_str]

        return ArchaiModel(
            arch=self.model_from_natsbench_tss(mutation_natsbenchid),
            archid=f"natsbench-tss-{mutation_natsbenchid}",
            metadata={"dataset": self.base_dataset},
        )

    @overrides
    def crossover(self, arch_list: List[ArchaiModel]) -> ArchaiModel:
        """Crossover is not implemented for this search space.

        Raises:
            NotImplementedError: Always.

        """

        raise NotImplementedError

    @overrides
    def encode(self, arch: ArchaiModel) -> np.ndarray:
        """Encode an architecture as a flat vector of per-operation codes.

        Each operation found in the architecture string is mapped to a
        4-element code ("none" maps to all zeros, the other operations to a
        one-hot vector) and the codes are concatenated in order.

        Args:
            arch: Model to be encoded.

        Returns:
            Concatenated operation codes.

        Raises:
            ValueError: If `arch.archid` does not match the expected pattern.

        """

        enc_dict = {
            "none": [0, 0, 0, 0],
            "avg_pool_3x3": [1, 0, 0, 0],
            "nor_conv_1x1": [0, 1, 0, 0],
            "nor_conv_3x3": [0, 0, 1, 0],
            "skip_connect": [0, 0, 0, 1],
        }

        # Gets string repr for `arch`
        natsbenchid = self.archid_pattern.match(arch.archid)
        if not natsbenchid:
            raise ValueError(f"Architecture {arch.archid} does not belong" " to `NatsbenchTssSearchSpace`. ")

        arch_str = self.api[int(natsbenchid.group(1))]

        # Captures the operation name that precedes each `~<digit>` marker
        arch_ops = re.findall(r"([^\|\~\+]+)~\d", arch_str)

        return np.hstack([np.array(enc_dict[op_name]) for op_name in arch_ops])