# Source code for vivainsights.create_sankey

# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
"""
Create a Sankey chart from a two-column count table.
"""

__all__ = ['create_sankey']

import pandas as pd
import plotly.graph_objects as go
import random
def create_sankey(data, var1, var2, count = "n"):
    """Create a Sankey diagram from a long count table.

    The input *data* should have at least three columns: two categorical
    variables and a count column, where each row represents a unique group
    combination. Each distinct value of *var1* becomes a node on the left,
    each distinct value of *var2* becomes a node on the right, and every row
    becomes a link from its *var1* node to its *var2* node whose width is
    taken from the *count* column. Left and right nodes are always kept
    separate, even when both columns contain the same category names.

    *data* is not modified.

    Parameters
    ----------
    data : pandas.DataFrame
        Long count table.
    var1 : str
        Column name for the variable shown on the left.
    var2 : str
        Column name for the variable shown on the right.
    count : str, default "n"
        Column name containing the count values.

    Returns
    -------
    None
        Displays an interactive Plotly Sankey diagram via ``fig.show()``.
        Node colours are drawn from the global ``random`` module, so they
        differ between calls unless ``random.seed`` is set beforehand.

    Raises
    ------
    KeyError
        If *var1*, *var2* or *count* is not a column of *data*.

    Examples
    --------
    Create a Sankey diagram from a person query dataset:

    >>> import vivainsights as vi
    >>> pq_data = vi.load_pq_data()
    >>> agg = pq_data.groupby(["Organization", "FunctionType"]).agg(n=("PersonId", "nunique")).reset_index()
    >>> vi.create_sankey(data=agg, var1="Organization", var2="FunctionType", count="n")

    Use a custom count column:

    >>> agg = agg.rename(columns={"n": "headcount"})
    >>> vi.create_sankey(data=agg, var1="Organization", var2="FunctionType", count="headcount")
    """
    # Read the columns directly; no helper columns are written to `data`.
    left = data[var1]
    right = data[var2]

    # Distinct categories, in order of first appearance.
    sources = left.unique()
    targets = right.unique()

    # Node ids: left nodes occupy 0..len(sources)-1 and right nodes follow.
    # Mapping by position rather than by label keeps the two sides distinct
    # for any dtype (strings, numbers, NaN) without label-suffix tricks.
    link_source = pd.Index(sources).get_indexer(left)
    link_target = pd.Index(targets).get_indexer(right) + len(sources)

    labels = list(sources) + list(targets)

    # One random hex colour per node: left nodes first, then right nodes.
    colours = [f"#{random.randint(0, 0xFFFFFF):06x}" for _ in labels]

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=colours,
        ),
        link=dict(
            source=link_source,
            target=link_target,
            value=data[count],
        ),
    )])
    fig.update_layout(title_text=f"Sankey Diagram of {var1} and {var2}", font_size=10)
    fig.show()