{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Compare BIRCH and MiniBatchKMeans\n",
    "\n",
    "This example compares the timing of BIRCH (with and without the global\n",
    "clustering step) and MiniBatchKMeans on a synthetic dataset having\n",
    "25,000 samples and 2 features generated using make_blobs.\n",
    "\n",
    "Both ``MiniBatchKMeans`` and ``BIRCH`` are very scalable algorithms and could\n",
    "run efficiently on hundreds of thousands or even millions of datapoints. We\n",
    "chose to limit the dataset size of this example in the interest of keeping\n",
    "our Continuous Integration resource usage reasonable but the interested\n",
    "reader might enjoy editing this script to rerun it with a larger value for\n",
    "`n_samples`.\n",
    "\n",
    "If ``n_clusters`` is set to None, the data is reduced from 25,000\n",
    "samples to a set of 158 clusters. This can be viewed as a preprocessing\n",
    "step before the final (global) clustering step that further reduces these\n",
    "158 clusters to 100 clusters.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "from itertools import cycle\n",
    "from time import time\n",
    "\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from joblib import cpu_count\n",
    "\n",
    "from sklearn.cluster import Birch, MiniBatchKMeans\n",
    "from sklearn.datasets import make_blobs\n",
    "\n",
    "# Generate centers for the blobs so that it forms a 10 X 10 grid.\n",
    "xx = np.linspace(-22, 22, 10)\n",
    "yy = np.linspace(-22, 22, 10)\n",
    "xx, yy = np.meshgrid(xx, yy)\n",
    "n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))\n",
    "\n",
    "# Generate blobs to do a comparison between MiniBatchKMeans and BIRCH.\n",
    "X, y = make_blobs(n_samples=25000, centers=n_centers, random_state=0)\n",
    "\n",
    "# Use all colors that matplotlib provides by default.\n",
    "# The cycle is shared by all three subplots, so each panel continues from\n",
    "# wherever the previous one stopped.\n",
    "colors_ = cycle(colors.cnames.keys())\n",
    "\n",
    "fig = plt.figure(figsize=(12, 4))\n",
    "fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)\n",
    "\n",
    "# Compute clustering with BIRCH with and without the final clustering step\n",
    "# and plot.\n",
    "birch_models = [\n",
    "    Birch(threshold=1.7, n_clusters=None),\n",
    "    Birch(threshold=1.7, n_clusters=100),\n",
    "]\n",
    "final_step = [\"without global clustering\", \"with global clustering\"]\n",
    "\n",
    "for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):\n",
    "    t = time()\n",
    "    birch_model.fit(X)\n",
    "    print(\"BIRCH %s as the final step took %0.2f seconds\" % (info, (time() - t)))\n",
    "\n",
    "    # Plot result\n",
    "    labels = birch_model.labels_\n",
    "    centroids = birch_model.subcluster_centers_\n",
    "    n_clusters = np.unique(labels).size\n",
    "    print(\"n_clusters : %d\" % n_clusters)\n",
    "\n",
    "    ax = fig.add_subplot(1, 3, ind + 1)\n",
    "    for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):\n",
    "        mask = labels == k\n",
    "        ax.scatter(X[mask, 0], X[mask, 1], c=\"w\", edgecolor=col, marker=\".\", alpha=0.5)\n",
    "        # Centroid markers are only drawn for the model fitted without the\n",
    "        # global clustering step.\n",
    "        if birch_model.n_clusters is None:\n",
    "            ax.scatter(this_centroid[0], this_centroid[1], marker=\"+\", c=\"k\", s=25)\n",
    "    ax.set_ylim([-25, 25])\n",
    "    ax.set_xlim([-25, 25])\n",
    "    ax.set_autoscaley_on(False)\n",
    "    ax.set_title(\"BIRCH %s\" % info)\n",
    "\n",
    "# Compute clustering with MiniBatchKMeans.\n",
    "mbk = MiniBatchKMeans(\n",
    "    init=\"k-means++\",\n",
    "    n_clusters=100,\n",
    "    batch_size=256 * cpu_count(),\n",
    "    n_init=10,\n",
    "    max_no_improvement=10,\n",
    "    verbose=0,\n",
    "    random_state=0,\n",
    ")\n",
    "t0 = time()\n",
    "mbk.fit(X)\n",
    "t_mini_batch = time() - t0\n",
    "print(\"Time taken to run MiniBatchKMeans %0.2f seconds\" % t_mini_batch)\n",
    "\n",
    "ax = fig.add_subplot(1, 3, 3)\n",
    "# Enumerate the model's own centers so the number of plotted clusters comes\n",
    "# from MiniBatchKMeans itself rather than from the ``n_clusters`` variable\n",
    "# left behind by the last BIRCH iteration. Position k in ``cluster_centers_``\n",
    "# is matched against label k, as in the BIRCH loop above.\n",
    "for k, (this_centroid, col) in enumerate(zip(mbk.cluster_centers_, colors_)):\n",
    "    mask = mbk.labels_ == k\n",
    "    ax.scatter(X[mask, 0], X[mask, 1], marker=\".\", c=\"w\", edgecolor=col, alpha=0.5)\n",
    "    ax.scatter(this_centroid[0], this_centroid[1], marker=\"+\", c=\"k\", s=25)\n",
    "ax.set_xlim([-25, 25])\n",
    "ax.set_ylim([-25, 25])\n",
    "ax.set_title(\"MiniBatchKMeans\")\n",
    "ax.set_autoscaley_on(False)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}