{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Label Propagation digits active learning\n",
    "\n",
    "Demonstrates an active learning technique to learn handwritten digits\n",
    "using label propagation.\n",
    "\n",
    "We start by training a label propagation model with only 40 labeled points,\n",
    "then we select the top five most uncertain points to label. Next, we train\n",
    "with 45 labeled points (original 40 + 5 new ones). We repeat this process\n",
    "four times to have a model trained with 60 labeled examples. Note you can\n",
    "increase this to label more than 60 by changing `max_iterations`. Labeling\n",
    "more than 60 can be useful to get a sense for the speed of convergence of\n",
    "this active learning technique.\n",
    "\n",
    "A plot will appear showing the top 5 most uncertain digits for each iteration\n",
    "of training. These may or may not contain mistakes, but we will train the next\n",
    "model with their true labels.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "from sklearn import datasets\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "from sklearn.semi_supervised import LabelSpreading\n",
    "\n",
    "# Work on a fixed, seeded random subset of 330 digits so runs are reproducible.\n",
    "digits = datasets.load_digits()\n",
    "rng = np.random.RandomState(0)\n",
    "indices = np.arange(len(digits.data))\n",
    "rng.shuffle(indices)\n",
    "\n",
    "X = digits.data[indices[:330]]\n",
    "y = digits.target[indices[:330]]\n",
    "images = digits.images[indices[:330]]\n",
    "\n",
    "n_total_samples = len(y)\n",
    "n_labeled_points = 40  # size of the initial labeled set\n",
    "max_iterations = 5  # number of fit / query rounds\n",
    "\n",
    "# Every sample after the first `n_labeled_points` starts out unlabeled.\n",
    "unlabeled_indices = np.arange(n_total_samples)[n_labeled_points:]\n",
    "f = plt.figure()\n",
    "\n",
    "for i in range(max_iterations):\n",
    "    if len(unlabeled_indices) == 0:\n",
    "        print(\"No unlabeled items left to label.\")\n",
    "        break\n",
    "\n",
    "    # hide the true label of every still-unlabeled sample behind -1\n",
    "    y_train = np.copy(y)\n",
    "    y_train[unlabeled_indices] = -1\n",
    "\n",
    "    lp_model = LabelSpreading(gamma=0.25, max_iter=20)\n",
    "    lp_model.fit(X, y_train)\n",
    "\n",
    "    # score the model only on the samples whose labels it could not see\n",
    "    predicted_labels = lp_model.transduction_[unlabeled_indices]\n",
    "    true_labels = y[unlabeled_indices]\n",
    "\n",
    "    cm = confusion_matrix(true_labels, predicted_labels, labels=lp_model.classes_)\n",
    "\n",
    "    print(\"Iteration %i %s\" % (i, 70 * \"_\"))\n",
    "    print(\n",
    "        \"Label Spreading model: %d labeled & %d unlabeled (%d total)\"\n",
    "        % (n_labeled_points, n_total_samples - n_labeled_points, n_total_samples)\n",
    "    )\n",
    "\n",
    "    print(classification_report(true_labels, predicted_labels))\n",
    "\n",
    "    print(\"Confusion matrix\")\n",
    "    print(cm)\n",
    "\n",
    "    # compute the entropies of transduced label distributions\n",
    "    # (transposed so that the entropy is taken over classes, one value per sample)\n",
    "    pred_entropies = stats.entropy(lp_model.label_distributions_.T)\n",
    "\n",
    "    # select up to 5 digit examples that the classifier is most uncertain about\n",
    "    uncertainty_index = np.argsort(pred_entropies)[::-1]\n",
    "    uncertainty_index = uncertainty_index[\n",
    "        np.isin(uncertainty_index, unlabeled_indices)\n",
    "    ][:5]\n",
    "\n",
    "    # for more than 5 iterations, visualize the gain only on the first 5\n",
    "    # (the figure is a fixed 5 x 5 grid: one row per model)\n",
    "    if i < 5:\n",
    "        # report the number of labels this model was actually fit with\n",
    "        f.text(\n",
    "            0.05,\n",
    "            (1 - (i + 1) * 0.183),\n",
    "            \"model %d\\n\\nfit with\\n%d labels\" % ((i + 1), n_labeled_points),\n",
    "            size=10,\n",
    "        )\n",
    "        for index, image_index in enumerate(uncertainty_index):\n",
    "            image = images[image_index]\n",
    "\n",
    "            sub = f.add_subplot(5, 5, index + 1 + (5 * i))\n",
    "            sub.imshow(image, cmap=plt.cm.gray_r, interpolation=\"none\")\n",
    "            sub.set_title(\n",
    "                \"predict: %i\\ntrue: %i\"\n",
    "                % (lp_model.transduction_[image_index], y[image_index]),\n",
    "                size=10,\n",
    "            )\n",
    "            sub.axis(\"off\")\n",
    "\n",
    "    # the queried points now get their true labels: remove them from the\n",
    "    # unlabeled set so the next model is trained with them\n",
    "    still_unlabeled = ~np.isin(unlabeled_indices, uncertainty_index)\n",
    "    unlabeled_indices = unlabeled_indices[still_unlabeled]\n",
    "    n_labeled_points += len(uncertainty_index)\n",
    "\n",
    "f.suptitle(\n",
    "    (\n",
    "        \"Active learning with Label Propagation.\\nRows show 5 most \"\n",
    "        \"uncertain labels to learn with the next model.\"\n",
    "    ),\n",
    "    y=1.15,\n",
    ")\n",
    "plt.subplots_adjust(left=0.2, bottom=0.03, right=0.9, top=0.9, wspace=0.2, hspace=0.85)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}