Source code for archai.discrete_search.api.search_space

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import abstractmethod
from typing import List

import numpy as np
from overrides import EnforceOverrides

from archai.discrete_search.api.archai_model import ArchaiModel


[docs]class DiscreteSearchSpace(EnforceOverrides): """Abstract class for discrete search spaces. This class serves as a base for implementing search spaces. The class enforces implementation of five methods: `save_arch`, `load_arch`, `save_model_weights`, `load_model_weights` and `random_sample`. Note: This class is inherited from `EnforceOverrides` and any overridden methods in the subclass should be decorated with `@overrides` to ensure they are properly overridden. Examples: >>> class MyDiscreteSearchSpace(DiscreteSearchSpace): >>> def __init__(self) -> None: >>> super().__init__() >>> >>> @overrides >>> def save_arch(self, arch, file_path) -> None: >>> torch.save(arch, file_path) >>> >>> @overrides >>> def load_arch(self, file_path) -> ArchaiModel: >>> return torch.load(file_path) >>> >>> @overrides >>> def save_model_weights(self, model, file_path) -> None: >>> torch.save(model.state_dict(), file_path) >>> >>> @overrides >>> def load_model_weights(self, model, file_path) -> None: >>> model.load_state_dict(torch.load(file_path)) >>> >>> @overrides >>> def random_sample(self, config) -> ArchaiModel: >>> return ArchaiModel(config) """
[docs] @abstractmethod def save_arch(self, model: ArchaiModel, file_path: str) -> None: """Save an architecture to a file without saving the weights. Args: model: Model's architecture to save. file_path: File path to save the architecture. """ pass
[docs] @abstractmethod def load_arch(self, file_path: str) -> ArchaiModel: """Load from a file an architecture that was saved using `SearchSpace.save_arch()`. Args: file_path: File path to load the architecture. Returns: Loaded model. """ pass
[docs] @abstractmethod def save_model_weights(self, model: ArchaiModel, file_path: str) -> None: """Save the weights of a model. Args: model: Model to save the weights. file_path: File path to save the weights. """ pass
[docs] @abstractmethod def load_model_weights(self, model: ArchaiModel, file_path: str) -> None: """Load the weights (created with `SearchSpace.save_model_weights()`) into a model of the same architecture. Args: model: Model to load the weights. file_path: File path to load the weights. """ pass
[docs] @abstractmethod def random_sample(self) -> ArchaiModel: """Randomly sample an architecture from the search spaces. Returns: Sampled architecture. """ pass
[docs]class EvolutionarySearchSpace(DiscreteSearchSpace, EnforceOverrides): """Abstract class for discrete search spaces compatible with evolutionary algorithms. The class enforces implementation of two methods: `mutate` and `crossover`. Note: This class is inherited from `EnforceOverrides` and any overridden methods in the subclass should be decorated with `@overrides` to ensure they are properly overridden. """
[docs] @abstractmethod def mutate(self, arch: ArchaiModel) -> ArchaiModel: """Mutate an architecture from the search space. This method should not alter the base model architecture directly, only generate a new one. Args: arch: Base model. Returns: Mutated model. """ pass
[docs] @abstractmethod def crossover(self, arch_list: List[ArchaiModel]) -> ArchaiModel: """Combine a list of architectures into a new one. Args: arch_list: List of architectures. Returns: Resulting model. """ pass
[docs]class BayesOptSearchSpace(DiscreteSearchSpace, EnforceOverrides): """Abstract class for discrete search spaces compatible with Bayesian Optimization algorithms. The class enforces implementation of a single method: `encode`. Note: This class is inherited from `EnforceOverrides` and any overridden methods in the subclass should be decorated with `@overrides` to ensure they are properly overridden. """
[docs] @abstractmethod def encode(self, arch: ArchaiModel) -> np.ndarray: """Encode an architecture into a fixed-length vector representation. Args: arch: Model from the search space. Returns: Fixed-length vector representation of `arch`. """ pass