Coverage for mlos_core/mlos_core/tests/spaces/monkey_patch_quantization_test.py: 100%
55 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for ConfigSpace quantization monkey patching."""
7import numpy as np
8from ConfigSpace import (
9 ConfigurationSpace,
10 UniformFloatHyperparameter,
11 UniformIntegerHyperparameter,
12)
13from numpy.random import RandomState
15from mlos_core.spaces.converters.util import (
16 QUANTIZATION_BINS_META_KEY,
17 monkey_patch_cs_quantization,
18 monkey_patch_hp_quantization,
19)
20from mlos_core.tests import SEED
def test_configspace_quant_int() -> None:
    """Check the quantization of an integer hyperparameter.

    With 11 quantization bins over ``[0, 100]`` every patched sample must be a
    multiple of 10.  All draws are seeded with ``SEED`` so the test is
    reproducible run-to-run.
    """
    quantization_bins = 11
    quantized_values = set(range(0, 101, 10))
    hp = UniformIntegerHyperparameter(
        "hp",
        lower=0,
        upper=100,
        log=False,
        meta={QUANTIZATION_BINS_META_KEY: quantization_bins},
    )

    # Before patching: expect that at least one value is not quantized.
    assert not set(hp.sample_value(100, seed=RandomState(SEED))).issubset(quantized_values)

    monkey_patch_hp_quantization(hp)
    # After patching: *all* values must belong to the set of quantized values.
    # Exercise the scalar (size=None) path separately from the batch path;
    # set membership would also raise TypeError if an unhashable ndarray were
    # returned instead of a scalar.
    assert hp.sample_value(seed=RandomState(SEED)) in quantized_values
    assert set(hp.sample_value(100, seed=RandomState(SEED))).issubset(quantized_values)
def test_configspace_quant_float() -> None:
    """Check the quantization of a float hyperparameter.

    With 5 quantization bins over ``[0, 1]`` every patched sample must be one
    of ``{0.0, 0.25, 0.5, 0.75, 1.0}``.  All draws are seeded with ``SEED`` so
    the test is reproducible run-to-run.
    """
    # 5 is a nice number of bins to avoid floating point errors:
    # the bin edges are exactly representable binary fractions.
    quantization_bins = 5
    quantized_values = set(np.linspace(0, 1, num=quantization_bins, endpoint=True))
    hp = UniformFloatHyperparameter(
        "hp",
        lower=0,
        upper=1,
        log=False,
        meta={QUANTIZATION_BINS_META_KEY: quantization_bins},
    )

    # Before patching: expect that at least one value is not quantized.
    assert not set(hp.sample_value(100, seed=RandomState(SEED))).issubset(quantized_values)

    monkey_patch_hp_quantization(hp)
    # After patching: *all* values must belong to the set of quantized values.
    # Exercise the scalar (size=None) path separately from the batch path;
    # set membership would also raise TypeError if an unhashable ndarray were
    # returned instead of a scalar.
    assert hp.sample_value(seed=RandomState(SEED)) in quantized_values
    assert set(hp.sample_value(100, seed=RandomState(SEED))).issubset(quantized_values)
def test_configspace_quant_repatch() -> None:
    """Repatch the same hyperparameter with different number of bins.

    Checks that:
    * patching quantizes samples to the configured grid,
    * repatching with unchanged ``meta`` yields identical seeded samples,
    * repatching with more bins produces values the coarser grid could not,
    * removing the bin count and repatching restores the original sampler,
      i.e. the same seeded samples as before any patching.
    """
    quantization_bins = 11
    quantized_values = set(range(0, 101, 10))
    hp = UniformIntegerHyperparameter(
        "hp",
        lower=0,
        upper=100,
        log=False,
        meta={QUANTIZATION_BINS_META_KEY: quantization_bins},
    )

    # Before patching: expect that at least one value is not quantized.
    # Keep these seeded samples to compare against after un-patching.
    orig_samples = hp.sample_value(100, seed=RandomState(SEED))
    assert not set(orig_samples).issubset(quantized_values)

    monkey_patch_hp_quantization(hp)
    # After patching: *all* values must belong to the set of quantized values.
    samples = hp.sample_value(100, seed=RandomState(SEED))
    assert set(samples).issubset(quantized_values)

    # Patch the same hyperparameter again: with unchanged meta the seeded
    # samples must be exactly the same as after the first patch.
    monkey_patch_hp_quantization(hp)
    assert np.array_equal(samples, hp.sample_value(100, seed=RandomState(SEED)))

    # Repatch with the higher number of bins (21 -> multiples of 5) and make
    # sure we get new values, i.e. at least one odd multiple of 5 that the
    # 11-bin grid could never produce.
    new_meta = dict(hp.meta or {})
    new_meta[QUANTIZATION_BINS_META_KEY] = 21
    hp.meta = new_meta
    monkey_patch_hp_quantization(hp)
    samples_set = set(hp.sample_value(100, seed=RandomState(SEED)))
    finer_values = set(range(0, 101, 5))
    new_only_values = finer_values - quantized_values  # {5, 15, ..., 95}
    assert samples_set.issubset(finer_values)
    assert not samples_set.isdisjoint(new_only_values)

    # Repatch without quantization and make sure we get the original values.
    new_meta = dict(hp.meta or {})
    del new_meta[QUANTIZATION_BINS_META_KEY]
    hp.meta = new_meta
    assert hp.meta.get(QUANTIZATION_BINS_META_KEY) is None
    monkey_patch_hp_quantization(hp)
    unpatched_samples = hp.sample_value(100, seed=RandomState(SEED))
    assert set(unpatched_samples).issubset(set(range(0, 101)))
    # The un-patched sampler must reproduce the pre-patch seeded samples and
    # therefore cover more distinct values than either quantized grid.
    assert np.array_equal(unpatched_samples, orig_samples)
    assert len(set(unpatched_samples)) > len(finer_values)
def test_configspace_quant() -> None:
    """Quantize several hyperparameters of one ConfigSpace and check seeded samples.

    Only ``hp_int_quant`` (5 bins) and ``hp_float`` (11 bins) get quantization
    metadata; the other hyperparameters are sampled without it.  The expected
    values are pinned to ``SEED`` (and presumably to the ConfigSpace version).
    """
    space = ConfigurationSpace(
        name="cs_test",
        space={
            "hp_int": (0, 100000),
            "hp_int_quant": (0, 100000),
            "hp_float": (0.0, 1.0),
            "hp_categorical": ["a", "b", "c"],
            "hp_constant": 1337,
        },
    )
    bins_by_name = {"hp_int_quant": 5, "hp_float": 11}
    for hp_name, n_bins in bins_by_name.items():
        space[hp_name].meta = {QUANTIZATION_BINS_META_KEY: n_bins}
    monkey_patch_cs_quantization(space)

    # Expected samples, one row per configuration, in `columns` order.
    columns = ("hp_categorical", "hp_constant", "hp_float", "hp_int", "hp_int_quant")
    expected_single = dict(zip(columns, ("c", 1337, 0.6, 60263, 0)))
    expected_batch = [
        dict(zip(columns, row))
        for row in (
            ("a", 1337, 0.4, 59150, 50000),
            ("a", 1337, 0.3, 65725, 75000),
            ("b", 1337, 0.6, 84654, 25000),
        )
    ]

    space.seed(SEED)
    assert dict(space.sample_configuration()) == expected_single
    assert [dict(conf) for conf in space.sample_configuration(3)] == expected_batch