Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 98%

56 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Saving and updating benchmark data using SQLAlchemy backend. 

7""" 

8 

9import logging 

10from datetime import datetime 

11from typing import List, Optional, Tuple, Union, Dict, Any 

12 

13from sqlalchemy import Engine, Connection 

14from sqlalchemy.exc import IntegrityError 

15 

16from mlos_bench.environments.status import Status 

17from mlos_bench.tunables.tunable_groups import TunableGroups 

18from mlos_bench.storage.base_storage import Storage 

19from mlos_bench.storage.sql.schema import DbSchema 

20from mlos_bench.util import nullable, utcify_timestamp 

21 

22_LOG = logging.getLogger(__name__) 

23 

24 

class Trial(Storage.Trial):
    """
    A single trial (one run of an experiment) persisted in a SQL database.

    Status changes, final results, and telemetry are written through the
    SQLAlchemy engine and the table schema supplied at construction time.
    """

    def __init__(self, *,
                 engine: Engine, schema: DbSchema, tunables: TunableGroups,
                 experiment_id: str, trial_id: int, config_id: int,
                 opt_target: str, opt_direction: Optional[str],
                 config: Optional[Dict[str, Any]] = None):
        """
        Create a trial object bound to a database engine and schema.

        All arguments are keyword-only. `engine` and `schema` are kept on the
        instance; everything else is forwarded to the base class constructor
        (`config_id` is passed on as `tunable_config_id`).
        """
        super().__init__(
            tunables=tunables, experiment_id=experiment_id, trial_id=trial_id,
            tunable_config_id=config_id, opt_target=opt_target,
            opt_direction=opt_direction, config=config,
        )
        self._schema = schema
        self._engine = engine

    def update(self, status: Status, timestamp: datetime,
               metrics: Optional[Union[Dict[str, Any], float]] = None
               ) -> Optional[Dict[str, Any]]:
        """
        Record a status change of the trial and, if given, its final results.

        The status record, the update of the trial row, and the insertion of
        the result metrics all happen on one connection inside a single
        `engine.begin()` block.

        Returns the metrics as returned by the base class `update()`.

        Raises `RuntimeError` when the status is a completed one and the
        trial row could not be updated.
        """
        # The database stores UTC timestamps, so convert before anything else.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
            try:
                if status.is_completed():
                    # Closing the trial: set the final status and `ts_end`,
                    # but only on a row that has not been finished yet.
                    trial = self._schema.trial
                    finished = ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']
                    stmt = trial.update().where(
                        trial.c.exp_id == self._experiment_id,
                        trial.c.trial_id == self._trial_id,
                        trial.c.ts_end.is_(None),
                        trial.c.status.notin_(finished),
                    ).values(status=status.name, ts_end=timestamp)
                    n_rows = conn.execute(stmt).rowcount
                    # A row count of -1 is accepted alongside 1.
                    if n_rows not in (1, -1):
                        _LOG.warning("Trial %s :: update failed: %s", self, status)
                        raise RuntimeError(
                            f"Failed to update the status of the trial {self} to {status}."
                            f" ({n_rows} rows)")
                    if metrics:
                        # One `trial_result` row per metric; values are stored
                        # via `nullable(str, ...)`.
                        result_rows = []
                        for (metric_id, metric_value) in metrics.items():
                            result_rows.append({
                                "exp_id": self._experiment_id,
                                "trial_id": self._trial_id,
                                "metric_id": metric_id,
                                "metric_value": nullable(str, metric_value),
                            })
                        conn.execute(self._schema.trial_result.insert().values(result_rows))
                else:
                    # Starting the trial: set the status and `ts_start`.
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    trial = self._schema.trial
                    blocked = ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']
                    stmt = trial.update().where(
                        trial.c.exp_id == self._experiment_id,
                        trial.c.trial_id == self._trial_id,
                        trial.c.ts_end.is_(None),
                        trial.c.status.notin_(blocked),
                    ).values(status=status.name, ts_start=timestamp)
                    n_rows = conn.execute(stmt).rowcount
                    if n_rows not in (1, -1):
                        # Not an error: the trial keeps its current status and
                        # start time (e.g., it is already running); just log it.
                        _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
            except Exception:
                # Undo the partial changes before propagating the error.
                conn.rollback()
                raise
        return metrics

    def update_telemetry(self, status: Status, timestamp: datetime,
                         metrics: List[Tuple[datetime, str, Any]]) -> None:
        """
        Record the current status of the trial and store its telemetry.

        `metrics` is a list of `(timestamp, metric_id, value)` tuples. Each
        record is inserted in a transaction of its own; a record that violates
        an integrity constraint is logged and skipped.
        """
        super().update_telemetry(status, timestamp, metrics)
        # The database stores UTC timestamps: convert the status timestamp
        # as well as the timestamp of every telemetry record.
        timestamp = utcify_timestamp(timestamp, origin="local")
        utc_records = []
        for (ts, key, val) in metrics:
            utc_records.append((utcify_timestamp(ts, origin="local"), key, val))
        # NOTE: `Insert.on_conflict_do_nothing()` is not available in every
        # SQLAlchemy dialect, and `.update_telemetry()` must stay idempotent,
        # so the records are inserted one by one rather than in a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        for record in utc_records:
            (metric_ts, key, val) = record
            with self._engine.begin() as conn:
                try:
                    stmt = self._schema.trial_telemetry.insert().values(
                        exp_id=self._experiment_id,
                        trial_id=self._trial_id,
                        ts=metric_ts,
                        metric_id=key,
                        metric_value=nullable(str, val),
                    )
                    conn.execute(stmt)
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", record, ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Add a (timestamp, status) record for this trial to the status table.

        An `IntegrityError` raised by the insert is logged as a warning and
        not re-raised, which keeps the call idempotent.
        """
        # The database stores UTC timestamps, so convert before inserting.
        timestamp = utcify_timestamp(timestamp, origin="local")
        try:
            stmt = self._schema.trial_status.insert().values(
                exp_id=self._experiment_id,
                trial_id=self._trial_id,
                ts=timestamp,
                status=status.name,
            )
            conn.execute(stmt)
        except IntegrityError as ex:
            _LOG.warning("Status with that timestamp already exists: %s %s :: %s",
                         self, timestamp, ex)