Coverage for mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for converting tunable parameters with explicitly 

7specified distributions to ConfigSpace. 

8""" 

9 

10import pytest 

11 

12from ConfigSpace import ( 

13 CategoricalHyperparameter, 

14 BetaFloatHyperparameter, 

15 BetaIntegerHyperparameter, 

16 NormalFloatHyperparameter, 

17 NormalIntegerHyperparameter, 

18 UniformFloatHyperparameter, 

19 UniformIntegerHyperparameter, 

20) 

21 

22from mlos_bench.tunables.tunable import DistributionName 

23from mlos_bench.tunables.tunable_groups import TunableGroups 

24from mlos_bench.optimizers.convert_configspace import ( 

25 special_param_names, 

26 tunable_groups_to_configspace, 

27) 

28 

29 

# Expected ConfigSpace hyperparameter class for each
# (tunable type, distribution name) combination exercised by the tests below.
_CS_HYPERPARAMETER = {
    ("float", "beta"): BetaFloatHyperparameter,
    ("int", "beta"): BetaIntegerHyperparameter,
    ("float", "normal"): NormalFloatHyperparameter,
    ("int", "normal"): NormalIntegerHyperparameter,
    ("float", "uniform"): UniformFloatHyperparameter,
    ("int", "uniform"): UniformIntegerHyperparameter,
}

38 

39 

@pytest.mark.parametrize("param_type", ["int", "float"])
@pytest.mark.parametrize("distr_name,distr_params", [
    ("normal", {"mu": 0.0, "sigma": 1.0}),
    ("beta", {"alpha": 2, "beta": 5}),
    ("uniform", {}),
])
def test_convert_numerical_distributions(param_type: str,
                                         distr_name: DistributionName,
                                         distr_params: dict) -> None:
    """
    Check the ConfigSpace conversion of a numerical tunable that declares
    an explicit distribution.

    The test builds a single-tunable group, verifies that the tunable reports
    the requested distribution and its parameters, converts the group to a
    ConfigSpace, and then asserts that:

    * the space contains exactly the tunable plus the two names returned by
      `special_param_names`, both of which are categorical hyperparameters;
    * the tunable itself maps to the class listed in `_CS_HYPERPARAMETER`
      for the given (type, distribution) pair;
    * every distribution parameter is exposed as an equally named attribute
      with the same value on the resulting hyperparameter.
    """
    name = "x"
    param_config = {
        "type": param_type,
        "range": [0, 100],
        "special": [-1, 0],
        "special_weights": [0.1, 0.2],
        "range_weight": 0.7,
        "distribution": {
            "type": distr_name,
            "params": distr_params,
        },
        "default": 0,
    }
    groups = TunableGroups({
        "tunable_group": {
            "cost": 1,
            "params": {name: param_config},
        },
    })

    # The tunable itself must carry the distribution settings verbatim.
    tunable, _ = groups.get_tunable(name)
    assert tunable.distribution == distr_name
    assert tunable.distribution_params == distr_params

    config_space = tunable_groups_to_configspace(groups)

    # Besides the tunable, only the two auxiliary parameters may be present.
    special_name, type_name = special_param_names(name)
    assert set(config_space.keys()) == {name, type_name, special_name}

    for aux_name in (special_name, type_name):
        assert isinstance(config_space[aux_name], CategoricalHyperparameter)

    hyperparameter = config_space[name]
    expected_cls = _CS_HYPERPARAMETER[param_type, distr_name]
    assert isinstance(hyperparameter, expected_cls)
    for attr_name, expected in distr_params.items():
        assert getattr(hyperparameter, attr_name) == expected