Source code for archai.discrete_search.search_spaces.config.search_space

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import hashlib
from random import Random
from typing import Any, Callable, Dict, List, Optional, Type, Union

import numpy as np
import torch
from overrides import overrides

from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.search_space import (
    BayesOptSearchSpace,
    EvolutionarySearchSpace,
)
from archai.discrete_search.search_spaces.config import utils
from archai.discrete_search.search_spaces.config.arch_config import (
    ArchConfig,
    build_arch_config,
)
from archai.discrete_search.search_spaces.config.arch_param_tree import ArchParamTree


[docs]class ConfigSearchSpace(EvolutionarySearchSpace, BayesOptSearchSpace): def __init__( self, model_cls: Type[torch.nn.Module], arch_param_tree: Union[ArchParamTree, Callable[..., ArchParamTree]], seed: Optional[int] = None, mutation_prob: float = 0.3, track_unused_params: bool = True, unused_param_value: float = -1.0, hash_archid: bool = True, model_kwargs: Optional[Dict[str, Any]] = None, builder_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """Config-based Discrete Search Space. Args: model_cls (Type[torch.nn.Module]): Model class. This class expects that the first argument from `model_cls` constructor an `ArchConfig` object. arch_param_tree (Union[ArchParamTree, Callable[..., ArchParamTree]]): `ArchParamTree` object or a builder function that returns an `ArchParamTree` object. seed (int, optional): Random seed used for sampling, mutations and crossovers. Defaults to None. mutation_prob (float, optional): Probability of mutating a parameter. Defaults to 0.3. track_unused_params (bool, optional): Whether to track unused parameters. Defaults to True. unused_param_value (int, optional): Value to use for unused parameters. Defaults to `float('NaN')`. hash_archid (bool, optional): Weather to hash architecture identifiers. Defaults to True. model_kwargs: Additional arguments to pass to `model_cls` constructor. builder_kwargs: Arguments to pass to `arch_param_tree` if a builder function is passed. """ self.model_cls = model_cls self.arch_param_tree = arch_param_tree self.mutation_prob = mutation_prob self.track_unused_params = track_unused_params self.unused_param_value = unused_param_value self.model_kwargs = model_kwargs or {} self.builder_kwargs = builder_kwargs or {} self.hash_archid = hash_archid if callable(self.arch_param_tree): self.arch_param_tree = self.arch_param_tree(**self.builder_kwargs) self.rng = Random(seed)
[docs] def get_archid(self, arch_config: ArchConfig) -> str: """Return the architecture identifier for the given architecture configuration. Args: arch_config: Architecture configuration. Returns: Architecture identifier. """ archid = self.arch_param_tree.encode_config(arch_config, track_unused_params=self.track_unused_params) archid = str(tuple(archid)) if self.hash_archid: archid = hashlib.sha1(archid.encode("utf-8")).hexdigest() return archid
[docs] @overrides def save_arch(self, model: ArchaiModel, path: str) -> None: model.metadata["config"].to_file(path)
[docs] @overrides def load_arch(self, path: str) -> ArchaiModel: config = ArchConfig.from_file(path) model = self.model_cls(config, **self.model_kwargs) return ArchaiModel(arch=model, archid=self.get_archid(config), metadata={"config": config})
[docs] @overrides def save_model_weights(self, model: ArchaiModel, path: str) -> None: torch.save(model.arch.get_state_dict(), path)
[docs] @overrides def load_model_weights(self, model: ArchaiModel, path: str) -> None: model.arch.load_state_dict(torch.load(path))
[docs] @overrides def random_sample(self) -> ArchaiModel: config = self.arch_param_tree.sample_config(self.rng) model = self.model_cls(config, **self.model_kwargs) return ArchaiModel(arch=model, archid=self.get_archid(config), metadata={"config": config})
[docs] @overrides def mutate(self, model: ArchaiModel) -> ArchaiModel: choices_dict = self.arch_param_tree.to_dict() # Mutates parameter with probability `self.mutation_prob` mutated_dict = utils.replace_ptree_pair_choices( choices_dict, model.metadata["config"].to_dict(), lambda d_choice, current_choice: ( self.rng.choice(d_choice.choices) if self.rng.random() < self.mutation_prob else current_choice ), ) mutated_config = build_arch_config(mutated_dict) mutated_model = self.model_cls(mutated_config, **self.model_kwargs) return ArchaiModel( arch=mutated_model, archid=self.get_archid(mutated_config), metadata={"config": mutated_config} )
[docs] @overrides def crossover(self, model_list: List[ArchaiModel]) -> ArchaiModel: # Selects two models from `model_list` to perform crossover model_1, model_2 = self.rng.choices(model_list, k=2) # Starting with arch param tree dict, randomly replaces DiscreteChoice objects # with params from model_1 with probability 0.5 choices_dict = self.arch_param_tree.to_dict() cross_dict = utils.replace_ptree_pair_choices( choices_dict, model_1.metadata["config"].to_dict(), lambda d_choice, m1_value: (m1_value if self.rng.random() < 0.5 else d_choice), ) # Replaces all remaining DiscreteChoice objects with params from model_2 cross_dict = utils.replace_ptree_pair_choices( cross_dict, model_2.metadata["config"].to_dict(), lambda d_choice, m2_value: m2_value ) cross_config = build_arch_config(cross_dict) cross_model = self.model_cls(cross_config, **self.model_kwargs) return ArchaiModel(arch=cross_model, archid=self.get_archid(cross_config), metadata={"config": cross_config})
[docs] @overrides def encode(self, model: ArchaiModel) -> np.ndarray: encoded_config = np.array( self.arch_param_tree.encode_config( model.metadata["config"], track_unused_params=self.track_unused_params ) ) return np.nan_to_num(encoded_config, nan=self.unused_param_value)