Coverage for mlos_core/mlos_core/spaces/converters/util.py: 90%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Helper functions for config space converters.""" 

6 

7from ConfigSpace import ConfigurationSpace 

8from ConfigSpace.functional import quantize 

9from ConfigSpace.hyperparameters import Hyperparameter, NumericalHyperparameter 

10 

# Key looked up in a Hyperparameter's ``meta`` dict; its value is the number of
# quantization bins to snap sampled values into (absent/None means "no quantization").
QUANTIZATION_BINS_META_KEY = "quantization_bins"

12 

13 

def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter:
    """
    Monkey-patch quantization into the Hyperparameter.

    Temporary workaround to dropped quantization support in ConfigSpace 1.0

    Notes
    -----
    See <https://github.com/automl/ConfigSpace/issues/390>.

    The requested number of bins is read from
    ``hp.meta[QUANTIZATION_BINS_META_KEY]``. When it is present, the
    ``sample_vector`` method of the hyperparameter's vector distribution is
    replaced in place by a wrapper that quantizes the samples produced by the
    original method. When it is absent, any previously applied patch is
    removed. Non-numerical hyperparameters are returned untouched.

    Parameters
    ----------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        ConfigSpace hyperparameter to patch.

    Returns
    -------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        Patched hyperparameter (the same object that was passed in).

    Raises
    ------
    ValueError
        If the quantization bins value cannot be converted to an integer,
        or is not greater than 1.
    """
    if not isinstance(hp, NumericalHyperparameter):
        return hp

    # Redundant at runtime given the guard above; presumably kept to help
    # static type checkers narrow ``hp``.
    assert isinstance(hp, NumericalHyperparameter)
    dist = hp._vector_dist  # pylint: disable=protected-access
    quantization_bins = (hp.meta or {}).get(QUANTIZATION_BINS_META_KEY)
    if quantization_bins is None:
        # No quantization requested.
        # Remove any previously applied patches.
        if hasattr(dist, "sample_vector_mlos_orig"):
            setattr(dist, "sample_vector", dist.sample_vector_mlos_orig)
            delattr(dist, "sample_vector_mlos_orig")
        return hp

    try:
        # NOTE: int() truncates non-integral floats (e.g. 2.5 -> 2).
        quantization_bins = int(quantization_bins)
    except (TypeError, ValueError) as ex:
        # int() raises TypeError (not ValueError) for values such as lists or
        # dicts; normalize both so callers only need to handle ValueError.
        raise ValueError(f"{quantization_bins=} :: must be an integer.") from ex

    if quantization_bins <= 1:
        raise ValueError(f"{quantization_bins=} :: must be greater than 1.")

    # Stash the original sampler only once, so that re-patching (e.g. with a
    # different bin count) wraps the pristine method rather than a previous
    # wrapper.
    if not hasattr(dist, "sample_vector_mlos_orig"):
        setattr(dist, "sample_vector_mlos_orig", dist.sample_vector)

    assert hasattr(dist, "sample_vector_mlos_orig")
    # The wrapper delegates to the stashed original sampler and snaps its
    # output into `quantization_bins` bins within the distribution's
    # vectorized bounds.
    setattr(
        dist,
        "sample_vector",
        lambda n, *, seed=None: quantize(
            dist.sample_vector_mlos_orig(n, seed=seed),
            bounds=(dist.lower_vectorized, dist.upper_vectorized),
            bins=quantization_bins,
        ),
    )
    return hp

70 

71 

def monkey_patch_cs_quantization(cs: ConfigurationSpace) -> ConfigurationSpace:
    """
    Apply hyperparameter quantization patching to every Hyperparameter of a ConfigSpace.

    Each hyperparameter is handed to ``monkey_patch_hp_quantization``, which
    patches (or un-patches) it in place. The same ConfigSpace object is then
    returned so the call can be chained.

    Parameters
    ----------
    cs : ConfigSpace.ConfigurationSpace
        ConfigSpace whose hyperparameters should be patched.

    Returns
    -------
    cs : ConfigSpace.ConfigurationSpace
        The same ConfigSpace, with its hyperparameters patched.
    """
    for hyperparameter in cs.values():
        # Patching happens in place; the helper's return value is not needed.
        monkey_patch_hp_quantization(hyperparameter)
    return cs