{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Hierarchical clustering: structured vs unstructured ward\n\nThis example builds a swiss roll dataset and runs\nhierarchical clustering on the positions of its points.\n\nFor more information, see `hierarchical_clustering`.\n\nIn a first step, the hierarchical clustering is performed without connectivity\nconstraints on the structure and is solely based on distance, whereas in\na second step the clustering is restricted to the k-Nearest Neighbors\ngraph: it's a hierarchical clustering with a structure prior.\n\nSome of the clusters learned without connectivity constraints do not\nrespect the structure of the swiss roll and extend across different folds of\nthe manifold. In contrast, when imposing connectivity constraints,\nthe clusters form a nice parcellation of the swiss roll.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport time\n\n# The following import is required\n# for 3D projection to work with matplotlib < 3.2\nimport mpl_toolkits.mplot3d # noqa: F401\nimport numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\nWe start by generating the Swiss Roll dataset.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import make_swiss_roll\n\nn_samples = 1500\nnoise = 0.05\nX, _ = make_swiss_roll(n_samples, noise=noise)\n# Make it thinner\nX[:, 1] *= 0.5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute clustering\n\nWe run `AgglomerativeClustering`, a hierarchical clustering method,\nwithout any connectivity constraints.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.cluster import 
AgglomerativeClustering\n\nprint(\"Compute unstructured hierarchical clustering...\")\nst = time.time()\nward = AgglomerativeClustering(n_clusters=6, linkage=\"ward\").fit(X)\nelapsed_time = time.time() - st\nlabel = ward.labels_\nprint(f\"Elapsed time: {elapsed_time:.2f}s\")\nprint(f\"Number of points: {label.size}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot result\n\nPlotting the unstructured hierarchical clusters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nfig1 = plt.figure()\nax1 = fig1.add_subplot(111, projection=\"3d\", elev=7, azim=-80)\nax1.set_position([0, 0, 0.95, 1])\nfor l in np.unique(label):\n    ax1.scatter(\n        X[label == l, 0],\n        X[label == l, 1],\n        X[label == l, 2],\n        color=plt.cm.jet(float(l) / np.max(label + 1)),\n        s=20,\n        edgecolor=\"k\",\n    )\n_ = fig1.suptitle(f\"Without connectivity constraints (time {elapsed_time:.2f}s)\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the k-Nearest Neighbors graph\n\nWe build a k-Nearest Neighbors graph with 10 neighbors per sample to use as\nthe connectivity constraint.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.neighbors import kneighbors_graph\n\nconnectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute clustering\n\nWe perform AgglomerativeClustering again with connectivity constraints.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\"Compute structured hierarchical clustering...\")\nst = time.time()\nward = AgglomerativeClustering(\n    n_clusters=6, connectivity=connectivity, linkage=\"ward\"\n).fit(X)\nelapsed_time = time.time() - st\nlabel = ward.labels_\nprint(f\"Elapsed time: {elapsed_time:.2f}s\")\nprint(f\"Number of points: {label.size}\")" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Plot result\n\nPlotting the structured hierarchical clusters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig2 = plt.figure()\nax2 = fig2.add_subplot(111, projection=\"3d\", elev=7, azim=-80)\nax2.set_position([0, 0, 0.95, 1])\nfor l in np.unique(label):\n    ax2.scatter(\n        X[label == l, 0],\n        X[label == l, 1],\n        X[label == l, 2],\n        color=plt.cm.jet(float(l) / np.max(label + 1)),\n        s=20,\n        edgecolor=\"k\",\n    )\nfig2.suptitle(f\"With connectivity constraints (time {elapsed_time:.2f}s)\")\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }