{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# GMM covariances\n\nDemonstration of several covariances types for Gaussian mixture models.\n\nSee `gmm` for more information on the estimator.\n\nAlthough GMM are often used for clustering, we can compare the obtained\nclusters with the actual classes from the dataset. We initialize the means\nof the Gaussians with the means of the classes from the training set to make\nthis comparison valid.\n\nWe plot predicted labels on both training and held out test data using a\nvariety of GMM covariance types on the iris dataset.\nWe compare GMMs with spherical, diagonal, full, and tied covariance\nmatrices in increasing order of performance. Although one would\nexpect full covariance to perform best in general, it is prone to\noverfitting on small datasets and does not generalize well to held out\ntest data.\n\nOn the plots, train data is shown as dots, while test data is shown as\ncrosses. The iris dataset is four-dimensional. Only the first two\ndimensions are shown here, and thus some points are separated in other\ndimensions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom sklearn import datasets\nfrom sklearn.mixture import GaussianMixture\nfrom sklearn.model_selection import StratifiedKFold\n\ncolors = [\"navy\", \"turquoise\", \"darkorange\"]\n\n\ndef make_ellipses(gmm, ax):\n for n, color in enumerate(colors):\n if gmm.covariance_type == \"full\":\n covariances = gmm.covariances_[n][:2, :2]\n elif gmm.covariance_type == \"tied\":\n covariances = gmm.covariances_[:2, :2]\n elif gmm.covariance_type == \"diag\":\n covariances = np.diag(gmm.covariances_[n][:2])\n elif gmm.covariance_type == \"spherical\":\n covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]\n v, w = np.linalg.eigh(covariances)\n u = w[0] / np.linalg.norm(w[0])\n angle = np.arctan2(u[1], u[0])\n angle = 180 * angle / np.pi # convert to degrees\n v = 2.0 * np.sqrt(2.0) * np.sqrt(v)\n ell = mpl.patches.Ellipse(\n gmm.means_[n, :2], v[0], v[1], angle=180 + angle, color=color\n )\n ell.set_clip_box(ax.bbox)\n ell.set_alpha(0.5)\n ax.add_artist(ell)\n ax.set_aspect(\"equal\", \"datalim\")\n\n\niris = datasets.load_iris()\n\n# Break up the dataset into non-overlapping training (75%) and testing\n# (25%) sets.\nskf = StratifiedKFold(n_splits=4)\n# Only take the first fold.\ntrain_index, test_index = next(iter(skf.split(iris.data, iris.target)))\n\n\nX_train = iris.data[train_index]\ny_train = iris.target[train_index]\nX_test = iris.data[test_index]\ny_test = iris.target[test_index]\n\nn_classes = len(np.unique(y_train))\n\n# Try GMMs using different types of covariances.\nestimators = {\n cov_type: GaussianMixture(\n n_components=n_classes, covariance_type=cov_type, max_iter=20, random_state=0\n )\n for cov_type in [\"spherical\", \"diag\", \"tied\", \"full\"]\n}\n\nn_estimators = len(estimators)\n\nplt.figure(figsize=(3 * n_estimators // 2, 6))\nplt.subplots_adjust(\n bottom=0.01, top=0.95, hspace=0.15, wspace=0.05, left=0.01, right=0.99\n)\n\n\nfor index, (name, estimator) in enumerate(estimators.items()):\n # Since we have class labels for the training data, we can\n # initialize the GMM parameters in a supervised manner.\n estimator.means_init = np.array(\n [X_train[y_train == i].mean(axis=0) for i in range(n_classes)]\n )\n\n # Train the other parameters using the EM algorithm.\n estimator.fit(X_train)\n\n h = plt.subplot(2, n_estimators // 2, index + 1)\n make_ellipses(estimator, h)\n\n for n, color in enumerate(colors):\n data = iris.data[iris.target == n]\n plt.scatter(\n data[:, 0], data[:, 1], s=0.8, color=color, label=iris.target_names[n]\n )\n # Plot the test data with crosses\n for n, color in enumerate(colors):\n data = X_test[y_test == n]\n plt.scatter(data[:, 0], data[:, 1], marker=\"x\", color=color)\n\n y_train_pred = estimator.predict(X_train)\n train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100\n plt.text(0.05, 0.9, \"Train accuracy: %.1f\" % train_accuracy, transform=h.transAxes)\n\n y_test_pred = estimator.predict(X_test)\n test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100\n plt.text(0.05, 0.8, \"Test accuracy: %.1f\" % test_accuracy, transform=h.transAxes)\n\n plt.xticks(())\n plt.yticks(())\n plt.title(name)\n\nplt.legend(scatterpoints=1, loc=\"lower right\", prop=dict(size=12))\n\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }