Coverage for mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6An interface to access the tunable config trial group data stored in SQL DB. 

7""" 

8 

9from typing import Dict, Optional, TYPE_CHECKING 

10 

11import pandas 

12from sqlalchemy import Engine, Integer, func 

13 

14from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

15from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData 

16from mlos_bench.storage.sql import common 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

19 

20if TYPE_CHECKING: 

21 from mlos_bench.storage.base_trial_data import TrialData 

22 

23 

class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
    """
    SQL interface for accessing the stored experiment benchmark tunable config
    trial group data.

    A (tunable) config is used to define an instance of values for a set of tunable
    parameters for a given experiment and can be used by one or more trial instances
    (e.g., for repeats), which we call a (tunable) config trial group.
    """

    def __init__(self, *,
                 engine: Engine,
                 schema: DbSchema,
                 experiment_id: str,
                 tunable_config_id: int,
                 tunable_config_trial_group_id: Optional[int] = None):
        """
        Create a SQL-backed view of one (tunable) config trial group.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a connection for each storage query.
        schema : DbSchema
            Table definitions of the storage DB (the ``trial`` table is queried here).
        experiment_id : str
            Id of the experiment the trial group belongs to.
        tunable_config_id : int
            Id of the (tunable) config shared by the trials in this group.
        tunable_config_trial_group_id : Optional[int]
            Group id, if already known. When None, it is forwarded as None to the
            base class; presumably the base class then resolves it through
            `_get_tunable_config_trial_group_id` (NOTE(review): confirm in base class).
        """
        super().__init__(
            experiment_id=experiment_id,
            tunable_config_id=tunable_config_id,
            tunable_config_trial_group_id=tunable_config_trial_group_id,
        )
        self._engine = engine
        self._schema = schema

    def _get_tunable_config_trial_group_id(self) -> int:
        """
        Retrieve this group's tunable_config_trial_group_id from the storage.

        The group id is the smallest ``trial_id`` among the stored trials of this
        experiment that used this (tunable) config.

        Returns
        -------
        tunable_config_trial_group_id : int
            The minimum matching ``trial_id``, cast to an integer by the DB.

        Raises
        ------
        ValueError
            If the storage has no trial for this (experiment, tunable config) pair.
        """
        trial_table = self._schema.trial
        with self._engine.connect() as conn:
            cur_trial_group = conn.execute(
                trial_table.select().with_only_columns(
                    func.min(trial_table.c.trial_id).cast(Integer).label(  # pylint: disable=not-callable
                        'tunable_config_trial_group_id'),
                ).where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.config_id == self._tunable_config_id,
                ).group_by(
                    # With GROUP BY, no matching trials yields no row at all
                    # (rather than a single row holding NULL).
                    trial_table.c.exp_id,
                    trial_table.c.config_id,
                )
            )
            row = cur_trial_group.fetchone()
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would turn this into an opaque AttributeError.
            if row is None:
                raise ValueError(
                    f"No trials found in storage for experiment {self._experiment_id!r} "
                    f"and tunable config {self._tunable_config_id!r}.")
            return row._tuple()[0]  # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Retrieve the (tunable) config data shared by the trials of this group.

        Returns
        -------
        tunable_config : TunableConfigData
            A new TunableConfigSqlData bound to the same engine and schema.
        """
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=self.tunable_config_id,
        )

    @property
    def trials(self) -> Dict[int, "TrialData"]:
        """
        Retrieve the trials' data for this (tunable) config trial group from the storage.

        Returns
        -------
        trials : Dict[int, TrialData]
            A dictionary of the trials' data, keyed by trial id.
        """
        return common.get_trials(self._engine, self._schema, self._experiment_id, self._tunable_config_id)

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Retrieve the results of the trials in this (tunable) config trial group.

        Returns
        -------
        results : pandas.DataFrame
            The frame built by `common.get_results_df` for this experiment and
            tunable config id (column layout is defined by that helper).
        """
        return common.get_results_df(self._engine, self._schema, self._experiment_id, self._tunable_config_id)