# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
"""
Create a bubble chart visualization of two metrics by organizational group.
The function `create_bubble` creates a bubble visualization and summary table for a given metric
and grouping variable in a dataset.
"""
__all__ = ['create_bubble']
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import to_hex
from adjustText import adjust_text
from vivainsights.totals_col import totals_col
from matplotlib.lines import Line2D
[docs]
def create_bubble(data, metric_x, metric_y, hrvar="Organization", mingroup=5, return_type="plot", bubble_size=(1, 100), figsize: tuple = None):
"""
Create a bubble plot of two metrics by organizational group.
Metrics are first aggregated per person, then per HR variable group.
Bubble size represents the number of employees in each group.
Parameters
----------
data : pandas.DataFrame
Person query data.
metric_x : str
Column name for the x-axis metric.
metric_y : str
Column name for the y-axis metric.
hrvar : str, optional
Organizational attribute for grouping. Defaults to ``"Organization"``.
mingroup : int, optional
Minimum group size. Groups below this are excluded. Defaults to 5.
return_type : str, optional
``"plot"`` (default) returns a bubble chart; ``"table"`` returns a
summary DataFrame.
bubble_size : tuple, optional
``(min_size, max_size)`` range for bubble scaling. Defaults to ``(1, 100)``.
figsize : tuple, optional
Figure size as ``(width, height)`` in inches. Defaults to ``(8, 6)``.
Returns
-------
matplotlib.figure.Figure or pandas.DataFrame
Bubble chart or summary table depending on ``return_type``.
Examples
--------
Return a bubble plot (default):
>>> import vivainsights as vi
>>> pq_data = vi.load_pq_data()
>>> vi.create_bubble(
... data=pq_data,
... metric_x="Collaboration_hours",
... metric_y="Multitasking_hours",
... hrvar="Organization",
... )
Return a summary table:
>>> vi.create_bubble(
... data=pq_data,
... metric_x="Collaboration_hours",
... metric_y="Multitasking_hours",
... hrvar="LevelDesignation",
... return_type="table",
... )
Customize bubble size range, minimum group size, and figure size:
>>> vi.create_bubble(
... data=pq_data,
... metric_x="Collaboration_hours",
... metric_y="Multitasking_hours",
... hrvar="Organization",
... bubble_size=(5, 200),
... mingroup=10,
... figsize=(12, 8),
... )
"""
# Handling NULL values passed to hrvar
if(hrvar is None):
data = totals_col(data)
hrvar = "Total"
col_highlight = "#fe7f4f"
# Input checks
required_variables = [hrvar, metric_x, metric_y, "PersonId"]
for var in required_variables:
if var not in data.columns:
raise ValueError(f"Missing required variable: {var}")
# Clean metric names
clean_x = metric_x.replace('_', ' ')
clean_y = metric_y.replace('_', ' ')
# Group and summarize data
myTable = data.groupby(['PersonId', hrvar]).agg({metric_x: 'mean', metric_y: 'mean'}).reset_index()
myTable = myTable.groupby(hrvar).agg({metric_x: 'mean', metric_y: 'mean', 'PersonId': 'count'}).reset_index()
myTable = myTable.rename(columns={'PersonId': 'n'})
myTable = myTable[myTable['n'] >= mingroup]
# Plotting
if return_type == "plot":
fig, ax = plt.subplots(figsize=figsize if figsize else (8, 6))
# Reserve more space for title/subtitle/orange line
plt.subplots_adjust(top=0.82)
# Scatterplot
sns.scatterplot(data=myTable, x=metric_x, y=metric_y, size='n', sizes=bubble_size,
alpha=0.5, color=to_hex((0, 120/255, 212/255)), ax=ax)
# Bubble labels
texts = [ax.text(row[metric_x], row[metric_y], row[hrvar], size=8) for _, row in myTable.iterrows()]
adjust_text(texts, ax=ax)
# Title and subtitle using fig.text (not ax.text) for exact positioning
fig.text(0.1, 0.95, f"{clean_x} and {clean_y}", ha='left', fontsize=14, weight='bold', alpha=0.9)
fig.text(0.1, 0.91, f"By {hrvar.replace('_', ' ')}", ha='left', fontsize=12, alpha=0.85)
# Orange decorative line (below subtitle)
fig.lines.append(
Line2D(
[0.1, 0.9], [0.89, 0.89], # y = just below subtitle
transform=fig.transFigure,
color=col_highlight,
linewidth=0.6,
clip_on=False
)
)
# Orange rectangle block
fig.patches.extend([
plt.Rectangle(
(0.1, 0.89), 0.05, -0.015,
facecolor=col_highlight,
transform=fig.transFigure,
clip_on=False,
linewidth=0
)
])
# Axes labels
ax.set_xlabel(clean_x)
ax.set_ylabel(clean_y)
# Caption
fig.text(0.1, 0.02, f"Total employees = {myTable['n'].sum()} | {pd.to_datetime(data['MetricDate']).min().strftime('%Y-%m-%d')} to {pd.to_datetime(data['MetricDate']).max().strftime('%Y-%m-%d')}", fontsize=8)
return fig
elif return_type == "table":
return myTable
else:
raise ValueError("Please enter a valid input for `return_type`.")