{ "cells": [ { "cell_type": "markdown", "id": "ultimate-understanding", "metadata": {}, "source": [ "## Classification with extra-trees classifier" ] }, { "cell_type": "code", "execution_count": 1, "id": "crude-shape", "metadata": { "tags": [] }, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import pandas as pd\n", "import numpy as np\n", "\n", "from sklearn import ensemble, datasets, metrics, model_selection, preprocessing, pipeline" ] }, { "cell_type": "markdown", "id": "circular-dylan", "metadata": {}, "source": [ "Load the data set" ] }, { "cell_type": "code", "execution_count": 2, "id": "turkish-hampton", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ".. _wine_dataset:\n", "\n", "Wine recognition dataset\n", "------------------------\n", "\n", "**Data Set Characteristics:**\n", "\n", " :Number of Instances: 178 (50 in each of three classes)\n", " :Number of Attributes: 13 numeric, predictive attributes and the class\n", " :Attribute Information:\n", " \t\t- Alcohol\n", " \t\t- Malic acid\n", " \t\t- Ash\n", "\t\t- Alcalinity of ash \n", " \t\t- Magnesium\n", "\t\t- Total phenols\n", " \t\t- Flavanoids\n", " \t\t- Nonflavanoid phenols\n", " \t\t- Proanthocyanins\n", "\t\t- Color intensity\n", " \t\t- Hue\n", " \t\t- OD280/OD315 of diluted wines\n", " \t\t- Proline\n", "\n", " - class:\n", " - class_0\n", " - class_1\n", " - class_2\n", "\t\t\n", " :Summary Statistics:\n", " \n", " ============================= ==== ===== ======= =====\n", " Min Max Mean SD\n", " ============================= ==== ===== ======= =====\n", " Alcohol: 11.0 14.8 13.0 0.8\n", " Malic Acid: 0.74 5.80 2.34 1.12\n", " Ash: 1.36 3.23 2.36 0.27\n", " Alcalinity of Ash: 10.6 30.0 19.5 3.3\n", " Magnesium: 70.0 162.0 99.7 14.3\n", " Total Phenols: 0.98 3.88 2.29 0.63\n", " Flavanoids: 0.34 5.08 2.03 1.00\n", " Nonflavanoid Phenols: 0.13 0.66 0.36 0.12\n", " 
Proanthocyanins: 0.41 3.58 1.59 0.57\n", " Colour Intensity: 1.3 13.0 5.1 2.3\n", " Hue: 0.48 1.71 0.96 0.23\n", " OD280/OD315 of diluted wines: 1.27 4.00 2.61 0.71\n", " Proline: 278 1680 746 315\n", " ============================= ==== ===== ======= =====\n", "\n", " :Missing Attribute Values: None\n", " :Class Distribution: class_0 (59), class_1 (71), class_2 (48)\n", " :Creator: R.A. Fisher\n", " :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n", " :Date: July, 1988\n", "\n", "This is a copy of UCI ML Wine recognition datasets.\n", "https://archive.ics.uci.edu/ml/machine-learning-databases/wine/wine.data\n", "\n", "The data is the results of a chemical analysis of wines grown in the same\n", "region in Italy by three different cultivators. There are thirteen different\n", "measurements taken for different constituents found in the three types of\n", "wine.\n", "\n", "Original Owners: \n", "\n", "Forina, M. et al, PARVUS - \n", "An Extendible Package for Data Exploration, Classification and Correlation. \n", "Institute of Pharmaceutical and Food Analysis and Technologies,\n", "Via Brigata Salerno, 16147 Genoa, Italy.\n", "\n", "Citation:\n", "\n", "Lichman, M. (2013). UCI Machine Learning Repository\n", "[https://archive.ics.uci.edu/ml]. Irvine, CA: University of California,\n", "School of Information and Computer Science. \n", "\n", ".. topic:: References\n", "\n", " (1) S. Aeberhard, D. Coomans and O. de Vel, \n", " Comparison of Classifiers in High Dimensional Settings, \n", " Tech. Rep. no. 92-02, (1992), Dept. of Computer Science and Dept. of \n", " Mathematics and Statistics, James Cook University of North Queensland. \n", " (Also submitted to Technometrics). \n", "\n", " The data was used with many others for comparing various \n", " classifiers. The classes are separable, though only RDA \n", " has achieved 100% correct classification. 
\n", " (RDA : 100%, QDA 99.4%, LDA 98.9%, 1NN 96.1% (z-transformed data)) \n", " (All results using the leave-one-out technique) \n", "\n", " (2) S. Aeberhard, D. Coomans and O. de Vel, \n", " \"THE CLASSIFICATION PERFORMANCE OF RDA\" \n", " Tech. Rep. no. 92-01, (1992), Dept. of Computer Science and Dept. of \n", " Mathematics and Statistics, James Cook University of North Queensland. \n", " (Also submitted to Journal of Chemometrics).\n", "\n" ] } ], "source": [ "wine = datasets.load_wine()\n", "print(wine.DESCR)" ] }, { "cell_type": "code", "execution_count": 3, "id": "advised-pride", "metadata": { "tags": [] }, "outputs": [], "source": [ "X = pd.DataFrame(wine.data, columns=wine.feature_names)\n", "y = wine.target" ] }, { "cell_type": "markdown", "id": "turkish-proceeding", "metadata": {}, "source": [ "Split the data into training and test sets of equal size, stratified by the target label" ] }, { "cell_type": "code", "execution_count": 4, "id": "emerging-emphasis", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train samples: 89\n", "test samples 89\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAVdklEQVR4nO3dfZBV9Z3n8fdX6LEhMkCgJShBGMZypNS0sWWSMTXxIUQ0M0El6pr4QNbddtaHMtkJpZlKnJjKbNwyMepuYoopEV0lRiVKJtEd1CVrfIrTMIyiENGMjq1GOrAoqDhCvvvHvWCLPNzGPvfQnPer6laf8zvn3PO91frh9O/87u9EZiJJqo69yi5AktRcBr8kVYzBL0kVY/BLUsUY/JJUMYPLLqARo0ePzgkTJpRdhiQNKIsXL/5dZrZt3T4ggn/ChAl0dXWVXYYkDSgR8fy22u3qkaSKMfglqWIMfkmqmAHRxy9JffX222/T3d3Nhg0byi6lcK2trYwbN46WlpaG9jf4Je2Ruru7GTZsGBMmTCAiyi6nMJnJ6tWr6e7uZuLEiQ0dY1ePpD3Shg0bGDVq1B4d+gARwahRo/r0l43BL2mPtaeH/mZ9/ZwGvyRVjMEvSTuxdu1afvCDH5RdRr/x5q70PvzbNw8tu4Q+G3/ZE2WXMOBsDv7zzz+/7FL6hVf8krQTl156Kc8++yzt7e2ceuqp3HXXXVu2feELX2DBggXMnTuX6dOnc/TRR3PggQdy+eWXb9nn5ptvZsqUKbS3t3PeeeexadOmEj7FOwoL/ohojYjHIuJfIuLJiLi83j43Iv41IpbWX+1F1SBJ/eGKK65g0qRJLF26lAsvvJC5c+cC8Oqrr/Lwww/zmc98BoDHHnuM+fPn8/jjj3P77bfT1dXF8uXL+fGPf8xDDz3E0qVLGTRoELfcckuJn6bYrp63gGMzc31EtAAPRsQ99W2zMvOOAs8tSYX45Cc/yfnnn09PTw/z589nxowZDB5ci9KpU6cyatQoAE455RQefPBBBg8ezOLFiznyyCMBePPNN9l3331Lqx8KDP6sPcV9fX21pf7yye6SBryzzz6bm2++mVtvvZUbbrhhS/vWwyojgszknHPO4dvf/nazy9yuQm/uRsQgYDHwx8D3M/NXEfFfgL+LiMuA+4FLM/OtbRzbCXQCjB8/vsgytRs5YtZNZZfQJ3cOK7sCNcOwYcNYt27dlvWZM2cyZcoUPvShDzF58uQt7ffeey9r1qxhyJAh3HXXXcyZM4ehQ4cyffp0vvzlL7PvvvuyZs0a1q1bxwEHHFDGRwEKvrmbmZsysx0YB0yJiEOArwJ/AhwJfBC4ZDvHzs7MjszsaGt7z3MEJKlpRo0axVFHHcUhhxzCrFmzGDNmDAcffDBf/OIX37XflClTmDFjBocddhgzZsygo6ODyZMn861vfYtPf/rTHHbYYUydOpWXX365pE9S05ThnJm5NiIWAdMy8zv15rci4gbgK82oQZLej3nz5m1ZfuONN1i5ciVnnHHGu/YZN27cu0b8bHb66adz+umnF11iw4oc1dMWESPqy0OAqcCKiBhbbwvgJGBZUTVIUn+77777OPjgg7nooosYPnx42eXskiKv+McCN9b7+fcCbsvMn0XE/4mINiCApcBfFViDJPWrT33qUzz//HufaDhz5kxmzpzZ/IJ2QZGjeh4HDt9G+7FFnVOStHN+c1eSKsbgl6SKMfglqWKcnVNSJfT3lwMXX3n2TvdZu3Yt8+bN6/OsnieeeCLz5s1jxIgRu1jdjnnFL0kF2d48/hs3btzhcXfffXdhoQ9e8UtSYXpP59zS0kJraysjR45kxYoVPP3005x00km88MILbNiwgYsvvpjOzk4AJkyYQFdXF+vXr+eEE07gE5/4BA8//DD7778/CxYsYMiQIe+rLq/4JakgvadzvvLKK1myZAnXXHMNTz/9NABz5sxh8eLFdHV1ce2117J69er3vMfKlSu54II
LePLJJxkxYgTz589/33V5xS9JTTJlyhQmTpy4Zf3aa6/lzjvvBOCFF15g5cqVW6Z13mzixIm0t7cDcMQRR/Dcc8+97zoMfklqkg984ANbln/xi19w33338cgjjzB06FCOPvpoNmzY8J5j9t577y3LgwYN4s0333zfddjVI0kF2Xo6595effVVRo4cydChQ1mxYgWPPvpo0+ryil9SJTQy/LK/9Z7OeciQIYwZM2bLtmnTpvHDH/6Qgw8+mIMOOoiPfexjTavL4JekAvWezrm3vffem3vuuWeb2zb3448ePZply96ZwPgrX+mfWezt6pGkijH4JaliDH5JqhiDX5IqxuCXpIox+CWpYhzOKakS/u2bh/br+42/7Imd7rOr0zIDXH311XR2djJ06NBdKW+HvOKXpIJsb1rmRlx99dW88cYb/VxRTWFX/BHRCjwA7F0/zx2Z+bcRMRG4FRgFLAbOysx/L6oOSSpL72mZp06dyr777sttt93GW2+9xcknn8zll1/O66+/zmmnnUZ3dzebNm3i61//Oq+88govvfQSxxxzDKNHj2bRokX9WleRXT1vAcdm5vqIaAEejIh7gP8KfC8zb42IHwLnAtcVWIckleKKK65g2bJlLF26lIULF3LHHXfw2GOPkZl89rOf5YEHHqCnp4f99tuPn//850BtDp/hw4dz1VVXsWjRIkaPHt3vdRXW1ZM16+urLfVXAscCd9TbbwROKqoGSdpdLFy4kIULF3L44Yfz0Y9+lBUrVrBy5UoOPfRQ7r33Xi655BJ++ctfMnz48MJrKfTmbkQMotad88fA94FngbWZufm5Y93A/ts5thPoBBg/fnzD5+zv52o2QxmTR0lqrszkq1/9Kuedd957ti1ZsoS7776br33taxx33HFcdtllhdZS6M3dzNyUme3AOGAK8Cd9OHZ2ZnZkZkdbW1tRJUpSYXpPy3z88cczZ84c1q+vdYS8+OKLrFq1ipdeeomhQ4dy5plnMmvWLJYsWfKeY/tbU4ZzZubaiFgEfBwYERGD61f944AXm1GDpGprZPhlf+s9LfMJJ5zA5z//eT7+8Y8DsM8++3DzzTfzzDPPMGvWLPbaay9aWlq47rraLc/Ozk6mTZvGfvvtN3Bu7kZEG/B2PfSHAFOB/w4sAj5HbWTPOcCComqQpLJtPS3zxRdf/K71SZMmcfzxx7/nuIsuuoiLLrqokJqKvOIfC9xY7+ffC7gtM38WEU8Bt0bEt4B/Bq4vsAZJ0lYKC/7MfBw4fBvtv6HW3y9JKoFTNuwG+vur5EUro69U2hWZSUSUXUbhMrNP+ztlg6Q9UmtrK6tXr+5zKA40mcnq1atpbW1t+Biv+CXtkcaNG0d3dzc9PT1ll1K41tZWxo0b1/D+Br+kPVJLSwsTJ04su4zdkl09klQxBr8kVYzBL0kVY/BLUsUY/JJUMQa/JFWMwS9JFWPwS1LFGPySVDEGvyRVjMEvSRVj8EtSxRj8klQxBr8kVYzBL0kVY/BLUsUUFvwR8eGIWBQRT0XEkxFxcb39GxHxYkQsrb9OLKoGSdJ7FfkEro3AX2fmkogYBiyOiHvr276Xmd8p8NySpO0oLPgz82Xg5fryuohYDuxf1PkkSY1pSh9/REwADgd+VW+6MCIej4g5ETFyO8d0RkRXRHRV4WHJktQshQd/ROwDzAe+lJmvAdcBk4B2an8RfHdbx2Xm7MzsyMyOtra2osuUpMooNPgjooVa6N+SmT8ByMxXMnNTZv4e+HtgSpE1SJLerchRPQFcDyzPzKt6tY/ttdvJwLKiapAkvVeRo3qOAs4CnoiIpfW2vwHOiIh2IIHngPMKrEGStJUiR/U8CMQ2Nt1d1DklSTtX5BW/pAo5YtZNZZfQZ3cOu7LsEvps/GVPvO/3cMoGSaoYg1+SKsbgl6SKMfglqWIMfkmqGINfkirG4JekijH4JaliDH5JqhiDX5IqxuCXpIox+CWpYgx+SaoYg1+SKsbgl6S
KMfglqWIMfkmqmIaCPyLub6RNkrT72+GjFyOiFRgKjI6IkbzzDN0/BPYvuDZJUgF29szd84AvAfsBi3kn+F8D/ueODoyIDwM3AWOABGZn5jUR8UHgx8AE4DngtMz8f7tWviSpr3bY1ZOZ12TmROArmflHmTmx/vpIZu4w+IGNwF9n5mTgY8AFETEZuBS4PzMPBO6vr0uSmmRnV/wAZOb/iIg/o3aVPrhX+007OOZl4OX68rqIWE6te2g6cHR9txuBXwCX9L10SdKuaCj4I+J/AZOApcCmenNS68pp5PgJwOHAr4Ax9X8UAH5LrStoW8d0Ap0A48ePb+Q0kqQGNBT8QAcwOTOzryeIiH2A+cCXMvO1iNiyLTMzIrb5npk5G5gN0NHR0efzSpK2rdFx/MuAD/X1zSOihVro35KZP6k3vxIRY+vbxwKr+vq+kqRd1+gV/2jgqYh4DHhrc2NmfnZ7B0Tt0v56YHlmXtVr00+Bc4Ar6j8X9LVoSdKuazT4v7EL730UcBbwREQsrbf9DbXAvy0izgWeB07bhfeWJO2iRkf1/N++vnFmPsg74/63dlxf30+S1D8aHdWzjtooHoA/AFqA1zPzD4sqTJJUjEav+IdtXq733U+n9qUsSdIA0+fZObPmLuD4/i9HklS0Rrt6Tum1uhe1cf0bCqlIklSoRkf1/GWv5Y3UJleb3u/VSJIK12gf/xeLLkSS1ByNPohlXETcGRGr6q/5ETGu6OIkSf2v0Zu7N1D7xu1+9dc/1NskSQNMo8Hflpk3ZObG+msu0FZgXZKkgjQa/Ksj4syIGFR/nQmsLrIwSVIxGg3+/0htTp3fUnu4yueAmQXVJEkqUKPDOb8JnLP52bj15+Z+h9o/CJKkAaTRK/7Dej8QPTPXUHuiliRpgGk0+PeKiJGbV+pX/I3+tSBJ2o00Gt7fBR6JiNvr66cCf1dMSZKkIjX6zd2bIqILOLbedEpmPlVcWZKkojTcXVMPesNekga4Pk/LLEka2Ax+SaqYwoI/IubUJ3Rb1qvtGxHxYkQsrb9OLOr8kqRtK/KKfy4wbRvt38vM9vrr7gLPL0nahsKCPzMfANYU9f6SpF1TRh//hRHxeL0raOTOd5ck9admB/91wCSgndpkb9/d3o4R0RkRXRHR1dPT06TyJGnP19Tgz8xXMnNTZv4e+Htgyg72nZ2ZHZnZ0dbm1P+S1F+aGvwRMbbX6snAsu3tK0kqRmETrUXEj4CjgdER0Q38LXB0RLQDCTwHnFfU+SVJ21ZY8GfmGdtovr6o80mSGuM3dyWpYgx+SaoYg1+SKsbgl6SKMfglqWIMfkmqGINfkirG4JekijH4JaliDH5JqhiDX5IqxuCXpIox+CWpYgx+SaoYg1+SKsbgl6SKMfglqWIMfkmqGINfkirG4Jekiiks+CNiTkSsiohlvdo+GBH3RsTK+s+RRZ1fkrRtRV7xzwWmbdV2KXB/Zh4I3F9flyQ1UWHBn5kPAGu2ap4O3FhfvhE4qajzS5K2rdl9/GMy8+X68m+BMdvbMSI6I6IrIrp6enqaU50kVUBpN3czM4HcwfbZmdmRmR1tbW1NrEyS9mzNDv5XImIsQP3nqiafX5Iqr9nB/1PgnPryOcCCJp9fkiqvyOGcPwIeAQ6KiO6IOBe4ApgaESuBT9XXJUlNNLioN87MM7az6biizilJ2jm/uStJFWPwS1LFGPySVDEGvyRVjMEvSRVj8EtSxRj8klQxBr8kVYzBL0kVY/BLUsUY/JJUMQa/JFWMwS9JFWPwS1LFGPySVDEGvyRVjMEvSRVj8EtSxRj8klQxBr8kVUxhD1vfkYh4DlgHbAI2ZmZHGXVIUhWVEvx1x2Tm70o8vyRVkl09klQxZQV/AgsjYnFEdG5rh4jojIiuiOjq6elpcnmStOcqK/g/kZkfBU4ALoiIP996h8ycnZkdmdnR1tbW/AolaQ9VSvBn5ov1n6uAO4EpZdQhSVXU9OCPiA9ExLDNy8C
ngWXNrkOSqqqMUT1jgDsjYvP552Xm/y6hDkmqpKYHf2b+BvhIs88rSapxOKckVYzBL0kVY/BLUsUY/JJUMQa/JFWMwS9JFWPwS1LFGPySVDEGvyRVjMEvSRVj8EtSxRj8klQxBr8kVYzBL0kVY/BLUsUY/JJUMQa/JFWMwS9JFWPwS1LFGPySVDGlBH9ETIuIX0fEMxFxaRk1SFJVNT34I2IQ8H3gBGAycEZETG52HZJUVWVc8U8BnsnM32TmvwO3AtNLqEOSKikys7knjPgcMC0z/1N9/SzgTzPzwq326wQ666sHAb9uaqHNNRr4XdlFaJf4uxvY9vTf3wGZ2bZ14+AyKmlEZs4GZpddRzNERFdmdpRdh/rO393AVtXfXxldPS8CH+61Pq7eJklqgjKC/5+AAyNiYkT8AfAfgJ+WUIckVVLTu3oyc2NEXAj8IzAImJOZTza7jt1MJbq09lD+7ga2Sv7+mn5zV5JULr+5K0kVY/BLUsUY/CVy6oqBKyLmRMSqiFhWdi3qu4j4cEQsioinIuLJiLi47JqayT7+ktSnrngamAp0UxvtdEZmPlVqYWpIRPw5sB64KTMPKbse9U1EjAXGZuaSiBgGLAZOqsr/f17xl8epKwawzHwAWFN2Hdo1mflyZi6pL68DlgP7l1tV8xj85dkfeKHXejcV+g9P2l1ExATgcOBXJZfSNAa/pMqKiH2A+cCXMvO1sutpFoO/PE5dIZUoIlqohf4tmfmTsutpJoO/PE5dIZUkIgK4HliemVeVXU+zGfwlycyNwOapK5YDtzl1xcARET8CHgEOiojuiDi37JrUJ0cBZwHHRsTS+uvEsotqFodzSlLFeMUvSRVj8EtSxRj8klQxBr8kVYzBL0kVY/Cr8iJiRESc34TznBQRk4s+j7QzBr8EI4CGgz9qduX/nZMAg1+lcxy/Ki8iNs+M+mtgEXAYMBJoAb6WmQvqE3n9I7WJvI4ATgTOBs4EeqhNuLc4M78TEZOA7wNtwBvAfwY+CPwMeLX+mpGZzzbrM0q9Nf1h69Ju6FLgkMxsj4jBwNDMfC0iRgOPRsTmqTQOBM7JzEcj4khgBvARav9ALKE2pzvUHuD9V5m5MiL+FPhBZh5bf5+fZeYdzfxw0tYMfundAvhv9Qet/J7aVNlj6tuez8xH68tHAQsycwOwISL+AbbM9vhnwO216WAA2LtZxUuNMPild/sCtS6aIzLz7Yh4Dmitb3u9geP3AtZmZnsx5Unvnzd3JVgHDKsvDwdW1UP/GOCA7RzzEPCXEdFav8r/C4D6nO7/GhGnwpYbwR/Zxnmk0hj8qrzMXA08VH9wejvQERFPULt5u2I7x/wTtWm0HwfuAZ6gdtMWan81nBsR/wI8yTuP1LwVmBUR/1y/ASyVwlE90i6KiH0yc31EDAUeADo3P8dV2p3Zxy/tutn1L2S1Ajca+hoovOKXpIqxj1+SKsbgl6SKMfglqWIMfkmqGINfkirm/wPjXfPkVn2NFAAAAABJRU5ErkJggg==\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, train_size=0.5, stratify=y)\n", "\n", "df_train = pd.DataFrame(y_train, columns=['target'])\n", "df_train['type'] = 'train'\n", "\n", "df_test = pd.DataFrame(y_test, columns=['target'])\n", "df_test['type'] = 'test'\n", "\n", "df_set = 
pd.concat([df_train, df_test])\n", "\n", "_ = sns.countplot(x='target', hue='type', data=df_set)\n", "\n", "print('train samples:', len(X_train))\n", "print('test samples', len(X_test))" ] }, { "cell_type": "code", "execution_count": 5, "id": "national-clinic", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "ExtraTreesClassifier(max_depth=3, n_estimators=50)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = ensemble.ExtraTreesClassifier(n_estimators=50, max_depth=3)\n", "model.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 6, "id": "plain-serial", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>0</th>\n", " <th>1</th>\n", " <th>2</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>29</td>\n", " <td>0</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>2</td>\n", " <td>33</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>24</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " 0 1 2\n", "0 29 0 0\n", "1 2 33 1\n", "2 0 0 24" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predicted = model.predict(X_test)\n", "\n", "confusion_matrix = pd.DataFrame(metrics.confusion_matrix(y_test, predicted))\n", "confusion_matrix" ] }, { "cell_type": "code", "execution_count": 7, "id": "stock-maldives", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAVoAAAD4CAYAAACt8i4nAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAU+klEQVR4nO3df5Rd473H8fdnZhKjiR8hTIIQRBFU1I8iLUER1JWotrhUu+i4bpWU24rmXr9ayqof62p7VVIp2l6qxaW43EhpUFRoSn7TNEiazBShSaTNr+/9I0dMdTLnzMzZs/c883llPSvn7H3Oc772sj4ez97P3ooIzMwsOzV5F2BmljoHrZlZxhy0ZmYZc9CamWXMQWtmlrG6rH9g49E/9GUNGVvy87PyLsGsKurrUGf72HifcyvOnBW/+16nf68SHtGamWUs8xGtmVmXUvHGjw5aM0tLTW3eFfwDB62ZpUVdMu3aLg5aM0uLpw7MzDLmEa2ZWcY8ojUzy5hHtGZmGfNVB2ZmGfPUgZlZxjx1YGaWMY9ozcwy5qA1M8tYrU+GmZlly3O0ZmYZK+DUQfEqMjPrDKny1mY3qpf0W0m/lzRD0uWl7TtKelbSK5J+Jql3uZIctGaWFtVU3tr2N+DwiNgbGAaMlHQgcA1wQ0QMAZYAZ5bryEFrZmmp0og21llWetur1AI4HPhFafttwKhyJTlozSwtNbUVN0mNkqa2aI0tu5JUK2ka0AxMAv4AvB0Rq0sfWQBsW64knwwzs7S042RYRIwHxrexfw0wTNLmwL3Abh0pyUFrZmnJ4PKuiHhb0mPAQcDmkupKo9rtgIXlvu+pAzNLS5VOhknaqjSSRdLGwJHALOAx4KTSx84A7itXkke0ZpaW6l1HOxC4TVIt6wald0XEA5JmAndK+hbwO+CWch05aM0sLVW6H21EvAjs08r2ecAB7enLQWtmafESXDOzjBVwCa6D1szS4hGtmVm25KA1M8uWg9bMLGOqcdAWynZb9uGH5x/K1ptvTARMnDSb7z8wg70Gb8F3/2U4fep78WrzMr54w2MsXbEq73KT8dQTU7jm6itZu2Ytoz/9Gc78UmP5L1m79ORj7BFtwaxeu5axtz7LtHlv0re+F7+5bhSTpy3kpn/9BGNve5YnZyzm80d8mK+O+ghX3PF83uUmYc2aNVx15RXcPOFHNDQ0cOrnTmLEYYez85AheZeWjJ5+jIsYtMW7DqILLV6ygmnz3gRg2V9XMXvB22yzZR+GbLMZT85YDMCvpi1k1EGDc6wyLdNfepFBg3Zgu0GD6NW7NyOPPY7HH5ucd1lJ6enHWFLFrauUDVpJu0m6SNKNpXaRpN27oriutP1WfRm245Y8N7eZWa8v4fgDdgDgxOE7sl3/PjlXl47mpiYGDByw/v3WDQ00NTXlWFF6evwxVjtaF2kzaCVdBNzJupJ+W2oC7pA0Nvvyukaf+jruuOiTfG3iMyxdsYqzvzeFxmN256lrR9G3vhcrV6/Nu0Qzq1ARR7Tl5mjPBPaIiL87EyTpemAGcHVrXyrdPLcRoG7Y6dQNPqQKpWajrlbc8fVP8rMpr3DfM/MBmLvwHY6//GEAhmyzKcfsNyjHCtOydUMDixctXv++uamJhoaGHCtKT08/xjU1xZsRLVfRWmCbVrYPLO1rVUSMj4j9ImK/IocswA++fAhzFrzNjfdPX79tq83qgXULTMaetA8THpmdV3nJ2WPPvXjttfksWPA6q1au5OGHHuTQww7Pu6yk9PRj3B1HtGOAyZJeBl4vbdseGAKcm2FdXeLg3Rv458N24aX5b/HM9aMBuPQnzzFkm804+5ihANz3zHxunzw3zzKTUldXx8XjLuGcxrNYu3YNo0Z/miFDdsm7rKT0+GNcvIsOUES0/QGphnW3BHvvuTgLgedKj3goa+PRP2z7B6zTlvz8rLxLMKuK+rrOx2T/L9xZcea8cevJXRLLZa+jjYi1wDNdUIuZWacV8TraHr1
gwczS4yW4ZmYZ84jWzCxjDlozs4w5aM3MMuagNTPLWvFy1kFrZmkp4hJcB62ZJaWIUwfFi34zs86o0m0SJQ2S9JikmZJmSDq/tP0ySQslTSu1Y8uV5BGtmSWliiPa1cCFEfGCpE2A5yVNKu27ISKurbQjB62ZJaVaQRsRi4BFpddLJc3i/Xu+tIunDswsKe25TaKkRklTW7RWn2IpaTCwD/BsadO5kl6UNFFSv3I1OWjNLCmqUcWt5b2zS238P/Qn9QXuBsZExF+Am4CdgWGsG/FeV64mTx2YWVKqedWBpF6sC9mfRsQ9ABHR1GL/BOCBcv04aM0sKdUKWq3r6BZgVkRc32L7wNL8LcBoYHpr32/JQWtmSanigHY4cDrwkqRppW3fAE6RNAwIYD5wdrmOHLRmlpQqXnXwJK1fbftQe/ty0JpZUmp8428zs2wVcAWug9bM0uIRrZlZxjyiNTPLWBHv3uWgNbOkFDBnHbRmlhbf+NvMLGMe0ZqZZcxztGZmGStgzjpozSwtHtGamWWsgDnroDWztPTIlWGv//gLWf9Ej9dv/3PzLiF5zU/fmHcJPUJ9XecvzfLUgZlZxgqYsw5aM0uLR7RmZhkrYM46aM0sLT3yZJiZWVfy1IGZWcYctGZmGStgzjpozSwtHtGamWWsgDnroDWztBTxqoPi3YrczKwTaqSKW1skDZL0mKSZkmZIOr+0fQtJkyS9XPq7X9maqvTPZmZWCFLlrYzVwIURMRQ4EPiypKHAWGByROwCTC69b5OD1sySIqni1paIWBQRL5ReLwVmAdsCJwC3lT52GzCqXE0OWjNLSo0qb5IaJU1t0Rpb61PSYGAf4FmgISIWlXYtBhrK1eSTYWaWlPacDIuI8cD4tj4jqS9wNzAmIv7SciQcESEpytZUcUVmZt2A2vGnbF9SL9aF7E8j4p7S5iZJA0v7BwLN5fpx0JpZUtozddAWrRu63gLMiojrW+y6Hzij9PoM4L5yNXnqwMySUsWVYcOB04GXJE0rbfsGcDVwl6QzgVeBz5bryEFrZkmpVs5GxJOwwfmFI9rTl4PWzJJSbiFCHhy0ZpaUIi7BddCaWVIKOKB10JpZWjx1YGaWseLFrIPWzBLjG3+bmWWsgOfCHLRmlhZfdWBmljFPHZiZZayAA1oHrZmlxSNaM7OMFS9mHbRmlpjaAs4dOGhLmhYv4puXXMySt94EiRNGf4bPnnp63mUlYaPedTx6yxh6966jrraWex/9Hd/6wUPcdOmpfHTo9gjxymvNfOmSH7N8xcq8y+32Lr9kHE9OeZx+W2zBXff8Mu9yupynDgqstraOr3z16+y6+1CWL1/Omad9hv0PPIgddxqSd2nd3t9WrmZk440sX7GSuroafjXxAv7vqZl8/dp7WLr8rwBcc+GJnHPyoVz7o0k5V9v9HX/CKD53yqlcMq7sw1mTVMCc9RMW3tN/q63YdfehAPTp04cddtyJPzeXfUKFVei9kWqvulrq6mqJiPUhC1C/US8iyj56ySrw0X33Z9NNN8+7jNzUSBW3Lqupo1+U9MVqFlIki/60kJdnz2KPPT+SdynJqKkRz9w5ltcmX82vnpnNc9NfBeDmy05j/qNXsevgBv7rzl/nXKWlQKq8dZXOjGgv39COlo/wvX3ihE78RNd7993ljPvaGM77t7H06ds373KSsXZtcODJVzPk6H9nvz13YOjOAwE4+7KfsNNR45j9x8WcdNS+OVdpKZBUcesqbc7RSnpxQ7to41nmLR/h+8ay1d3m/wdXr1rFuK+N4ahjjmPE4UfmXU6S3lm2gl9PnctRBw9l5h8WAetC+OePPM8FZxzJj+9/JucKrburLeAkbbmTYQ3A0cCSD2wX8JtMKspJRPDtb17CDjvuxMmnfSHvcpLSv19fVq1awzvLVlC/US+O+NhuXH/bo+w0qD/zXn8DgE8d+hHmzm/KuVJLQQGv7iobtA8AfSNi2gd3SHo8i4Ly8uK0F3j4wfv
ZeciHOeOUEwE4+8tjOPjjh+RcWfc3oP+mTLjidGpraqipEXdPeoH/fWIGkyeOYZM+GyPBS3MXct5VP8u71CR846ILeX7qb3n77bc59sgRNJ5zLqNOPCnvsrpMEYNWWZ/p7U5TB93VoE+MybuE5DU/fWPeJfQIm9R3PiYv/OWcijPnuuN37ZJY9nW0ZpaUIo5oHbRmlpQCngvzggUzS0udVHErR9JESc2SprfYdpmkhZKmldqx5fpx0JpZUqq8YOFWYGQr22+IiGGl9lC5Tjx1YGZJqebS2oiYImlwZ/vxiNbMktKeEW3LVayl1ljhz5wr6cXS1EK/ch920JpZUmpUeYuI8RGxX4s2voKfuAnYGRgGLAKuK/cFTx2YWVKyvvF3RKxfwihpAusWdrXJQWtmScn6OlpJAyNiUentaGB6W58HB62ZJUZVfGqYpDuAEUB/SQuAS4ERkoYBAcwHzi7Xj4PWzJJSzRFtRJzSyuZb2tuPg9bMkuIluGZmGfPDGc3MMlZbwItWHbRmlpSufOhipRy0ZpYUz9GamWWsgANaB62ZpaWmitfRVouD1syS4hGtmVnG6go4SeugNbOkeERrZpYxX95lZpaxAuasg9bM0lLAhWEOWjNLi6cOzMwy5qA1M8tY8WLWQWtmiSnggNZBa2Zp8f1ozcwy5qsOzMwy1iNPhvWtd5Znbclz38u7hOTtPe6RvEvoEeZcc3Sn+/DUgZlZxjx1YGaWMY9ozcwyVryYLeYo28ysw2qlils5kiZKapY0vcW2LSRNkvRy6e9+5fpx0JpZUqTKWwVuBUZ+YNtYYHJE7AJMLr1vk4PWzJKidvwpJyKmAG99YPMJwG2l17cBo8r146A1s6S0Z0QrqVHS1BatsYKfaIiIRaXXi4GGcl/wyTAzS0p7noIbEeOB8R39rYgISVG+JjOzhFR5jrY1TZIGrvstDQSay33BQWtmSamRKm4ddD9wRun1GcB95b7gqQMzS0o1nzYu6Q5gBNBf0gLgUuBq4C5JZwKvAp8t14+D1sySUsnVBJWKiFM2sOuI9vTjoDWzpBRwBa6D1szSUs0RbbU4aM0sKdWco60WB62ZJaVH3vjbzKwrFS9mHbRmlhiPaM3MMla8mHXQmllqCpi0DlozS4qnDszMMla8mHXQmllqCpi0DlozS4pXhpmZZayAU7QOWjNLSwFz1kFrZmlRAYe0DlozS0oBc9ZBa2ZpKWDOOmjNLDEFTFoHrZklpYiXd/kpuC089cQU/um4o/nUyCO5ZUKHH/VuZfg4V9+Azeq5vXF/HrxgOA9cMJzPD9/+7/Z/8RM7MOeao+n3oV45Vdh1uuBx4+3mEW3JmjVruOrKK7h5wo9oaGjg1M+dxIjDDmfnIUPyLi0pPs7ZWLN2LVc/MJuZf1pKn9613H3eQTz18pv8oXk5AzarZ/iH+7NwyYq8y+wSRTwZ5hFtyfSXXmTQoB3YbtAgevXuzchjj+PxxybnXVZyfJyz8eelK5n5p6UALF+5hnnNy2nYrB6Ai4/fle88NJeIPCvsOmrHn65SNmgl7SbpCEl9P7B9ZHZldb3mpiYGDByw/v3WDQ00NTXlWFGafJyzt22/enbfdhN+/9rbHDF0K5rf+RtzFi3Nu6wuU8SpgzaDVtJ5wH3AV4Dpkk5osfuqLAszs/b7UO9abjxtGFfdP5s1a4OzD9uJ/5z0St5ldSm1o3WVciPaLwH7RsQoYATwH5LOL+3bYJ2SGiVNlTS1u5zs2LqhgcWLFq9/39zURENDQ44VpcnHOTt1NeLG04fxy2mLmDSjme23/BDbbbEx951/MJMvOoQBm23EPecfRP++vfMuNVtVTFpJ8yW9JGmapKkdLancybCaiFgGEBHzJY0AfiFph7bKjIjxwHiAv66mW8wM7bHnXrz22nwWLHidhq0bePihB/n2d67Lu6zk+Dhn58qT9mBe83JufeJVAOYuXsbB33x8/f7JFx3CSd99miXvrsqnwC6SwY2/D4u
INzrTQbmgbZI0LCKmAUTEMkmfAiYCe3Xmh4umrq6Oi8ddwjmNZ7F27RpGjf40Q4bskndZyfFxzsa+gzdn1L7bMmfRUv7n/IMAuP7hl5kyp1P50C0V8KIDFG2cipS0HbA6Iha3sm94RDxV7ge6y4jWrC17j3sk7xJ6hDnXHN3pnJzb9G7FmbPrgD5nA40tNo0v/R85AJL+CCwBAri55b72aHNEGxEL2thXNmTNzLpaey7bajnNuQEfj4iFkrYGJkmaHRFT2luTr6M1s6RU8/KuiFhY+rsZuBc4oCM1OWjNLCnVuuhAUh9Jm7z3GjgKmN6RmrwE18ySUsUbfzcA95b6qwP+OyIe7khHDlozS0q1cjYi5gF7V6MvB62ZJaWIl3c5aM0sLQVMWgetmSWliDf+dtCaWVKKeD9aB62ZJaXGQWtmlrXiJa2D1syS4qkDM7OMFTBnHbRmlhaPaM3MMlbFJbhV46A1s6QUL2YdtGaWmAIOaB20ZpYWrwwzM8ta8XLWQWtmaSlgzjpozSwtGTxuvNMctGaWlALmrJ8ZZmaWNY9ozSwpRRzROmjNLCm+vMvMLGMe0ZqZZcxBa2aWMU8dmJllrIgjWl/eZWZJUTta2b6kkZLmSHpF0tiO1uSgNbO0VClpJdUC3weOAYYCp0ga2pGSPHVgZkmp4hLcA4BXImIegKQ7gROAme3tKPOgra8r4Mx0GZIaI2J83nWkrLsd4znXHJ13Ce3W3Y5xtbQncyQ1Ao0tNo1vccy2BV5vsW8B8LGO1OSpg9Y1lv+IdZKPcfZ8jMuIiPERsV+Llsl/mBy0ZmatWwgMavF+u9K2dnPQmpm17jlgF0k7SuoNnAzc35GOfDKsdT1uXisHPsbZ8zHuhIhYLelc4BGgFpgYETM60pcioqrFmZnZ3/PUgZlZxhy0ZmYZc9C2UK3ldrZhkiZKapY0Pe9aUiVpkKTHJM2UNEPS+XnX1NN5jraktNxuLnAk6y5Mfg44JSLavQrENkzSIcAy4PaI2DPvelIkaSAwMCJekLQJ8Dwwyv8u58cj2vetX24XESuB95bbWRVFxBTgrbzrSFlELIqIF0qvlwKzWLfKyXLioH1fa8vt/C+ndWuSBgP7AM/mXEqP5qA1S5SkvsDdwJiI+Eve9fRkDtr3VW25nVneJPViXcj+NCLuybuens5B+76qLbczy5MkAbcAsyLi+rzrMQftehGxGnhvud0s4K6OLrezDZN0B/A0sKukBZLOzLumBA0HTgcOlzSt1I7Nu6iezJd3mZllzCNaM7OMOWjNzDLmoDUzy5iD1swsYw5aM7OMOWjNzDLmoDUzy9j/A2ZYW2d9gfQGAAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 2 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = sns.heatmap(confusion_matrix, annot=True, cmap=\"Blues\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "therapeutic-quest", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy: 0.966\n", "precision: 0.968\n", "recall: 0.966\n", "f1 score: 0.966\n" ] } ], "source": [ "print(\"accuracy: {:.3f}\".format(metrics.accuracy_score(y_test, predicted)))\n", "print(\"precision: {:.3f}\".format(metrics.precision_score(y_test, 
predicted, average='weighted')))\n", "print(\"recall: {:.3f}\".format(metrics.recall_score(y_test, predicted, average='weighted')))\n", "print(\"f1 score: {:.3f}\".format(metrics.f1_score(y_test, predicted, average='weighted')))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 5 }