# Source code for archai.supergraph.utils.heatmap
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Callable, Dict, List, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
def heatmap(
    data: np.ndarray,
    ax: Optional[Axes] = None,
    xtick_labels: Optional[List[str]] = None,
    ytick_labels: Optional[List[str]] = None,
    cbar_kwargs: Optional[Dict[str, Any]] = None,
    cbar_label: Optional[str] = None,
    fmt: Optional[Union[str, Callable[..., str]]] = "{x:.2f}",
    **kwargs,
) -> None:
    """Plot a 2-D array as an annotated heatmap with a color bar.

    Cells are drawn with ``imshow``, separated by a white grid, and each cell
    is annotated with its formatted value.

    Args:
        data: 2-D data to plot, indexed as ``data[row, column]``. Any
            array-like accepted by ``np.asarray`` (e.g. a nested list) works.
        ax: Axis to plot on. Defaults to the current axis (``plt.gca()``).
        xtick_labels: Labels for the columns (x-axis). Defaults to the column
            indices.
        ytick_labels: Labels for the rows (y-axis). Defaults to the row
            indices.
        cbar_kwargs: Keyword arguments forwarded to ``Figure.colorbar``.
        cbar_label: Label for the color bar.
        fmt: Format of the cell annotations. A string is wrapped in
            ``matplotlib.ticker.StrMethodFormatter`` (the value is the ``x``
            field, e.g. ``"{x:.2f}"``). A callable is invoked as
            ``fmt(value, None)`` and must return a string. ``None`` disables
            the annotations.
        **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.
    """
    # Accept array-likes as well as ndarrays; the code below relies on
    # ``.shape`` and 2-D ``[i, j]`` indexing.
    data = np.asarray(data)

    # Create the axis (if needed) and plot the heatmap
    if ax is None:
        ax = plt.gca()
    im = ax.imshow(data, **kwargs)

    n_rows, n_cols = data.shape[0], data.shape[1]

    # Create the color bar next to the heatmap
    if cbar_kwargs is None:
        cbar_kwargs = {}
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kwargs)
    cbar.ax.set_ylabel(cbar_label, rotation=-90, va="bottom")

    # Major ticks sit on cell centres and carry the labels; minor ticks sit on
    # cell edges (offset by 0.5) and are only used to draw the grid below.
    if xtick_labels is None:
        xtick_labels = list(range(n_cols))
    ax.set_xticks(np.arange(n_cols), labels=xtick_labels)
    ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    if ytick_labels is None:
        ytick_labels = list(range(n_rows))
    ax.set_yticks(np.arange(n_rows), labels=ytick_labels)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)

    # Hide the frame and draw a thick white grid on the cell edges so the
    # cells appear visually separated.
    ax.spines[:].set_visible(False)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    # NOTE(review): labeltop/labelbottom are applied to the minor ticks only,
    # which are set without labels above, so they likely have no visible
    # effect; confirm whether the major labels were meant to move to the top.
    ax.tick_params(which="minor", top=False, bottom=False, left=False, labeltop=True, labelbottom=False)

    # Annotate every cell with its formatted value (skipped when fmt is None)
    if fmt is None:
        return
    if isinstance(fmt, str):
        fmt = matplotlib.ticker.StrMethodFormatter(fmt)
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(
                j,
                i,
                fmt(data[i, j], None),
                horizontalalignment="center",
                verticalalignment="center",
            )