Source code for vivainsights.create_sankey

# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

#Create a sankey chart from a two-column count table

import pandas as pd
import plotly.graph_objects as go
import random
[docs] def create_sankey(data, var1, var2, count = "n"): """ Name ---- create_sankey Description ------------ Create a 'networkD3' style sankey chart based on a long count table with two variables. The input data should have three columns, where each row is a unique group: 1. Variable 1 2. Variable 2 3. Count Parameters ---------- data : dataframe Data frame of the long count table. var1 : str String containing the name of the variable to be shown on the left. var2 : str String containing the name of the variable to be shown on the right. count : str String containing the name of the count variable. Returns ------- A 'sankeyNetwork' and 'htmlwidget' object containing a two-tier sankey plot. The output can be saved locally with `htmlwidgets::saveWidget()`. Example ------- >>> create_sankey(data = pq_data, var1 = "Organization", var2 = "FunctionType") """ #Rename data['pregroup'] = data[[var1]] data['group'] = data[[var2]] #Set up nodes group_source = data['pregroup'].unique() group_target = data['group'].unique() + " " nodes_source = pd.DataFrame({'name': group_source}) nodes_target = pd.DataFrame({'name': group_target}) nodes = pd.concat([nodes_source, nodes_target], axis=0) nodes = nodes.reset_index(drop=True) nodes["node"] = range(len(nodes)) links = data.assign(group=data['group'] + " ") links = links.rename(columns={'pregroup': 'source', 'group': 'target', count: 'value'}) links = links.loc[:, ['source', 'target', 'value']] sources = links['source'].unique() targets = links['target'].unique() source_colours = [] target_colours = [] for i in range(len(sources)): source_colours.append(f"#{random.randint(0, 0xFFFFFF):06x}") for i in range(len(targets)): target_colours.append(f"#{random.randint(0, 0xFFFFFF):06x}") #left join # links = pd.merge(links, nodes_source, left_on='source', right_on='name', how='left') # links = pd.merge(links, nodes_target, left_on='target', right_on='name', how='left') '''Sankey diagram''' fig = go.Figure(data=[go.Sankey( node=dict( pad=15, thickness=20, line=dict(color="black", width=0.5), label=nodes['name'], color = source_colours + target_colours ), link=dict( source=links['source'].map(nodes.set_index('name')['node']), target=links['target'].map(nodes.set_index('name')['node']), value=links['value'], ))] ) fig.update_layout(title_text="Sankey Diagram of " + var1 + " and " + var2, font_size=10) fig.show()