{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# [NTDS'19] tutorial 5: machine learning with scikit-learn\n", "[ntds'19]: https://github.com/mdeff/ntds_2019\n", "\n", "[Nicolas Aspert](https://people.epfl.ch/nicolas.aspert), [EPFL LTS2](https://lts2.epfl.ch).\n", "\n", "* Dataset: [digits](https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits)\n", "* Tools: [scikit-learn](https://scikit-learn.org/stable/), [numpy](http://www.numpy.org), [scipy](https://www.scipy.org), [matplotlib](https://matplotlib.org)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*scikit-learn* is a Python machine learning library. The most commonly used algorithms for classification, clustering and regression are implemented as part of the library, e.g.\n", "* [Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression)\n", "* [k-means clustering](https://en.wikipedia.org/wiki/K-means_clustering)\n", "* [Support vector machines](https://en.wikipedia.org/wiki/Support-vector_machine)\n", "* ...\n", "\n", "The aim of this tutorial is to show basic usage of some simple machine learning techniques. \n", "Check the official [documentation](https://scikit-learn.org/stable/documentation.html) for more information, especially the [tutorials](https://scikit-learn.org/stable/tutorial/index.html) section." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "import sklearn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data loading\n", "\n", "We will use a dataset named *digits*.\n", "It contains 1797 images of handwritten digits (8x8 pixels each) acquired from 13 different writers; it is a copy of the test set of the UCI optical recognition of handwritten digits dataset. 
\n", "Each image is labelled according to the digit present in the image.\n", "\n", "You can find more information about this dataset [here](https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits).\n", "\n", "![digits](https://scikit-learn.org/stable/_images/sphx_glr_plot_lle_digits_001.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load the dataset." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_digits\n", "\n", "digits = load_digits()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `digits` variable contains several fields.\n", "\n", "In `images` you have all samples as 2-dimensional arrays." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1797, 8, 8)\n", "[[ 0. 0. 5. 13. 9. 1. 0. 0.]\n", " [ 0. 0. 13. 15. 10. 15. 5. 0.]\n", " [ 0. 3. 15. 2. 0. 11. 8. 0.]\n", " [ 0. 4. 12. 0. 0. 8. 8. 0.]\n", " [ 0. 5. 8. 0. 0. 9. 8. 0.]\n", " [ 0. 4. 11. 0. 1. 12. 7. 0.]\n", " [ 0. 2. 14. 5. 10. 12. 0. 0.]\n", " [ 0. 0. 6. 13. 10. 0. 0. 
0.]]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAKyklEQVR4nO3dX4hc5RnH8d+vUWn9h6G1RXZD44oEpFBjQkACQmNaYhXtRQ0JKFQK642itKCxd73zSuxFEULUCqZKNyqIWG2CihVa626StsaNJV0s2UQbxUjUQkPi04udQNS1e2bmnPecffx+YHF3dsj7TDZfz8zszHkdEQKQx1faHgBAvYgaSIaogWSIGkiGqIFkzmjiD7Wd8in1pUuXFl1vZGSk2FrHjh0rttahQ4eKrXXy5Mlia5UWEZ7v8kaizmr9+vVF17v33nuLrbVr165ia23ZsqXYWkePHi22Vldw9xtIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZS1LY32H7T9gHb5V4OBKBvC0Zte4mkX0u6RtJlkjbbvqzpwQAMpsqReo2kAxExExHHJT0u6YZmxwIwqCpRj0g6eNrXs73LPsX2uO1J25N1DQegf1XepTXf27s+99bKiNgqaauU962XwGJQ5Ug9K2nZaV+PSjrczDgAhlUl6tckXWr7YttnSdok6elmxwIwqAXvfkfECdu3SXpe0hJJD0XEvsYnAzCQSmc+iYhnJT3b8CwAasAryoBkiBpIhqiBZIgaSIaogWSIGkiGqIFk2KGjDyV3zJCksbGxYmuV3FLo/fffL7bWxo0bi60lSRMTE0XXmw9HaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkqmyQ8dDto/Yfr3EQACGU+VI/RtJGxqeA0BNFow6Il6WVO4V+ACGUtu7tGyPSxqv688DMJjaombbHaAbePYbSIaogWSq/ErrMUl/krTC9qztnzY/FoBBVdlLa3OJQQDUg7vfQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDKLftudVatWFVur5DY4knTJJZcUW2tmZqbYWjt37iy2Vsl/HxLb7gBoAFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8lUOUfZMtsv2p62vc/2HSUGAzCYKq/9PiHp5xGx2/Z5kqZs74yINxqeDcAAqmy783ZE7O59/qGkaUkjTQ8GYDB9vUvL9nJJKyW9Os/32HYH6IDKUds+V9ITku6MiGOf/T7b7gDdUOnZb9tnai7o7RHxZLMjARhGlWe/LelBSdMRcV/zIwEYRpUj9VpJN0taZ3tv7+OHDc8FYEBVtt15RZILzAKgBryiDEiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkFv1eWkuXLi221tTUVLG1pLL7W5VU+u/xy4YjNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQTJUTD37V9l9s/7W37c4vSwwGYDBVXib6X0nrIuKj3qmCX7H9+4j4c8OzARhAlRMPhqSPel+e2fvgZP1AR1U9mf8S23slHZG0MyLm3XbH9qTtybqHBFBdpagj4mREXC5pVNIa29+Z5zpbI2J1RKyue0gA1fX17HdEfCDpJUkbGpkGwNCqPPt9oe0Lep9/TdJ6SfubHgzAYKo8+32RpEdsL9Hc/wR+FxHPNDsWgEFVefb7b5rbkxrAIsAryoBkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhm13+rBr165ia2VW8md29OjRYmt1BUdqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSqRx174T+e2xz0kGgw/o5Ut8habqpQQDUo+q2O6OSrpW0rdlxAAyr6pH6fkl3Sfrki67AXlpAN1TZoeM6SUciYur/XY+9tIBuqHKkXivpettvS
Xpc0jrbjzY6FYCBLRh1RNwTEaMRsVzSJkkvRMRNjU8GYCD8nhpIpq/TGUXES5rbyhZAR3GkBpIhaiAZogaSIWogGaIGkiFqIBmiBpJZ9NvulNxWZdWqVcXWKq3kVjgl/x4nJiaKrdUVHKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkim0stEe2cS/VDSSUknOA0w0F39vPb7exHxXmOTAKgFd7+BZKpGHZL+YHvK9vh8V2DbHaAbqt79XhsRh21/U9JO2/sj4uXTrxARWyVtlSTbUfOcACqqdKSOiMO9/x6R9JSkNU0OBWBwVTbIO8f2eac+l/QDSa83PRiAwVS5+/0tSU/ZPnX930bEc41OBWBgC0YdETOSvltgFgA14FdaQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDKOqP9l2iVf+z02NlZqKU1Oln2vyq233lpsrRtvvLHYWiV/ZqtX533rf0R4vss5UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEylqG1fYHuH7f22p21f2fRgAAZT9bzfv5L0XET82PZZks5ucCYAQ1gwatvnS7pK0k8kKSKOSzre7FgABlXl7veYpHclPWx7j+1tvfN/fwrb7gDdUCXqMyRdIemBiFgp6WNJWz57pYjYGhGr2eYWaFeVqGclzUbEq72vd2gucgAdtGDUEfGOpIO2V/QuulrSG41OBWBgVZ/9vl3S9t4z3zOSbmluJADDqBR1ROyVxGNlYBHgFWVAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJLPo99IqaXx8vOh6d999d7G1pqamiq21cePGYmtlxl5awJcEUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQzIJR215he+9pH8ds31liOAD9W/AcZRHxpqTLJcn2EkmHJD3V8FwABtTv3e+rJf0zIv7VxDAAhlf1FMGnbJL02HzfsD0uqew7HgB8TuUjde+c39dLmpjv+2y7A3RDP3e/r5G0OyL+3dQwAIbXT9Sb9QV3vQF0R6WobZ8t6fuSnmx2HADDqrrtzn8kfb3hWQDUgFeUAckQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZBMU9vuvCup37dnfkPSe7UP0w1Zbxu3qz3fjogL5/tGI1EPwvZk1nd4Zb1t3K5u4u43kAxRA8l0KeqtbQ/QoKy3jdvVQZ15TA2gHl06UgOoAVEDyXQiatsbbL9p+4DtLW3PUwfby2y/aHva9j7bd7Q9U51sL7G9x/Yzbc9SJ9sX2N5he3/vZ3dl2zP1q/XH1L0NAv6hudMlzUp6TdLmiHij1cGGZPsiSRdFxG7b50makvSjxX67TrH9M0mrJZ0fEde1PU9dbD8i6Y8Rsa13Bt2zI+KDtufqRxeO1GskHYiImYg4LulxSTe0PNPQIuLtiNjd+/xDSdOSRtqdqh62RyVdK2lb27PUyfb5kq6S9KAkRcTxxRa01I2oRyQdPO3rWSX5x3+K7eWSVkp6td1JanO/pLskfdL2IDUbk/SupId7Dy222T6n7aH61YWoPc9laX7PZvtcSU9IujMijrU9z7BsXyfpSERMtT1LA86QdIWkByJipaSPJS2653i6EPWspGWnfT0q6XBLs9TK9pmaC3p7RGQ5vfJaSdfbfktzD5XW2X603ZFqMytpNiJO3aPaobnIF5UuRP2apEttX9x7YmKTpKdbnmlotq25x2bTEXFf2/PUJSLuiYjRiFiuuZ/VCxFxU8tj1SIi3pF00PaK3kVXS1p0T2z2u0Fe7SLihO3bJD0vaYmkhyJiX8tj1WGtpJsl/d323t5lv4iIZ1ucCQu7XdL23gFmRtItLc/Tt9Z/pQWgXl24+w2gRkQNJEPUQDJEDSRD1EAyRA0kQ9RAMv8DNH2NFu1/p/oAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "print(digits.images.shape)\n", "print(digits.images[0])\n", "plt.imshow(digits.images[0], cmap=plt.cm.gray);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In `data`, the same samples are represented as 1-d vectors of length 64." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1797, 64)\n", "[ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0. 13. 15. 10. 15. 5. 0. 0. 3.\n", " 15. 2. 0. 11. 8. 0. 0. 4. 12. 0. 0. 8. 8. 0. 0. 5. 8. 0.\n", " 0. 9. 8. 0. 0. 4. 11. 0. 1. 12. 7. 0. 0. 2. 14. 5. 10. 12.\n", " 0. 0. 0. 0. 6. 13. 10. 0. 0. 0.]\n" ] } ], "source": [ "print(digits.data.shape)\n", "print(digits.data[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " In `target` you have the label corresponding to each image." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1797,)\n", "[0 1 2 ... 
8 9 8]\n" ] } ], "source": [ "print(digits.target.shape)\n", "print(digits.target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us visualize the 20 first entries of the dataset (image display kept small on purpose)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0sAAAA5CAYAAADnRrAhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQO0lEQVR4nO3df+xd9V3H8dd7lLgBQlt/EJ1KC9ncorEl7V8abBupOI1SxBKUbaWJsYGwtI0z7R9boGxmbWIsZL9sE2w7MSY0waJzcQEpzVgytbVgsgzJpOCYNBujdMAAFd/+cW6v7/Peuaf3+90553Ohz0fyTc/t+bb3fc/5/LjnfN6fzzF3FwAAAACg7m2lAwAAAACAWcTFEgAAAAA04GIJAAAAABpwsQQAAAAADbhYAgAAAIAGXCwBAAAAQAMulgAAAACgQacXS2a22Mz+2sxeMbNnzOz3uvz/u2Jmt5nZUTN73cz2l46niZn9kJndMzqOL5nZcTN7X+m4mpjZvWb2nJl918yeNLPfLx1TGzN7l5m9Zmb3lo5lEjN7ZBTjy6Offysd0yRmdqOZfW1U7//dzK4qHVMUjuGZnzfM7JOl42piZkvM7AtmdsrMTprZp8xsQem4mpjZe83sYTM7bWZfN7PrSsfUhH6pO/RL/aFf6hb9Unfol7ofWfq0pP+SdKmkmyR91sx+ruP36MJ/Svq4pD8vHUiLBZK+IWmVpEskfVTSfWa2pGBMk3xC0hJ3v1jSb0n6uJmtKBxTm09L+ufSQUzhNne/aPTzs6WDaWJmayXtkrRR0g9L+mVJTxUNKgnH8CJVbdOrkg4WDmuSz0j6lqSfkLRcVf2/tWhEDUYd5QOSPi9psaQ/kHSvmb27aGDN6Je6Q7/UH/qljtAvde6c75c6u1gyswslXS/po+7+srs/KulvJH2gq/foirvf7+6HJH2ndCyTuPsr7n6Huz/t7v/r7p+XdELSzDX27v5Vd3/9zMvRzxUFQ5rIzG6U9KKkfygdy1vEDkl3uvtXRuX0m+7+zdJBtfgdVY3+l0oHMsFSSfe5+2vuflLS30uaxS/275H0k5J2u/sb7v6wpC9rxtp7+qVu0S/1g36pc/RL3Trn+6UuR5beLekNd38y/N3jms0D+qZjZpeqOsZfLR1LEzP7jJl9T9ITkp6T9IXCIX0fM7tY0p2S/rB0LFP6hJk9b2ZfNrPVpYPJzOw8SSsl/dhouPvZ0fD8O0rH1mKDpM+5u5cOZIK7Jd1oZheY2TslvU9VxzRrbMLf/fzQgZwF/VKP6Jd+cPRL3aJf6sU53y91ebF0kaTT6e9OqxoCxQ/AzM6X9JeSDrj7E6XjaeLut6o611dJul/S6+3/ooiPSbrH3b9ROpApbJN0uaR3Stor6W/NbNbuil4q6XxVd8WuUjU8f6Wkj5QMahIz+xlV6QMHSsfS4oiqL/LflfSspKOSDhWNqNkTqu6E/pGZnW9mv6rq2F5QNqzvQ7/UE/qlztAvdYt+qXvnfL/U5cXSy5IuTn93saSXOnyPc46ZvU3SX6jKub+tcDitRsOej0r6KUm3lI4nMrPlkq6WtLt0LNNw939095fc/XV3P6BqKPnXS8eVvDr685Pu/py7Py/pTzV7cZ7xQUmPuvuJ0o
E0GdX1L6r6UnehpB+VtEhV7v1Mcff/lrRO0m9IOqnqrvh9qjrSWUK/1AP6pW7QL/WCfqlD9EuVLi+WnpS0wMzeFf5umWZ0eP7NwMxM0j2q7pRcPyoIbwYLNHu54aslLZH0H2Z2UtKHJV1vZv9SMqg5cDUPMRfj7qdUNUKzmjqQfVCzffdusaSflvSp0ZeR70japxnt5N39X919lbv/iLtfo+qO8z+VjiuhX+oY/VKnVot+qVP0S52jX1KHF0vu/oqqK887zexCM/slSdequvs0U8xsgZm9XdJ5ks4zs7fP6DKIn5X0Xkm/6e6vnu2XSzCzHx8t0XmRmZ1nZtdI+l1JD5eOLdmrqqNcPvr5M0l/J+makkE1MbOFZnbNmXJpZjepWs3ni6Vja7BP0odG5WCRpC2qVqKZKWb2i6pSR2Z1tSGN7oCekHTL6LwvVJXL/njZyJqZ2S+MyugFZvZhVSsl7S8cVg39Ui/ol7pDv9QP+qWO0C9Vul46/FZJ71CVM/hXkm5x91m8g/cRVUO12yW9f7Q9U/msZnaZpE2qGtCT9v9r8d9UOLTMVaU2PCvplKQ/kbTF3R8oGlXi7t9z95NnflSl57zm7t8uHVuD81UtIfxtSc9L+pCkde4+i8+0+Jiq5W6flPQ1Sccl/XHRiJptkHS/u896+tVvS/o1Vef+65L+R9LWohFN9gFVk+a/JelXJK0Nq4/NEvqljtAvdYt+qTf0S9065/slm93FNwAAAACgnK5HlgAAAADgLYGLJQAAAABowMUSAAAAADTgYgkAAAAAGnCxBAAAAAANWp/hYGYTl8pbv379eHvnzp21fQ899NB4e/v27bV9p06dmvh+7j6vh5u1xRk98sgjtdcLFy4cb99+++21fQ88MHmF0fnGKU0f6+rVq2uvDx06NN5+7LHHWn836uOYbtu2bbydz/1TTz013l65cmVtX8lzH8+1JO3fv3+8vW7duqnfr484Y7l8+umna/tuvvnm+bxd0bq0fPnyqd+vr7q0ZcuW8XY+9/F8L1u2rLbv9OnT4+0lS5bU9p06darzY3rXXXc1xiXVy2j8PUl68cUXJ75fH+c+tj/5eLa1P236iDPGdscdd9T2xbqUy29bG9B3XcpiG5DPczzWeV9fdenaa68db2/dWl8tOB63tjKZdXFMc/2MdT63mzG2WJalej3LfWvf5z6X0fgZ8ufru87H8yzVz3Wu87ndjJYuXTrezv3ZW/14StO3TzG2/DqX0bbvAX0c07Y60daO5s8U9RFnjGUu7X2bSXEysgQAAAAADeb9dPA4onD55ZfX9i1atGi8/cILL9T23XDDDePtgweHfWhxvpOwatWq8faaNWtq+9pGlvoS78gfPny4tq/tjnff8uhRHFXctGlTbd+ePXvG2ytWrKjtiyOOQ8t3ZvLdkpLi+YxlUpI2bNgw3n7mmWcm/rshxDuPOc4dO3YMGstc5Hof737lO2Hxzt9c7pTPV9soXCyzefRmvqM508plK991juKz+h5/vP5Q97mMMnYh3hHNMccymtuD+Dr+H0PIcV522WWN29Lw5VOSDhw4MPE943HLo599y2U01okcSzxumzdvru2Ln2mIfiHGksthHomZ9O/6OPcbN26svY5tfPz+IdXrUr5r3/YZ+pbbxXichqovUW7/YtvSNtLVd/t+NjGW/Bnivlx+Y70bohzEke3cVs53ZGkSRpYAAAAAoAEXSwAAAADQgIslAAAAAGgw9ZylPP8kzlO64ooravviimgPPvjgxP9niDlLMd+yLQ90FuawxPzLnPsfV0fJK/f1be/evbXXu3btGm8fPXq0ti+e+5JzlKT23PCYW9s292eIvNuYA5zzbmOueNsKdEPkY7fNS8qr95TWNn8irprTNu9hCLHdaVsJsW1VtC7ysbO8+lV05MiR2usY99DHr21uVZxrI9XPe/58Q8+tiu6+++6J+9qO9VDazm+s90PPWcrlPp7D3N7Hc5/n3wzddsXj1L
ZSZz7X8fPOZQXXaeXvQPF45n3xM5SYCxTFOPNc2rx649Byfx6P41zmMw0t1ok8r7dttc6h26e24xnne+eV8uYTJyNLAAAAANCAiyUAAAAAaDB1Gl5cDlySjh07Nt6OqVdZ/L0h5CHDOPx2ySWXTPx3faSyzFXbsotx39DLmufzG1Mw87LxMfUul5m2h9L2IaZi5GHtaR/4mYdv+xDPdX7YXyyzORVi6PSHmDKS00RLp7HOZWnttgfnxfSWIZaQju9x/Pjx2r5YZodOd2j7/3MKUNsDa/vWVgfazl/JuiPV25ycqlNabitj3c7HrXS60CRtaWo5XafvupTbm5gelNPEYiz5+8rQbWzbEvYxltJloC2FtnR6eP6uFh//kR8ZEMtsjjse4yFS3eL5zXHG9Oa2h+UOIbajuc+Pxyl/x5tPGisjSwAAAADQgIslAAAAAGjAxRIAAAAANJj3nKVpl4Ueet5Kzk2Meett7z10rn3Te8bc5racytJ5onEO0+LFi2v74lLxedn4tWvXjrf7KAc5t3b37t3j7byMcLR58+ba640bN3Yb2FnEc53zbmM+dvw82RDL9sbymvOmY9nN+dZD5Fjn95j2kQG5ng09d7Gt3YnL4C5durS2r+9jmuemxDlque7Gpa/z/IG+c+1LLvk9F3leR3wd5zFI9fkhJeYC5vPUNm8zxprLcsklpfM8oXgcu5i/MBdtc3pyX942nzLPaexaPi5t9XXfvn29xjIXbW3oiRMnxtt5nm18BMtQ88CnPYdxXptUL0NDPJ4h9uH5uMXv1KWXjY/v33Zc8neStmXxJ2FkCQAAAAAacLEEAAAAAA2mTsPLqRcrVqyY+Lsx9S7/3sGDB6d9y0HldI4h0h9yekNOB4tiqkDpoc8ol4uYardnz57avm3bto23t2/f3nks+cns8XUe1p7V5UbnkgY29JKtMS0jPyk9pkLkdMErr7xyvN1XvcopI7G+uPvEfUOn3eVyd/jw4fH2jh07avvi+c1lMn6GIdIcY9xzaStjak8f6U5t752XXo5lNH+Gvh8RkOOMKSM5fTie67mkafUllq8cT2xjZ6lfakvLbTsXfbQHuWzFcpjrRCyzOT2z71SxfP5iulUuo1GJ705RW92NKcJt+/o6tjlFMLbxOW0stvf5sQclv5PkOGMsQ6QEdiGnjcYyM22/xMgSAAAAADTgYgkAAAAAGnCxBAAAAAANpp6zFJeLlupzkdavX1/bl19Hu3btmvYt3/JyXmrM/1y2bFltX8wTzfm1MR9ziCUwd+7cOd7OS8jH+WpXX311bV/f89VyvnnbHIX4u3lZ8aFz72M+eJ531ZaPPXQecyyveV5SnCOQ51LFnOCh8tnjXJl8TI8cOTJIDE3yXIoYW166Nx7HvORsnDvS93ybLJ/DGHee09L3ssy5rsZzu3Xr1tq+6667buK/K7FE9xm5fEazMA8ont88r7at/MbYc1/XhTwfJM6jzI8siXO98ly2vud+5nMY60j+DHEOcOnHGMTjmfvIuJx0yboj1b83tfWJuXzGstxXGcjnPrbVuU7E4z90m57Pfeyn8r6h50q3afuOF+VHb8TvXNN+HkaWAAAAAKABF0sAAAAA0GDeaXhx6eeYliVJx44dG2+vXLlyvrF1Ig6D5hS1OBSXl0DsI20gy8PXbcvzxmHZvIxnHDIdIg0vpgrk5cGjnHa3adOm3mI6mzwcHlMxhjjXbdasWTPebls+PqdCDJ2mEY9THrqOqSU5rhLLnsb6nJeNL5nalN87Hqu8DH9Mccr1OqeU9C2+X26bYipEbkeHTtGJaX/5GMW4c7pgSfkYxRSnnI4dj/VQ5bit3sfYc8pljK+PtiqnB+W0y0lyXSrZ/ucyGuv80HHleh37m5y62Hd67VzEMpiPZ/zelPvWWA6GePxClsvv0P15lNuSGFuOq+07ytBiO56nBkSxTZXq537adpSRJQAAAABowMUSAAAAADTgYgkAAAAAGpi7l44BAA
AAAGYOI0sAAAAA0ICLJQAAAABowMUSAAAAADTgYgkAAAAAGnCxBAAAAAANuFgCAAAAgAb/BwPA+a2bf/nBAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = plt.figure(figsize=(15, 0.5))\n", "for index, (image, label) in enumerate(zip(digits.images[0:20], digits.target[0:20])):\n", " ax = fig.add_subplot(1, 20, index+1)\n", " ax.imshow(image, cmap=plt.cm.gray)\n", " ax.set_title(label)\n", " ax.axis('off')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training/Test set\n", "\n", "Before training our model, the [`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html) function will separate our dataset into a training set and a test set. The samples from the test set are never used during the training phase. This allows for a fair evaluation of the model's performance." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "train_img, test_img, train_lbl, test_lbl = train_test_split(\n", " digits.data, digits.target, test_size=1/6) # keep ~300 images as test set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check that all classes are well balanced in the training and test sets." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([151, 152, 143, 157, 151, 149, 155, 155, 137, 147]),\n", " array([0. , 0.9, 1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9. ]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.histogram(train_lbl, bins=10)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([27, 30, 34, 26, 30, 33, 26, 24, 37, 33]),\n", " array([0. , 0.9, 1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9. 
]))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.histogram(test_lbl, bins=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Supervised learning: logistic regression\n", "\n", "### Linear regression reminder\n", "\n", "Linear regression is used to predict a dependent value $y$ from an $n$-dimensional vector $x$.\n", "The assumption made here is that the output depends linearly on the input components, i.e. $y = m^\\top x + b$, where $m$ is a vector of $n$ weights and $b$ is a scalar intercept.\n", "\n", "Given a set of input and output values, the goal is to compute the $m$ and $b$ that minimize the [mean squared error (MSE)](https://en.wikipedia.org/wiki/Mean_squared_error) between the predicted and actual outputs.\n", "In scikit-learn this method is available through [`LinearRegression`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html).\n", "\n", "### Logistic regression\n", "\n", "Logistic regression is used to predict categorical data (e.g. yes/no, member/non-member, ham/spam, benign/malignant, ...).\n", "It takes the output of a linear predictor and maps it to a probability using a [sigmoid function](https://en.wikipedia.org/wiki/Sigmoid_function), such as the logistic function $s(z) = \\frac{1}{1+e^{-z}}$. \n", "The output is a probability score between 0 and 1. With a simple threshold, the predicted class is positive if the probability is greater than 0.5, and negative otherwise.\n", "A [log-loss cost function](http://wiki.fast.ai/index.php/Logistic_Regression#Cost_Function) (rather than the MSE used for linear regression) is minimized to train logistic regression (using gradient descent for instance).\n", "\n", "[Multinomial logistic regression](https://en.wikipedia.org/wiki/Multinomial_logistic_regression) extends binary logistic regression to classification problems with $n$ classes." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now create a logistic regression object and fit the parameters using the training data.\n", "\n", "NB: as the dataset is quite simple, default parameters will give good results. Check the [documentation](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) for fine-tuning possibilities." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression\n", "\n", "# All unspecified parameters are left to their default values.\n", "logisticRegr = LogisticRegression(verbose=1, solver='liblinear', multi_class='auto') # set solver and multi_class to silence warnings" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[LibLinear]" ] }, { "data": { "text/plain": [ "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, l1_ratio=None, max_iter=100,\n", " multi_class='auto', n_jobs=None, penalty='l2',\n", " random_state=None, solver='liblinear', tol=0.0001, verbose=1,\n", " warm_start=False)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logisticRegr.fit(train_img, train_lbl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model performance evaluation\n", "\n", "For a binary classification problem, let us denote by $TP$, $TN$, $FP$, and $FN$ the number of true positives, true negatives, false positives and false negatives.\n", "\n", "### Accuracy\n", "\n", "The *accuracy* is the fraction of correct predictions, defined by $a = \\frac{TP + TN}{TP + TN + FP + FN}$\n", "\n", "NB: in scikit-learn, models may have different definitions of the `score` method. For classifiers such as multi-class logistic regression, it returns the mean accuracy on the given data, i.e. the fraction of samples whose label is predicted correctly." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy = 0.9600\n" ] } ], "source": [ "score = logisticRegr.score(test_img, test_lbl)\n", "print(f'accuracy = {score:.4f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### F1 score\n", "\n", "Accuracy only provides partial information about the performance of a model. Many other [metrics](https://scikit-learn.org/stable/modules/model_evaluation.html#classification-metrics) are part of scikit-learn.\n", "\n", "A metric that provides a more complete overview of the classification performance is the [F1 score](https://en.wikipedia.org/wiki/F1_score). It accounts for both false positives and false negatives by combining precision and recall.\n", "\n", "*Precision* is the number of true positives divided by the total number of samples predicted as positive, i.e. $p=\\frac{TP}{TP+FP}$. A low precision indicates a high number of false positives.\n", "\n", "*Recall* is the number of true positives divided by the number of actual positive samples in the test data, i.e. $r=\\frac{TP}{TP+FN}$. A low recall indicates a high number of false negatives.\n", "\n", "Finally the F1 score is the harmonic mean of precision and recall, i.e. 
$F_1 = 2\\frac{p \\cdot r}{p + r}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us compute the predicted labels in the test set:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "pred_lbl = logisticRegr.predict(test_img)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score, classification_report\n", "from sklearn.utils.multiclass import unique_labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [`f1_score`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) function computes the F1 score. The `average` parameter controls whether the result is computed globally over all classes (`average='micro'`) or whether the F1 score is computed for each class and then averaged (`average='macro'`)." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.96" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f1_score(test_lbl, pred_lbl, average='micro')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9594503798365037" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f1_score(test_lbl, pred_lbl, average='macro')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`classification_report` provides a concise summary of the results for each class, as well as globally." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 27\n", " 1 0.91 0.97 0.94 30\n", " 2 1.00 1.00 1.00 34\n", " 3 0.95 0.81 0.88 26\n", " 4 1.00 0.97 0.98 30\n", " 5 1.00 1.00 1.00 33\n", " 6 1.00 0.96 0.98 26\n", " 7 0.92 1.00 0.96 24\n", " 8 0.88 0.97 0.92 37\n", " 9 0.97 0.91 0.94 33\n", "\n", " accuracy 0.96 300\n", " macro avg 0.96 0.96 0.96 300\n", "weighted avg 0.96 0.96 0.96 300\n", "\n" ] } ], "source": [ "print(classification_report(test_lbl, pred_lbl))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Confusion matrix\n", "\n", "In the case of a multi-class problem, the *confusion matrix* is often used to present the results." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import confusion_matrix\n", "\n", "def plot_confusion_matrix(y_true, y_pred, classes,\n", " normalize=False,\n", " title=None,\n", " cmap=plt.cm.Blues):\n", " \"\"\"\n", " This function prints and plots the confusion matrix.\n", " Normalization can be applied by setting `normalize=True`.\n", " \"\"\"\n", " if not title:\n", " if normalize:\n", " title = 'Normalized confusion matrix'\n", " else:\n", " title = 'Confusion matrix, without normalization'\n", "\n", " # Compute confusion matrix\n", " cm = confusion_matrix(y_true, y_pred)\n", " # Only use the labels that appear in the data\n", " classes = classes[unique_labels(y_true, y_pred)]\n", " if normalize:\n", " cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", " print(\"Normalized confusion matrix\")\n", " else:\n", " print('Confusion matrix, without normalization')\n", "\n", " print(cm)\n", "\n", " fig, ax = plt.subplots()\n", " im = ax.imshow(cm, interpolation='nearest', cmap=cmap)\n", " ax.figure.colorbar(im, ax=ax)\n", " # We want to show all ticks...\n", " 
ax.set(xticks=np.arange(cm.shape[1]),\n", " yticks=np.arange(cm.shape[0]),\n", " # ... and label them with the respective list entries\n", " xticklabels=classes, yticklabels=classes,\n", " title=title,\n", " ylabel='True label',\n", " xlabel='Predicted label')\n", "\n", " # Rotate the tick labels and set their alignment.\n", " plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n", " rotation_mode=\"anchor\")\n", "\n", " # Loop over data dimensions and create text annotations.\n", " fmt = '.2f' if normalize else 'd'\n", " thresh = cm.max() / 2.\n", " for i in range(cm.shape[0]):\n", " for j in range(cm.shape[1]):\n", " ax.text(j, i, format(cm[i, j], fmt),\n", " ha=\"center\", va=\"center\",\n", " color=\"white\" if cm[i, j] > thresh else \"black\")\n", " fig.tight_layout()\n", " return ax" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Confusion matrix, without normalization\n", "[[27 0 0 0 0 0 0 0 0 0]\n", " [ 0 29 0 1 0 0 0 0 0 0]\n", " [ 0 0 34 0 0 0 0 0 0 0]\n", " [ 0 0 0 21 0 0 0 0 4 1]\n", " [ 0 1 0 0 29 0 0 0 0 0]\n", " [ 0 0 0 0 0 33 0 0 0 0]\n", " [ 0 1 0 0 0 0 25 0 0 0]\n", " [ 0 0 0 0 0 0 0 24 0 0]\n", " [ 0 1 0 0 0 0 0 0 36 0]\n", " [ 0 0 0 0 0 0 0 2 1 30]]\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_confusion_matrix(test_lbl, pred_lbl, np.array([str(d) for d in range(10)]), normalize=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Supervised learning: support-vector machines\n", "\n", "[Support-vector machines (SVM)](https://en.wikipedia.org/wiki/Support-vector_machine) are also used for classification tasks.\n", "For a binary classification task on $n$-dimensional feature vectors, a linear SVM looks for the ($n-1$)-dimensional hyperplane that separates the two classes with the largest possible margin.\n", "Nonlinear SVMs fit the maximum-margin hyperplane in a transformed feature space.\n", "Although the classifier is a hyperplane in the transformed feature space, it may be nonlinear in the original input space. \n", "\n", "The goal here is to show that one method (e.g. the logistic regression used previously) can be swapped transparently for another." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from sklearn import svm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Default parameters perform well on this dataset.\n", "You may need to adjust $C$ and $\\gamma$ (e.g. via a grid search) for optimal performance (cf. [SVC documentation](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html#sklearn.svm.SVC))." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "clf = svm.SVC(gamma='scale') # default kernel is RBF" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.fit(train_img, train_lbl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The classification accuracy improves with respect to logistic regression (here `score` also computes mean accuracy, as in logistic regression)." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9933333333333333" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.score(test_img, test_lbl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The F1 score is also improved." 
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 27\n", " 1 0.97 1.00 0.98 30\n", " 2 0.97 1.00 0.99 34\n", " 3 1.00 0.96 0.98 26\n", " 4 1.00 1.00 1.00 30\n", " 5 1.00 1.00 1.00 33\n", " 6 1.00 1.00 1.00 26\n", " 7 1.00 1.00 1.00 24\n", " 8 1.00 0.97 0.99 37\n", " 9 1.00 1.00 1.00 33\n", "\n", " accuracy 0.99 300\n", " macro avg 0.99 0.99 0.99 300\n", "weighted avg 0.99 0.99 0.99 300\n", "\n" ] } ], "source": [ "pred_lbl_svm = clf.predict(test_img)\n", "print(classification_report(test_lbl, pred_lbl_svm))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unsupervised learning: $k$-means\n", "\n", "[$k$-means](https://en.wikipedia.org/wiki/K-means_clustering) aims at partitioning the samples into $k$ clusters, such that each sample belongs to the cluster with the closest mean. The algorithm is iterative, and requires the number of clusters to be known in advance. \n", "\n", "One important step in $k$-means clustering is the initialization, i.e. the choice of initial clusters to be refined.\n", "This choice can have a significant impact on results." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "from sklearn.cluster import KMeans" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "kmeans = KMeans(n_clusters=10)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KMeans(algorithm='auto', copy_x=True, init='k-means++', max_iter=300,\n", " n_clusters=10, n_init=10, n_jobs=None, precompute_distances='auto',\n", " random_state=None, tol=0.0001, verbose=0)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kmeans.fit(digits.data)\n", "km_labels = kmeans.predict(digits.data)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 1, 2, ..., 8, 9, 8])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "digits.target" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([9, 7, 7, ..., 7, 4, 4], dtype=int32)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "km_labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we have ground truth information of classes, we can check if the $k$-means results make sense.\n", "However, as you can see, the labels produced by $k$-means do not match the ground truth ones, because cluster indices are arbitrary.\n", "An agreement score based on [mutual information](https://scikit-learn.org/stable/modules/clustering.html#clustering-evaluation), which is insensitive to label permutations, can be used to evaluate the results. 
" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import adjusted_mutual_info_score" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/michael/.conda/envs/ntds_2019/lib/python3.7/site-packages/sklearn/metrics/cluster/supervised.py:746: FutureWarning: The behavior of AMI will change in version 0.22. To match the behavior of 'v_measure_score', AMI will use average_method='arithmetic' by default.\n", " FutureWarning)\n" ] }, { "data": { "text/plain": [ "0.7366155637042371" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adjusted_mutual_info_score(digits.target, kmeans.labels_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Unsupervised learning: dimensionality reduction\n", "\n", "You can also try to visualize the clusters as in this [scikit-learn demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html). Visualization requires mapping the input features to lower-dimensional embeddings (2D or 3D), e.g. using PCA or t-SNE. [This demo](https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html) provides an overview of the possibilities." 
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "from matplotlib import offsetbox\n", "\n", "def plot_embedding(X, y, title=None):\n", " \"\"\"Scale and visualize the embedding vectors.\"\"\"\n", " x_min, x_max = np.min(X, 0), np.max(X, 0)\n", " X = (X - x_min) / (x_max - x_min)\n", "\n", " plt.figure()\n", " ax = plt.subplot(111)\n", " for i in range(X.shape[0]):\n", " plt.text(X[i, 0], X[i, 1], str(y[i]),\n", " color=plt.cm.Set1(y[i] / 10.),\n", " fontdict={'weight': 'bold', 'size': 9})\n", "\n", " if hasattr(offsetbox, 'AnnotationBbox'):\n", " # only print thumbnails with matplotlib > 1.0\n", " shown_images = np.array([[1., 1.]]) # just something big\n", " for i in range(X.shape[0]):\n", " dist = np.sum((X[i] - shown_images) ** 2, 1)\n", " if np.min(dist) < 4e-3:\n", " # don't show points that are too close\n", " continue\n", " shown_images = np.r_[shown_images, [X[i]]]\n", " imagebox = offsetbox.AnnotationBbox(\n", " offsetbox.OffsetImage(digits.images[i], cmap=plt.cm.gray_r),\n", " X[i])\n", " ax.add_artist(imagebox)\n", " plt.xticks([]), plt.yticks([])\n", " if title is not None:\n", " plt.title(title)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "from sklearn import manifold" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)\n", "X_tsne = tsne.fit_transform(digits.data)\n", "\n", "plot_embedding(X_tsne, digits.target,\n", " \"t-SNE embedding of the digits (ground truth labels)\")" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_embedding(X_tsne, km_labels,\n", " \"t-SNE embedding of the digits (kmeans labels)\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }