Coverage for mlos_core/mlos_core/spaces/converters/flaml.py: 96%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Contains space converters for :py:class:`~mlos_core.optimizers.flaml_optimizer`""" 

6 

7import sys 

8from typing import TYPE_CHECKING, Dict 

9 

10import ConfigSpace 

11import flaml.tune 

12import flaml.tune.sample 

13import numpy as np 

14 

15if TYPE_CHECKING: 

16 from ConfigSpace.hyperparameters import Hyperparameter 

17 

18if sys.version_info >= (3, 10): 

19 from typing import TypeAlias 

20else: 

21 from typing_extensions import TypeAlias 

22 

23 

# A single FLAML search-space domain (one tunable parameter).
FlamlDomain: TypeAlias = flaml.tune.sample.Domain
"""Flaml domain type alias."""

# A FLAML search space: a string-keyed mapping of domains.
FlamlSpace: TypeAlias = Dict[str, FlamlDomain]
"""Flaml space type alias - a `Dict[str, FlamlDomain]`"""

29 

30 

def configspace_to_flaml_space(
    config_space: ConfigSpace.ConfigurationSpace,
) -> Dict[str, FlamlDomain]:
    """
    Converts a ConfigSpace.ConfigurationSpace to a dict of FLAML domains.

    Parameters
    ----------
    config_space : ConfigSpace.ConfigurationSpace
        Input configuration space.

    Returns
    -------
    flaml_space : dict
        A dictionary of flaml.tune.sample.Domain objects keyed by parameter name.

    Raises
    ------
    ValueError
        If a categorical parameter has non-uniform probabilities, or if a
        parameter is not a uniform float, uniform integer, or categorical
        hyperparameter.
    """

    def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain:
        # The sampler is chosen inside each isinstance() branch rather than by
        # looking up the exact type(parameter) in a table, so subclasses of the
        # supported hyperparameter types are converted instead of failing with
        # a KeyError.
        if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter):
            # FIXME: upper isn't included in the range
            float_sampler = flaml.tune.loguniform if parameter.log else flaml.tune.uniform
            return float_sampler(parameter.lower, parameter.upper)

        if isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter):
            int_sampler = flaml.tune.lograndint if parameter.log else flaml.tune.randint
            # NOTE(review): the +1 presumably compensates for FLAML's integer
            # samplers excluding the upper bound - confirm against FLAML docs.
            return int_sampler(parameter.lower, parameter.upper + 1)

        if isinstance(parameter, ConfigSpace.CategoricalHyperparameter):
            # More than one distinct probability value means the choices are
            # weighted, which cannot be expressed with flaml.tune.choice here.
            if len(np.unique(parameter.probabilities)) > 1:
                raise ValueError(
                    "FLAML doesn't support categorical parameters with non-uniform probabilities."
                )
            return flaml.tune.choice(parameter.choices)  # TODO: set order?

        raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.")

    return {param.name: _one_parameter_convert(param) for param in config_space.values()}