Coverage for mlos_viz/mlos_viz/__init__.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 01:18 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6mlos_viz is a framework to help visualize, explain, and gain insights from results 

7from the :py:mod:`mlos_bench` framework for benchmarking and optimization automation. 

8 

9Its main entrypoint is the :py:func:`plot` function, which can be used to 

10automatically visualize :py:class:`~.ExperimentData` from :py:mod:`mlos_bench` using 

11other libraries for automatic data correlation and visualization like 

12:external:py:func:`dabl <dabl.plot>`. 

13""" 

14 

15from enum import Enum 

16from typing import Any, Dict, Literal, Optional 

17 

18import pandas 

19 

20from mlos_bench.storage.base_experiment_data import ExperimentData 

21from mlos_viz import base 

22from mlos_viz.util import expand_results_data_args 

23from mlos_viz.version import VERSION 

24 

# Expose the package version (imported from mlos_viz.version) under the conventional dunder name.
__version__ = VERSION

26 

27 

class MlosVizMethod(Enum):
    """
    Selects the backend used to visualize experiment results.

    ``AUTO`` shares its value with ``DABL``, so :py:class:`~enum.Enum` makes it an
    alias rather than a distinct member (``MlosVizMethod.AUTO is MlosVizMethod.DABL``).
    """

    DABL = "dabl"
    # Alias for the current default backend; re-point this to change the default.
    AUTO = DABL

33 

34 

def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None:
    """
    Register warnings filters that silence noisy third-party plotting warnings.

    The generic filters from :py:mod:`mlos_viz.base` are always installed first;
    the backend-specific filters for `plotter_method` are installed afterwards.

    Parameters
    ----------
    plotter_method: MlosVizMethod
        The visualization backend whose warnings should be filtered.

    Raises
    ------
    NotImplementedError
        If `plotter_method` is not a supported backend (raised after the generic
        filters have already been installed).
    """
    # Backend-independent filters go in unconditionally, before any dispatch.
    base.ignore_plotter_warnings()

    if plotter_method != MlosVizMethod.DABL:
        raise NotImplementedError(f"Unhandled method: {plotter_method}")

    # Deferred import so the dabl backend is only loaded when it is requested.
    import mlos_viz.dabl  # pylint: disable=import-outside-toplevel

    mlos_viz.dabl.ignore_plotter_warnings()

52 

53 

def plot(
    exp_data: Optional[ExperimentData] = None,
    *,
    results_df: Optional[pandas.DataFrame] = None,
    objectives: Optional[Dict[str, Literal["min", "max"]]] = None,
    plotter_method: MlosVizMethod = MlosVizMethod.AUTO,
    filter_warnings: bool = True,
    **kwargs: Any,
) -> None:
    """
    Plots the results of the experiment.

    Intended to be used from a Jupyter notebook.

    Draws the optimizer-trend and top-N-config plots from :py:mod:`mlos_viz.base`,
    then hands off to the backend selected by `plotter_method`.

    Parameters
    ----------
    exp_data: Optional[ExperimentData]
        The experiment data to plot. When omitted, `results_df` (and presumably
        `objectives`) must be supplied instead; this is validated by
        ``expand_results_data_args``. TODO confirm exact requirements.
    results_df : Optional[pandas.DataFrame]
        Optional `results_df` to plot.
        If not provided, defaults to :py:attr:`.ExperimentData.results_df` property.
    objectives : Optional[Dict[str, Literal["min", "max"]]]
        Optional objectives to plot.
        If not provided, defaults to :py:attr:`.ExperimentData.objectives` property.
    plotter_method: MlosVizMethod
        The method to use for visualizing the experiment results.
    filter_warnings: bool
        Whether or not to filter some warnings from the plotter.
    kwargs : dict
        Remaining keyword arguments are passed along to
        :py:func:`mlos_viz.base.plot_top_n_configs`.

    Raises
    ------
    NotImplementedError
        If `plotter_method` is not a supported visualization method.
    """
    if filter_warnings:
        ignore_plotter_warnings(plotter_method)
    (results_df, _obj_cols) = expand_results_data_args(exp_data, results_df, objectives)

    base.plot_optimizer_trends(exp_data, results_df=results_df, objectives=objectives)
    base.plot_top_n_configs(exp_data, results_df=results_df, objectives=objectives, **kwargs)

    # Dispatch on the *requested* method. Testing the bare member `MlosVizMethod.DABL`
    # is always truthy, which previously ignored `plotter_method` entirely and made
    # the NotImplementedError branch unreachable.
    if plotter_method == MlosVizMethod.DABL:
        import mlos_viz.dabl  # pylint: disable=import-outside-toplevel

        mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives)
    else:
        raise NotImplementedError(f"Unhandled method: {plotter_method}")