Source code for archai.discrete_search.search_spaces.config.discrete_choice
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from random import Random
from numbers import Number
from typing import Any, List, Union, Optional
class DiscreteChoice:
    """A discrete set of choices that can be sampled from and encoded as numbers.

    Attributes are plain and public:
        choices: The list of options, stored as passed in (not copied).
        probabilities: Optional sampling weights, one per choice, or `None`.
        encode_strategy: The resolved encoding strategy ('auto' is replaced
            by 'numeric' or 'one_hot' at construction time).
    """

    def __init__(self, choices: List[Union[int, float, str]],
                 probabilities: Optional[List[float]] = None,
                 encode_strategy: str = 'auto') -> None:
        """ Stores a discrete choice of numeric or non-numeric values.

        The choice can be encoded as a numeric value or using one-hot encoding depending on the
        value passed to `encode_strategy`.

        Args:
            choices (List[Union[int, float, str]]): List of choices. Choices can be integers,
                floats or strings.
            probabilities (Optional[List[float]], optional): Probability distribution of each choice
                used during sampling. If `None`, a uniform distribution is used.
            encode_strategy (str, optional): Encoding strategy to use ['one_hot', 'numeric']. If
                'auto', the encoding strategy is chosen based on the type of the choices.
                Defaults to 'auto'.
        """
        self.choices = choices
        self.probabilities = probabilities

        if encode_strategy == 'auto':
            # Numeric encoding only makes sense when *every* choice is a number;
            # a single non-numeric choice forces one-hot encoding.
            encode_strategy = (
                'numeric' if all(isinstance(choice, Number) for choice in choices)
                else 'one_hot'
            )

        self.encode_strategy = encode_strategy

    def __getitem__(self, idx: int) -> Any:
        """Returns the choice stored at position `idx`."""
        return self.choices[idx]

    def __repr__(self) -> str:
        return f"DiscreteChoice({repr(self.choices)})"

    def __str__(self) -> str:
        return self.__repr__()

    def __len__(self) -> int:
        """Returns the number of choices."""
        return len(self.choices)

    def encode(self, option: Any) -> List[float]:
        """Encodes the option into a numeric value or a one-hot encoding.

        Args:
            option (Any): Option to encode.

        Returns:
            List[float]: Encoded option. With the 'one_hot' strategy the list has
                one entry per choice (1.0 where the choice equals `option`, 0.0
                elsewhere); otherwise it is the single-element list `[float(option)]`.

        Raises:
            AssertionError: If the strategy is 'one_hot' and `option` is not one of
                the choices.
        """
        if self.encode_strategy == 'one_hot':
            # Raised explicitly rather than via `assert` so the check still runs
            # under `python -O`; the exception type and message are kept as
            # AssertionError for backward compatibility with existing callers.
            if option not in self.choices:
                raise AssertionError(
                    f'Invalid option: {option}. Valid options: {self.choices}'
                )

            return [float(choice == option) for choice in self.choices]

        # Any strategy other than 'one_hot' is treated as numeric.
        return [float(option)]

    def random_sample(self, rng: Optional[Random] = None) -> Any:
        """Randomly samples a choice from the discrete set.

        Args:
            rng (Optional[Random], optional): Random number generator. If `None`,
                a new `Random` instance is created for this call.

        Returns:
            Any: Randomly sampled choice, drawn using `self.probabilities` as
                weights (uniform when they are `None`).
        """
        if rng is None:
            rng = Random()

        return rng.choices(self.choices, weights=self.probabilities, k=1)[0]