Coverage for mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py: 100%
24 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Unit tests for converting tunable parameters with explicitly
7specified distributions to ConfigSpace.
8"""
10import pytest
12from ConfigSpace import (
13 CategoricalHyperparameter,
14 BetaFloatHyperparameter,
15 BetaIntegerHyperparameter,
16 NormalFloatHyperparameter,
17 NormalIntegerHyperparameter,
18 UniformFloatHyperparameter,
19 UniformIntegerHyperparameter,
20)
22from mlos_bench.tunables.tunable import DistributionName
23from mlos_bench.tunables.tunable_groups import TunableGroups
24from mlos_bench.optimizers.convert_configspace import (
25 special_param_names,
26 tunable_groups_to_configspace,
27)
# Expected ConfigSpace hyperparameter class for every
# (tunable type, distribution name) combination exercised below.
# Each row lists a distribution followed by its float and int classes;
# the comprehension expands the rows into tuple-keyed entries in the order
# (float, int) per distribution.
_CS_HYPERPARAMETER = {
    (kind, distr): cs_class
    for (distr, float_class, int_class) in (
        ("beta", BetaFloatHyperparameter, BetaIntegerHyperparameter),
        ("normal", NormalFloatHyperparameter, NormalIntegerHyperparameter),
        ("uniform", UniformFloatHyperparameter, UniformIntegerHyperparameter),
    )
    for (kind, cs_class) in (("float", float_class), ("int", int_class))
}
@pytest.mark.parametrize("param_type", ["int", "float"])
@pytest.mark.parametrize("distr_name,distr_params", [
    ("normal", {"mu": 0.0, "sigma": 1.0}),
    ("beta", {"alpha": 2, "beta": 5}),
    ("uniform", {}),
])
def test_convert_numerical_distributions(param_type: str,
                                         distr_name: DistributionName,
                                         distr_params: dict) -> None:
    """
    Check that a numerical tunable declared with an explicit distribution
    keeps that distribution and converts to the matching ConfigSpace
    hyperparameter class with the same distribution parameters.

    The converted space must also contain exactly the two categorical
    helper hyperparameters named by `special_param_names`.
    """
    param_name = "x"
    param_config = {
        "type": param_type,
        "range": [0, 100],
        "special": [-1, 0],
        "special_weights": [0.1, 0.2],
        "range_weight": 0.7,
        "distribution": {"type": distr_name, "params": distr_params},
        "default": 0,
    }
    tunable_groups = TunableGroups({
        "tunable_group": {"cost": 1, "params": {param_name: param_config}},
    })

    # The tunable itself must report the distribution exactly as declared.
    tunable, _ = tunable_groups.get_tunable(param_name)
    assert tunable.distribution == distr_name
    assert tunable.distribution_params == distr_params

    space = tunable_groups_to_configspace(tunable_groups)
    special_name, type_name = special_param_names(param_name)

    # Besides the tunable, the space holds only the two categorical helpers.
    assert set(space.keys()) == {param_name, type_name, special_name}
    for helper_name in (special_name, type_name):
        assert isinstance(space[helper_name], CategoricalHyperparameter)

    # The tunable maps to the class for its (type, distribution) pair and
    # exposes each declared distribution parameter as an attribute.
    converted = space[param_name]
    expected_class = _CS_HYPERPARAMETER[(param_type, distr_name)]
    assert isinstance(converted, expected_class)
    for attr_name, expected_value in distr_params.items():
        assert getattr(converted, attr_name) == expected_value