{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Nearest Neighbors Classification\n\nThis example shows how to use :class:`~sklearn.neighbors.KNeighborsClassifier`.\nWe train such a classifier on the iris dataset and observe how the decision\nboundary changes with the parameter `weights`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the data\n\nIn this example, we use the iris dataset. We split the data into a training set and\na test set.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import load_iris\nfrom sklearn.model_selection import train_test_split\n\niris = load_iris(as_frame=True)\nX = iris.data[[\"sepal length (cm)\", \"sepal width (cm)\"]]\ny = iris.target\nX_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## K-nearest neighbors classifier\n\nWe want to use a k-nearest neighbors classifier considering a neighborhood of 11 data\npoints. Because our k-nearest neighbors model uses the Euclidean distance to find the\nnearest neighbors, it is important to scale the data beforehand. 
Refer to\nthe example entitled\n`sphx_glr_auto_examples_preprocessing_plot_scaling_importance.py` for more\ndetailed information.\n\nThus, we use a :class:`~sklearn.pipeline.Pipeline` to chain a scaler before using\nour classifier.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier\nfrom sklearn.pipeline import Pipeline\nfrom sklearn.preprocessing import StandardScaler\n\nclf = Pipeline(\n    steps=[(\"scaler\", StandardScaler()), (\"knn\", KNeighborsClassifier(n_neighbors=11))]\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision boundary\n\nNow, we fit two classifiers with different values of the parameter\n`weights`. We plot the decision boundary of each classifier as well as the original\ndataset to observe the difference.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nfrom sklearn.inspection import DecisionBoundaryDisplay\n\n_, axs = plt.subplots(ncols=2, figsize=(12, 5))\n\nfor ax, weights in zip(axs, (\"uniform\", \"distance\")):\n    clf.set_params(knn__weights=weights).fit(X_train, y_train)\n    disp = DecisionBoundaryDisplay.from_estimator(\n        clf,\n        X_test,\n        response_method=\"predict\",\n        plot_method=\"pcolormesh\",\n        xlabel=iris.feature_names[0],\n        ylabel=iris.feature_names[1],\n        shading=\"auto\",\n        alpha=0.5,\n        ax=ax,\n    )\n    scatter = disp.ax_.scatter(X.iloc[:, 0], X.iloc[:, 1], c=y, edgecolors=\"k\")\n    disp.ax_.legend(\n        scatter.legend_elements()[0],\n        iris.target_names,\n        loc=\"lower left\",\n        title=\"Classes\",\n    )\n    _ = disp.ax_.set_title(\n        f\"3-Class classification\\n(k={clf[-1].n_neighbors}, weights={weights!r})\"\n    )\n\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n\nWe observe that the parameter `weights` has an impact on the decision boundary. 
When\n`weights=\"uniform\"`, all nearest neighbors have the same impact on the decision,\nwhereas when `weights=\"distance\"`, the weight given to each neighbor is proportional\nto the inverse of the distance from that neighbor to the query point.\n\nIn some cases, taking the distance into account might improve the model.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }