{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Automatic Gaussian Mixture Modeling" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "import seaborn as sns\n", "sns.set(font_scale=1.75)\n", "sns.set_style(\"white\")\n", "\n", "# Fix the NumPy random seed so the synthetic data is reproducible\n", "np.random.seed(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Clustering is a foundational data analysis task in which members of a data set are sorted into groups, or \"clusters\", according to measured similarities between them.\n", "\n", "The Automatic Gaussian Mixture Model (AutoGMM) is a wrapper around [Sklearn's Gaussian Mixture class](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html#sklearn.mixture.GaussianMixture). The algorithm tries different combinations of agglomeration settings, GMM settings, and numbers of clusters, and returns the clustering with the best value of the selection criterion, either the Bayesian Information Criterion (BIC) or the Akaike Information Criterion (AIC).\n", "\n", "Let's use AutoGMM on synthetic data and compare it to the existing Sklearn implementation." 
] }, { "source": [ "## Using AutoGMM on Synthetic Data" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Synthetic data: two 2-D Gaussian classes of 150 points each\n", "\n", "# Dim 1\n", "class_1 = np.random.randn(150, 1)\n", "class_2 = 2 + np.random.randn(150, 1)\n", "dim_1 = np.vstack((class_1, class_2))\n", "\n", "# Dim 2\n", "class_1 = np.random.randn(150, 1)\n", "class_2 = 2 + np.random.randn(150, 1)\n", "dim_2 = np.vstack((class_1, class_2))\n", "\n", "X = np.hstack((dim_1, dim_2))\n", "\n", "# True labels: 0 for the first class, 1 for the second\n", "label_1 = np.zeros((150, 1))\n", "label_2 = 1 + label_1\n", "\n", "c = np.vstack((label_1, label_2)).reshape(300,)\n", "\n", "# Plotting function for a clustering: scatter the points, colored by label\n", "def plot(title, c_hat, X):\n", "    plt.figure(figsize=(10, 10))\n", "    n_components = int(np.max(c_hat) + 1)\n", "    palette = sns.color_palette(\"deep\")[:n_components]\n", "    ax = sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=c_hat, legend=False, palette=palette)\n", "    ax.set(xticks=[], yticks=[], title=title)\n", "    plt.show()\n", "\n", "plot('True Clustering', c, X)" ] }, { "source": [ "In the existing Sklearn implementation, one has to choose model parameters a priori, including the number of components. If the chosen parameters don't match the data well, clustering performance can suffer. Performance can be measured by the Adjusted Rand Index (ARI). An ARI score of 1 indicates the estimated clusters are identical to the true clusters, a score close to 0 is what random assignments would produce, and the score can be negative when agreement is worse than chance." 
], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn import mixture\n", "from sklearn.metrics import adjusted_rand_score\n", "\n", "from graspologic.utils import remap_labels\n", "\n", "# Suppose the user provides an inaccurate estimate of the number of components\n", "gmm_ = mixture.GaussianMixture(n_components=3)\n", "c_hat_gmm = gmm_.fit_predict(X)\n", "\n", "# Remap predicted labels to best match the true labels\n", "c_hat_gmm = remap_labels(c, c_hat_gmm)\n", "\n", "plot('Sklearn Clustering', c_hat_gmm, X)\n", "\n", "# ARI Score\n", "print(\"ARI Score for Model: %.2f\" % adjusted_rand_score(c, c_hat_gmm))" ] }, { "source": [ "Our method expands upon the existing Sklearn framework by automatically estimating the best hyperparameters for a Gaussian mixture model. In particular, AutoGMM estimates the ideal number of components, `n_components_`, from a range of possible values given by the user. AutoGMM also sweeps over multiple covariance structures." 
], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from graspologic.cluster.autogmm import AutoGMMCluster\n", "\n", "# Fit AutoGMM model\n", "autogmm_ = AutoGMMCluster(min_components=1, max_components=10)\n", "c_hat_autogmm = autogmm_.fit_predict(X)\n", "\n", "c_hat_autogmm = remap_labels(c, c_hat_autogmm)\n", "\n", "plot('AutoGMM Clustering', c_hat_autogmm, X)\n", "\n", "print(\"ARI Score for Model: %.2f\" % adjusted_rand_score(c, c_hat_autogmm))" ] } ], "metadata": { "language_info": { "name": "python", "codemirror_mode": { "name": "ipython", "version": 3 }, "version": "3.7.9-final" }, "orig_nbformat": 2, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "npconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 3, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "nbformat": 4, "nbformat_minor": 2 }