Coverage for mlos_core/mlos_core/spaces/converters/util.py: 90%
30 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Helper functions for config space converters."""
7from ConfigSpace import ConfigurationSpace
8from ConfigSpace.functional import quantize
9from ConfigSpace.hyperparameters import Hyperparameter, NumericalHyperparameter
# Key looked up in a hyperparameter's ``meta`` mapping to request quantization.
# Its value is the number of quantization bins (must convert to an int > 1).
QUANTIZATION_BINS_META_KEY: str = "quantization_bins"
def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter:
    """
    Monkey-patch quantization into the Hyperparameter.

    Temporary workaround for the quantization support that was dropped in
    ConfigSpace 1.0.

    Notes
    -----
    See <https://github.com/automl/ConfigSpace/issues/390>.

    The patch replaces ``sample_vector`` on the hyperparameter's (private)
    ``_vector_dist`` object. The unpatched method is stashed once in
    ``sample_vector_mlos_orig``. As a result, re-patching with a different bin
    count wraps the original sampler rather than a previous patch. Removing
    the ``quantization_bins`` meta key and calling this again restores the
    original sampler.

    Parameters
    ----------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        ConfigSpace hyperparameter to patch. Non-numerical hyperparameters
        are returned unchanged.

    Returns
    -------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        The same hyperparameter object, patched in place.

    Raises
    ------
    ValueError
        If the ``quantization_bins`` meta value cannot be converted to an
        integer, or is not greater than 1.
    """
    if not isinstance(hp, NumericalHyperparameter):
        return hp

    # Redundant at runtime; kept to narrow the type for static checkers.
    assert isinstance(hp, NumericalHyperparameter)
    dist = hp._vector_dist  # pylint: disable=protected-access
    quantization_bins = (hp.meta or {}).get(QUANTIZATION_BINS_META_KEY)
    if quantization_bins is None:
        # No quantization requested.
        # Remove any previously applied patches.
        if hasattr(dist, "sample_vector_mlos_orig"):
            setattr(dist, "sample_vector", dist.sample_vector_mlos_orig)
            delattr(dist, "sample_vector_mlos_orig")
        return hp

    try:
        quantization_bins = int(quantization_bins)
    except (TypeError, ValueError) as ex:
        # int() raises TypeError (not ValueError) for values such as lists or
        # dicts; report every non-convertible value with the same ValueError.
        raise ValueError(f"{quantization_bins=} :: must be an integer.") from ex

    if quantization_bins <= 1:
        raise ValueError(f"{quantization_bins=} :: must be greater than 1.")

    if not hasattr(dist, "sample_vector_mlos_orig"):
        # Stash the original sampler only once, so repeated patching always
        # wraps the unpatched method instead of stacking patches.
        setattr(dist, "sample_vector_mlos_orig", dist.sample_vector)

    assert hasattr(dist, "sample_vector_mlos_orig")
    setattr(
        dist,
        "sample_vector",
        lambda n, *, seed=None: quantize(
            dist.sample_vector_mlos_orig(n, seed=seed),
            bounds=(dist.lower_vectorized, dist.upper_vectorized),
            bins=quantization_bins,
        ),
    )
    return hp
def monkey_patch_cs_quantization(cs: ConfigurationSpace) -> ConfigurationSpace:
    """
    Apply the quantization monkey-patch to every Hyperparameter in a ConfigSpace.

    Each hyperparameter is handed to ``monkey_patch_hp_quantization``, which
    patches it in place. No copy of the ConfigSpace is made.

    Parameters
    ----------
    cs : ConfigSpace.ConfigurationSpace
        ConfigSpace whose hyperparameters should be patched.

    Returns
    -------
    cs : ConfigSpace.ConfigurationSpace
        The same ConfigSpace instance that was passed in.
    """
    hyperparameters = cs.values()
    for param in hyperparameters:
        # Patching happens in place; the helper's return value is not needed.
        monkey_patch_hp_quantization(param)
    return cs