{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# t-SNE: The effect of various perplexity values on the shape\n\nAn illustration of t-SNE on the two concentric circles, S-curve and uniform\ngrid datasets for different perplexity values.\n\nWe observe a tendency towards clearer shapes as the perplexity value increases.\n\nThe size, distance and shape of clusters may vary with the initialization and\nthe perplexity value, and do not always convey meaning.\n\nAs shown below, for higher perplexities t-SNE recovers the topology of the\ntwo concentric circles, although the size and the distance of the circles vary\nslightly from the original. In contrast to the two circles dataset, the\nembeddings of the S-curve dataset visually diverge from the S-curve topology\neven for larger perplexity values.\n\nFor further details, \"How to Use t-SNE Effectively\"\nhttps://distill.pub/2016/misread-tsne/ provides a good discussion of the\neffects of various parameters, as well as interactive plots to explore\nthose effects.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nfrom time import time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom matplotlib.ticker import NullFormatter\n\nfrom sklearn import datasets, manifold\n\nn_samples = 150\nn_components = 2\n# One row per dataset; column 0 shows the original data, columns 1-4 the embeddings\n(fig, subplots) = plt.subplots(3, 5, figsize=(15, 8))\nperplexities = [5, 30, 50, 100]\n\n# First example using two concentric circles\nX, y = datasets.make_circles(\n n_samples=n_samples, factor=0.5, noise=0.05, random_state=0\n)\n\nred = y == 0\ngreen = y == 1\n\nax = subplots[0][0]\nax.scatter(X[red, 0], X[red, 1], c=\"r\")\nax.scatter(X[green, 0], X[green, 1], c=\"g\")\nax.xaxis.set_major_formatter(NullFormatter())\nax.yaxis.set_major_formatter(NullFormatter())\nax.axis(\"tight\")\n\nfor i, perplexity in enumerate(perplexities):\n ax = subplots[0][i + 1]\n\n t0 = time()\n tsne = manifold.TSNE(\n 
n_components=n_components,\n init=\"random\",\n random_state=0,\n perplexity=perplexity,\n max_iter=300,\n )\n Y = tsne.fit_transform(X)\n t1 = time()\n print(\"circles, perplexity=%d in %.2g sec\" % (perplexity, t1 - t0))\n ax.set_title(\"Perplexity=%d\" % perplexity)\n ax.scatter(Y[red, 0], Y[red, 1], c=\"r\")\n ax.scatter(Y[green, 0], Y[green, 1], c=\"g\")\n ax.xaxis.set_major_formatter(NullFormatter())\n ax.yaxis.set_major_formatter(NullFormatter())\n ax.axis(\"tight\")\n\n# Another example using s-curve\nX, color = datasets.make_s_curve(n_samples, random_state=0)\n\nax = subplots[1][0]\nax.scatter(X[:, 0], X[:, 2], c=color)\nax.xaxis.set_major_formatter(NullFormatter())\nax.yaxis.set_major_formatter(NullFormatter())\n\nfor i, perplexity in enumerate(perplexities):\n ax = subplots[1][i + 1]\n\n t0 = time()\n tsne = manifold.TSNE(\n n_components=n_components,\n init=\"random\",\n random_state=0,\n perplexity=perplexity,\n learning_rate=\"auto\",\n max_iter=300,\n )\n Y = tsne.fit_transform(X)\n t1 = time()\n print(\"S-curve, perplexity=%d in %.2g sec\" % (perplexity, t1 - t0))\n\n ax.set_title(\"Perplexity=%d\" % perplexity)\n ax.scatter(Y[:, 0], Y[:, 1], c=color)\n ax.xaxis.set_major_formatter(NullFormatter())\n ax.yaxis.set_major_formatter(NullFormatter())\n ax.axis(\"tight\")\n\n\n# Another example using a 2D uniform grid\nx = np.linspace(0, 1, int(np.sqrt(n_samples)))\nxx, yy = np.meshgrid(x, x)\nX = np.hstack(\n [\n xx.ravel().reshape(-1, 1),\n yy.ravel().reshape(-1, 1),\n ]\n)\ncolor = xx.ravel()\nax = subplots[2][0]\nax.scatter(X[:, 0], X[:, 1], c=color)\nax.xaxis.set_major_formatter(NullFormatter())\nax.yaxis.set_major_formatter(NullFormatter())\n\nfor i, perplexity in enumerate(perplexities):\n ax = subplots[2][i + 1]\n\n t0 = time()\n tsne = manifold.TSNE(\n n_components=n_components,\n init=\"random\",\n random_state=0,\n perplexity=perplexity,\n max_iter=400,\n )\n Y = tsne.fit_transform(X)\n t1 = time()\n print(\"uniform grid, perplexity=%d in %.2g 
sec\" % (perplexity, t1 - t0))\n\n ax.set_title(\"Perplexity=%d\" % perplexity)\n ax.scatter(Y[:, 0], Y[:, 1], c=color)\n ax.xaxis.set_major_formatter(NullFormatter())\n ax.yaxis.set_major_formatter(NullFormatter())\n ax.axis(\"tight\")\n\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }