{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Caching nearest neighbors\n\nThis example demonstrates how to precompute the k nearest neighbors before\nusing them in KNeighborsClassifier. KNeighborsClassifier can compute the\nnearest neighbors internally, but precomputing them can have several benefits,\nsuch as finer parameter control, caching for repeated use, or custom\nimplementations.\n\nHere we use the caching property of pipelines to cache the nearest neighbors\ngraph between multiple fits of KNeighborsClassifier. The first call is slow\nsince it computes the neighbors graph, while subsequent calls are faster as they\ndo not need to recompute the graph. Here the durations are small since the\ndataset is small, but the gain can be more substantial when the dataset grows\nlarger, or when the grid of parameters to search is large.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nfrom tempfile import TemporaryDirectory\n\nimport matplotlib.pyplot as plt\n\nfrom sklearn.datasets import load_digits\nfrom sklearn.model_selection import GridSearchCV\nfrom sklearn.neighbors import KNeighborsClassifier, KNeighborsTransformer\nfrom sklearn.pipeline import Pipeline\n\nX, y = load_digits(return_X_y=True)\nn_neighbors_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]\n\n# The transformer computes the nearest neighbors graph using the maximum number\n# of neighbors necessary in the grid search. The classifier model filters the\n# nearest neighbors graph as required by its own n_neighbors parameter.\ngraph_model = KNeighborsTransformer(n_neighbors=max(n_neighbors_list), mode=\"distance\")\nclassifier_model = KNeighborsClassifier(metric=\"precomputed\")\n\n# Note that we give `memory` a directory to cache the graph computation\n# that will be used several times when tuning the hyperparameters of the\n# classifier.\nwith TemporaryDirectory(prefix=\"sklearn_graph_cache_\") as tmpdir:\n    full_model = Pipeline(\n        steps=[(\"graph\", graph_model), (\"classifier\", classifier_model)], memory=tmpdir\n    )\n\n    param_grid = {\"classifier__n_neighbors\": n_neighbors_list}\n    grid_model = GridSearchCV(full_model, param_grid)\n    grid_model.fit(X, y)\n\n# Plot the results of the grid search.\nfig, axes = plt.subplots(1, 2, figsize=(8, 4))\naxes[0].errorbar(\n    x=n_neighbors_list,\n    y=grid_model.cv_results_[\"mean_test_score\"],\n    yerr=grid_model.cv_results_[\"std_test_score\"],\n)\naxes[0].set(xlabel=\"n_neighbors\", title=\"Classification accuracy\")\naxes[1].errorbar(\n    x=n_neighbors_list,\n    y=grid_model.cv_results_[\"mean_fit_time\"],\n    yerr=grid_model.cv_results_[\"std_fit_time\"],\n    color=\"r\",\n)\naxes[1].set(xlabel=\"n_neighbors\", title=\"Fit time (with caching)\")\nfig.tight_layout()\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }