{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Neighborhood Components Analysis Illustration\n\nThis example illustrates a learned distance metric that maximizes\nthe nearest neighbors classification accuracy. It provides a visual\nrepresentation of this metric compared to the original point\nspace. Please refer to the `User Guide ` for more information.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom matplotlib import cm\nfrom scipy.special import logsumexp\n\nfrom sklearn.datasets import make_classification\nfrom sklearn.neighbors import NeighborhoodComponentsAnalysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Original points\nFirst we create a data set of 9 samples from 3 classes, and plot the points\nin the original space. For this example, we focus on the classification of\npoint no. 3. The thickness of a link between point no. 3 and another point\nis proportional to their distance.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "X, y = make_classification(\n n_samples=9,\n n_features=2,\n n_informative=2,\n n_redundant=0,\n n_classes=3,\n n_clusters_per_class=1,\n class_sep=1.0,\n random_state=0,\n)\n\nplt.figure(1)\nax = plt.gca()\nfor i in range(X.shape[0]):\n ax.text(X[i, 0], X[i, 1], str(i), va=\"center\", ha=\"center\")\n ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)\n\nax.set_title(\"Original points\")\nax.axes.get_xaxis().set_visible(False)\nax.axes.get_yaxis().set_visible(False)\nax.axis(\"equal\") # so that boundaries are displayed correctly as circles\n\n\ndef link_thickness_i(X, i):\n diff_embedded = X[i] - X\n dist_embedded = np.einsum(\"ij,ij->i\", diff_embedded, diff_embedded)\n dist_embedded[i] = np.inf\n\n # compute exponentiated distances (use the log-sum-exp trick to\n # avoid numerical instabilities\n exp_dist_embedded = np.exp(-dist_embedded - logsumexp(-dist_embedded))\n return exp_dist_embedded\n\n\ndef relate_point(X, i, ax):\n pt_i = X[i]\n for j, pt_j in enumerate(X):\n thickness = link_thickness_i(X, i)\n if i != j:\n line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])\n ax.plot(*line, c=cm.Set1(y[j]), linewidth=5 * thickness[j])\n\n\ni = 3\nrelate_point(X, i, ax)\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learning an embedding\nWe use :class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` to learn an\nembedding and plot the points after the transformation. 
We then take the\nembedding and find the nearest neighbors, as sketched in the final cell\nbelow.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)\nnca = nca.fit(X, y)\n\nplt.figure(2)\nax2 = plt.gca()\nX_embedded = nca.transform(X)\nrelate_point(X_embedded, i, ax2)\n\nfor i in range(len(X)):\n    ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), va=\"center\", ha=\"center\")\n    ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)\n\nax2.set_title(\"NCA embedding\")\nax2.axes.get_xaxis().set_visible(False)\nax2.axes.get_yaxis().set_visible(False)\nax2.axis(\"equal\")\nplt.show()" ] }
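, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of that last step, the cell below uses\n:class:`~sklearn.neighbors.NearestNeighbors` to look up the nearest neighbor\nof point no. 3 in the original space and in the NCA embedding. It assumes\n`X`, `X_embedded` and `y` from the previous cells are still in scope.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.neighbors import NearestNeighbors\n\nfor name, space in [(\"original\", X), (\"embedded\", X_embedded)]:\n    nn = NearestNeighbors(n_neighbors=2).fit(space)\n    # the nearest neighbor of a training point is the point itself,\n    # so we report the second-nearest one\n    neighbor = nn.kneighbors(space[[3]], return_distance=False)[0, 1]\n    print(\n        f\"{name} space: nearest neighbor of point 3 is point {neighbor} \"\n        f\"(class {y[neighbor]}); point 3 has class {y[3]}\"\n    )" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }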