{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "<figure>\n",
    "<img src=\"../Imagenes/logo-final-ap.png\"  width=\"80\" height=\"80\" align=\"left\"/> \n",
    "</figure>\n",
    "\n",
    "# <span style=\"color:blue\"><left>Deep Learning</left></span>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <span style=\"color:red\"><center>Classification, Softmax, Iris</center></span>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center>Multi-class classification</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##   <span style=\"color:blue\">Professors</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Alvaro Mauricio Montenegro Díaz, ammontenegrod@unal.edu.co\n",
    "2. Daniel Mauricio Montenegro Reyes, dextronomo@gmail.com \n",
    "3. Campo Elías Pardo Turriago, cepardot@unal.edu.co "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##   <span style=\"color:blue\">Media and Digital Marketing Advisor</span>\n",
    " "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. Maria del Pilar Montenegro, pmontenegro88@gmail.com "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Assistants</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5. Oleg Jarma, ojarmam@unal.edu.co \n",
    "6. Laura Lizarazo, ljlizarazore@unal.edu.co "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <span style=\"color:blue\">Contents</span> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* [Introduction](#Introducción)\n",
    "* [Import modules](#Importa-módulos)\n",
    "* [Activation functions](#Funciones-de-activación)\n",
    "* [The Iris dataset](#El-conjunto-de-datos-Iris)\n",
    "* [Reading the data](#Lectura-de-datos)\n",
    "* [Preprocessing](#Preprocesamiento)\n",
    "* [Build the model with the functional API](#Crea-el-modelo-usando-la-API-funcional)\n",
    "* [Compile](#Compila)\n",
    "* [Train](#Entrena)\n",
    "* [Model evaluation](#Evaluación-del-modelo)\n",
    "* [Predictions](#Predicciones)\n",
    "* [Confusion matrix](#Matriz-de-confusión)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Introducción\"></a><span style=\"color:blue\">Introduction</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This lesson is devoted to a multi-class classification model, the natural generalization of the logistic model to more than two categories. \n",
    "\n",
    "* We will practice *one-hot* encoding of the output data.\n",
    "* We will also use the functional API of tf.keras, a more flexible and powerful way of building models than the Sequential model.\n",
    "* We will use the *relu* activation for the input and hidden layers, and the *softmax* activation for the output layer, because there are several classes. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Importa-módulos\"></a><span style=\"color:blue\">Import modules</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the libraries\n",
    "* *seaborn* for somewhat more elegant plots\n",
    "* *sklearn* for data standardization and the confusion matrix\n",
    "\n",
    "You can install them from the console with the following commands (remove the leading `#` to run them from the notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !conda install -c anaconda seaborn\n",
    "# !conda install -c intel scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Versión de Tensorflow: 2.4.1\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "#\n",
    "from tensorflow.keras.models import Model\n",
    "#\n",
    "from tensorflow.keras.layers import Dense, Input, Activation, Dropout\n",
    "#\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "from tensorflow.keras.utils import plot_model\n",
    "#\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import confusion_matrix\n",
    "#\n",
    "#from sklearn.model_selection import KFold\n",
    "print(\"Versión de Tensorflow:\", tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Funciones-de-activación\"></a><span style=\"color:blue\">Activation functions</span> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Relu"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the output of the summation unit, say $y=\\mathbf{w}'\\mathbf{x} +b$, the *relu* activation function is defined by\n",
    "\n",
    "$$\n",
    "\\text{relu}(y) = \\max(0, y) = \\begin{cases} 0, & \\text{if } y\\le 0,\\\\\n",
    " y, & \\text{otherwise} \\end{cases}\n",
    "$$\n",
    "\n"
   ]
  },
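  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the definition above, for illustration only. The notebook itself uses the built-in Keras `'relu'` activation, and the helper name `relu` is our own.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def relu(y):\n",
    "    # element-wise max(0, y): negative values become 0, the rest pass unchanged\n",
    "    return np.maximum(0.0, y)\n",
    "\n",
    "y = np.array([-2.0, -0.5, 0.0, 1.5])\n",
    "print(relu(y))\n",
    "```"
   ]
  },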
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Softmax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Given values $x_1,\\ldots, x_n$, the *softmax* function is defined by\n",
    "\n",
    "$$\n",
    "\\text{softmax}(x_i) = \\frac{e^{x_i}}{\\sum_{j=1}^{n} e^{x_j}}\n",
    "$$\n",
    "\n",
    "That is, *softmax* transforms the values into a probability distribution: every output is positive and the outputs sum to 1."
   ]
  },
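  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the formula above for a single vector. This is a helper of our own, not the Keras implementation. Subtracting the maximum before exponentiating is a standard numerical-stability trick. The common factor $e^{-\\max_j x_j}$ cancels in the ratio, so the result is unchanged, but `np.exp` no longer overflows for large inputs.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(x):\n",
    "    # shift by the maximum so the largest exponent is exp(0) = 1\n",
    "    z = np.exp(x - np.max(x))\n",
    "    # normalize so the outputs are positive and sum to 1\n",
    "    return z / z.sum()\n",
    "\n",
    "p = softmax(np.array([2.0, 1.0, 0.1]))\n",
    "print(p, p.sum())\n",
    "\n",
    "# without the shift, np.exp(1000.0) would overflow to inf\n",
    "print(softmax(np.array([1000.0, 1000.0])))\n",
    "```"
   ]
  },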
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"El-conjunto-de-datos-Iris\"></a><span style=\"color:blue\">The Iris dataset</span>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
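    "This dataset was introduced by Sir Ronald Fisher in 1936. It contains 150 flowers from three iris species (*setosa*, *versicolor* and *virginica*), each described by four measurements: sepal length, sepal width, petal length and petal width."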
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Lectura-de-datos\"></a><span style=\"color:blue\">Reading the data and a first look</span>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We download the data with *tf.keras.utils.get_file*, which caches the files locally (by default under `~/.keras/datasets`), and then load them into pandas DataFrames."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SepalLength</th>\n",
       "      <th>SepalWidth</th>\n",
       "      <th>PetalLength</th>\n",
       "      <th>PetalWidth</th>\n",
       "      <th>Species</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.2</td>\n",
       "      <td>1.5</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6.9</td>\n",
       "      <td>3.1</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.3</td>\n",
       "      <td>1.7</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>6.0</td>\n",
       "      <td>3.4</td>\n",
       "      <td>4.5</td>\n",
       "      <td>1.6</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.5</td>\n",
       "      <td>2.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1.3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   SepalLength  SepalWidth  PetalLength  PetalWidth  Species\n",
       "0          5.9         3.0          4.2         1.5        1\n",
       "1          6.9         3.1          5.4         2.1        2\n",
       "2          5.1         3.3          1.7         0.5        0\n",
       "3          6.0         3.4          4.5         1.6        1\n",
       "4          5.5         2.5          4.0         1.3        1"
      ]
     },
     "execution_count": 201,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# column names of the data\n",
    "col_names = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']\n",
    "target_dimensions = ['Setosa', 'Versicolor', 'Virginica']\n",
    "\n",
    "# download the CSV files (get_file caches them locally)\n",
    "training_data_path = tf.keras.utils.get_file(\"iris_training.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv\")\n",
    "test_data_path = tf.keras.utils.get_file(\"iris_test.csv\", \"https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv\")\n",
    "\n",
    "# header=0 replaces the file's own first row with col_names\n",
    "training = pd.read_csv(training_data_path, names=col_names, header=0)\n",
    "test = pd.read_csv(test_data_path, names=col_names, header=0)\n",
    "\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Preprocesamiento\"></a><span style=\"color:blue\">Preprocessing</span> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
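    "The target variable has three categories, coded as the integers 0, 1 and 2. One-hot encoding turns each code into a binary vector of length 3 with a single 1 in the position of its class. For example, class 1 (Versicolor) becomes $(0, 1, 0)$."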
   ]
  },
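  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What `to_categorical` does can be sketched in a few lines of NumPy. `one_hot` is a hypothetical helper written only for illustration. For integer labels it should produce the same matrix as `to_categorical`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def one_hot(labels, num_classes):\n",
    "    labels = np.asarray(labels, dtype=int)\n",
    "    out = np.zeros((labels.size, num_classes), dtype='float32')\n",
    "    # row i gets a 1 in column labels[i] and 0 everywhere else\n",
    "    out[np.arange(labels.size), labels] = 1.0\n",
    "    return out\n",
    "\n",
    "# integer codes as in the Species column: 0 = Setosa, 1 = Versicolor, 2 = Virginica\n",
    "print(one_hot([1, 2, 0], 3))\n",
    "```"
   ]
  },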
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One-hot encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to_categorical turns each integer label (0, 1, 2) into a one-hot row vector\n",
    "y_train = pd.DataFrame(to_categorical(training.Species))\n",
    "y_train.columns = target_dimensions\n",
    "\n",
    "y_test = pd.DataFrame(to_categorical(test.Species))\n",
    "y_test.columns = target_dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Setosa</th>\n",
       "      <th>Versicolor</th>\n",
       "      <th>Virginica</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Setosa  Versicolor  Virginica\n",
       "0      0.0         1.0        0.0\n",
       "1      0.0         0.0        1.0\n",
       "2      1.0         0.0        0.0\n",
       "3      0.0         1.0        0.0\n",
       "4      0.0         1.0        0.0\n",
       "5      0.0         1.0        0.0\n",
       "6      1.0         0.0        0.0\n",
       "7      0.0         0.0        1.0\n",
       "8      0.0         1.0        0.0\n",
       "9      0.0         0.0        1.0\n",
       "10     0.0         0.0        1.0\n",
       "11     1.0         0.0        0.0\n",
       "12     0.0         0.0        1.0\n",
       "13     0.0         1.0        0.0\n",
       "14     0.0         1.0        0.0\n",
       "15     1.0         0.0        0.0\n",
       "16     0.0         1.0        0.0\n",
       "17     1.0         0.0        0.0\n",
       "18     1.0         0.0        0.0\n",
       "19     0.0         0.0        1.0\n",
       "20     1.0         0.0        0.0\n",
       "21     0.0         1.0        0.0\n",
       "22     0.0         0.0        1.0\n",
       "23     0.0         1.0        0.0\n",
       "24     0.0         1.0        0.0\n",
       "25     0.0         1.0        0.0\n",
       "26     1.0         0.0        0.0\n",
       "27     0.0         1.0        0.0\n",
       "28     0.0         0.0        1.0\n",
       "29     0.0         1.0        0.0"
      ]
     },
     "execution_count": 203,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Remove the Species column from the DataFrames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train_species = training.pop('Species')  # removes the column and returns it as a Series\n",
    "#test.drop(['Species'], axis=1, inplace=True)\n",
    "y_test_species = test.pop('Species')  # removes the column and stores it in y_test_species\n",
    "#\n",
    "# If you need to append the one-hot columns to the DataFrames, use these lines\n",
    "#training = training.join(y_train)\n",
    "#test = test.join(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Standardize the features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5.845      3.065      3.73916667 1.19666667]\n"
     ]
    }
   ],
   "source": [
    "# create the StandardScaler object\n",
    "scaler = StandardScaler()\n",
    "\n",
    "# fit the scaler (per-feature mean and standard deviation) on the training data only\n",
    "scaler.fit(training)\n",
    "print(scaler.mean_)\n",
    "\n",
    "# scale training and test with the same parameters\n",
    "x_train = scaler.transform(training)\n",
    "x_test = scaler.transform(test)\n",
    "\n",
    "# the labels do not need scaling\n"
   ]
  },
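  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`StandardScaler` replaces each feature $x_j$ by $(x_j - \\bar{x}_j)/s_j$. Here $\\bar{x}_j$ and $s_j$ are the mean and the population standard deviation of that feature, computed on the data passed to `fit`. The following sketch checks this on a small made-up array (not the Iris data):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# toy data: 3 samples, 2 features\n",
    "X = np.array([[5.1, 3.5], [6.2, 2.9], [7.7, 3.0]])\n",
    "\n",
    "# manual standardization; np.std uses ddof=0 (population std) by default\n",
    "manual = (X - X.mean(axis=0)) / X.std(axis=0)\n",
    "scaled = StandardScaler().fit_transform(X)\n",
    "\n",
    "print(np.allclose(manual, scaled))\n",
    "print(manual.mean(axis=0), manual.std(axis=0))\n",
    "```\n",
    "\n",
    "The scaler above is fitted only on the training set, and the same $\\bar{x}_j$ and $s_j$ are reused for the test set. This keeps information from the test data out of the preprocessing."
   ]
  },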
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Crea-el-modelo-usando-la-API-funcional\"></a><span style=\"color:blue\">Build the model with the functional API</span> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
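    "The Keras functional API is more flexible than the Sequential model. Each layer is called on a tensor and returns a tensor, so a model can have several inputs or outputs, shared layers, or non-linear topologies."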
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 266,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model_22\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "capa_entrada (InputLayer)    [(None, 4)]               0         \n",
      "_________________________________________________________________\n",
      "activation_16 (Activation)   (None, 4)                 0         \n",
      "_________________________________________________________________\n",
      "primera_capa_oculta (Dense)  (None, 8)                 40        \n",
      "_________________________________________________________________\n",
      "dropout_20 (Dropout)         (None, 8)                 0         \n",
      "_________________________________________________________________\n",
      "segunda_capa_oculta (Dense)  (None, 16)                144       \n",
      "_________________________________________________________________\n",
      "dropout_21 (Dropout)         (None, 16)                0         \n",
      "_________________________________________________________________\n",
      "capa_salida (Dense)          (None, 3)                 51        \n",
      "=================================================================\n",
      "Total params: 235\n",
      "Trainable params: 235\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 266,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The functional API requires an Input layer, which defines the shape\n",
    "# of the input tensor (here, 4 features per flower)\n",
    "#\n",
    "inputs = Input(shape=(4,),name='capa_entrada')\n",
    "#\n",
    "# build the network layer by layer\n",
    "# note: relu applied directly to the standardized inputs sets every negative feature value to 0\n",
    "x = Activation('relu')(inputs)\n",
    "x = Dense(8, activation='relu',name='primera_capa_oculta')(x)\n",
    "x = Dropout(0.2)(x)\n",
    "x = Dense(16, activation='relu', name='segunda_capa_oculta')(x)\n",
    "x = Dropout(0.2)(x)\n",
    "outputs = Dense(3, activation='softmax', name='capa_salida')(x)\n",
    "\n",
    "# now create the model\n",
    "model_iris = Model(inputs=inputs, outputs=outputs)\n",
    "\n",
    "model_iris.summary()\n",
    "# plot_model needs the pydot and graphviz packages installed\n",
    "plot_model(model_iris, to_file='../Imagenes/iris_model.png', \n",
    "           show_shapes=True)"
   ]
  },
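  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Param #` column of the summary can be checked by hand. A `Dense` layer with $n_{in}$ inputs and $n$ units has $n_{in} \\cdot n$ weights plus $n$ biases, while `Input`, `Activation` and `Dropout` have no trainable parameters. A small sketch (`dense_params` is our own helper):\n",
    "\n",
    "```python\n",
    "def dense_params(n_in, n_units):\n",
    "    # one weight per (input, unit) pair plus one bias per unit\n",
    "    return n_in * n_units + n_units\n",
    "\n",
    "# (inputs, units) of the three Dense layers in model_iris\n",
    "layers = [(4, 8), (8, 16), (16, 3)]\n",
    "counts = [dense_params(n_in, n) for n_in, n in layers]\n",
    "print(counts, sum(counts))\n",
    "```\n",
    "\n",
    "This reproduces the 40, 144 and 51 parameters reported by `model_iris.summary()`, 235 in total."
   ]
  },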
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Compila\"></a><span style=\"color:blue\">Compile</span> "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 267,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_iris.compile(optimizer='adam',\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['accuracy']\n",
    ")"
   ]
  },
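  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Consider one-hot targets $y_{ik}$ and predicted probabilities $\\hat{p}_{ik}$. The categorical cross-entropy averaged over $N$ samples is $L = -\\frac{1}{N}\\sum_{i=1}^{N}\\sum_{k} y_{ik}\\log \\hat{p}_{ik}$, so only the probability assigned to the true class of each sample contributes. Below is a NumPy sketch of that formula, not the exact Keras code. The clipping constant `eps` is our own choice to avoid $\\log 0$.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def categorical_crossentropy(y_true, y_pred, eps=1e-7):\n",
    "    # clip probabilities away from 0 and 1 so the log is finite\n",
    "    y_pred = np.clip(y_pred, eps, 1.0 - eps)\n",
    "    # per-sample loss -sum_k y_k log p_k, then the mean over samples\n",
    "    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))\n",
    "\n",
    "y_true = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])\n",
    "y_pred = np.array([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])\n",
    "# equals -(log 0.8 + log 0.7) / 2\n",
    "print(categorical_crossentropy(y_true, y_pred))\n",
    "```"
   ]
  },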
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Entrena\"></a><span style=\"color:blue\">Train</span> "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 268,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PrintDot(tf.keras.callbacks.Callback):\n",
    "    # prints one dot per finished epoch as a compact progress indicator\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        print('.', end='')\n",
    "\n",
    "epochs = 200\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 269,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "........................................................................................................................................................................................................\n",
      "Hecho\n",
      "Resultados finales de pérdida y exactitud\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>val_loss</th>\n",
       "      <th>val_accuracy</th>\n",
       "      <th>epoch</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>195</th>\n",
       "      <td>0.303951</td>\n",
       "      <td>0.842593</td>\n",
       "      <td>0.360190</td>\n",
       "      <td>0.916667</td>\n",
       "      <td>195</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>196</th>\n",
       "      <td>0.281763</td>\n",
       "      <td>0.842593</td>\n",
       "      <td>0.361281</td>\n",
       "      <td>0.916667</td>\n",
       "      <td>196</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>197</th>\n",
       "      <td>0.290119</td>\n",
       "      <td>0.861111</td>\n",
       "      <td>0.361293</td>\n",
       "      <td>0.916667</td>\n",
       "      <td>197</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>198</th>\n",
       "      <td>0.292439</td>\n",
       "      <td>0.898148</td>\n",
       "      <td>0.360291</td>\n",
       "      <td>0.916667</td>\n",
       "      <td>198</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>199</th>\n",
       "      <td>0.293885</td>\n",
       "      <td>0.907407</td>\n",
       "      <td>0.360623</td>\n",
       "      <td>0.916667</td>\n",
       "      <td>199</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         loss  accuracy  val_loss  val_accuracy  epoch\n",
       "195  0.303951  0.842593  0.360190      0.916667    195\n",
       "196  0.281763  0.842593  0.361281      0.916667    196\n",
       "197  0.290119  0.861111  0.361293      0.916667    197\n",
       "198  0.292439  0.898148  0.360291      0.916667    198\n",
       "199  0.293885  0.907407  0.360623      0.916667    199"
      ]
     },
     "execution_count": 269,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "history = model_iris.fit(x_train, y_train,\n",
    "                    batch_size= 16,\n",
    "                    epochs= epochs,\n",
    "                    validation_split=0.1, verbose=0,\n",
    "                    callbacks=[PrintDot()])     \n",
    "print('\\nHecho')\n",
    "print('Resultados finales de pérdida y exactitud\\n')\n",
    "# show the last epochs of the training history\n",
    "hist = pd.DataFrame(history.history)\n",
    "hist['epoch'] = history.epoch\n",
    "hist.tail()"
   ]
  },
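  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `validation_split=0.1`, Keras holds out the last 10% of the samples passed to `fit` as validation data, taken before any shuffling. Below is a NumPy sketch of that split. The figure of 120 training rows is an assumption about `iris_training.csv`, not something printed in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "n_samples = 120  # assumed number of rows in x_train\n",
    "validation_split = 0.1\n",
    "idx = np.arange(n_samples)\n",
    "\n",
    "# split point at floor(n * (1 - validation_split)); the tail is held out\n",
    "split_at = int(np.floor(n_samples * (1.0 - validation_split)))\n",
    "fit_idx, val_idx = idx[:split_at], idx[split_at:]\n",
    "print(len(fit_idx), len(val_idx))\n",
    "```\n",
    "\n",
    "Under that assumption only 12 flowers are used for validation, so each misclassified flower changes `val_accuracy` by $1/12 \\approx 0.083$. The 0.916667 in the table above is 11/12. Because the split is not random, the held-out rows are always the last rows of the file."
   ]
  },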
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Evaluación-del-modelo\"></a><span style=\"color:blue\">Model evaluation</span> "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def plot_history(history):\n",
    "  hist = pd.DataFrame(history.history)\n",
    "  hist['epoch'] = history.epoch\n",
    "\n",
    "  plt.figure()\n",
    "  plt.xlabel('Epoch')\n",
    "  plt.ylabel('Loss')\n",
    "  plt.plot(hist['epoch'], hist['loss'],\n",
    "           label='Training loss')\n",
    "  plt.plot(hist['epoch'], hist['val_loss'],\n",
    "           label='Validation loss')\n",
    "  plt.ylim([0,2])\n",
    "  plt.legend()\n",
    "\n",
    "  plt.figure()\n",
    "  plt.xlabel('Epoch')\n",
    "  plt.ylabel('Accuracy')\n",
    "  plt.plot(hist['epoch'], hist['accuracy'],\n",
    "           label='Training accuracy')\n",
    "  plt.plot(hist['epoch'], hist['val_accuracy'],\n",
    "           label='Validation accuracy')\n",
    "  plt.ylim([0,1])\n",
    "  plt.legend()\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Predicciones\"></a><span style=\"color:blue\">Predictions</span> "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 271,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:5 out of the last 11 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f99d8514940> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    }
   ],
   "source": [
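    "# Predict the test set: one row of class probabilities per flower\n",
    "y_pred = model_iris.predict(x_test)\n",
    "# predicted class = index of the largest probability in each row\n",
    "y_pred_c = np.argmax(y_pred, axis=1)"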
   ]
  },
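  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`model_iris.predict` returns one row of softmax probabilities per flower. `np.argmax(..., axis=1)` then keeps, for each row, the index of the largest probability. A toy illustration with made-up probabilities:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# columns: Setosa, Versicolor, Virginica (each row sums to 1)\n",
    "probs = np.array([[0.05, 0.90, 0.05],\n",
    "                  [0.10, 0.20, 0.70],\n",
    "                  [0.98, 0.01, 0.01]])\n",
    "\n",
    "# predicted class = column with the highest probability in each row\n",
    "classes = np.argmax(probs, axis=1)\n",
    "print(classes)\n",
    "```"
   ]
  },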
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id=\"Matriz-de-confusión\"></a><span style=\"color:blue\">Confusion matrix</span> "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 272,
   "metadata": {},
   "outputs": [],
   "source": [
    "cm = confusion_matrix(y_test_species, y_pred_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our accuracy is 93.33333333333333%\n"
     ]
    }
   ],
   "source": [
    "print(\"Our accuracy is {}%\".format(((cm[0][0] + cm[1][1]+ cm[2][2])/y_test_species.shape[0])*100))"
   ]
  },
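  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The accuracy printed above is the sum of the diagonal of the confusion matrix (the correct predictions) divided by the number of test samples. For any number of classes this can be written as `np.trace(cm) / cm.sum()`. A sketch with made-up labels, not the Iris test set:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix\n",
    "\n",
    "y_true = [0, 1, 2, 2, 1, 0]\n",
    "y_pred = [0, 2, 2, 2, 1, 0]\n",
    "\n",
    "# rows are true classes, columns are predicted classes\n",
    "cm = confusion_matrix(y_true, y_pred)\n",
    "accuracy = np.trace(cm) / cm.sum()\n",
    "print(cm)\n",
    "print(accuracy, accuracy_score(y_true, y_pred))\n",
    "```\n",
    "\n",
    "On the 30 test flowers, the 93.33% reported above corresponds to 28 correct predictions."
   ]
  },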
  {
   "cell_type": "code",
   "execution_count": 274,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Heatmap of cm (the confusion matrix); annot=True writes the count in each cell\n",
    "sns.heatmap(cm, annot=True)\n",
    "plt.savefig('h.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Back to top](#Contenido)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"Exploración-interna-de-la-red\"></a>\n",
    "## <span style=\"color:blue\">Exploring the network internals</span> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing the network output for the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 275,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.  , 0.01, 0.99],\n",
       "       [0.51, 0.44, 0.05],\n",
       "       [0.  , 0.38, 0.62],\n",
       "       [0.8 , 0.18, 0.02],\n",
       "       [1.  , 0.  , 0.  ],\n",
       "       [0.98, 0.02, 0.  ],\n",
       "       [1.  , 0.  , 0.  ],\n",
       "       [0.  , 0.02, 0.98],\n",
       "       [0.02, 0.79, 0.19],\n",
       "       [1.  , 0.  , 0.  ]], dtype=float32)"
      ]
     },
     "execution_count": 275,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class probabilities (softmax output) for the training data; first 10 rows, rounded\n",
    "inputs = x_train\n",
    "outputs = model_iris(inputs)\n",
    "outputs.numpy().round(2)[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract the output of the second hidden layer for these data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 217,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative for a Sequential model (kept commented out)\n",
    "#layer_2 = tf.keras.models.Model(\n",
    "#    inputs=model_iris.inputs,\n",
    "#    outputs=model_iris.get_layer(name='segunda_capa_oculta').output,\n",
    "#)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 276,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:6 out of the last 12 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f99ac40e940> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    }
   ],
   "source": [
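    "# Functional API:\n",
    "# 1. build a new model from model_iris's input to the output of 'segunda_capa_oculta'\n",
    "# 2. compile it (optional: predict() also works on an uncompiled model)\n",
    "# 3. predict\n",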
    "\n",
    "inputs = x_train\n",
    "model = Model(model_iris.input, model_iris.get_layer(name='segunda_capa_oculta').output)\n",
    "model.compile()\n",
    "output = model.predict(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 277,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(120, 16)"
      ]
     },
     "execution_count": 277,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.shape"
   ]
  },
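  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell is a minimal NumPy sketch of what the truncated model computes. It assumes the hidden layers are `Dense` layers with ReLU activation. That assumption is consistent with the many exact zeros in `output`, but it is not confirmed here. The data and weights are random, hypothetical stand-ins, and so is the 8-unit width of the first hidden layer. With the trained network, each layer's `[kernel, bias]` pair could be read with `model_iris.get_layer(name).get_weights()`.\n",
    "\n",
    "The sketch also suggests why some columns of `output` can be zero for every sample: a ReLU unit whose pre-activation is never positive is \"dead\". Similarly, when an input switches off every unit of the first hidden layer, its second-layer output reduces to `relu(b2)`. Different samples can then share exactly the same vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def second_hidden_output(x, W1, b1, W2, b2):\n",
    "    # Two Dense layers with ReLU (assumed): relu(relu(x W1 + b1) W2 + b2)\n",
    "    h1 = np.maximum(x @ W1 + b1, 0.0)     # first hidden layer\n",
    "    return np.maximum(h1 @ W2 + b2, 0.0)  # second hidden layer\n",
    "\n",
    "# Hypothetical stand-ins: 120 samples x 4 features, 8 units (assumed) -> 16 units\n",
    "rng_demo = np.random.default_rng(0)\n",
    "x_demo = rng_demo.normal(size=(120, 4))\n",
    "W1_demo, b1_demo = rng_demo.normal(size=(4, 8)), np.zeros(8)\n",
    "W2_demo, b2_demo = rng_demo.normal(size=(8, 16)), np.zeros(16)\n",
    "\n",
    "h2_demo = second_hidden_output(x_demo, W1_demo, b1_demo, W2_demo, b2_demo)\n",
    "print(h2_demo.shape)  # (120, 16)\n",
    "# Indices of 'dead' units: zero for every sample\n",
    "print(np.flatnonzero(~h2_demo.any(axis=0)))"
   ]
  },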
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build a data table for the t-SNE plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append the species label (0, 1, 2) as the last column of the hidden activations\n",
    "plot_data = np.hstack([output, np.array(y_train_species).reshape(y_train_species.shape[0],1)])\n",
    "plot_data = pd.DataFrame(plot_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>11</th>\n",
       "      <th>12</th>\n",
       "      <th>13</th>\n",
       "      <th>14</th>\n",
       "      <th>15</th>\n",
       "      <th>16</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.467787</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.731057</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.510913</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.827167</td>\n",
       "      <td>3.498898</td>\n",
       "      <td>2.964854</td>\n",
       "      <td>1.161228</td>\n",
       "      <td>1.649435</td>\n",
       "      <td>2.353893</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.230899</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.215546</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.113442</td>\n",
       "      <td>0.213827</td>\n",
       "      <td>0.396466</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.100383</td>\n",
       "      <td>0.180319</td>\n",
       "      <td>0.579343</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.981748</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.150882</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.072097</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.005695</td>\n",
       "      <td>1.802041</td>\n",
       "      <td>1.445802</td>\n",
       "      <td>0.958601</td>\n",
       "      <td>0.674698</td>\n",
       "      <td>1.188161</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.418447</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.336955</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.276314</td>\n",
       "      <td>0.321667</td>\n",
       "      <td>0.520616</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.088640</td>\n",
       "      <td>0.445429</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.559006</td>\n",
       "      <td>0.097274</td>\n",
       "      <td>1.694688</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.285218</td>\n",
       "      <td>1.820933</td>\n",
       "      <td>1.972289</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>115</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.115792</td>\n",
       "      <td>0.183403</td>\n",
       "      <td>0.435880</td>\n",
       "      <td>0.432974</td>\n",
       "      <td>0.678667</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.090053</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>116</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.171450</td>\n",
       "      <td>0.110357</td>\n",
       "      <td>0.347058</td>\n",
       "      <td>0.385124</td>\n",
       "      <td>0.669227</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.025612</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.230899</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.215546</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.113442</td>\n",
       "      <td>0.213827</td>\n",
       "      <td>0.396466</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.100383</td>\n",
       "      <td>0.180319</td>\n",
       "      <td>0.579343</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>118</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.230899</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.215546</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.113442</td>\n",
       "      <td>0.213827</td>\n",
       "      <td>0.396466</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.100383</td>\n",
       "      <td>0.180319</td>\n",
       "      <td>0.579343</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>119</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.230899</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.215546</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.113442</td>\n",
       "      <td>0.213827</td>\n",
       "      <td>0.396466</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.100383</td>\n",
       "      <td>0.180319</td>\n",
       "      <td>0.579343</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>120 rows × 17 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           0         1         2         3         4         5         6   \\\n",
       "0    2.467787  0.000000  0.731057  0.000000  0.510913  0.000000  0.000000   \n",
       "1    0.000000  0.230899  0.000000  0.215546  0.000000  0.113442  0.213827   \n",
       "2    0.981748  0.000000  0.150882  0.000000  0.072097  0.000000  0.000000   \n",
       "3    0.000000  0.418447  0.000000  0.336955  0.000000  0.276314  0.321667   \n",
       "4    0.000000  2.559006  0.097274  1.694688  0.000000  2.285218  1.820933   \n",
       "..        ...       ...       ...       ...       ...       ...       ...   \n",
       "115  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "116  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "117  0.000000  0.230899  0.000000  0.215546  0.000000  0.113442  0.213827   \n",
       "118  0.000000  0.230899  0.000000  0.215546  0.000000  0.113442  0.213827   \n",
       "119  0.000000  0.230899  0.000000  0.215546  0.000000  0.113442  0.213827   \n",
       "\n",
       "           7         8         9         10        11        12        13  \\\n",
       "0    0.000000  1.827167  3.498898  2.964854  1.161228  1.649435  2.353893   \n",
       "1    0.396466  0.000000  0.100383  0.180319  0.579343  0.000000  0.000000   \n",
       "2    0.000000  1.005695  1.802041  1.445802  0.958601  0.674698  1.188161   \n",
       "3    0.520616  0.000000  0.000000  0.088640  0.445429  0.000000  0.000000   \n",
       "4    1.972289  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000   \n",
       "..        ...       ...       ...       ...       ...       ...       ...   \n",
       "115  0.115792  0.183403  0.435880  0.432974  0.678667  0.000000  0.090053   \n",
       "116  0.171450  0.110357  0.347058  0.385124  0.669227  0.000000  0.025612   \n",
       "117  0.396466  0.000000  0.100383  0.180319  0.579343  0.000000  0.000000   \n",
       "118  0.396466  0.000000  0.100383  0.180319  0.579343  0.000000  0.000000   \n",
       "119  0.396466  0.000000  0.100383  0.180319  0.579343  0.000000  0.000000   \n",
       "\n",
       "      14   15   16  \n",
       "0    0.0  0.0  2.0  \n",
       "1    0.0  0.0  1.0  \n",
       "2    0.0  0.0  2.0  \n",
       "3    0.0  0.0  0.0  \n",
       "4    0.0  0.0  0.0  \n",
       "..   ...  ...  ...  \n",
       "115  0.0  0.0  1.0  \n",
       "116  0.0  0.0  1.0  \n",
       "117  0.0  0.0  0.0  \n",
       "118  0.0  0.0  0.0  \n",
       "119  0.0  0.0  1.0  \n",
       "\n",
       "[120 rows x 17 columns]"
      ]
     },
     "execution_count": 279,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "plot_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the t-SNE plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[t-SNE] Computing 119 nearest neighbors...\n",
      "[t-SNE] Indexed 120 samples in 0.000s...\n",
      "[t-SNE] Computed neighbors for 120 samples in 0.011s...\n",
      "[t-SNE] Computed conditional probabilities for sample 120 / 120\n",
      "[t-SNE] Mean sigma: 1.722451\n",
      "[t-SNE] KL divergence after 250 iterations with early exaggeration: 49.608604\n",
      "[t-SNE] KL divergence after 700 iterations: 0.032179\n"
     ]
    }
   ],
   "source": [
    "from sklearn.manifold import TSNE\n",
    "\n",
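    "# Reduce the 16-dimensional hidden activations to 2 dimensions with t-SNE\n",
    "# (recent scikit-learn releases rename n_iter to max_iter)\n",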
    "tsne = TSNE(n_components=2, verbose=1, perplexity=50, n_iter=1000, learning_rate=200)\n",
    "tsne_results = tsne.fit_transform(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 251,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each integer label to its species name\n",
    "labels = [target_dimensions[i] for i in y_train_species]\n",
    "# target_dimensions: ['Setosa', 'Versicolor', 'Virginica']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 281,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x7f99a4611bb0>"
      ]
     },
     "execution_count": 281,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 445x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the 2-D t-SNE embedding with seaborn, colored by species\n",
    "df_tsne = pd.DataFrame(tsne_results, columns=['x', 'y'])\n",
    "df_tsne['label'] = labels\n",
    "sns.lmplot(x='x', y='y', data=df_tsne, hue='label', fit_reg=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Rewrite and retrain the network in PyTorch.\n",
    "1. Find out how to extract the output of a hidden layer in PyTorch.\n",
    "1. Make a t-SNE plot of the original (input) data.\n",
    "1. Apply a PCA reduction and make the corresponding plot.\n",
    "\n",
    "What are your conclusions?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "toc-autonumbering": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}