Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 100%

39 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Saving and restoring the benchmark data in SQL database. 

7""" 

8 

9import logging 

10from typing import Dict, Optional 

11 

12from sqlalchemy import URL, create_engine 

13 

14from mlos_bench.tunables.tunable_groups import TunableGroups 

15from mlos_bench.services.base_service import Service 

16from mlos_bench.storage.base_storage import Storage 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.storage.sql.experiment import Experiment 

19from mlos_bench.storage.base_experiment_data import ExperimentData 

20from mlos_bench.storage.sql.experiment_data import ExperimentSqlData 

21 

22_LOG = logging.getLogger(__name__) 

23 

24 

class SqlStorage(Storage):
    """
    An implementation of the Storage interface using SQLAlchemy backend.

    Two storage-specific keys are removed from the config in the constructor;
    every remaining key is forwarded to ``sqlalchemy.URL.create``:

    * ``lazy_schema_create`` -- when true, defer creating the DB schema until
      it is first needed; otherwise it is created in the constructor.
    * ``log_sql`` -- passed to ``create_engine`` as ``echo`` to enable
      SQLAlchemy statement logging.
    """

    def __init__(self,
                 config: dict,
                 global_config: Optional[dict] = None,
                 service: Optional[Service] = None):
        """
        Create a new SQL storage backend and connect its engine.

        Parameters
        ----------
        config : dict
            Storage configuration. ``lazy_schema_create`` and ``log_sql`` are
            popped from ``self._config``. All remaining keys are passed as
            keyword arguments to ``sqlalchemy.URL.create`` (e.g.
            ``drivername``, ``database``).
        global_config : Optional[dict]
            Forwarded unchanged to the ``Storage`` base class.
        service : Optional[Service]
            Forwarded unchanged to the ``Storage`` base class.
        """
        super().__init__(config, global_config, service)
        lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        _LOG.info("Connect to the database: %s", self)
        self._engine = create_engine(self._url, echo=self._log_sql)
        # Declared but deliberately left unassigned: the ``_schema`` property
        # uses ``hasattr(self, '_db_schema')`` to detect first access.
        self._db_schema: DbSchema
        if not lazy_schema_create:
            # Access the property purely for its side effect: creating the
            # schema now. This must not be an ``assert`` -- asserts are
            # stripped under ``python -O``, which would silently skip it.
            _ = self._schema
        else:
            _LOG.info("Using lazy schema create for database: %s", self)

    @property
    def _schema(self) -> DbSchema:
        """
        Lazily create schema upon first access.

        On the first call, builds a ``DbSchema`` bound to this storage's
        engine, calls its ``create()``, and caches the result in
        ``self._db_schema``. Later calls return the cached object.
        """
        if not hasattr(self, '_db_schema'):
            self._db_schema = DbSchema(self._engine).create()
            if _LOG.isEnabledFor(logging.DEBUG):
                # Use the cached attribute directly rather than re-entering
                # this property.
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def __repr__(self) -> str:
        """Return ``"<backend name>:<database>"`` as computed in ``__init__``."""
        return self._repr

    def experiment(self, *,
                   experiment_id: str,
                   trial_id: int,
                   root_env_config: str,
                   description: str,
                   tunables: TunableGroups,
                   opt_target: str,
                   opt_direction: Optional[str]) -> Storage.Experiment:
        """
        Build a SQL-backed ``Storage.Experiment`` for the given parameters.

        All keyword arguments are forwarded unchanged to the SQL
        ``Experiment`` class, together with this storage's engine and schema.
        Accessing the schema here creates it if lazy schema creation was
        requested and it does not exist yet.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_target=opt_target,
            opt_direction=opt_direction,
        )

    @property
    def experiments(self) -> Dict[str, ExperimentData]:
        """
        Retrieve every experiment stored in the database.

        Returns
        -------
        Dict[str, ExperimentData]
            Mapping from experiment id to an ``ExperimentSqlData`` view.
            Insertion order follows ascending ``exp_id``.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().order_by(
                    self._schema.experiment.c.exp_id.asc(),
                )
            )
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }