Coverage for mlos_viz/mlos_viz/__init__.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6mlos_viz is a framework to help visualize, explain, and gain insights from results 

7from the mlos_bench framework for benchmarking and optimization automation. 

8""" 

9 

10from enum import Enum 

11from typing import Any, Dict, Literal, Optional 

12 

13import pandas 

14 

15from mlos_bench.storage.base_experiment_data import ExperimentData 

16from mlos_viz import base 

17from mlos_viz.util import expand_results_data_args 

18 

19 

class MlosVizMethod(Enum):
    """
    Selects which visualization back-end is used to render experiment results.
    """

    # The dabl-based plotter.
    DABL = "dabl"

    # Default back-end. AUTO is assigned DABL's value, so Enum registers it as an
    # alias: MlosVizMethod.AUTO is the very same member as MlosVizMethod.DABL.
    AUTO = DABL

27 

28 

def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None:
    """
    Register warnings filters that silence noisy third-party plotting warnings.

    The base plotter's filters are always installed first; then the filters
    specific to the selected back-end are added.

    Parameters
    ----------
    plotter_method: MlosVizMethod
        The visualization back-end whose warnings should be suppressed.

    Raises
    ------
    NotImplementedError
        If ``plotter_method`` is not a supported back-end. The base filters
        have already been installed by the time this is raised.
    """
    base.ignore_plotter_warnings()

    # Guard clause: anything other than the dabl back-end is unsupported.
    if plotter_method != MlosVizMethod.DABL:
        raise NotImplementedError(f"Unhandled method: {plotter_method}")

    # Imported lazily so the dabl back-end is only loaded when it is selected.
    import mlos_viz.dabl  # pylint: disable=import-outside-toplevel
    mlos_viz.dabl.ignore_plotter_warnings()

45 

46 

def plot(exp_data: Optional[ExperimentData] = None, *,
         results_df: Optional[pandas.DataFrame] = None,
         objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
         plotter_method: MlosVizMethod = MlosVizMethod.AUTO,
         filter_warnings: bool = True,
         **kwargs: Any) -> None:
    """
    Plots the results of the experiment.

    Intended to be used from a Jupyter notebook.

    Parameters
    ----------
    exp_data: Optional[ExperimentData]
        The experiment data to plot.
    results_df : Optional["pandas.DataFrame"]
        Optional results_df to plot.
        If not provided, defaults to exp_data.results_df property.
    objectives : Optional[Dict[str, Literal["min", "max"]]]
        Optional objectives to plot.
        If not provided, defaults to exp_data.objectives property.
    plotter_method: MlosVizMethod
        The method to use for visualizing the experiment results.
    filter_warnings: bool
        Whether or not to filter some warnings from the plotter.
    kwargs : dict
        Remaining keyword arguments are passed along to
        ``base.plot_top_n_configs``.

    Raises
    ------
    NotImplementedError
        If ``plotter_method`` is not a handled visualization method.
    """
    if filter_warnings:
        ignore_plotter_warnings(plotter_method)
    # Only the resolved results_df is used below; the objective columns are discarded.
    results_df, _obj_cols = expand_results_data_args(exp_data, results_df, objectives)

    base.plot_optimizer_trends(exp_data, results_df=results_df, objectives=objectives)
    base.plot_top_n_configs(exp_data, results_df=results_df, objectives=objectives, **kwargs)

    # Compare the *requested* method explicitly. Testing a bare enum member
    # (e.g. `if MlosVizMethod.DABL:`) is always truthy and would ignore
    # plotter_method entirely.
    if plotter_method == MlosVizMethod.DABL:
        import mlos_viz.dabl  # pylint: disable=import-outside-toplevel
        mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives)
    else:
        raise NotImplementedError(f"Unhandled method: {plotter_method}")