# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
"""
Create a week-by-week heatmap of a selected Viva Insights metric.
The `create_trend` function provides a week-by-week view of a selected Viva Insights metric,
allowing you to return either a week-by-week heatmap bar plot or a summary table.
By default, `create_trend` returns a week-by-week heatmap bar plot, highlighting the points in time with the most activity.
Set `return_type="table"` to return a pivoted summary table instead.
"""
__all__ = ['create_trend', 'create_trend_calc', 'create_trend_viz']
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from vivainsights.extract_date_range import extract_date_range
from vivainsights import totals_col
[docs]
def create_trend(data: pd.DataFrame,
                 metric: str,
                 palette=None,
                 hrvar: str = "Organization",
                 mingroup=5,
                 return_type: str = "plot",
                 legend_title: str = "Hours",
                 date_column: str = "MetricDate",
                 date_format: str = "%Y-%m-%d",
                 figsize: tuple = None,
                 size_x_axis_label: int = 5
                 ):
    """Create a week-by-week heatmap of a selected metric.

    Produces a heatmap bar plot highlighting activity hotspots over time,
    or returns a summary table.

    Parameters
    ----------
    data : pandas.DataFrame
        Person query data. Must contain ``PersonId``, *date_column*,
        *metric* and (unless *hrvar* is ``None``) *hrvar*.
    metric : str
        Name of the metric column to plot.
    palette : list of str or None, default None
        Colours used for the heatmap gradient. ``None`` uses the built-in
        eight-colour palette (``"#0c3c44"`` through ``"#fe7f4f"``).
    hrvar : str or None, default "Organization"
        Name of the organizational attribute for grouping. If ``None``,
        ``totals_col`` is applied to *data* and the ``"Total"`` column is
        used for grouping.
    mingroup : int, default 5
        Minimum number of distinct ``PersonId`` values a group must have
        (across the whole dataset) to be included.
    return_type : {"plot", "table"}, default "plot"
        ``"plot"`` for a heatmap figure, ``"table"`` for a pivoted DataFrame.
    legend_title : str, default "Hours"
        Label for the colour-bar legend.
    date_column : str, default "MetricDate"
        Name of the date column.
    date_format : str, default "%Y-%m-%d"
        Format passed to ``pandas.to_datetime`` to parse *date_column*.
    figsize : tuple or None, default None
        Figure size ``(width, height)`` in inches. Defaults to ``(8, 6)``.
    size_x_axis_label : int, default 5
        Font size of the date labels drawn under the x-axis brackets.

    Returns
    -------
    matplotlib.figure.Figure or pandas.DataFrame
        For ``"plot"``, the heatmap figure. For ``"table"``, a DataFrame
        indexed by group with one column per date holding the metric mean.

    Raises
    ------
    ValueError
        If *return_type* is not ``"plot"`` or ``"table"``, or if a required
        column is missing from *data*.

    Examples
    --------
    Return a heatmap (default):

    >>> import vivainsights as vi
    >>> pq_data = vi.load_pq_data()
    >>> vi.create_trend(pq_data, metric="Collaboration_hours", hrvar="LevelDesignation")

    Return a pivoted summary table:

    >>> vi.create_trend(pq_data, metric="Collaboration_hours", hrvar="LevelDesignation", return_type="table")

    Customize the legend title, figure size, and date format:

    >>> vi.create_trend(
    ...     pq_data,
    ...     metric="Collaboration_hours",
    ...     hrvar="Organization",
    ...     legend_title="Avg Hours",
    ...     figsize=(12, 6),
    ...     date_format="%Y-%m-%d",
    ... )
    """
    # Validate first so an invalid return_type fails before any data work.
    if return_type not in ("table", "plot"):
        raise ValueError("Please enter a valid input for return_type.")

    # Build the default palette per call: a list literal used as a default
    # argument would be one object shared by every call.
    if palette is None:
        palette = [
            "#0c3c44",
            "#1d627e",
            "#34b1e2",
            "#bfe5ee",
            "#fcf0eb",
            "#fbdacd",
            "#facebc",
            "#fe7f4f",
        ]

    if hrvar is None:
        # NOTE(review): totals_col is defined elsewhere; it is assumed to add
        # the "Total" column used as the grouping attribute below.
        data = totals_col(data)
        hrvar = "Total"

    if return_type == "table":
        trend_table = create_trend_calc(data, metric, hrvar, mingroup, date_column, date_format)
        # Groups as rows, dates as columns, metric mean as values.
        return trend_table.pivot(index="group", columns=date_column, values=metric)

    return create_trend_viz(data, metric, palette, hrvar, mingroup, legend_title,
                            date_column, date_format, size_x_axis_label, figsize)
[docs]
def create_trend_calc(data, metric, hrvar, mingroup, date_column, date_format):
    """
    Compute weekly group-level metric averages for trend analysis.

    Used internally by ``create_trend``. Groups whose number of distinct
    ``PersonId`` values across the whole dataset is below *mingroup* are
    dropped before aggregating. *data* itself is not modified.

    Parameters
    ----------
    data : pandas.DataFrame
        Person query data.
    metric : str
        Name of the metric column.
    hrvar : str
        Name of the organizational attribute for grouping. It appears as the
        ``"group"`` column in the output.
    mingroup : int
        Minimum number of distinct ``PersonId`` values a group needs overall.
    date_column : str
        Name of the date column.
    date_format : str
        Format passed to ``pandas.to_datetime`` to parse *date_column*.

    Returns
    -------
    pandas.DataFrame
        One row per (group, date) with columns
        ``[date_column, "group", "Employee_Count", metric]``, where
        ``Employee_Count`` is the number of distinct ``PersonId`` values in
        that group on that date and *metric* is the mean value.

    Raises
    ------
    ValueError
        If *date_column*, *metric*, ``PersonId`` or *hrvar* is not a column
        of *data*.
    """
    # Check inputs up front so a missing column yields a clear message
    # instead of a KeyError from inside pandas.
    required_variables = [date_column, metric, "PersonId", hrvar]
    for var in required_variables:
        if var not in data.columns:
            raise ValueError(f"{var} is not in the data")

    # Select only the needed columns before copying (cheaper on wide data)
    # and before renaming, so an existing "group" column in *data* cannot
    # collide with the renamed hrvar.
    trend = data[["PersonId", date_column, hrvar, metric]].copy()
    trend[date_column] = pd.to_datetime(trend[date_column], format=date_format)
    trend = trend.rename(columns={hrvar: "group"})

    # Eligibility is based on overall group size across the dataset, not on
    # the per-week headcount.
    group_sizes = trend.groupby("group")["PersonId"].nunique()
    eligible_groups = group_sizes.index[group_sizes >= mingroup]
    trend = trend[trend["group"].isin(eligible_groups)]

    # Per (group, date): distinct people and mean metric value.
    agg = (
        trend
        .groupby(["group", date_column])
        .agg(Employee_Count=("PersonId", "nunique"), **{metric: (metric, "mean")})
        .reset_index()
    )
    return agg[[date_column, "group", "Employee_Count", metric]]
[docs]
def _bracket_label_spans(dates):
    """Group consecutive heatmap columns under a shared date label.

    Parameters
    ----------
    dates : list of pandas.Timestamp
        Heatmap column dates, sorted ascending.

    Returns
    -------
    dict
        Maps each label to ``[first_index, last_index]`` of the columns it
        covers. Labels are years (``"%Y"``) when the dates span more than
        365 days, month-year (``"%b %Y"``) when they span more than 90 days,
        and one ``"%d-%m-%y"`` label per column otherwise.
    """
    if not dates:
        return {}
    span_days = (dates[-1] - dates[0]).days
    if span_days > 365:
        label_format = '%Y'
    elif span_days > 90:
        label_format = '%b %Y'
    else:
        label_format = '%d-%m-%y'

    spans = {}
    for idx, day in enumerate(dates):
        label = day.strftime(label_format)
        if label in spans:
            # Sorted input keeps equal labels contiguous: extend the span.
            spans[label][1] = idx
        else:
            spans[label] = [idx, idx]
    return spans


def create_trend_viz(
    data: pd.DataFrame,
    metric: str,
    palette,
    hrvar: str,
    mingroup,
    legend_title: str,
    date_column: str,
    date_format: str,
    size_x_axis_label,
    figsize: tuple = None
):
    """
    Create a heatmap visualization of a metric over time by group.

    Used internally by ``create_trend`` when ``return_type="plot"``. Rows are
    the *hrvar* groups that pass the *mingroup* filter, columns are dates in
    chronological order, and cell colour is the mean of *metric*. Instead of
    regular x tick labels, dates are shown as labelled brackets under the
    heatmap (years, months or individual dates depending on the range).

    Parameters
    ----------
    data : pandas.DataFrame
        Person query data.
    metric : str
        Name of the metric column.
    palette : list of str
        Colours for the heatmap gradient (passed to ``seaborn.heatmap`` as
        ``cmap``).
    hrvar : str
        Name of the organizational attribute for grouping.
    mingroup : int
        Minimum number of distinct ``PersonId`` values a group needs overall.
    legend_title : str
        Label for the colour-bar legend.
    date_column : str
        Name of the date column.
    date_format : str
        Format passed to ``pandas.to_datetime`` to parse *date_column*.
    size_x_axis_label : int
        Font size of the date labels drawn under the brackets.
    figsize : tuple or None, default None
        Figure size ``(width, height)`` in inches; ``None`` means ``(8, 6)``.

    Returns
    -------
    matplotlib.figure.Figure
        The heatmap figure.

    Raises
    ------
    ValueError
        If no group meets *mingroup*, or if a required column is missing
        (raised by ``create_trend_calc``).
    """
    trend = create_trend_calc(data, metric, hrvar, mingroup, date_column, date_format)

    # Groups as rows, dates as chronologically ordered columns.
    pivot_table = (
        trend.pivot(index="group", columns=date_column, values=metric)
        .sort_index(axis=1)
    )
    if pivot_table.empty:
        # Fail early with a clear message when nothing survives the filter.
        raise ValueError(
            f"No group in '{hrvar}' has at least {mingroup} employees; nothing to plot."
        )

    # Labels for the title block and caption.
    clean_nm = metric.replace("_", " ")
    title_text = f"{clean_nm} Hotspots"
    subtitle_text = f'By {hrvar}'
    caption_text = extract_date_range(data, return_type='text')

    fig, ax = plt.subplots(figsize=figsize if figsize else (8, 6))
    # No tick marks on any side.
    ax.tick_params(which='both', top=False, bottom=False, left=False, right=False)

    sns.heatmap(
        data=pivot_table,
        cmap=palette,
        cbar_kws={"label": legend_title},
        xticklabels=False,
        ax=ax,
    )

    # Draw a bracket with a centred label under each run of columns sharing
    # a date label. x is in data units (cell i spans [i, i + 1]); y is in
    # axes units, so negative values sit below the heatmap.
    x_axis_transform = ax.get_xaxis_transform()
    bracket_y = -0.08
    spans = _bracket_label_spans(list(pivot_table.columns))
    for label, (start, end) in spans.items():
        ax.hlines(y=bracket_y, xmin=start, xmax=end + 1, color='black', linewidth=1.2,
                  transform=x_axis_transform, clip_on=False)
        ax.vlines([start, end + 1], ymin=bracket_y - 0.01, ymax=bracket_y, color='black',
                  linewidth=1.2, transform=x_axis_transform, clip_on=False)
        ax.text((start + end + 1) / 2, bracket_y - 0.03, label, ha='center', va='top',
                fontsize=size_x_axis_label, transform=x_axis_transform, clip_on=False)

    # Leave room at the bottom for the brackets and their labels.
    fig.subplots_adjust(bottom=0.12)

    ax.yaxis.set_tick_params(labelsize=9)
    ax.set_xlabel('')
    ax.set_ylabel('')

    # Accent line and small rectangle above the plot, in figure coordinates.
    ax.plot(
        [-0.08, .9],
        [0.9, 0.9],
        transform=fig.transFigure,
        clip_on=False,
        color='#fe7f4f',
        linewidth=.6
    )
    ax.add_patch(
        plt.Rectangle(
            (-0.08, 0.9),  # lower-left corner
            0.05,          # width
            -0.025,        # height (negative: extends downward)
            facecolor='#fe7f4f',
            transform=fig.transFigure,
            clip_on=False,
            linewidth=0
        )
    )

    ax.text(x=-0.08, y=1.00, s=title_text, transform=fig.transFigure,
            ha='left', fontsize=13, weight='bold', alpha=.8)
    ax.text(x=-0.08, y=0.95, s=subtitle_text, transform=fig.transFigure,
            ha='left', fontsize=11, alpha=.8)

    # Hide x ticks entirely; the brackets above act as the date axis.
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.xaxis.set_major_formatter(plt.NullFormatter())

    ax.text(x=-0.08, y=-0.12, s=caption_text, transform=fig.transFigure,
            ha='left', fontsize=9, alpha=.7)
    return fig