Source code for vivainsights.create_sankey
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
"""
Create a Sankey chart from a long count table.

The table holds two categorical columns (left and right side of the chart)
plus a count column; see :func:`create_sankey`.
"""
# Public API of this module: only ``create_sankey`` is exported by ``import *``.
__all__ = ['create_sankey']
import pandas as pd
import plotly.graph_objects as go
import random
[docs]
def create_sankey(data, var1, var2, count="n"):
    """Create a Sankey diagram from a long count table.

    The input *data* should have at least three columns: two categorical
    variables and a count column, where each row represents a unique
    group combination. *data* is not modified.

    Parameters
    ----------
    data : pandas.DataFrame
        Long count table.
    var1 : str
        Column name for the variable shown on the left.
    var2 : str
        Column name for the variable shown on the right.
    count : str, default "n"
        Column name containing the count values.

    Returns
    -------
    None
        Displays an interactive Plotly Sankey diagram.

    Raises
    ------
    KeyError
        If *var1*, *var2* or *count* is not a column of *data*.

    Examples
    --------
    Create a Sankey diagram from a person query dataset:

    >>> import vivainsights as vi
    >>> pq_data = vi.load_pq_data()
    >>> agg = pq_data.groupby(["Organization", "FunctionType"]).agg(n=("PersonId", "nunique")).reset_index()
    >>> vi.create_sankey(data=agg, var1="Organization", var2="FunctionType", count="n")

    Use a custom count column:

    >>> agg_custom = agg.rename(columns={"n": "Employees"})
    >>> vi.create_sankey(data=agg_custom, var1="Organization", var2="FunctionType", count="Employees")
    """
    missing = [col for col in (var1, var2, count) if col not in data.columns]
    if missing:
        raise KeyError(f"Column(s) not found in data: {missing}")

    # Work on standalone Series instead of writing helper columns into the
    # caller's DataFrame (which also clobbered any existing 'group' column).
    # Casting to str lets numeric/categorical columns be used as labels.
    # The trailing space on right-hand labels is kept so the rendered labels
    # are unchanged from earlier versions.
    source_labels = data[var1].astype(str)
    target_labels = data[var2].astype(str) + " "

    # unique() keeps order of first appearance.
    group_source = source_labels.unique()
    group_target = target_labels.unique()

    # Left nodes take indices [0, len(group_source)); right nodes follow.
    # Indexing each side separately keeps the two sides distinct even when a
    # left label happens to equal a right label.
    source_index = {name: i for i, name in enumerate(group_source)}
    offset = len(group_source)
    target_index = {name: offset + i for i, name in enumerate(group_target)}

    labels = list(group_source) + list(group_target)
    # One random hex colour per node: left nodes first, then right nodes.
    colours = [f"#{random.randint(0, 0xFFFFFF):06x}" for _ in labels]

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=colours,
        ),
        link=dict(
            source=source_labels.map(source_index).tolist(),
            target=target_labels.map(target_index).tolist(),
            value=data[count].tolist(),
        ),
    )])
    fig.update_layout(title_text=f"Sankey Diagram of {var1} and {var2}", font_size=10)
    fig.show()