Coverage for mlos_bench/mlos_bench/util.py: 89%

110 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Various helper functions for mlos_bench. 

7""" 

8 

9# NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports. 

10 

11from datetime import datetime 

12import os 

13import json 

14import logging 

15import importlib 

16import subprocess 

17 

18from typing import ( 

19 Any, Callable, Dict, Iterable, Literal, Mapping, Optional, 

20 Tuple, Type, TypeVar, TYPE_CHECKING, Union, 

21) 

22 

23import pandas 

24import pytz 

25 

26 

27_LOG = logging.getLogger(__name__) 

28 

29if TYPE_CHECKING: 

30 from mlos_bench.environments.base_environment import Environment 

31 from mlos_bench.optimizers.base_optimizer import Optimizer 

32 from mlos_bench.schedulers.base_scheduler import Scheduler 

33 from mlos_bench.services.base_service import Service 

34 from mlos_bench.storage.base_storage import Storage 

35 

36# BaseTypeVar is a generic type variable constrained to the five base classes listed below.

37BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage") 

38BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"] 

39 

40 

def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict:
    """
    Substitute ``$name`` placeholders in the destination config.

    Every top-level value of `dest` that is a string starting with ``$`` and
    whose remainder is a key of `source` is replaced, in place, with the
    matching `source` value. Anything else is left untouched.

    Parameters
    ----------
    dest : dict
        Destination config; it is modified in place.
    source : Optional[dict]
        Source config used to resolve the placeholder names.
        Treated as an empty dict when None.

    Returns
    -------
    dest : dict
        A reference to the same destination config after the substitution.
    """
    lookup = {} if source is None else source
    for name, value in dest.items():
        # Only string values of the form "$key" are substitution candidates.
        if not (isinstance(value, str) and value.startswith("$")):
            continue
        ref = value[1:]
        if ref in lookup:
            # Re-assigning an existing key does not resize the dict,
            # so this is safe to do while iterating over it.
            dest[name] = lookup[ref]
    return dest

64 

65 

def merge_parameters(*, dest: dict, source: Optional[dict] = None,
                     required_keys: Optional[Iterable[str]] = None) -> dict:
    """
    Merge values from the source config into the destination config.

    Only keys that already exist in `dest` are overwritten with the values
    from `source`. In addition, each key listed in `required_keys` that is
    absent from `dest` is copied over from `source`.

    Parameters
    ----------
    dest : dict
        Destination config; it is modified in place.
    source : Optional[dict]
        Source config. Treated as an empty dict when None.
    required_keys : Optional[Iterable[str]]
        An optional collection of keys that must end up in the destination config.

    Returns
    -------
    dest : dict
        A reference to the same destination config after the merge.

    Raises
    ------
    ValueError
        If a required key is present in neither `dest` nor `source`.
    """
    overrides = {} if source is None else source

    # Refresh only the keys the destination already knows about.
    dest.update({name: overrides[name] for name in dest if name in overrides})

    for name in required_keys or []:
        if name in dest:
            continue
        if name not in overrides:
            raise ValueError("Missing required parameter: " + name)
        dest[name] = overrides[name]

    return dest

102 

103 

def path_join(*args: str, abs_path: bool = False) -> str:
    """
    Join path components into a single normalized path with forward slashes.

    Parameters
    ----------
    args : str
        Path components to join.

    abs_path : bool
        If True, the joined path is made absolute before normalization.

    Returns
    -------
    str
        The joined and normalized path, with every backslash replaced by "/".
    """
    joined = os.path.join(*args)
    resolved = os.path.abspath(joined) if abs_path else joined
    normalized = os.path.normpath(resolved)
    # Always emit forward slashes, regardless of the host OS separator.
    return normalized.replace("\\", "/")

125 

126 

def prepare_class_load(config: dict,
                       global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]:
    """
    Pull the class name and its constructor configuration out of a config dict.

    The "config" sub-dictionary is created inside `config` if it is missing,
    and then updated in place from `global_config` via `merge_parameters`.

    Parameters
    ----------
    config : dict
        Configuration with a mandatory "class" entry and an optional "config" entry.
    global_config : dict
        Global configuration parameters (optional).

    Returns
    -------
    (class_name, class_config) : (str, dict)
        Name of the class to instantiate and its configuration.

    Raises
    ------
    KeyError
        If `config` has no "class" entry.
    """
    # Look up "class" first so that a missing key fails before `config` is mutated.
    cls_fqn = config["class"]
    cls_params = config.setdefault("config", {})

    merge_parameters(dest=cls_params, source=global_config)

    # Serialize the config only when it is actually going to be logged.
    if _LOG.isEnabledFor(logging.DEBUG):
        pretty = json.dumps(cls_params, indent=2)
        _LOG.debug("Instantiating: %s with config:\n%s", cls_fqn, pretty)

    return cls_fqn, cls_params

154 

155 

def get_class_from_name(class_name: str) -> type:
    """
    Resolve a fully qualified class name to the class object.

    Parameters
    ----------
    class_name : str
        Fully qualified class name, i.e., "<module path>.<class name>".

    Returns
    -------
    type
        Class object.
    """
    # Everything before the last dot is the module path; the rest is the class.
    module_name, _, class_id = class_name.rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_id)
    assert isinstance(cls, type)
    return cls

179 

180 

181# FIXME: Technically, this should return a type "class_name" derived from "base_class". 

def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str,
                            *args: Any, **kwargs: Any) -> BaseTypeVar:
    """
    Factory method: build an instance of the class named in the config.

    Parameters
    ----------
    base_class : type
        Base type of the class to instantiate
        (one of the types `BaseTypeVar` is constrained to).
    class_name : str
        FQN of a Python class to instantiate, e.g.,
        "mlos_bench.environments.remote.HostEnv".
        Must be derived from the `base_class`.
    args : list
        Positional arguments to pass to the constructor.
    kwargs : dict
        Keyword arguments to pass to the constructor.

    Returns
    -------
    inst : BaseTypeVar
        An instance of the `class_name` class.
    """
    target_cls = get_class_from_name(class_name)
    _LOG.info("Instantiating: %s :: %s", class_name, target_cls)

    # Check the class before constructing it, and the instance afterwards.
    assert issubclass(target_cls, base_class)
    instance: BaseTypeVar = target_cls(*args, **kwargs)
    assert isinstance(instance, base_class)
    return instance

213 

214 

def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None:
    """
    Verify that the configuration contains every required parameter.

    Parameters
    ----------
    config : dict
        Free-format dictionary with the configuration
        of the service or benchmarking environment.
    required_params : Iterable[str]
        A collection of identifiers of the parameters that must be present
        in the configuration.

    Raises
    ------
    ValueError
        If at least one of the required parameters is missing from `config`.
    """
    missing_params = set(required_params).difference(config)
    if not missing_params:
        return
    raise ValueError(
        "The following parameters must be provided in the configuration"
        f" or as command line arguments: {missing_params}")

234 

235 

def get_git_info(path: str = __file__) -> Tuple[str, str, str]:
    """
    Look up the git repository, commit hash, and repo-relative path of a file.

    All git commands are executed with `git -C <directory of path>`.

    Parameters
    ----------
    path : str
        Path to the file in git repository.

    Returns
    -------
    (git_repo, git_commit, git_path) : Tuple[str, str, str]
        URL of the "origin" remote, hash of HEAD, and the file path relative
        to the repository root (using forward slashes).

    Raises
    ------
    subprocess.CalledProcessError
        If any of the git commands exits with a non-zero status.
    """
    dirname = os.path.dirname(path)

    def _git(*git_args: str) -> str:
        """Run a git subcommand in `dirname` and return its stripped output."""
        return subprocess.check_output(
            ["git", "-C", dirname, *git_args], text=True).strip()

    git_repo = _git("remote", "get-url", "origin")
    git_commit = _git("rev-parse", "HEAD")
    git_root = _git("rev-parse", "--show-toplevel")
    _LOG.debug("Current git branch: %s %s", git_repo, git_commit)

    abs_file = os.path.abspath(path)
    abs_root = os.path.abspath(git_root)
    rel_path = os.path.relpath(abs_file, abs_root)
    return (git_repo, git_commit, rel_path.replace("\\", "/"))

260 

261 

262# Note: to avoid circular imports, we don't specify TunableValue here. 

def try_parse_val(val: Optional[str]) -> Optional[Union[int, float, str]]:
    """
    Try to parse the value as an int or float, otherwise return the string.

    This can help with config schema validation to make sure early on that
    the args we're expecting are the right type.

    Integer strings are always returned as exact `int` values, even when they
    are too large to be represented exactly (or at all) as a float.

    Parameters
    ----------
    val : str
        The initial cmd line arg value.

    Returns
    -------
    TunableValue
        The parsed value: None for None input, an int for integer input,
        a float for other numeric input, and a string otherwise.
    """
    if val is None:
        return val
    try:
        val_float = float(val)
    except ValueError:
        # Not numeric at all.
        return str(val)
    try:
        val_int = int(val)
    except (ValueError, OverflowError):
        # Numeric, but not an integer literal (e.g., "1.5", "1e3", "inf").
        return val_float
    # A string accepted by int() is an integer literal, so the int is exact.
    # Comparing it to the float would wrongly reject integers above 2**53,
    # which floats cannot represent exactly.
    # For non-string input int() may truncate (e.g., int(1.5) == 1),
    # so the int is kept only when no information was lost.
    if isinstance(val, str) or val_int == val_float:
        return val_int
    return val_float

291 

292 

def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]:
    """
    Apply a function to a value unless the value is None (a minimal Maybe monad).

    Parameters
    ----------
    func : Callable
        Function to apply to the value.
    value : Optional[Any]
        Value to apply the function to.

    Returns
    -------
    value : Optional[Any]
        `func(value)`, or None if `value` is None (`func` is not called then).
    """
    if value is None:
        return None
    return func(value)

310 

311 

def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime:
    """
    Attach zone info to a timestamp if it is missing and convert it to UTC.

    Parameters
    ----------
    timestamp : datetime
        A timestamp to convert to UTC.
        Note: The original datetime may or may not have tzinfo associated with it.

    origin : Literal["utc", "local"]
        Whether a naive source timestamp is considered to be in UTC or local time.
        In the case of loading data from storage, where we intentionally convert all
        timestamps to UTC, this can help us retrieve the original timezone when the
        storage backend doesn't explicitly store it.
        In the case of receiving data from a client or other source, this can help us
        convert the timestamp to UTC if it's not already.
        It is not consulted when the timestamp already carries tzinfo.

    Returns
    -------
    datetime
        A datetime with zoneinfo in UTC.

    Raises
    ------
    ValueError
        If the timestamp is naive and `origin` is neither "utc" nor "local".
    """
    is_aware = timestamp.tzinfo is not None
    if is_aware or origin == "local":
        # A timestamp with no zoneinfo is interpreted as "local" time
        # (e.g., according to the TZ environment variable).
        # That could be UTC or some other timezone, but either way we convert it to
        # be explicitly UTC with zone info.
        return timestamp.astimezone(pytz.UTC)
    if origin != "utc":
        raise ValueError(f"Invalid origin: {origin}")
    # The naive timestamp is already in UTC, so we only attach the zoneinfo.
    # Converting with astimezone() when the local time is *not* UTC would cause
    # a timestamp conversion which we don't want.
    return timestamp.replace(tzinfo=pytz.UTC)

348 

349 

def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]:
    """
    A nullable version of utcify_timestamp: None is passed through unchanged.

    Parameters
    ----------
    timestamp : Optional[datetime]
        A timestamp to convert to UTC, or None.

    origin : Literal["utc", "local"]
        Forwarded to utcify_timestamp as is.

    Returns
    -------
    Optional[datetime]
        The result of utcify_timestamp, or None if `timestamp` is None.
    """
    if timestamp is None:
        return None
    return utcify_timestamp(timestamp, origin=origin)

355 

356 

357# All timestamps in the telemetry data must be greater than this date 

358# (a very rough approximation for the start of this feature). 

359_MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) 

360 

361 

def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series:
    """
    Convert a pandas column to timezone-aware datetimes in UTC and validate it.

    Parameters
    ----------
    datetime_col : pandas.Series
        The column to convert.

    origin : Literal["utc", "local"]
        Whether to interpret naive timestamps as originating from UTC or local time.
        Only consulted when the parsed column carries no timezone.

    Returns
    -------
    pandas.Series
        The converted datetime column, in UTC.

    Raises
    ------
    ValueError
        On parse errors, if a naive column is given an invalid `origin`,
        if any value is missing after the conversion, or if any timestamp
        is not strictly greater than `_MIN_TS`.
    """
    parsed = pandas.to_datetime(datetime_col, utc=False)
    if parsed.dt.tz is None:
        # If timezone data is missing, assume the provided origin timezone.
        if origin == "utc":
            tzinfo = pytz.UTC
        elif origin == "local":
            # NOTE(review): this is the local timezone as reported right now;
            # presumably adequate for recent data - confirm for timestamps
            # that fall on the other side of a DST change.
            tzinfo = datetime.now().astimezone().tzinfo
        else:
            raise ValueError(f"Invalid timezone origin: {origin}")
        parsed = parsed.dt.tz_localize(tzinfo)
    assert parsed.dt.tz is not None
    # And convert it to UTC.
    parsed = parsed.dt.tz_convert('UTC')
    if parsed.isna().any():
        raise ValueError(f"Invalid date format in the data: {datetime_col}")
    if parsed.le(_MIN_TS).any():
        raise ValueError(f"Invalid date range in the data: {datetime_col}")
    return parsed