Coverage for mlos_core/mlos_core/spaces/converters/flaml.py: 96%
24 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Contains space converters for :py:class:`~mlos_core.optimizers.flaml_optimizer`"""
7import sys
8from typing import TYPE_CHECKING, Dict
10import ConfigSpace
11import flaml.tune
12import flaml.tune.sample
13import numpy as np
15if TYPE_CHECKING:
16 from ConfigSpace.hyperparameters import Hyperparameter
18if sys.version_info >= (3, 10):
19 from typing import TypeAlias
20else:
21 from typing_extensions import TypeAlias
FlamlDomain: TypeAlias = flaml.tune.sample.Domain
"""Flaml domain type alias."""

# Spelled in terms of FlamlDomain so the two aliases cannot drift apart.
FlamlSpace: TypeAlias = Dict[str, FlamlDomain]
"""Flaml space type alias - a `Dict[str, FlamlDomain]`"""
def configspace_to_flaml_space(
    config_space: ConfigSpace.ConfigurationSpace,
) -> Dict[str, FlamlDomain]:
    """
    Converts a ConfigSpace.ConfigurationSpace to dict.

    Parameters
    ----------
    config_space : ConfigSpace.ConfigurationSpace
        Input configuration space.

    Returns
    -------
    flaml_space : dict
        A dictionary of flaml.tune.sample.Domain objects keyed by parameter name.

    Raises
    ------
    ValueError
        If a hyperparameter is not exactly a UniformIntegerHyperparameter or
        UniformFloatHyperparameter (linear or log scale), and is not a
        CategoricalHyperparameter. Also raised if a categorical hyperparameter
        has non-uniform probabilities.
    """
    # (exact hyperparameter class, log-scale flag) -> FLAML sampler factory.
    flaml_numeric_type = {
        (ConfigSpace.UniformIntegerHyperparameter, False): flaml.tune.randint,
        (ConfigSpace.UniformIntegerHyperparameter, True): flaml.tune.lograndint,
        (ConfigSpace.UniformFloatHyperparameter, False): flaml.tune.uniform,
        (ConfigSpace.UniformFloatHyperparameter, True): flaml.tune.loguniform,
    }

    def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain:
        # Exact-type lookup: subclasses of the uniform types are not assumed to
        # be uniform. They fall through to the "not supported" ValueError below
        # instead of escaping as a bare KeyError. Parameters without a `log`
        # attribute (e.g. categoricals) never match a key.
        sampler = flaml_numeric_type.get((type(parameter), getattr(parameter, "log", None)))
        if sampler is not None:
            upper = parameter.upper
            if type(parameter) is ConfigSpace.UniformIntegerHyperparameter:
                # The FLAML integer samplers treat `upper` as exclusive, so +1
                # keeps the ConfigSpace upper bound reachable.
                upper = upper + 1
            # FIXME: for floats, upper isn't included in the range
            return sampler(parameter.lower, upper)

        if isinstance(parameter, ConfigSpace.CategoricalHyperparameter):
            # flaml.tune.choice samples uniformly, so weighted categoricals
            # cannot be represented faithfully.
            if len(np.unique(parameter.probabilities)) > 1:
                raise ValueError(
                    "FLAML doesn't support categorical parameters with non-uniform probabilities."
                )
            return flaml.tune.choice(parameter.choices)  # TODO: set order?

        raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.")

    return {param.name: _one_parameter_convert(param) for param in config_space.values()}