{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# GMM Initialization Methods\n\nExamples of the different initialization methods for Gaussian Mixture Models.\n\nSee `gmm` for more information on the estimator.\n\nHere we generate some sample data with four easy-to-identify clusters. The\npurpose of this example is to show the four different methods for the\ninitialization parameter *init_params*.\n\nThe four initializations are *kmeans* (default), *random*, *random_from_data* and\n*k-means++*.\n\nOrange diamonds represent the initial means of the GMM produced by each\n*init_params* method. The rest of the data is represented as crosses, and the\ncolouring shows the component each point is assigned to after the GMM has\nfinished fitting.\n\nThe numbers in the top right of each subplot show the number of\niterations the GaussianMixture took to converge and the relative time\ntaken by the initialization step. Methods with shorter\ninitialization times tend to need more iterations to converge.\n\nThe initialization time is the ratio of the time taken by that method to\nthe time taken by the default *kmeans* method. As you can see, all three\nalternative methods take less time to initialize than *kmeans*.\n\nIn this example, when initialized with *random_from_data* or *random*, the model takes\nmore iterations to converge. 
Here, *k-means++* does well on both counts: it is fast to\ninitialize and needs few GaussianMixture iterations to converge.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nfrom timeit import default_timer as timer\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom sklearn.datasets import make_blobs\nfrom sklearn.mixture import GaussianMixture\n\nprint(__doc__)\n\n# Generate some data with four well-separated clusters\n\nn_samples = 4000\nn_components = 4\n\nX, y_true = make_blobs(\n    n_samples=n_samples, centers=n_components, cluster_std=0.60, random_state=0\n)\nX = X[:, ::-1]\n\n\ndef get_initial_means(X, init_params, r):\n    # Run a GaussianMixture with max_iter=0 to output the initialization means\n    gmm = GaussianMixture(\n        n_components=n_components,\n        init_params=init_params,\n        tol=1e-9,\n        max_iter=0,\n        random_state=r,\n    ).fit(X)\n    return gmm.means_\n\n\nmethods = [\"kmeans\", \"random_from_data\", \"k-means++\", \"random\"]\ncolors = [\"navy\", \"turquoise\", \"cornflowerblue\", \"darkorange\"]\ntimes_init = {}\nrelative_times = {}\n\nplt.figure(figsize=(4 * len(methods) // 2, 6))\nplt.subplots_adjust(\n    bottom=0.1, top=0.9, hspace=0.15, wspace=0.05, left=0.05, right=0.95\n)\n\nfor n, method in enumerate(methods):\n    r = np.random.RandomState(seed=1234)\n    plt.subplot(2, len(methods) // 2, n + 1)\n\n    # Time only the initialization step\n    start = timer()\n    ini = get_initial_means(X, method, r)\n    end = timer()\n    init_time = end - start\n\n    # Fit the full model, starting from the initial means computed above\n    gmm = GaussianMixture(\n        n_components=n_components,\n        means_init=ini,\n        tol=1e-9,\n        max_iter=2000,\n        random_state=r,\n    ).fit(X)\n\n    times_init[method] = init_time\n    # Predict once, then colour each point by its assigned component\n    labels = gmm.predict(X)\n    for i, color in enumerate(colors):\n        data = X[labels == i]\n        plt.scatter(data[:, 0], data[:, 1], color=color, marker=\"x\")\n\n    # Plot the initial means as orange diamonds\n    plt.scatter(\n        ini[:, 0], ini[:, 1], s=75, marker=\"D\", c=\"orange\", lw=1.5, edgecolors=\"black\"\n    )\n    relative_times[method] = times_init[method] / times_init[methods[0]]\n\n    plt.xticks(())\n    plt.yticks(())\n    plt.title(method, loc=\"left\", fontsize=12)\n    plt.title(\n        \"Iter %i | Init Time %.2fx\" % (gmm.n_iter_, relative_times[method]),\n        loc=\"right\",\n        fontsize=10,\n    )\nplt.suptitle(\"GMM iterations and relative time taken to initialize\")\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }