{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Practice with the TensorFlow 2 Functional API\n", "\n", "> This post demonstrates how to build models with the Functional syntax. You'll build one using the Sequential API and see how you can do the same with the Functional API. Both arrive at the same architecture, and you can train and evaluate either one as usual. This is a summary of the lecture \"Custom Models, Layers and Loss functions with Tensorflow\" from DeepLearning.AI.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow, DeepLearning.AI]\n", "- image: images/fashion_mnist_siamese.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow.keras.utils import plot_model\n", "from tensorflow.keras.models import Model\n", "from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Lambda\n", "from tensorflow.keras.optimizers import RMSprop\n", "from tensorflow.keras import backend as K\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import confusion_matrix\n", "from PIL import Image, ImageFont, ImageDraw\n", "import itertools\n", "import random" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1 - Comparing Functional API with Sequential API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the Data\n", "\n", "We will use the MNIST dataset for this comparison." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "mnist = tf.keras.datasets.mnist\n", "\n", "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n", "X_train, X_test = X_train / 255.0, X_test / 255.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sequential API\n", "\n", "Here is how we use the `Sequential()` class to build a model." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def build_model_with_sequential():\n", "    # instantiate a Sequential class and linearly stack the layers of your model\n", "    seq_model = tf.keras.Sequential([\n", "        tf.keras.layers.Flatten(input_shape=(28, 28)),\n", "        tf.keras.layers.Dense(128, activation=tf.nn.relu),\n", "        tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n", "    ])\n", "    return seq_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Functional API\n", "\n", "And here is how you build the same model with the functional syntax." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def build_model_with_functional():\n", "    # instantiate the input Tensor\n", "    input_layer = tf.keras.Input(shape=(28, 28))\n", "\n", "    # stack the layers using the syntax: new_layer()(previous_layer)\n", "    flatten_layer = tf.keras.layers.Flatten()(input_layer)\n", "    first_dense = tf.keras.layers.Dense(128, activation=tf.nn.relu)(flatten_layer)\n", "    output_layer = tf.keras.layers.Dense(10, activation=tf.nn.softmax)(first_dense)\n", "\n", "    # declare inputs and outputs\n", "    func_model = Model(inputs=input_layer, outputs=output_layer)\n", "    return func_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build the model and visualize the model graph" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = build_model_with_sequential()\n", "\n", "plot_model(model, show_shapes=True, show_layer_names=True, to_file='./image/sequential_model.png')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = build_model_with_functional()\n", "\n", "plot_model(model, show_shapes=True, show_layer_names=True, to_file='./image/functional_model.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that both models have the same architecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training the model\n", "\n", "Regardless of whether you built it with the Sequential or Functional API, you'll follow the same steps when training and evaluating your model." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "1875/1875 [==============================] - 5s 2ms/step - loss: 0.2587 - accuracy: 0.9260\n", "Epoch 2/5\n", "1875/1875 [==============================] - 5s 3ms/step - loss: 0.1154 - accuracy: 0.9660\n", "Epoch 3/5\n", "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0790 - accuracy: 0.9762\n", "Epoch 4/5\n", "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0582 - accuracy: 0.9820\n", "Epoch 5/5\n", "1875/1875 [==============================] - 5s 3ms/step - loss: 0.0455 - accuracy: 0.9862\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.compile(optimizer=tf.optimizers.Adam(),\n", "              loss='sparse_categorical_crossentropy',\n", "              metrics=['accuracy'])\n", "\n", "model.fit(X_train, y_train, epochs=5)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "313/313 [==============================] - 1s 2ms/step - loss: 0.0805 - accuracy: 0.9743\n" ] }, { "data": { "text/plain": [ "[0.08054833114147186, 0.9743000268936157]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2 - Build a Multi-output Model\n", "\n", "In this section, we'll show how you can build models with more than one output. The dataset we will be working on is available from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Energy+efficiency). It is an Energy Efficiency dataset which uses the building features (e.g. wall area, roof area) as inputs and has two outputs: Cooling Load and Heating Load." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Utilities\n", "\n", "We define a few utilities for data conversion and visualization to keep our code tidy." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def format_output(data):\n", "    y1 = data.pop('Y1')\n", "    y1 = np.array(y1)\n", "    y2 = data.pop('Y2')\n", "    y2 = np.array(y2)\n", "    return y1, y2\n", "\n", "def norm(x):\n", "    return (x - train_stats['mean']) / train_stats['std']\n", "\n", "def plot_diff(y_true, y_pred, title=''):\n", "    plt.scatter(y_true, y_pred)\n", "    plt.title(title)\n", "    plt.xlabel('True Values')\n", "    plt.ylabel('Predictions')\n", "    plt.axis('equal')\n", "    plt.axis('square')\n", "    plt.xlim(plt.xlim())\n", "    plt.ylim(plt.ylim())\n", "    plt.plot([-100, 100], [-100, 100])\n", "\n", "def plot_metrics(history, metric_name, title, ylim=5):\n", "    plt.title(title)\n", "    plt.ylim(0, ylim)\n", "    plt.plot(history.history[metric_name], color='blue', label=metric_name)\n", "    plt.plot(history.history['val_' + metric_name], color='green', label='val_' + metric_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the Data\n", "\n", "We load the dataset and format it for training." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Specify data URI\n", "URI = './dataset/ENB2012_data.xlsx'\n", "\n", "# Use pandas excel reader\n", "df = pd.read_excel(URI)\n", "df.dropna(axis=1, inplace=True)\n", "df = df.sample(frac=1).reset_index(drop=True)\n", "\n", "# Split the data into train and test with 80 train / 20 test\n", "train, test = train_test_split(df, test_size=0.2)\n", "train_stats = train.describe()\n", "\n", "# Get Y1 and Y2 as the 2 outputs and format them as np arrays\n", "train_stats.pop('Y1')\n", "train_stats.pop('Y2')\n", "train_stats = train_stats.transpose()\n", "train_Y = format_output(train)\n", "test_Y = format_output(test)\n", "\n", "# Normalize the train and test data\n", "norm_train_X = norm(train)\n", "norm_test_X = norm(test)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>X1</th><th>X2</th><th>X3</th><th>X4</th><th>X5</th><th>X6</th><th>X7</th><th>X8</th><th>Y1</th><th>Y2</th></tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><th>0</th><td>0.74</td><td>686.0</td><td>245.0</td><td>220.5</td><td>3.5</td><td>4</td><td>0.40</td><td>2</td><td>14.18</td><td>16.99</td></tr>\n", "    <tr><th>1</th><td>0.90</td><td>563.5</td><td>318.5</td><td>122.5</td><td>7.0</td><td>5</td><td>0.00</td><td>0</td><td>19.68</td><td>29.60</td></tr>\n", "    <tr><th>2</th><td>0.64</td><td>784.0</td><td>343.0</td><td>220.5</td><td>3.5</td><td>2</td><td>0.10</td><td>5</td><td>15.16</td><td>19.24</td></tr>\n", "    <tr><th>3</th><td>0.71</td><td>710.5</td><td>269.5</td><td>220.5</td><td>3.5</td><td>2</td><td>0.10</td><td>2</td><td>10.64</td><td>13.67</td></tr>\n", "    <tr><th>4</th><td>0.90</td><td>563.5</td><td>318.5</td><td>122.5</td><td>7.0</td><td>5</td><td>0.25</td><td>2</td><td>31.66</td><td>37.72</td></tr>\n", "  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " X1 X2 X3 X4 X5 X6 X7 X8 Y1 Y2\n", "0 0.74 686.0 245.0 220.5 3.5 4 0.40 2 14.18 16.99\n", "1 0.90 563.5 318.5 122.5 7.0 5 0.00 0 19.68 29.60\n", "2 0.64 784.0 343.0 220.5 3.5 2 0.10 5 15.16 19.24\n", "3 0.71 710.5 269.5 220.5 3.5 2 0.10 2 10.64 13.67\n", "4 0.90 563.5 318.5 122.5 7.0 5 0.25 2 31.66 37.72" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build the Model\n", "\n", "Here is how we'll build the model using the functional syntax. Notice that we can specify a list of outputs (i.e. `[y1_output, y2_output]`) when we instantiate the `Model()` class." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# define model layers\n", "input_layer = Input(shape=(len(train.columns), ))\n", "first_dense = Dense(units=128, activation='relu')(input_layer)\n", "second_dense = Dense(units=128, activation='relu')(first_dense)\n", "\n", "# Y1 output will be fed directly from the second dense\n", "y1_output = Dense(units=1, name='y1_output')(second_dense)\n", "third_dense = Dense(units=64, activation='relu')(second_dense)\n", "\n", "# Y2 output will come via the third dense\n", "y2_output = Dense(units=1, name='y2_output')(third_dense)\n", "\n", "# Define the model with the input layer and a list of output layers\n", "model = Model(inputs=input_layer, outputs=[y1_output, y2_output])" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_1\"\n", "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "input_2 (InputLayer) [(None, 8)] 0 \n", 
"__________________________________________________________________________________________________\n", "dense_4 (Dense) (None, 128) 1152 input_2[0][0] \n", "__________________________________________________________________________________________________\n", "dense_5 (Dense) (None, 128) 16512 dense_4[0][0] \n", "__________________________________________________________________________________________________\n", "dense_6 (Dense) (None, 64) 8256 dense_5[0][0] \n", "__________________________________________________________________________________________________\n", "y1_output (Dense) (None, 1) 129 dense_5[0][0] \n", "__________________________________________________________________________________________________\n", "y2_output (Dense) (None, 1) 65 dense_6[0][0] \n", "==================================================================================================\n", "Total params: 26,114\n", "Trainable params: 26,114\n", "Non-trainable params: 0\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot_model(model, show_shapes=True, show_layer_names=True, to_file='./image/multi_output_model.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configure parameters\n", "\n", "We specify the optimizer as well as the loss and metrics for each output." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Specify the optimizer, and compile the model with loss functions for both outputs\n", "optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n", "model.compile(optimizer=optimizer,\n", "              loss={\n", "                  'y1_output': 'mse',\n", "                  'y2_output': 'mse'\n", "              },\n", "              metrics={\n", "                  'y1_output': tf.keras.metrics.RootMeanSquaredError(),\n", "                  'y2_output': tf.keras.metrics.RootMeanSquaredError()\n", "              })" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the Model" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# train the model for 500 epochs\n", "history = model.fit(norm_train_X, train_Y,\n", "                    epochs=500, batch_size=10, validation_data=(norm_test_X, test_Y), verbose=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate the Model and Plot Metrics" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5/5 [==============================] - 0s 4ms/step - loss: 0.6642 - y1_output_loss: 0.1839 - y2_output_loss: 0.4802 - y1_output_root_mean_squared_error: 0.4289 - y2_output_root_mean_squared_error: 0.6930\n", "Loss = 0.6641601920127869, Y1_loss = 0.183916375041008, Y1_rmse = 0.42885470390319824, Y2_loss = 0.4802437722682953, Y2_rmse = 0.6929962038993835\n" ] } ], "source": [ "# Test the model and print the loss and RMSE for both outputs\n", "loss, Y1_loss, Y2_loss, Y1_rmse, Y2_rmse = model.evaluate(x=norm_test_X, y=test_Y)\n", "print(\"Loss = {}, Y1_loss = {}, Y1_rmse = {}, Y2_loss = {}, Y2_rmse = {}\".format(loss, Y1_loss, Y1_rmse, Y2_loss, Y2_rmse))" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQkAAAEWCAYAAAB16GIqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAd0klEQVR4nO3dfZRV9X3v8fd3hgkMRhhRUDIG8RKjIaCgEzWSpoqxGmPMSNWmCb2mtZqbZW5NYkkx0iXeJtEbjKarq8sbjV5pJUaJMhJvLFfxIa2NJOCAA1HCTRTiCQI+jBgcYYb53j/2PuNhOGeffc6cfR4/r7VmzZx9nn67jR/2/j18f+buiIjk0lTpBohIdVNIiEgkhYSIRFJIiEgkhYSIRFJIiEgkhYSIRFJISFHMbJmZ3TXs2B+b2Wtm9kdmtsrMXjUzTcSpcabJVFIMMzsc2AT8hbs/amZjgOeAbwM/Bz4GvAp0ubtVrqUyUgoJKZqZXQJ8B5gBLAJmufsnM57/ALBFIVHbFBIyImb2Y+A9wBxgtrtvy3hOIVEHRlW6AVLzrgJ+A1yXGRBSP9RxKSPi7jsI+h42VbotkgyFhIhE0u2GlJyZGTCaoK+CcOTD3X1vRRsmRdGVhCThGKCPd29B+oDNlWuOjIRGN0Qkkq4kRCSSQkJEIikkRCSSQkJEItXEEOgRRxzhU6dOrXQzROrK63v2kert49Axo3jtpRdedfeJ2V5XEyExdepU1q5dW+lmiNSNZWu2ct2KjXz+hEncNv9kxrSM2prrtbrdEGkw6YCYGwbE6FHNka9XSIg0kEIDAhQSIg2jmIAAhYRIQyg2IEAhIVL3RhIQoJAQqWsjDQhQSIjUrVIEBCgkROpSOiBGj2ri8Rd2Mvfmp+jqThX1WQoJkTqTDogmg70DgwCkevu49sGeooJCISFSRzKvIAaHlYrp69/PklWF1/5RSIjUicw+iPQVxHC/7+0r+HMVEiJ1YHgnZXtba9bXvS/H8SgKCZEal20UY8G5x9PacuBoRmtLcLxQNbEKVEQO1tWdYvHKTfT29TN6VBPnzzhqaJizc3Y7AEtWbeb3vX28r62VBeceP3S8EAoJkRrU1Z1iwfIN9Ie9k3sHBvn7hzYxqrlpKAg6Z7cXFQrD6XZDpAYtXrlpKCDSih29yEchIVJjlq3ZSm9ff9bnihm9yEchIVJDMudBZFPM6EU+CgmRGpE5ivGtzhklG73IRx2XIlWoqzt1wMjEGdMOZ/m6lw8Y5hzV3FSS0Yt8Et/mz8yagbVAyt0vMLMJwH3AVOAl4FJ3fyPqMzo6OlyFcKVRLOrqYdkz2xj+X+b0yeNYcdUZRa/mjGJm69y9I9tz5bjduBp4PuPxQmC1ux8HrA4fiwjBFcQ9WQICoPftfYkERD6JhoSZHQ18CvhBxuHPAEvDv5cCnUm2QaSWLF65Kedz2998p4wteVfSVxLfA74OZK42OdLdtwOEvydle6OZXWlma81s7a5duxJupkh1yDW0CcmMXMSRWEiY2QXATndfV8z73f12d+9w946JE7NuLCTSUJIYuYgjydGNOcCFZnY+MAYYZ2b3ADvMbLK7bzezycDOBNsgUlPGtjTxdv/By7zHtjQlMnIRR2JXEu5+rbsf7e5Tgc8Cj7v7fGAlcFn4ssuAh5Jqg0gtWbZma9aAAJh3ytFlbs27KjFP4ibgfjO7HNgGXFKBNohUzPA5EAvOPZ49+waGZlJmKxjzxAuV65crS0i4+5PAk+HfrwFnl+N7RapNV3eKax/soa9/PxDUnkyv5px7wiQefyH73XcSazLi0rRskTJasmrzUECk9Q86o0c1lbyiVKkoJETKKNcVwd6BwZJXlCoVhYRIGeW6IkhfQXTObufGeTNpb2vFwuM3zptZsZEN0AIvkbJacO7xB1SUgoOvFEpVUapUFBIiZbRn38BQH8TegUHaE1y9WSoKCZEyKdXenOWmPgmRMqjVgABdSYiMWLbJUfBuOfvxrS3
09vXXZECAQkJkRHJNjsKgf3/QOdnb10+TccC+GLVEtxsiI5BrclQ6INIGHW59bEs5m1YyCgmREShkunQlp1aPhEJCZATGt7bEfm3b2PivrSbqkxApQrqzMqqS1HAJ15xOjEJCpEDD9+GM680CAqWaKCREImQb3sy2D2cclVzJORIKCZEcsg1vZj4uVCVXco6EOi5Fcsg2vFlsQBw2tqWq12dEUUiI5FDKIcvrP/3hkn1WuSVZUn+Mmf3CzDaY2SYzuyE8vtjMUma2Pvw5P6k2iIxEqfoQavkqApK9ktgLzHX3k4BZwHlmdnr43K3uPiv8+WmCbRApWrYqUYVqbWmu6asISLakvrv7H8KHLeFPjY4USyPKViWqtSX6P5k50yZUVVWpUkh0V/FwR/F1wAeAf3b3vzOzxcAXgN0Eu41fk21XcTO7ErgSYMqUKads3bo1sXaKxJFe7t1kwVqMbNrbWnl64dzyNqwEKraruLvvd/dZwNHAqWY2A7gNmEZwC7Id+G6O92qbP6kamfUgvvOnJ+Z8Xa2uz4hSltENd+8l2HfjPHffEYbHIHAHcGo52iBSrOEFYy7ueH9Vlr5PSpKjGxPNrC38uxX4BPBCuP9n2kXAxqTaIDJSuSpKVWPp+6QkOeNyMrA07JdoAu5394fN7F/NbBZBJ+ZLwBcTbINI0aJKzqU7I4dP2a71TspsEu24LJWOjg5fu3ZtpZshDaSWa1IWo2IdlyK1qNECIh8t8JKGlG11Z+fsdgVEFrrdkIYzfHUngPHuTL/pk8ex4qozGiogdLshkmHxyk0HrebM/Kfy1zt280jPK+VtVBVTSEhD6epO5S05NzAIN/xkU5laVP0UEtJQlqzaHOt1b7xdm6XmkqCQkIZSj9Omk6aQkIYSd9p0WwGl8uudQkIayoJzj6elySJf09JkLL6wtmtAlJJCQhrKnn0DkZWumwyWXHJSXU6vLpZCQhpG5kSpmy8+kZbmA68oWpqNWy6dpYAYRjMupSFkm0k5qrmpIRZojZRCQuperqnWnbPbFQox6HZD6prWYoycQkLqlgKiNBQSUpcUEKWjkJC6o4AorcQ6Ls1sDPAzYHT4PT929+vNbAJwHzCVoHzdpdlK6ovElVkbYnxrC719/QqIEqrEDl4LgdXufhywOnwsUpR0bYhUbx8O9Pb102Rw/oyjFBAlUokdvD4DLA2PLwU6k2qD1L9rH3zuoNoQgw63PralQi2qP4n2SZhZs5mtB3YCj7r7GuBId98OEP6elGQbpH4t6uqhr38w63Na7Vk6ldjBKxYzu9LM1prZ2l27diXWRqldy9Zsy/lcPW6SUyllmXHp7r1m9iRwHrDDzCa7+/Zwo56dOd5zO3A7BDUuy9FOqX6fv+PnPP2b1/O+rh43yamUsu/gBawELgtfdhnwUFJtkPpyzi1PxgqIJkPTrUuoEjt4/Ry438wuB7YBlyTYBqkTi7p62LJzT6zXfu60KQm3prEkFhLu/hwwO8vx14Czk/peqT9d3SmWPZO7/yHT2JYmvtk5M+EWNZZYtxtmNs3MRod/n2lmf5O+lRBJUld3iq/dv544nVKtLc18e96Jibep0cTtk3gA2G9mHwDuBI4FfphYq0QIOim/ct96IgpJDTlsbAs3zpupvogExL3dGHT3ATO7CPieu/+TmXUn2TBpbIu6emJ1UgLMP32KbjESFPdKot/M/pxgNOLh8JjKCUti7l3zu1ivO2xsiwIiYXFD4i+BjwLfcvcXzexY4J7kmiWNbn/MPWqv/7SqWict1u2Gu/8K+JuMxy8CNyXVKJHMDXxzaW1pUh9EGcQKCTObAywGjgnfYwRruP5Lck2TRrVszda8AdEE3KiRjLKI23F5J/BVYB2wP89rRYqWWTDmyHGjue+XvztodKOttYXFF35YVxFlEjck3nT3RxJtiTS8bBWldLVQeXFD4gkzWwI8SFBMBgB3fzaRVkndyawe1Ta2BXd4s69/aL+LPfsGVHKuSsU
NidPC3x0ZxxyYW9rmSD1KV49KF4d54+3+oedSvX0sWL6B/kFXQFSpuKMbZyXdEKlfS1ZtPqh6VKb+QWf0qCYFRJWKu3ZjvJndki4CY2bfNbPxSTdO6kOcKlF7BwaZe/NTdHWnytAiKUTcyVR3AW8Bl4Y/u4H/nVSjpL60jY03OTfV28e1D/YoKKpM3JCY5u7Xu/tvw58bAM2RkLy6ulP84Z2B2K/v69/PklWbE2yRFCpuSPSZ2cfSD8LJVao0KnktWbWZ/jjLODOoiG11iTu68SWCKlPjCWZbvg58IalGSf2I+g++va2VVJbnVcS2usS6knD39eEmOycCM919trtvSLZpUg9y/QffHs6PaG05cDSjtaVZRWyrTGRImNn88PfXzOxrwF8Df53xOOq97zezJ8zseTPbZGZXh8cXm1nKzNaHP+eX6mSk+iw493hamuyAY+kg6Jzdzo3zZtLe1ooRBIcKx1SffLcbh4S/D83yXL4bzQHgGnd/1swOBdaZ2aPhc7e6+80FtFOqXOaMyvQsys7Z7ezZNzA0D2LvwODQFUQ6CDpntysUqlxkSLj798M/H3P3pzOfCzsvo967HUjv1PWWmT0P6H8NdWj4jMr0UObT/+9Vlq97WTMpa1zc0Y1/inksKzObSlA5e0146Mtm9pyZ3WVmh+V4j3bwqhHZZlT29e9XQNSJyCsJM/socAYwcVgfxDgg1v/Xzey9BIV0v+Luu83sNuAfCG5X/gH4LvBXw9+nHbxqR9QIhgKi9uW7kngP8F6CMDk042c3cHG+DzezFoKAWObuDwK4+45wj9BB4A7g1OKbL9Ug1wjG+8aPUUDUgXx9Ek8BT5nZ3e6+tZAPNjMjKFbzvLvfknF8cnpXceAiYGOBbZYq0tWdYs/eg2dUjhnVxNfPO6ECLZJSi9sn8YPMzXjM7DAzW5XnPXOAvwDmDhvu/I6Z9ZjZc8BZBBWvpAalOyx7+/oPOG7AxR1Ha9SiTsSdcXmEu/emH7j7G2Y2KeoN7v4fBP97Ge6n8Zsn1aqrO8U192/IWtXagQfWpeg4ZoKCog7EvZIYNLOhXVjN7Bjyz5OQOtXVnWLB8uwBkaaFWvUj7pXEdcB/mNlT4eOPA1cm0ySpdotXboq1aEsLtepD3MpU/2ZmJwOnE9xCfNXdX020ZVK1hvdB5KKFWvUh39qNE8LfJwNTgN8DKWBKeEwkKy3Uqh/5riSuAa4gmPA0nArhNqBla+KNhGuhVv3IN0/iivC3CuHK0L4Y0yeP49c7djMwmP117W2tCog6km9a9ryo59OzKKX+Dd8455GeV7jhJ5sOKI8Pus2oR/luNz4d/p5EsIbj8fDxWcCTBJv1SJ3LtrNWeol3riXiUj/y3W78JYCZPQxMT0+nNrPJwD8n3zyptGwBkUn1IOpf3MlUUzPWWwDsAD6YQHukiuQLCGkMcSdTPRmu1biXYFTjs8ATibVKKk4BIWlxJ1N92cwuIphpCXC7u69IrllSSQoIyRT3SgLgWeAtd3/MzMaa2aHu/lZSDZPKUEDIcLFCwsyuIFirMQGYRlCr8n8BZyfXNCmHzNGJ8a0t9Pb1KyDkAHE7Lq8iqA+xG8DdtxAMi0oNS9eDSPX24QRrMpoMzp9xlAJChsQNib3uvi/9wMxGoaXiNS9bAdtBh1sf21KhFkk1ihsST5nZN4BWMzsHWA78JLlmSTnkWsqtJd6SKW5I/B2wC+gBvkhQXWpRUo2S8hjf2pL1uJZ4S6a8HZdm1gQ85+4zCKpbx2Jm7wf+BTgKGCQYNv1HM5sA3AdMBV4CLnX3NwpvuozEsjVbh/ogMuvHaO2FDJc3JNx90Mw2mNkUd99WwGfn2ubvC8Bqd7/JzBYCCwmuVCRh6ZGM9E7e0yeP46/mTOXWx7Zo7YXkFHeexGRgk5n9AtiTPujuF+Z6Q8Q2f58BzgxftpRgoZhCImHDt+ID+O2uPzCquYmnF6osiOQWNyRuGMmXDNvm78j0OhB3356r6ra
ZXUlYR3PKlCnZXiIFyDaS8c7AIEtWbdaVg0TKV09iDPDfgA8QdFre6e4H78QS/RnDt/mL9T5t81daKY1kSJHyXUksBfqBfwc+CUwHro774dm2+QN2pHfxCpec7yy82QIcVMvhrBMm8sQLuw56nCsgQCMZkl++kJju7jMBzOxO4BdxPzjXNn/ASuAy4Kbw90MFtViAg/sYUr193PPMu/3Kwx9no5EMiSNfSAzVJnP3gbi3CqH0Nn89ZrY+PPYNgnC438wuB7YBlxTyoRLI1scQR7MZg+4ayZDY8oXESWa2O/zbCGZc7g7/dncfl+uNEdv8gRaGjVixfQmD7rx406dK3BqpZ/nK12mVT5V6X1trZF9D1PtEChF3WrZUmQXnHk9rS2EZrj4IKUYhRWekiqT7EobXgsilXX0QUiSFRA3LrFSdrijV0mRDm/m2tbaw+MIPKxhkRBQSdUAl5yRJ6pOocQoISZpCooYpIKQcdLtRI7q6Uyxeuemgzsnpk8cpICRRCokqlysc0n79ym4e6XlFnZOSGN1uVLGu7hQLlm+IHNoc8GAYVCQpCokqtnjlpqHhzCha7i1JUkhUsagriEyaai1JUkjUAU21liQpJKrYYWOzl7zPNP/0Keq0lEQpJKrYJz50ZM7n2lpb+N6fzeKbnTPL2CJpRBoCrQLDy9AtOPd49uwbYPm6l5k+eRy9b+9j+5vvqFCMVIRCosLSw5zpUYxUbx9fuW89EEyUWnHVGZooJRWV2O2Gmd1lZjvNbGPGscVmljKz9eHP+Ul9fy3o6k7x1fvW5xzmTE+UEqmkJPsk7gbOy3L8VnefFf78NMHvr2qLunr46n3rI7dm10QpqQaJhYS7/wx4PanPr2Vd3SmWPbMtMiDSNFFKKq0SfRJfNrP/Cqwl2Cu0ITYLzuycNCNWQIAmSknllXsI9DZgGjCLYJ/Q7+Z6oZldaWZrzWztrl27ytS8ZKT3yEj19uEcuIt3PpooJZVW1pBw9x3uvt/dB4E7gFMjXnu7u3e4e8fEiRPL18gEFLtHhiZKSTUo6+1Genu/8OFFwMao19eLQvsVVLRWqkliIWFm9wJnAkeY2cvA9cCZZjaL4Jb8JeCLSX1/JWSbFNU5u72gPTIMeHrh3GQbKlKAxELC3f88y+E7k/q+Ssu2N+e1D/YAQb9CvuHONHVUSrXR2o0Sydbv0Ne/nyWrNtM5u53Pnz4l72e0NJk6KqXqaFp2ieTqd0j19nHswv/D+NboFZ1jW5r49rwT1Q8hVUchUSJR/Q5OdAGZOdMmsOyKjybUMpGR0e1GicTdm3NsSxPNFmy23mzG/NOnKCCkqulKokQ6Z7ezduvr3PPMtsjX9fUP8uJNnypTq0RGTiFRAl3dKa5b0cOeffknTGn0QmqNQmKEurpTLPjxBvr3x5trrdELqTXqkxihJas2xw4IQKMXUnMUEiPQ1Z2KPZNSpFbpdqNIi7p6WJank3K4dvVHSA3SlUQRCikak9ba0qz+CKlJupIo0KKunrzDnMNpVafUMoVEAYoJiPmnT9HeGFLTdLsRU1d3quCAmDNtggJCap6uJGLo6k4N7YURh24vpJ4oJPIoNCCaLHvRmFwFaUSqnUIiQno2ZSE+d9rBdSOiCtIoKKTaqU8iwnUregqaTZmrkzKqII1ItUuyxuVdwAXATnefER6bANwHTCWocXlpte670dWdirVg66UYKzpzFaTRxjtSC8q9zd9CYLW7HwesDh9XpVL+K59r5adWhEotKPc2f58BloZ/LwU6k/r+kYrzr/ycaRNifVa2gjSagSm1otx9Ekem990If0/K9cJK7+CVryblcZMOiV1RqnN2OzfOm0l7WytGMER647yZ6rSUmlC1oxvufjtwO0BHR0chyyRGbNmarfT29dNkB2/Jd8h7mvnWRYX/B945u12hIDWp3CGxI72Ll5lNBnaW+fvzWrZmK9et2MjcEyZx/oyjuPWxLZrbIA2t3CGxErgMuCn8/VCZvz9
SZkDcNv9kRo9q5uKO91e6WSIVlVifRLjN38+B483sZTO7nCAczjGzLcA54eOqkC0gRKT82/wBnJ3UdxZLASGSW8PPuFRAiERr6JBQQIjk17AhoYAQiachQ0IBIRJfw4WEAkKkMA0VEgoIkcI1TEgoIESKU7VrN0ZieKm4M6YdzvJ1LysgRIpQdyGRrVTc8nUvM33yOAWESBHq7nYjW6k4gN639ykgRIpQVyERtYHv9jffKXNrROpD3YRE+jYjF5WKEylO3YRErtsMUKk4kZGom5CIqkmpUnEixauLkOjqTuV8rr2tVQEhMgI1HxJd3SkWLN9AtiKYus0QGbmaD4nFKzfRP7xaLdBsptsMkRKo6ZBIV7XOZtBdASFSAhWZcWlmLwFvAfuBAXfvKPQz0msxRo9qYu/A4EHPa8hTpDQqeSVxlrvPGklAzD1hEt/qnKHdsUQSVHNrN7Kt5hzV3HTAgi7tjyFSOuZe1s2xgi81exF4A3Dg++FuXcNfcyVwJcCUKVNO2bp1q5Z7iyTEzNbluqqv1O3GHHc/GfgkcJWZfXz4C9z9dnfvcPeOiRMnKiBEKqQiIeHuvw9/7wRWAKdGvf71PfsUECIVUvaQMLNDzOzQ9N/AnwAbo96T6u1TQIhUSCU6Lo8EVphZ+vt/6O7/FvWGQ8eMUkCIVEjZQ8LdfwucVMh7jplwiAJCpEIqMrpRKDPbBWyt0NcfAbxaoe9OSj2eE9TneZXrnI5x94nZnqiJkKgkM1tbzISvalaP5wT1eV7VcE41vXZDRJKnkBCRSAqJ/A6aDVoH6vGcoD7Pq+LnpD4JEYmkKwkRiaSQEJFICokMZnaXme00s40ZxyaY2aNmtiX8fVgl21goM3u/mT1hZs+b2SYzuzo8XrPnZWZjzOwXZrYhPKcbwuM1e05pZtZsZt1m9nD4uOLnpJA40N3AecOOLQRWu/txwOrwcS0ZAK5x9w8BpxOsup1ObZ/XXmCuu58EzALOM7PTqe1zSrsaeD7jceXPyd31k/EDTAU2ZjzeDEwO/54MbK50G0d4fg8B59TLeQFjgWeB02r9nICjCYJgLvBweKzi56QrifyOdPftAOHvSRVuT9HMbCowG1hDjZ9XeFm+HtgJPOruNX9OwPeArwOZRVsrfk4KiQZhZu8FHgC+4u67K92ekXL3/e4+i+Bf31PNbEaFmzQiZnYBsNPd11W6LcMpJPLbYWaTAcLfOyvcnoKZWQtBQCxz9wfDwzV/XgDu3gs8SdCXVMvnNAe4MKwk/yNgrpndQxWck0Iiv5XAZeHflxHc09cMCwp33Ak87+63ZDxVs+dlZhPNrC38uxX4BPACNXxO7n6tux/t7lOBzwKPu/t8quCcNOMyg5ndC5xJsDx3B3A90AXcD0wBtgGXuPvrFWpiwczsY8C/Az28e6/7DYJ+iZo8LzM7EVgKNBP8Q3e/u/8PMzucGj2nTGZ2JvC37n5BNZyTQkJEIul2Q0QiKSREJJJCQkQiKSREJJJCQkQiKSQagJkdbmbrw59XzCyV8fg9Jfj8xWZ247Bjs8zs+Tzv+duRfrckr+Z2FZfCuftrBKslMbPFwB/c/eb082Y2yt0HRvAV9wKPANdmHPss8MMRfKZUCV1JNCgzu9vMbjGzJ4D/OfxfdjPbGC4Iw8zmh/Ub1pvZ983sgJ2S3H0z0Gtmp2UcvhT4kZldYWa/DGs/PGBmY7O05Ukz6wj/PiKcmpxexLUkfP9zZvbF8PhkM/tZ2J6NZvZHpf2/jmRSSDS2DwKfcPdrcr3AzD4E/BnBTvCzgP3A57O89F6CqwfC2g6vufsW4EF3/4gHtR+eBy4voH2XA2+6+0eAjwBXmNmxwOeAVWF7TgLWF/CZUiDdbjS25e6+P89rzgZOAX4Z7t/aSvZFRj8C/tPMriEIi3vD4zPM7JtAG/BeYFUB7fsT4EQzuzh8PB44DvglcFe4cK3L3dcX8JlSIIVEY9uT8fcAB15Zjgl/G7DU3TP7Gw7i7r8
LbxP+GPhT4KPhU3cDne6+wcy+QLA2ZrjM7x6TcdyA/+7uBwWLmX0c+BTwr2a2xN3/Jap9UjzdbkjaS8DJAGZ2MnBseHw1cLGZTQqfm2Bmx+T4jHuBW4HfuPvL4bFDge3hv/rZblPS331K+PfFGcdXAV8K34uZfdDMDgm/f6e730GwwvXkQk5UCqOQkLQHgAlhtacvAb8GcPdfAYuA/2tmzwGPEpRRy2Y58GGCW4+0vydYcfoowXLubG4mCIP/JFiBm/YD4FfAsxYUJ/4+wdXvmcB6M+smuGr5x0JOVAqjVaAiEklXEiISSSEhIpEUEiISSSEhIpEUEiISSSEhIpEUEiIS6f8DiJpwXrk8cHUAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot predicted vs. true values for Y1\n", "Y_pred = model.predict(norm_test_X)\n", "plot_diff(test_Y[0], Y_pred[0], title='Y1')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_diff(test_Y[1], Y_pred[1], title='Y2')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_metrics(history, metric_name='y1_output_root_mean_squared_error', title='Y1 RMSE', ylim=6)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_metrics(history, metric_name='y2_output_root_mean_squared_error', title='Y2 RMSE', ylim=7)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 3 - Implement a Siamese Network\n", "\n", "This section walks through creating and training a multi-input model. You will build a basic Siamese Network to find the similarity or dissimilarity between items of clothing." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the Dataset\n", "\n", "First define a few utilities for preparing and visualizing your dataset." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "def create_pairs(x, digit_indices):\n", "    '''\n", "    Positive and negative pair creation.\n", "    Alternates between positive and negative pairs.\n", "    '''\n", "    pairs = []\n", "    labels = []\n", "    n = min([len(digit_indices[d]) for d in range(10)]) - 1\n", "\n", "    for d in range(10):\n", "        for i in range(n):\n", "            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]\n", "            pairs += [[x[z1], x[z2]]]\n", "            inc = random.randrange(1, 10)\n", "            dn = (d + inc) % 10\n", "            z1, z2 = digit_indices[d][i], digit_indices[dn][i]\n", "            pairs += [[x[z1], x[z2]]]\n", "            labels += [1, 0]\n", "\n", "    return np.array(pairs), np.array(labels)\n", "\n", "def create_pairs_on_set(images, labels):\n", "    digit_indices = [np.where(labels == i)[0] for i in range(10)]\n", "    pairs, y = create_pairs(images, digit_indices)\n", "    y = y.astype('float32')\n", "    return pairs, y\n", "\n", "def show_image(image):\n", "    plt.figure()\n", "    plt.imshow(image)\n", "    plt.colorbar()\n", "    plt.grid(False)\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can now download and prepare the train and test sets. You will also create pairs of images that will go into the multi-input model." 
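] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before running this on the full dataset, it can help to see what `create_pairs_on_set` returns on a tiny input. The cell below is only an illustrative sketch: `toy_images` and `toy_labels` are made-up arrays (they are not part of the original notebook), holding 3 fake 2x2 images for each of the 10 classes. The labels should alternate between 1 (same class) and 0 (different class)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical toy data: every pixel of an image equals its class id\n", "toy_labels = np.repeat(np.arange(10), 3)\n", "toy_images = np.stack([np.full((2, 2), label, dtype='float32') for label in toy_labels])\n", "\n", "toy_pairs, toy_y = create_pairs_on_set(toy_images, toy_labels)\n", "\n", "# n = (smallest class size) - 1 = 2, so 10 classes * 2 steps * 2 pairs = 40 pairs\n", "print(toy_pairs.shape)\n", "print(toy_y[:4])" 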
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# load the dataset\n", "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()\n", "\n", "# prepare train and test sets\n", "train_images = train_images.astype('float32')\n", "test_images = test_images.astype('float32')\n", "\n", "# normalize values\n", "train_images = train_images / 255.0\n", "test_images = test_images / 255.0\n", "\n", "# create pairs on train and test sets\n", "train_pairs, train_y = create_pairs_on_set(train_images, train_labels)\n", "test_pairs, test_y = create_pairs_on_set(test_images, test_labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see a sample pair of images below." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "# array index\n", "this_pair = 8\n", "\n", "# show images at this index\n", "show_image(test_pairs[this_pair][0])\n", "show_image(test_pairs[this_pair][1])\n", "\n", "# print the label for this pair\n", "print(test_y[this_pair])" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# print other pairs\n", "\n", "show_image(train_pairs[:,0][0])\n", "show_image(train_pairs[:,0][1])\n", "\n", "show_image(train_pairs[:,1][0])\n", "show_image(train_pairs[:,1][1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build the Model\n", "\n", "Next, you'll define some utilities for building our model." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "def initialize_base_network():\n", " input = Input(shape=(28,28,), name=\"base_input\")\n", " x = Flatten(name=\"flatten_input\")(input)\n", " x = Dense(128, activation='relu', name=\"first_base_dense\")(x)\n", " x = Dropout(0.1, name=\"first_dropout\")(x)\n", " x = Dense(128, activation='relu', name=\"second_base_dense\")(x)\n", " x = Dropout(0.1, name=\"second_dropout\")(x)\n", " x = Dense(128, activation='relu', name=\"third_base_dense\")(x)\n", "\n", " return Model(inputs=input, outputs=x)\n", "\n", "\n", "def euclidean_distance(vects):\n", " x, y = vects\n", " sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)\n", " return K.sqrt(K.maximum(sum_square, K.epsilon()))\n", "\n", "\n", "def eucl_dist_output_shape(shapes):\n", " shape1, shape2 = shapes\n", " return (shape1[0], 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how our base network looks. This is where the two inputs will pass through to generate an output vector." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "base_network = initialize_base_network()\n", "plot_model(base_network, show_shapes=True, show_layer_names=True, to_file='./image/base-siamese-model.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now build the Siamese network. 
The plot will show two inputs going to the base network." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create the left input and point to the base network\n", "input_a = Input(shape=(28,28,), name=\"left_input\")\n", "vect_output_a = base_network(input_a)\n", "\n", "# create the right input and point to the base network\n", "input_b = Input(shape=(28,28,), name=\"right_input\")\n", "vect_output_b = base_network(input_b)\n", "\n", "# measure the similarity of the two vector outputs\n", "output = Lambda(euclidean_distance, name=\"output_layer\", output_shape=eucl_dist_output_shape)([vect_output_a, vect_output_b])\n", "\n", "# specify the inputs and output of the model\n", "model = Model([input_a, input_b], output)\n", "\n", "# plot model graph\n", "plot_model(model, show_shapes=True, show_layer_names=True, to_file='./image/outer-siamese-model.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the Model\n", "\n", "You can now define the custom loss for our network and start training." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "def contrastive_loss_with_margin(margin):\n", " def contrastive_loss(y_true, y_pred):\n", " '''\n", " Contrastive loss from Hadsell-et-al.'06\n", " http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf\n", " '''\n", " square_pred = K.square(y_pred)\n", " margin_square = K.square(K.maximum(margin - y_pred, 0))\n", " return (y_true * square_pred + (1 - y_true) * margin_square)\n", " return contrastive_loss" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "938/938 [==============================] - 8s 8ms/step - loss: 0.1115 - val_loss: 0.0827\n", "Epoch 2/20\n", "938/938 [==============================] - 8s 8ms/step - loss: 0.0791 - val_loss: 0.0760\n", "Epoch 3/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0703 - val_loss: 0.0707\n", "Epoch 4/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0658 - val_loss: 0.0664\n", "Epoch 5/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0628 - val_loss: 0.0651\n", "Epoch 6/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0606 - val_loss: 0.0653\n", "Epoch 7/20\n", "938/938 [==============================] - 8s 8ms/step - loss: 0.0590 - val_loss: 0.0654\n", "Epoch 8/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0579 - val_loss: 0.0678\n", "Epoch 9/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0565 - val_loss: 0.0632\n", "Epoch 10/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0559 - val_loss: 0.0636\n", "Epoch 11/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0552 - val_loss: 0.0647\n", "Epoch 12/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0539 - val_loss: 0.0627\n", "Epoch 13/20\n", "938/938 
[==============================] - 7s 8ms/step - loss: 0.0536 - val_loss: 0.0643\n", "Epoch 14/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0525 - val_loss: 0.0654\n", "Epoch 15/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0526 - val_loss: 0.0638\n", "Epoch 16/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0521 - val_loss: 0.0634\n", "Epoch 17/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0514 - val_loss: 0.0641\n", "Epoch 18/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0508 - val_loss: 0.0639\n", "Epoch 19/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0506 - val_loss: 0.0632\n", "Epoch 20/20\n", "938/938 [==============================] - 7s 8ms/step - loss: 0.0497 - val_loss: 0.0639\n" ] } ], "source": [ "model.compile(loss=contrastive_loss_with_margin(margin=1), optimizer=RMSprop())\n", "history = model.fit([train_pairs[:,0], train_pairs[:,1]], train_y, epochs=20, batch_size=128, validation_data=([test_pairs[:,0], test_pairs[:,1]], test_y))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Evaluation\n", "\n", "As usual, you can evaluate our model by computing the accuracy and observing the metrics during training." 
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "def compute_accuracy(y_true, y_pred):\n", " '''\n", " Compute classification accuracy with a fixed threshold on distances.\n", " '''\n", " pred = y_pred.ravel() < 0.5\n", " return np.mean(pred == y_true)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "625/625 [==============================] - 1s 2ms/step - loss: 0.0639\n" ] } ], "source": [ "loss = model.evaluate(x=[test_pairs[:,0],test_pairs[:,1]], y=test_y)\n", "\n", "y_pred_train = model.predict([train_pairs[:,0], train_pairs[:,1]])\n", "train_accuracy = compute_accuracy(train_y, y_pred_train)\n", "\n", "y_pred_test = model.predict([test_pairs[:,0], test_pairs[:,1]])\n", "test_accuracy = compute_accuracy(test_y, y_pred_test)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss = 0.06388980895280838, Train Accuracy = 0.9388898149691616 Test Accuracy = 0.9131631631631631\n" ] } ], "source": [ "print(\"Loss = {}, Train Accuracy = {} Test Accuracy = {}\".format(loss, train_accuracy, test_accuracy))" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def plot_metrics(history, metric_name, title, ylim=5):\n", " plt.title(title)\n", " plt.ylim(0,ylim)\n", " plt.plot(history.history[metric_name],color='blue',label=metric_name)\n", " plt.plot(history.history['val_' + metric_name],color='green',label='val_' + metric_name)\n", " plt.legend()\n", "\n", "plot_metrics(history, metric_name='loss', title=\"Loss\", ylim=0.2)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# Matplotlib config\n", "def visualize_images():\n", " plt.rc('image', cmap='gray_r')\n", " plt.rc('grid', linewidth=0)\n", " plt.rc('xtick', top=False, bottom=False, labelsize='large')\n", " plt.rc('ytick', left=False, right=False, labelsize='large')\n", " plt.rc('axes', facecolor='F8F8F8', titlesize=\"large\", edgecolor='white')\n", " plt.rc('text', color='a8151a')\n", " plt.rc('figure', facecolor='F0F0F0')# Matplotlib fonts\n", "\n", "\n", "# utility to display a row of digits with their predictions\n", "def display_images(left, right, predictions, labels, title, n):\n", " plt.figure(figsize=(17,3))\n", " plt.title(title)\n", " plt.yticks([])\n", " plt.xticks([])\n", " plt.grid(None)\n", " left = np.reshape(left, [n, 28, 28])\n", " left = np.swapaxes(left, 0, 1)\n", " left = np.reshape(left, [28, 28*n])\n", " plt.imshow(left)\n", " plt.figure(figsize=(17,3))\n", " plt.yticks([])\n", " plt.xticks([28*x+14 for x in range(n)], predictions)\n", " for i,t in enumerate(plt.gca().xaxis.get_ticklabels()):\n", " if predictions[i] > 0.5: t.set_color('red') # bad predictions in red\n", " plt.grid(None)\n", " right = np.reshape(right, [n, 28, 28])\n", " right = np.swapaxes(right, 0, 1)\n", " right = np.reshape(right, [28, 28*n])\n", " plt.imshow(right)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see sample results for 10 pairs of items below." 
] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_pred_train = np.squeeze(y_pred_train)\n", "indexes = np.random.choice(len(y_pred_train), size=10)\n", "display_images(train_pairs[:, 0][indexes], train_pairs[:, 1][indexes], y_pred_train[indexes], train_y[indexes], \"clothes and their dissimilarity\", 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Application - Multiple Output Models using the Keras Functional API\n", "\n", "In this section, you will use the Keras Functional API to train a model that predicts two outputs. The data comes from the **[Wine Quality Dataset](https://archive.ics.uci.edu/ml/datasets/Wine+Quality)** in the **UCI Machine Learning Repository**, which provides separate datasets for red wine and white wine.\n", "\n", "Normally, the wines are classified into one of the quality ratings specified in the attributes. In this exercise, you will combine the two datasets and predict both the wine quality and whether the wine is red or white, using only the attributes. \n", "\n", "You will model wine quality estimation as a regression problem and wine type detection as a binary classification problem." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load Dataset\n", "\n", "You will now load the dataset. The files originally come from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php); here they are read from local copies under `./dataset/`." 
] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "# path to the local copy of the white wine dataset\n", "URI = './dataset/winequality-white.csv'\n", "\n", "# load the dataset from the CSV file (fields are separated by semicolons)\n", "white_df = pd.read_csv(URI, sep=\";\")\n", "\n", "# add an `is_red` column filled with zeros\n", "white_df[\"is_red\"] = 0\n", "\n", "# keep only the first of duplicate items\n", "white_df = white_df.drop_duplicates(keep='first')" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.8\n", "9.1\n" ] } ], "source": [ "print(white_df.alcohol[0])\n", "print(white_df.alcohol[100])" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "# path to the local copy of the red wine dataset\n", "URI = './dataset/winequality-red.csv'\n", "\n", "# load the dataset from the CSV file (fields are separated by semicolons)\n", "red_df = pd.read_csv(URI, sep=\";\")\n", "\n", "# add an `is_red` column filled with ones\n", "red_df[\"is_red\"] = 1\n", "\n", "# keep only the first of duplicate items\n", "red_df = red_df.drop_duplicates(keep='first')" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.4\n", "10.2\n" ] } ], "source": [ "print(red_df.alcohol[0])\n", "print(red_df.alcohol[100])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Concatenate the datasets\n", "\n", "Next, concatenate the red and white wine dataframes." 
] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "df = pd.concat([red_df, white_df], ignore_index=True)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9.4" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.alcohol[0]" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9.5" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.alcohol[100]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In a real-world scenario, you should shuffle the data." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "df = df.iloc[np.random.permutation(len(df))]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will chart the quality of the wines." ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQk0lEQVR4nO3df6zddX3H8edL6rAUmShyw1q2sqQxAs1QbhgbCbmMTasYwWUmJUxgc6khuOjWZCn7xy1LE5aM/ZANsg4cNSJNh5KSIU7CdudMRCyKKT8kdFKxlFEdiJQZtPjeH/eLXtrT9vbcH+fc83k+kpNz7ud8P9/zfvec+7rf+7nfc5qqQpLUhtcMugBJ0sIx9CWpIYa+JDXE0Jekhhj6ktSQJYMu4EhOOumkWrlyZV9zX3zxRZYtWza3BQ3IqPQyKn2AvQyrUelltn088MAD36uqNx84PvShv3LlSrZv397X3MnJSSYmJua2oAEZlV5GpQ+wl2E1Kr3Mto8k3+417vKOJDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1ZOjfkSsNqx1PPc+VG+7qa+6uay+a42qkmfFIX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1JAjhn6SU5P8R5JHkzyc5CPd+BuT3JPk8e76xGlzrkmyM8ljSd45bfzsJDu6+z6eJPPTliSpl5kc6e8H1lfVW4FzgauTnA5sAO6tqlXAvd3XdPetBc4A1gA3JDmm29eNwDpgVXdZM4e9SJKO4IihX1VPV9XXutsvAI8Cy4GLgc3dZpuBS7rbFwNbquqlqnoC2Amck+QU4ISq+nJVFfDJaXMkSQvgqNb0k6wE3gZ8BRirqqdh6gcDcHK32XLgO9Om7e7Glne3DxyXJC2QJTPdMMnxwGeAj1bVDw6zHN/rjjrMeK/HWsfUMhBjY2NMTk7OtMxX2bdvX99zh82o9DIqfQCMLYX1q/f3NXfY/g1G6XkZlV7mq48ZhX6S1zIV+LdW1We74WeSnFJVT3dLN3u78d3AqdOmrwD2dOMreowfpKo2AZsAxsfHa2JiYmbdHGBycpJ+5w6bUellVPoAuP7WbVy3Y8bHTa+y67KJuS1mlkbpeRmVXuarj5mcvRPgZuDRqvrraXfdCVzR3b4C2DZtfG2SY5OcxtQfbO/vloBeSHJut8/Lp82RJC2AmRymnAd8ANiR5MFu7E+Ba4GtST4IPAm8H6CqHk6yFXiEqTN/rq6ql7t5VwG3AEuBu7uLJGmBHDH0q+pL9F6PB7jwEHM2Aht7jG8HzjyaAiVJc8d35EpSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JD+vtfnaUDrNxw14y2W796P1cesO2uay+aj5Ik9eCRviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhwx9JN8IsneJA9NG/uzJE8lebC7vHvafdck2ZnksSTvnDZ+dpId3X0fT5K5b0eSdDgzOdK/BVjTY/xvquqs7vI5gCSnA2uBM7o5NyQ5ptv+RmAdsKq79NqnJGkeHTH0q+qLwLMz3N/FwJaqeqmqngB2AuckOQU4oaq+XFUFfBK4pM+aJUl9WjKLuR9OcjmwHVhfVc8By4H7pm2zuxv7cXf7wPGekqxj6rcCxsbGmJyc7KvAffv29T132Ax7L+tX75/RdmNLD952mPs6nF69zNSw9Tzsr6+jMSq9zFcf/Yb+jcBfANVdXwf8PtBrnb4
OM95TVW0CNgGMj4/XxMREX0VOTk7S79xhM+y9XLnhrhltt371fq7b8eqX3a7LJuahovl3/a3bDuplpoat52F/fR2NUellvvro6+ydqnqmql6uqp8A/wSc0921Gzh12qYrgD3d+Ioe45KkBdRX6Hdr9K94H/DKmT13AmuTHJvkNKb+YHt/VT0NvJDk3O6sncuBbbOoW5LUhyP+bprkNmACOCnJbuBjwESSs5haotkFfAigqh5OshV4BNgPXF1VL3e7uoqpM4GWAnd3F0nSAjpi6FfVpT2Gbz7M9huBjT3GtwNnHlV1kqQ55TtyJakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhwx9JN8IsneJA9NG3tjknuSPN5dnzjtvmuS7EzyWJJ3Ths/O8mO7r6PJ8nctyNJOpyZHOnfAqw5YGwDcG9VrQLu7b4myenAWuCMbs4NSY7p5twIrANWdZcD9ylJmmdHDP2q+iLw7AHDFwObu9ubgUumjW+pqpeq6glgJ3BOklOAE6rqy1VVwCenzZEkLZB+1/THquppgO765G58OfCdadvt7saWd7cPHJckLaAlc7y/Xuv0dZjx3jtJ1jG1FMTY2BiTk5N9FbNv376+5w6bYe9l/er9M9pubOnB2w5zX4fTq5eZGraeh/31dTRGpZf56qPf0H8mySlV9XS3dLO3G98NnDptuxXAnm58RY/xnqpqE7AJYHx8vCYmJvoqcnJykn7nDpth7+XKDXfNaLv1q/dz3Y5Xv+x2XTYxDxXNv+tv3XZQLzM1bD0P++vraIxKL/PVR7/LO3cCV3S3rwC2TRtfm+TYJKcx9Qfb+7sloBeSnNudtXP5tDmSpAVyxMOUJLcBE8BJSXYDHwOuBbYm+SDwJPB+gKp6OMlW4BFgP3B1Vb3c7eoqps4EWgrc3V0kSQvoiKFfVZce4q4LD7H9RmBjj/HtwJlHVZ0kaU75jlxJaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhoy15+9I2nIrezxkRnrV++f0Udp7Lr2ovkoSQvII31JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1ZFahn2RXkh1JHkyyvRt7Y5J7kjzeXZ84bftrkuxM8liSd862eEnS0ZmLI/0Lquqsqhrvvt4A3FtVq4B7u69JcjqwFjgDWAPckOSYOXh8SdIMzcfyzsXA5u72ZuCSaeNbquqlqnoC2AmcMw+PL0k6hFRV/5OTJ4DngAL+sao2Jfl+Vb1h2jbPVdWJSf4euK+qPtWN3wzcXVW399jvOmAdwNjY2Nlbtmzpq759+/Zx/PHH9zV32Ax7Lzueen5G240thWd++Oqx1ct/fh4qmn97n33+oF5mapA993quej0vvSyG52rYv1dmarZ9XHDBBQ9MW4H5qSWzqgrOq6o9SU4G7knyzcNsmx5jPX/iVNUmYBPA+Ph4TUxM9FXc5OQk/c4dNsPey5Ub7prRdutX7+e6Ha9+2e26bGIeKpp/19+67aBeZmqQPfd6rno9L70shudq2L9XZmq++pjV8k5V7emu9wJ3MLVc80ySUwC6673d5ruBU6dNXwHsmc3jS5KOTt+hn2RZkte/cht4B/AQcCdwRbfZFcC27vadwNokxyY5DVgF3N/v40uSjt5slnfGgDuSvLKfT1fV55N8Fdia5IPAk8D7Aarq4SR
bgUeA/cDVVfXyrKqXJB2VvkO/qr4F/EqP8f8FLjzEnI3Axn4fU5I0O74jV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDen7P0bX/Fm54a6Dxtav3s+VPcYPtOvai+ajJEkjwiN9SWqIoS9JDTH0Jakhhr4kNcTQl6SGePaOpKHX64y2Q+l1pptntf2MR/qS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktSQBX9HbpI1wN8BxwA3VdW18/VYO556fkafQd+L7+CTNIoW9Eg/yTHAPwDvAk4HLk1y+kLWIEktW+gj/XOAnVX1LYAkW4CLgUcWuA5JmndH85lBB7plzbI5rORnUlXzsuOeD5b8DrCmqv6g+/oDwK9W1YcP2G4dsK778i3AY30+5EnA9/qcO2xGpZdR6QPsZViNSi+z7eOXqurNBw4u9JF+eowd9FOnqjYBm2b9YMn2qhqf7X6Gwaj0Mip9gL0Mq1HpZb76WOizd3YDp077egWwZ4FrkKRmLXTofxVYleS0JD8HrAXuXOAaJKlZC7q8U1X7k3wY+DemTtn8RFU9PI8POesloiEyKr2MSh9gL8NqVHqZlz4W9A+5kqTB8h25ktQQQ1+SGjJyoZ/kdUnuT/KNJA8n+fNB1zRbSY5J8vUk/zroWmYjya4kO5I8mGT7oOuZjSRvSHJ7km8meTTJrw26pqOV5C3dc/HK5QdJPjrouvqV5I+67/mHktyW5HWDrqlfST7S9fHwXD8nI7emnyTAsqral+S1wJeAj1TVfQMurW9J/hgYB06oqvcMup5+JdkFjFfVon/jTJLNwH9V1U3dmWjHVdX3B1xW37qPSHmKqTdLfnvQ9RytJMuZ+l4/vap+mGQr8LmqumWwlR29JGcCW5j6BIMfAZ8Hrqqqx+di/yN3pF9T9nVfvra7LNqfbElWABcBNw26Fk1JcgJwPnAzQFX9aDEHfudC4L8XY+BPswRYmmQJcByL9z1AbwXuq6r/q6r9wH8C75urnY9c6MNPl0MeBPYC91TVVwZc0mz8LfAnwE8GXMdcKOALSR7oPmpjsfpl4LvAP3fLbjclmZ8PSlk4a4HbBl1Ev6rqKeCvgCeBp4Hnq+oLg62qbw8B5yd5U5LjgHfz6je1zspIhn5VvVxVZzH1jt9zul+XFp0k7wH2VtUDg65ljpxXVW9n6lNWr05y/qAL6tMS4O3AjVX1NuBFYMNgS+pftzz1XuBfBl1Lv5KcyNSHN54G/AKwLMnvDraq/lTVo8BfAvcwtbTzDWD/XO1/JEP/Fd2v3JPAmsFW0rfzgPd2a+FbgN9I8qnBltS/qtrTXe8F7mBqzXIx2g3snvYb5O1M/RBYrN4FfK2qnhl0IbPwm8ATVfXdqvox8Fng1wdcU9+q6uaqentVnQ88C8zJej6MYOgneXOSN3S3lzL1YvjmQIvqU1VdU1UrqmolU79+/3tVLcqjlyTLkrz+ldvAO5j6NXbRqar/Ab6T5C3d0IUs7o8Hv5RFvLTTeRI4N8lx3ckcFwKPDrimviU5ubv+ReC3mcPnZ8H/56wFcAqwuTsb4TXA1qpa1Kc6jogx4I6p70eWAJ+uqs8PtqRZ+UPg1m5p5FvA7w24nr50a8a/BXxo0LXMRlV9JcntwNeYWgr5Oov74xg+k+RNwI+Bq6vqubna8cidsilJOrSRW96RJB2aoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5Ia8v9wPc08XnoF+AAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df['quality'].hist(bins=20);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Imbalanced data\n", "You can see from the plot above that the wine quality dataset is imbalanced. \n", "\n", "- Since there are very few observations with quality equal to 3, 4, 8 and 9, you can drop these observations from your dataset. \n", "- You can do this by removing data belonging to all classes except those > 4 and < 8." ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "# get data with wine quality greater than 4 and less than 8\n", "df = df[(df['quality'] > 4) & (df['quality'] < 8 )]\n", "\n", "# reset index and drop the old one\n", "df = df.reset_index(drop=True)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10.8\n", "11.2\n" ] } ], "source": [ "print(df.alcohol[0])\n", "print(df.alcohol[100])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can plot again to see the new range of data and quality" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQaklEQVR4nO3df6zddX3H8edrRQkW+ZXOO9IywaQx48dE2zDUzN2GRapuK/vDpIQIZCydBpOZkEWYyTRZmuAfbAk4yDpxQGQ2zB8rEdlGGDdmU8TiwPJDtEonpUinIFBiMLD3/jjfbsfLae85p/ec2/p5PpKT8z2f7/fz/b6/Xz687rmfe863qSokSW34laUuQJI0PYa+JDXE0Jekhhj6ktQQQ1+SGnLUUhewkBUrVtSpp546Vt8XX3yR5cuXL25Bi8C6RmNdo7Gu0fyy1nX//ff/uKp+9VUrquqwfqxZs6bGdc8994zdd5KsazTWNRrrGs0va13A9hqQqU7vSFJDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQw772zBIh6sdTz7HpVfeMVbfXVe/b5GrkYbjO31JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDFgz9JKckuSfJo0keTvKnXftJSe5K8r3u+cS+Plcl2ZnksSTn97WvSbKjW3dtkkzmtCRJgwzzTv9l4Iqq+g3gXODyJKcDVwJ3V9Vq4O7uNd26jcAZwHrg+iTLun3dAGwCVneP9Yt4LpKkBSwY+lX1VFV9q1t+AXgUWAlsAG7uNrsZuKBb3gBsraqXqupxYCdwTpKTgeOq6utVVcAtfX0kSVMw0px+klOBtwLfAGaq6ino/WAA3tBtthJ4oq/b7q5tZbc8v12SNCVHDbthkmOBLwAfqarnDzIdP2hFHaR90LE20ZsGYmZmhrm5uWHL/AX79u0bu+8kWddoDte6Zo6BK856eay+kzyfw/V6WddoJlXXUKGf5DX0Av/Wqvpi1/x0kpOr6qlu6mZv174bOKWv+ypgT9e+akD7q1TVFmALwNq1a2t2dna4s5lnbm6OcftOknWN5nCt67pbt3HNjqHfN/2CXRfNLm4xfQ7X62Vdo5lUXcN8eifAjcCjVfVXfatuBy7pli8BtvW1b0xydJLT6P3B9r5uCuiFJOd2+7y4r48kaQqGeZvyTuADwI4kD3Rtfw5cDdyW5DLgh8D7Aarq4SS3AY/Q++TP5VX1StfvQ8BNwDHAnd1DkjQlC4Z+Vf07g+fjAc47QJ/NwOYB7duBM0cpUJK0ePxGriQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDVkvH/V+Qix48nnuPTKO8bqu+vq9y1yNZK09HynL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIQuGfpLPJNmb5KG+tk8keTLJA93jvX3rrkqyM8ljSc7va1+TZEe37tokWfzTkSQdzDDv9G8C1g9o/+uqOrt7fAUgyenARuCMrs/1SZZ1298AbAJWd49B+5QkTdCCoV9VXwWeGXJ/G4CtVfVSVT0O7ATOSXIycFxVfb2qCrgFuGDMmiVJY0ovgxfYKDkV+HJVndm9/gRwKfA8sB24oqqeTfIp4N6q+my33Y3AncAu4Oqq+t2u/beBj1bV7x3geJvo/VbAzMzMmq1bt451cnufeY6nfzZWV85aefx4HYewb98+jj322Intf1zWNRrH12isazSHWte6devur6q189u
PGnN/NwB/CVT3fA3wR8Cgefo6SPtAVbUF2AKwdu3amp2dHavI627dxjU7xjvFXReNd8xhzM3NMe45TZJ1jcbxNRrrGs2k6hrr0ztV9XRVvVJV/wP8HXBOt2o3cErfpquAPV37qgHtkqQpGiv0uzn6/f4Q2P/JntuBjUmOTnIavT/Y3ldVTwEvJDm3+9TOxcC2Q6hbkjSGBX83TfI5YBZYkWQ38HFgNsnZ9KZodgF/AlBVDye5DXgEeBm4vKpe6Xb1IXqfBDqG3jz/nYt4HpKkISwY+lV14YDmGw+y/WZg84D27cCZI1UnSVpUfiNXkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqyIKhn+QzSfYmeaiv7aQkdyX5Xvd8Yt+6q5LsTPJYkvP72tck2dGtuzZJFv90JEkHM8w7/ZuA9fPargTurqrVwN3da5KcDmwEzuj6XJ9kWdfnBmATsLp7zN+nJGnCFgz9qvoq8My85g3Azd3yzcAFfe1bq+qlqnoc2Amck+Rk4Liq+npVFXBLXx9J0pSMO6c/U1VPAXTPb+jaVwJP9G23u2tb2S3Pb5ckTdFRi7y/QfP0dZD2wTtJNtGbCmJmZoa5ubmxipk5Bq446+Wx+o57zGHs27dvovsfl3WNxvE1GusazaTqGjf0n05yclU91U3d7O3adwOn9G23CtjTta8a0D5QVW0BtgCsXbu2Zmdnxyryulu3cc2O8U5x10XjHXMYc3NzjHtOk2Rdo3F8jca6RjOpusad3rkduKRbvgTY1te+McnRSU6j9wfb+7opoBeSnNt9aufivj6SpClZ8G1Kks8Bs8CKJLuBjwNXA7cluQz4IfB+gKp6OMltwCPAy8DlVfVKt6sP0fsk0DHAnd1DkjRFC4Z+VV14gFXnHWD7zcDmAe3bgTNHqk6StKj8Rq4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ1Z7HvvSJI6p155x9h9b1q/fBEr+X++05ekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDin0k+xKsiPJA0m2d20nJbkryfe65xP7tr8qyc4kjyU5/1CLlySNZjHe6a+rqrOram33+krg7qpaDdzdvSbJ6cBG4AxgPXB9kmWLcHxJ0pAmMb2zAbi5W74ZuKCvfWtVvVRVjwM7gXMmcHxJ0gGkqsbvnDwOPAsU8LdVtSXJT6vqhL5tnq2qE5N8Cri3qj7btd8I3FlVnx+w303AJoCZmZk1W7duHau+vc88x9M/G6srZ608fryOQ9i3bx/HHnvsxPY/LusajeNrNC3WtePJ58bue9rxyw6prnXr1t3fNwPzf44ae48976yqPUneANyV5DsH2TYD2gb+xKmqLcAWgLVr19bs7OxYxV136zau2THeKe66aLxjDmNubo5xz2mSrGs0jq/RtFjXpVfeMXbfm9Yvn0hdhzS9U1V7uue9wJfoTdc8neRkgO55b7f5buCUvu6rgD2HcnxJ0mjGDv0ky5O8fv8y8G7gIeB24JJus0uAbd3y7cDGJEcnOQ1YDdw37vElSaM7lOmdGeBLSfbv5x+q6p+TfBO4LcllwA+B9wNU1cNJbgMeAV4GLq+qVw6peknSSMYO/ar6AfCWAe0/Ac47QJ/NwOZxjylJOjR+I1eSGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNC
XpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDZl66CdZn+SxJDuTXDnt40tSy6Ya+kmWAX8DvAc4HbgwyenTrEGSWjbtd/rnADur6gdV9XNgK7BhyjVIUrOOmvLxVgJP9L3eDfzW/I2SbAI2dS/3JXlszOOtAH48Tsd8cswjDmfsuibMukbj+BqNdY1g3ScPua43DmqcduhnQFu9qqFqC7DlkA+WbK+qtYe6n8VmXaOxrtFY12haq2va0zu7gVP6Xq8C9ky5Bklq1rRD/5vA6iSnJXktsBG4fco1SFKzpjq9U1UvJ/kw8C/AMuAzVfXwBA95yFNEE2Jdo7Gu0VjXaJqqK1WvmlKXJP2S8hu5ktQQQ1+SGnLEhn6SXUl2JHkgyfYB65Pk2u52D99O8ra+dRO7FcQQdV3U1fPtJF9L8pZh+064rtkkz3XrH0jyF33rlvJ6/VlfTQ8leSXJScP0PcS6Tkjy+STfSfJokrfPW79U42uhupZqfC1U11KNr4Xqmvr4SvLmvmM+kOT5JB+Zt83kxldVHZEPYBew4iDr3wvcSe+7AecC3+jalwHfB94EvBZ4EDh9inW9AzixW37P/rqG6TvhumaBLw9oX9LrNW/b3wf+bUrX62bgj7vl1wInHCbja6G6lmp8LVTXUo2vg9a1VONr3vn/CHjjtMbXEftOfwgbgFuq517ghCQns8S3gqiqr1XVs93Le+l9V+FwdjjdOuNC4HOTPkiS44B3ATcCVNXPq+qn8zab+vgapq6lGF9DXq8DWdLrNc9Uxtc85wHfr6r/mtc+sfF1JId+Af+a5P70btsw36BbPqw8SPu06up3Gb2f5uP0nURdb0/yYJI7k5zRtR0W1yvJ64D1wBdG7TuGNwH/Dfx9kv9M8ukky+dtsxTja5i6+k1rfA1b17TH19DXa8rjq99GBv+gmdj4OpJD/51V9TZ6v8JenuRd89Yf6JYPQ90KYoJ19YpL1tH7n/Kjo/adUF3fovcr5luA64B/2l/qgH1N/XrR+9X7P6rqmTH6juoo4G3ADVX1VuBFYP7c6VKMr2Hq6hU33fE1TF1LMb6Gvl5Md3wBkN4XVP8A+MdBqwe0Lcr4OmJDv6r2dM97gS/R+7Wn34Fu+TDRW0EMURdJfhP4NLChqn4ySt9J1VVVz1fVvm75K8BrkqzgMLhenVe9I5rg9doN7K6qb3SvP08vPOZvM+3xNUxdSzG+FqxricbXUNerM83xtd97gG9V1dMD1k1sfB2RoZ9keZLX718G3g08NG+z24GLu7+Cnws8V1VPMcFbQQxTV5JfB74IfKCqvjviOU2yrl9Lkm75HHpj4ycs8fXq1h0P/A6wbdS+46iqHwFPJHlz13Qe8Mi8zaY+voapaynG15B1TX18Dfnfcerjq8/B/oYwufF1qH99XooHvbm6B7vHw8DHuvYPAh/slkPvH2z5PrADWNvX/73Ad7t1H5tyXZ8GngUe6B7bD9Z3inV9uFv3IL0/AL7jcLhe3etLga3D9F3E2s4GtgPfpjcVceJSj68h65r6+BqyrqmPr2HqWsLx9Tp6P/SO72ubyvjyNgyS1JAjcnpHkjQeQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ15H8BPPRzumSE188AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df['quality'].hist(bins=20);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train Test Split\n", "\n", "Next, you can split the datasets into training, test and validation datasets.\n", "- The data frame should be split 80:20 into `train` and `test` sets.\n", "- The resulting `train` should then be split 80:20 into `train` and `val` sets.\n", "- The `train_test_split` parameter `test_size` takes a float value that ranges between 0. and 1, and represents the proportion of the dataset that is allocated to the test set. The rest of the data is allocated to the training set." ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "# split df into 80:20 train and test sets\n", "train, test = train_test_split(df, test_size=0.2, random_state=1)\n", " \n", "# split train into 80:20 train and val sets\n", "train, val = train_test_split(train, test_size=0.2, random_state=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's where you can explore the training stats. You can pop the labels 'is_red' and 'quality' from the data as these will be used as the labels\n" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "train_stats = train.describe()\n", "train_stats.pop('is_red')\n", "train_stats.pop('quality')\n", "train_stats = train_stats.transpose()" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>count</th><th>mean</th><th>std</th><th>min</th><th>25%</th><th>50%</th><th>75%</th><th>max</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>fixed acidity</th><td>3155.0</td><td>7.218051</td><td>1.314434</td><td>3.80000</td><td>6.40000</td><td>7.0000</td><td>7.700000</td><td>15.90000</td></tr>\n",
"    <tr><th>volatile acidity</th><td>3155.0</td><td>0.343124</td><td>0.167368</td><td>0.08500</td><td>0.23000</td><td>0.2900</td><td>0.410000</td><td>1.33000</td></tr>\n",
"    <tr><th>citric acid</th><td>3155.0</td><td>0.318659</td><td>0.146892</td><td>0.00000</td><td>0.25000</td><td>0.3100</td><td>0.390000</td><td>1.66000</td></tr>\n",
"    <tr><th>residual sugar</th><td>3155.0</td><td>5.097195</td><td>4.603185</td><td>0.70000</td><td>1.80000</td><td>2.7000</td><td>7.600000</td><td>65.80000</td></tr>\n",
"    <tr><th>chlorides</th><td>3155.0</td><td>0.057591</td><td>0.037349</td><td>0.00900</td><td>0.03800</td><td>0.0470</td><td>0.068000</td><td>0.61100</td></tr>\n",
"    <tr><th>free sulfur dioxide</th><td>3155.0</td><td>30.198415</td><td>17.107398</td><td>2.00000</td><td>17.00000</td><td>28.0000</td><td>41.000000</td><td>128.00000</td></tr>\n",
"    <tr><th>total sulfur dioxide</th><td>3155.0</td><td>114.670523</td><td>56.857906</td><td>6.00000</td><td>75.00000</td><td>116.0000</td><td>155.000000</td><td>303.00000</td></tr>\n",
"    <tr><th>density</th><td>3155.0</td><td>0.994618</td><td>0.003055</td><td>0.98711</td><td>0.99222</td><td>0.9948</td><td>0.996845</td><td>1.03898</td></tr>\n",
"    <tr><th>pH</th><td>3155.0</td><td>3.224685</td><td>0.160749</td><td>2.74000</td><td>3.11000</td><td>3.2200</td><td>3.330000</td><td>4.01000</td></tr>\n",
"    <tr><th>sulphates</th><td>3155.0</td><td>0.536076</td><td>0.147063</td><td>0.22000</td><td>0.44000</td><td>0.5100</td><td>0.600000</td><td>1.62000</td></tr>\n",
"    <tr><th>alcohol</th><td>3155.0</td><td>10.520502</td><td>1.178784</td><td>8.00000</td><td>9.50000</td><td>10.3000</td><td>11.300000</td><td>14.90000</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " count mean std min 25% \\\n", "fixed acidity 3155.0 7.218051 1.314434 3.80000 6.40000 \n", "volatile acidity 3155.0 0.343124 0.167368 0.08500 0.23000 \n", "citric acid 3155.0 0.318659 0.146892 0.00000 0.25000 \n", "residual sugar 3155.0 5.097195 4.603185 0.70000 1.80000 \n", "chlorides 3155.0 0.057591 0.037349 0.00900 0.03800 \n", "free sulfur dioxide 3155.0 30.198415 17.107398 2.00000 17.00000 \n", "total sulfur dioxide 3155.0 114.670523 56.857906 6.00000 75.00000 \n", "density 3155.0 0.994618 0.003055 0.98711 0.99222 \n", "pH 3155.0 3.224685 0.160749 2.74000 3.11000 \n", "sulphates 3155.0 0.536076 0.147063 0.22000 0.44000 \n", "alcohol 3155.0 10.520502 1.178784 8.00000 9.50000 \n", "\n", " 50% 75% max \n", "fixed acidity 7.0000 7.700000 15.90000 \n", "volatile acidity 0.2900 0.410000 1.33000 \n", "citric acid 0.3100 0.390000 1.66000 \n", "residual sugar 2.7000 7.600000 65.80000 \n", "chlorides 0.0470 0.068000 0.61100 \n", "free sulfur dioxide 28.0000 41.000000 128.00000 \n", "total sulfur dioxide 116.0000 155.000000 303.00000 \n", "density 0.9948 0.996845 1.03898 \n", "pH 3.2200 3.330000 4.01000 \n", "sulphates 0.5100 0.600000 1.62000 \n", "alcohol 10.3000 11.300000 14.90000 " ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_stats" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Get the labels\n", "\n", "The features and labels are currently in the same dataframe.\n", "- You will want to store the label columns `is_red` and `quality` separately from the feature columns. \n", "- The following function, `format_output`, gets these two columns from the dataframe (it's given to you).\n", "- `format_output` also converts each label column into a NumPy array and returns them as the tuple `(quality, is_red)`. \n", "- Apply `format_output` to the `train`, `val` and `test` sets to get the labels as NumPy arrays." 
] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "def format_output(data):\n", " is_red = data.pop('is_red')\n", " is_red = np.array(is_red)\n", " quality = data.pop('quality')\n", " quality = np.array(quality)\n", " return (quality, is_red)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "# format the output of the train set\n", "train_Y = format_output(train)\n", "\n", "# format the output of the val set\n", "val_Y = format_output(val)\n", " \n", "# format the output of the test set\n", "test_Y = format_output(test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that after you get the labels, the `train`, `val` and `test` dataframes no longer contain the label columns, and contain just the feature columns.\n", "- This is because you used `.pop` in the `format_output` function." ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>fixed acidity</th><th>volatile acidity</th><th>citric acid</th><th>residual sugar</th><th>chlorides</th><th>free sulfur dioxide</th><th>total sulfur dioxide</th><th>density</th><th>pH</th><th>sulphates</th><th>alcohol</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>225</th><td>6.4</td><td>0.26</td><td>0.25</td><td>10.70</td><td>0.046</td><td>66.0</td><td>179.0</td><td>0.99606</td><td>3.17</td><td>0.55</td><td>9.9</td></tr>\n",
"    <tr><th>3557</th><td>7.3</td><td>0.22</td><td>0.40</td><td>14.75</td><td>0.042</td><td>44.5</td><td>129.5</td><td>0.99980</td><td>3.36</td><td>0.41</td><td>9.1</td></tr>\n",
"    <tr><th>3825</th><td>6.8</td><td>0.16</td><td>0.29</td><td>10.40</td><td>0.046</td><td>59.0</td><td>143.0</td><td>0.99518</td><td>3.20</td><td>0.40</td><td>10.8</td></tr>\n",
"    <tr><th>1740</th><td>6.6</td><td>0.16</td><td>0.57</td><td>1.10</td><td>0.130</td><td>58.0</td><td>140.0</td><td>0.99270</td><td>3.12</td><td>0.39</td><td>9.3</td></tr>\n",
"    <tr><th>1221</th><td>6.9</td><td>0.28</td><td>0.41</td><td>1.40</td><td>0.016</td><td>6.0</td><td>55.0</td><td>0.98876</td><td>3.16</td><td>0.40</td><td>13.4</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " fixed acidity volatile acidity citric acid residual sugar chlorides \\\n", "225 6.4 0.26 0.25 10.70 0.046 \n", "3557 7.3 0.22 0.40 14.75 0.042 \n", "3825 6.8 0.16 0.29 10.40 0.046 \n", "1740 6.6 0.16 0.57 1.10 0.130 \n", "1221 6.9 0.28 0.41 1.40 0.016 \n", "\n", " free sulfur dioxide total sulfur dioxide density pH sulphates \\\n", "225 66.0 179.0 0.99606 3.17 0.55 \n", "3557 44.5 129.5 0.99980 3.36 0.41 \n", "3825 59.0 143.0 0.99518 3.20 0.40 \n", "1740 58.0 140.0 0.99270 3.12 0.39 \n", "1221 6.0 55.0 0.98876 3.16 0.40 \n", "\n", " alcohol \n", "225 9.9 \n", "3557 9.1 \n", "3825 10.8 \n", "1740 9.3 \n", "1221 13.4 " ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Normalize the data \n", "\n", "Next, you can normalize the data, x, using the formula:\n", "\n", "$$x_{norm} = \\frac{x - \\mu}{\\sigma}$$\n", "\n", "- The `norm` function is defined for you.\n", "- Please apply the `norm` function to normalize the dataframes that contains the feature columns of `train`, `val` and `test` sets." ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "def norm(x):\n", " return (x - train_stats['mean']) / train_stats['std']" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "# normalize the train set\n", "norm_train_X = norm(train)\n", " \n", "# normalize the val set\n", "norm_val_X = norm(val)\n", " \n", "# normalize the test set\n", "norm_test_X = norm(test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Model\n", "\n", "Define the model using the functional API. 
The base model will be 2 `Dense` layers of 128 neurons each, and have the `'relu'` activation.\n", "- Check out the documentation for [tf.keras.layers.Dense](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "def base_model(inputs):\n", " \n", " # connect a Dense layer with 128 neurons and a relu activation\n", " x = Dense(units=128, activation='relu')(inputs)\n", " \n", " # connect another Dense layer with 128 neurons and a relu activation\n", " x = Dense(units=128, activation='relu')(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define output layers of the model\n", "\n", "You will add output layers to the base model. \n", "- The model will need two outputs.\n", "\n", "One output layer will predict wine quality, which is a numeric value.\n", "- Define a `Dense` layer with 1 neuron.\n", "- Since this is a regression output, the activation can be left as its default value `None`.\n", "\n", "The other output layer will predict the wine type, which is either red `1` or not red `0` (white).\n", "- Define a `Dense` layer with 1 neuron.\n", "- Since there are two possible categories, you can use a sigmoid activation for binary classification.\n", "\n", "Define the `Model`\n", "- Define the `Model` object, and set the following parameters:\n", " - `inputs`: pass in the inputs to the model as a list.\n", " - `outputs`: pass in a list of the outputs that you just defined: wine quality, then wine type.\n", " - **Note**: please list the wine quality before wine type in the outputs, as this will affect the calculated loss if you choose the other order." 
] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "def final_model(inputs):\n", "    \n", "    # get the base model\n", "    x = base_model(inputs)\n", "\n", "    # connect the output Dense layer for regression (units must be an integer, not a string)\n", "    wine_quality = Dense(units=1, name='wine_quality')(x)\n", "\n", "    # connect the output Dense layer for classification. this will use a sigmoid activation.\n", "    wine_type = Dense(units=1, activation='sigmoid', name='wine_type')(x)\n", "\n", "    # define the model using the input and output layers\n", "    model = Model(inputs=inputs, outputs=[wine_quality, wine_type])\n", "\n", "    return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compiling the Model\n", "\n", "Next, compile the model. When setting the loss parameter of `model.compile`, you're setting the loss for each of the two outputs (wine quality and wine type).\n", "\n", "To set more than one loss, use a dictionary of key-value pairs.\n", "- You can look at the docs for the losses [here](https://www.tensorflow.org/api_docs/python/tf/keras/losses#functions).\n", "  - **Note**: For the desired spelling, please look at the \"Functions\" section of the documentation and not the \"Classes\" section on that same page.\n", "- wine_type: Since you will be performing binary classification on wine type, you should use the binary crossentropy loss function for it. Please pass this in as a string. \n", "  - **Hint**: this should be all lowercase. In the documentation, you'll see this under the \"Functions\" section, not the \"Classes\" section.\n", "- wine_quality: Since this is a regression output, use the mean squared error. Please pass it in as a string, all lowercase.\n", "  - **Hint**: You may notice that there are two aliases for mean squared error. Please use the shorter name.\n", "\n", "\n", "You will also set the metric for each of the two outputs. 
Again, to set metrics for two or more outputs, use a dictionary with key value pairs.\n", "- The metrics documentation is linked [here](https://www.tensorflow.org/api_docs/python/tf/keras/metrics).\n", "- For the wine type, please set it to accuracy as a string, all lowercase.\n", "- For wine quality, please use the root mean squared error. Instead of a string, you'll set it to an instance of the class [RootMeanSquaredError](https://www.tensorflow.org/api_docs/python/tf/keras/metrics/RootMeanSquaredError), which belongs to the tf.keras.metrics module." ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "inputs = tf.keras.layers.Input(shape=(11,))\n", "rms = tf.keras.optimizers.RMSprop(learning_rate=0.0001)\n", "model = final_model(inputs)\n", "\n", "model.compile(optimizer=rms, \n", " loss = {'wine_type' : 'binary_crossentropy',\n", " 'wine_quality' : 'mean_squared_error'\n", " },\n", " metrics = {'wine_type' : 'accuracy',\n", " 'wine_quality': tf.keras.metrics.RootMeanSquaredError()\n", " }\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training the Model\n", "\n", "Fit the model to the training inputs and outputs. \n", "- Check the documentation for [model.fit](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit).\n", "- Remember to use the normalized training set as inputs. \n", "- For the validation data, please use the normalized validation set." 
] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/40\n", "99/99 [==============================] - 2s 12ms/step - loss: 23.0054 - wine_quality_loss: 22.3884 - wine_type_loss: 0.6170 - wine_quality_root_mean_squared_error: 4.7316 - wine_type_accuracy: 0.7423 - val_loss: 15.7310 - val_wine_quality_loss: 15.1511 - val_wine_type_loss: 0.5800 - val_wine_quality_root_mean_squared_error: 3.8924 - val_wine_type_accuracy: 0.7465\n", "Epoch 2/40\n", "99/99 [==============================] - 1s 11ms/step - loss: 9.8386 - wine_quality_loss: 9.2897 - wine_type_loss: 0.5490 - wine_quality_root_mean_squared_error: 3.0479 - wine_type_accuracy: 0.7385 - val_loss: 5.6479 - val_wine_quality_loss: 5.1287 - val_wine_type_loss: 0.5191 - val_wine_quality_root_mean_squared_error: 2.2647 - val_wine_type_accuracy: 0.7452\n", "Epoch 3/40\n", "99/99 [==============================] - 1s 11ms/step - loss: 3.8056 - wine_quality_loss: 3.3435 - wine_type_loss: 0.4621 - wine_quality_root_mean_squared_error: 1.8285 - wine_type_accuracy: 0.7594 - val_loss: 2.8688 - val_wine_quality_loss: 2.4544 - val_wine_type_loss: 0.4144 - val_wine_quality_root_mean_squared_error: 1.5667 - val_wine_type_accuracy: 0.8074\n", "Epoch 4/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 2.5611 - wine_quality_loss: 2.2106 - wine_type_loss: 0.3505 - wine_quality_root_mean_squared_error: 1.4868 - wine_type_accuracy: 0.8859 - val_loss: 2.3535 - val_wine_quality_loss: 2.0324 - val_wine_type_loss: 0.3210 - val_wine_quality_root_mean_squared_error: 1.4256 - val_wine_type_accuracy: 0.9138\n", "Epoch 5/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 2.1611 - wine_quality_loss: 1.8892 - wine_type_loss: 0.2719 - wine_quality_root_mean_squared_error: 1.3745 - wine_type_accuracy: 0.9429 - val_loss: 2.0623 - val_wine_quality_loss: 1.8073 - val_wine_type_loss: 0.2550 - 
val_wine_quality_root_mean_squared_error: 1.3444 - val_wine_type_accuracy: 0.9480\n", "Epoch 6/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.8903 - wine_quality_loss: 1.6780 - wine_type_loss: 0.2123 - wine_quality_root_mean_squared_error: 1.2954 - wine_type_accuracy: 0.9705 - val_loss: 1.8711 - val_wine_quality_loss: 1.6652 - val_wine_type_loss: 0.2059 - val_wine_quality_root_mean_squared_error: 1.2904 - val_wine_type_accuracy: 0.9670\n", "Epoch 7/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.6902 - wine_quality_loss: 1.5210 - wine_type_loss: 0.1692 - wine_quality_root_mean_squared_error: 1.2333 - wine_type_accuracy: 0.9791 - val_loss: 1.7000 - val_wine_quality_loss: 1.5310 - val_wine_type_loss: 0.1689 - val_wine_quality_root_mean_squared_error: 1.2373 - val_wine_type_accuracy: 0.9759\n", "Epoch 8/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.5363 - wine_quality_loss: 1.3980 - wine_type_loss: 0.1383 - wine_quality_root_mean_squared_error: 1.1824 - wine_type_accuracy: 0.9823 - val_loss: 1.5654 - val_wine_quality_loss: 1.4226 - val_wine_type_loss: 0.1428 - val_wine_quality_root_mean_squared_error: 1.1927 - val_wine_type_accuracy: 0.9747\n", "Epoch 9/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.4062 - wine_quality_loss: 1.2909 - wine_type_loss: 0.1153 - wine_quality_root_mean_squared_error: 1.1362 - wine_type_accuracy: 0.9848 - val_loss: 1.4485 - val_wine_quality_loss: 1.3264 - val_wine_type_loss: 0.1221 - val_wine_quality_root_mean_squared_error: 1.1517 - val_wine_type_accuracy: 0.9772\n", "Epoch 10/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.2977 - wine_quality_loss: 1.1993 - wine_type_loss: 0.0984 - wine_quality_root_mean_squared_error: 1.0951 - wine_type_accuracy: 0.9861 - val_loss: 1.3499 - val_wine_quality_loss: 1.2433 - val_wine_type_loss: 0.1066 - val_wine_quality_root_mean_squared_error: 1.1150 - val_wine_type_accuracy: 
0.9797\n", "Epoch 11/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.1989 - wine_quality_loss: 1.1140 - wine_type_loss: 0.0850 - wine_quality_root_mean_squared_error: 1.0554 - wine_type_accuracy: 0.9876 - val_loss: 1.2570 - val_wine_quality_loss: 1.1623 - val_wine_type_loss: 0.0947 - val_wine_quality_root_mean_squared_error: 1.0781 - val_wine_type_accuracy: 0.9810\n", "Epoch 12/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.1149 - wine_quality_loss: 1.0395 - wine_type_loss: 0.0754 - wine_quality_root_mean_squared_error: 1.0196 - wine_type_accuracy: 0.9892 - val_loss: 1.1906 - val_wine_quality_loss: 1.1047 - val_wine_type_loss: 0.0859 - val_wine_quality_root_mean_squared_error: 1.0510 - val_wine_type_accuracy: 0.9835\n", "Epoch 13/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 1.0436 - wine_quality_loss: 0.9758 - wine_type_loss: 0.0678 - wine_quality_root_mean_squared_error: 0.9878 - wine_type_accuracy: 0.9899 - val_loss: 1.1063 - val_wine_quality_loss: 1.0278 - val_wine_type_loss: 0.0785 - val_wine_quality_root_mean_squared_error: 1.0138 - val_wine_type_accuracy: 0.9848\n", "Epoch 14/40\n", "99/99 [==============================] - 1s 11ms/step - loss: 0.9753 - wine_quality_loss: 0.9135 - wine_type_loss: 0.0618 - wine_quality_root_mean_squared_error: 0.9558 - wine_type_accuracy: 0.9908 - val_loss: 1.0402 - val_wine_quality_loss: 0.9671 - val_wine_type_loss: 0.0730 - val_wine_quality_root_mean_squared_error: 0.9834 - val_wine_type_accuracy: 0.9861\n", "Epoch 15/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.9150 - wine_quality_loss: 0.8581 - wine_type_loss: 0.0569 - wine_quality_root_mean_squared_error: 0.9263 - wine_type_accuracy: 0.9918 - val_loss: 0.9950 - val_wine_quality_loss: 0.9267 - val_wine_type_loss: 0.0683 - val_wine_quality_root_mean_squared_error: 0.9627 - val_wine_type_accuracy: 0.9886\n", "Epoch 16/40\n", "99/99 [==============================] - 1s 
11ms/step - loss: 0.8608 - wine_quality_loss: 0.8077 - wine_type_loss: 0.0531 - wine_quality_root_mean_squared_error: 0.8987 - wine_type_accuracy: 0.9918 - val_loss: 0.9258 - val_wine_quality_loss: 0.8614 - val_wine_type_loss: 0.0644 - val_wine_quality_root_mean_squared_error: 0.9281 - val_wine_type_accuracy: 0.9886\n", "Epoch 17/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.8093 - wine_quality_loss: 0.7595 - wine_type_loss: 0.0498 - wine_quality_root_mean_squared_error: 0.8715 - wine_type_accuracy: 0.9918 - val_loss: 0.8713 - val_wine_quality_loss: 0.8101 - val_wine_type_loss: 0.0613 - val_wine_quality_root_mean_squared_error: 0.9000 - val_wine_type_accuracy: 0.9886\n", "Epoch 18/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.7664 - wine_quality_loss: 0.7193 - wine_type_loss: 0.0472 - wine_quality_root_mean_squared_error: 0.8481 - wine_type_accuracy: 0.9918 - val_loss: 0.8281 - val_wine_quality_loss: 0.7696 - val_wine_type_loss: 0.0584 - val_wine_quality_root_mean_squared_error: 0.8773 - val_wine_type_accuracy: 0.9886\n", "Epoch 19/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.7210 - wine_quality_loss: 0.6761 - wine_type_loss: 0.0449 - wine_quality_root_mean_squared_error: 0.8223 - wine_type_accuracy: 0.9918 - val_loss: 0.7787 - val_wine_quality_loss: 0.7224 - val_wine_type_loss: 0.0563 - val_wine_quality_root_mean_squared_error: 0.8500 - val_wine_type_accuracy: 0.9886\n", "Epoch 20/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.6851 - wine_quality_loss: 0.6421 - wine_type_loss: 0.0430 - wine_quality_root_mean_squared_error: 0.8013 - wine_type_accuracy: 0.9924 - val_loss: 0.7432 - val_wine_quality_loss: 0.6889 - val_wine_type_loss: 0.0543 - val_wine_quality_root_mean_squared_error: 0.8300 - val_wine_type_accuracy: 0.9886\n", "Epoch 21/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.6503 - wine_quality_loss: 0.6090 - wine_type_loss: 0.0413 - 
wine_quality_root_mean_squared_error: 0.7804 - wine_type_accuracy: 0.9921 - val_loss: 0.7027 - val_wine_quality_loss: 0.6499 - val_wine_type_loss: 0.0528 - val_wine_quality_root_mean_squared_error: 0.8061 - val_wine_type_accuracy: 0.9886\n", "Epoch 22/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.6166 - wine_quality_loss: 0.5767 - wine_type_loss: 0.0399 - wine_quality_root_mean_squared_error: 0.7594 - wine_type_accuracy: 0.9921 - val_loss: 0.6744 - val_wine_quality_loss: 0.6231 - val_wine_type_loss: 0.0513 - val_wine_quality_root_mean_squared_error: 0.7894 - val_wine_type_accuracy: 0.9886\n", "Epoch 23/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.5857 - wine_quality_loss: 0.5471 - wine_type_loss: 0.0386 - wine_quality_root_mean_squared_error: 0.7396 - wine_type_accuracy: 0.9930 - val_loss: 0.6430 - val_wine_quality_loss: 0.5929 - val_wine_type_loss: 0.0501 - val_wine_quality_root_mean_squared_error: 0.7700 - val_wine_type_accuracy: 0.9899\n", "Epoch 24/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.5587 - wine_quality_loss: 0.5213 - wine_type_loss: 0.0375 - wine_quality_root_mean_squared_error: 0.7220 - wine_type_accuracy: 0.9930 - val_loss: 0.6154 - val_wine_quality_loss: 0.5663 - val_wine_type_loss: 0.0491 - val_wine_quality_root_mean_squared_error: 0.7525 - val_wine_type_accuracy: 0.9924\n", "Epoch 25/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.5356 - wine_quality_loss: 0.4991 - wine_type_loss: 0.0365 - wine_quality_root_mean_squared_error: 0.7064 - wine_type_accuracy: 0.9933 - val_loss: 0.5934 - val_wine_quality_loss: 0.5453 - val_wine_type_loss: 0.0481 - val_wine_quality_root_mean_squared_error: 0.7384 - val_wine_type_accuracy: 0.9911\n", "Epoch 26/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.5144 - wine_quality_loss: 0.4788 - wine_type_loss: 0.0357 - wine_quality_root_mean_squared_error: 0.6919 - wine_type_accuracy: 0.9930 - 
val_loss: 0.5690 - val_wine_quality_loss: 0.5218 - val_wine_type_loss: 0.0473 - val_wine_quality_root_mean_squared_error: 0.7223 - val_wine_type_accuracy: 0.9911\n", "Epoch 27/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.4950 - wine_quality_loss: 0.4602 - wine_type_loss: 0.0348 - wine_quality_root_mean_squared_error: 0.6783 - wine_type_accuracy: 0.9930 - val_loss: 0.5516 - val_wine_quality_loss: 0.5051 - val_wine_type_loss: 0.0465 - val_wine_quality_root_mean_squared_error: 0.7107 - val_wine_type_accuracy: 0.9911\n", "Epoch 28/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.4776 - wine_quality_loss: 0.4435 - wine_type_loss: 0.0342 - wine_quality_root_mean_squared_error: 0.6659 - wine_type_accuracy: 0.9930 - val_loss: 0.5321 - val_wine_quality_loss: 0.4862 - val_wine_type_loss: 0.0458 - val_wine_quality_root_mean_squared_error: 0.6973 - val_wine_type_accuracy: 0.9911\n", "Epoch 29/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.4611 - wine_quality_loss: 0.4277 - wine_type_loss: 0.0334 - wine_quality_root_mean_squared_error: 0.6540 - wine_type_accuracy: 0.9933 - val_loss: 0.5207 - val_wine_quality_loss: 0.4754 - val_wine_type_loss: 0.0453 - val_wine_quality_root_mean_squared_error: 0.6895 - val_wine_type_accuracy: 0.9911\n", "Epoch 30/40\n", "99/99 [==============================] - 1s 11ms/step - loss: 0.4476 - wine_quality_loss: 0.4147 - wine_type_loss: 0.0329 - wine_quality_root_mean_squared_error: 0.6440 - wine_type_accuracy: 0.9930 - val_loss: 0.5014 - val_wine_quality_loss: 0.4567 - val_wine_type_loss: 0.0447 - val_wine_quality_root_mean_squared_error: 0.6758 - val_wine_type_accuracy: 0.9911\n", "Epoch 31/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.4344 - wine_quality_loss: 0.4021 - wine_type_loss: 0.0322 - wine_quality_root_mean_squared_error: 0.6342 - wine_type_accuracy: 0.9933 - val_loss: 0.4937 - val_wine_quality_loss: 0.4494 - val_wine_type_loss: 0.0443 - 
val_wine_quality_root_mean_squared_error: 0.6703 - val_wine_type_accuracy: 0.9924\n", "Epoch 32/40\n", "99/99 [==============================] - 1s 11ms/step - loss: 0.4235 - wine_quality_loss: 0.3919 - wine_type_loss: 0.0316 - wine_quality_root_mean_squared_error: 0.6260 - wine_type_accuracy: 0.9943 - val_loss: 0.4820 - val_wine_quality_loss: 0.4380 - val_wine_type_loss: 0.0440 - val_wine_quality_root_mean_squared_error: 0.6618 - val_wine_type_accuracy: 0.9911\n", "Epoch 33/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.4126 - wine_quality_loss: 0.3813 - wine_type_loss: 0.0313 - wine_quality_root_mean_squared_error: 0.6175 - wine_type_accuracy: 0.9937 - val_loss: 0.4675 - val_wine_quality_loss: 0.4239 - val_wine_type_loss: 0.0436 - val_wine_quality_root_mean_squared_error: 0.6511 - val_wine_type_accuracy: 0.9924\n", "Epoch 34/40\n", "99/99 [==============================] - 1s 11ms/step - loss: 0.4044 - wine_quality_loss: 0.3737 - wine_type_loss: 0.0308 - wine_quality_root_mean_squared_error: 0.6113 - wine_type_accuracy: 0.9946 - val_loss: 0.4645 - val_wine_quality_loss: 0.4213 - val_wine_type_loss: 0.0432 - val_wine_quality_root_mean_squared_error: 0.6491 - val_wine_type_accuracy: 0.9911\n", "Epoch 35/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.3968 - wine_quality_loss: 0.3666 - wine_type_loss: 0.0302 - wine_quality_root_mean_squared_error: 0.6054 - wine_type_accuracy: 0.9946 - val_loss: 0.4503 - val_wine_quality_loss: 0.4074 - val_wine_type_loss: 0.0429 - val_wine_quality_root_mean_squared_error: 0.6383 - val_wine_type_accuracy: 0.9924\n", "Epoch 36/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.3900 - wine_quality_loss: 0.3601 - wine_type_loss: 0.0299 - wine_quality_root_mean_squared_error: 0.6001 - wine_type_accuracy: 0.9946 - val_loss: 0.4430 - val_wine_quality_loss: 0.4004 - val_wine_type_loss: 0.0427 - val_wine_quality_root_mean_squared_error: 0.6327 - val_wine_type_accuracy: 
0.9924\n", "Epoch 37/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.3846 - wine_quality_loss: 0.3552 - wine_type_loss: 0.0294 - wine_quality_root_mean_squared_error: 0.5960 - wine_type_accuracy: 0.9946 - val_loss: 0.4366 - val_wine_quality_loss: 0.3942 - val_wine_type_loss: 0.0424 - val_wine_quality_root_mean_squared_error: 0.6279 - val_wine_type_accuracy: 0.9924\n", "Epoch 38/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.3779 - wine_quality_loss: 0.3490 - wine_type_loss: 0.0289 - wine_quality_root_mean_squared_error: 0.5907 - wine_type_accuracy: 0.9946 - val_loss: 0.4317 - val_wine_quality_loss: 0.3896 - val_wine_type_loss: 0.0421 - val_wine_quality_root_mean_squared_error: 0.6241 - val_wine_type_accuracy: 0.9924\n", "Epoch 39/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.3733 - wine_quality_loss: 0.3448 - wine_type_loss: 0.0285 - wine_quality_root_mean_squared_error: 0.5872 - wine_type_accuracy: 0.9946 - val_loss: 0.4358 - val_wine_quality_loss: 0.3938 - val_wine_type_loss: 0.0420 - val_wine_quality_root_mean_squared_error: 0.6276 - val_wine_type_accuracy: 0.9937\n", "Epoch 40/40\n", "99/99 [==============================] - 1s 10ms/step - loss: 0.3685 - wine_quality_loss: 0.3403 - wine_type_loss: 0.0282 - wine_quality_root_mean_squared_error: 0.5834 - wine_type_accuracy: 0.9946 - val_loss: 0.4327 - val_wine_quality_loss: 0.3909 - val_wine_type_loss: 0.0418 - val_wine_quality_root_mean_squared_error: 0.6252 - val_wine_type_accuracy: 0.9937\n" ] } ], "source": [ "history = model.fit(x=norm_train_X, y=train_Y,\n", " epochs = 40, validation_data=(norm_val_X, val_Y))" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "25/25 [==============================] - 0s 4ms/step - loss: 0.4327 - wine_quality_loss: 0.3909 - wine_type_loss: 0.0418 - wine_quality_root_mean_squared_error: 0.6252 - wine_type_accuracy: 
0.9937\n", "\n", "loss: 0.43266481161117554\n", "wine_quality_loss: 0.39088746905326843\n", "wine_type_loss: 0.04177739843726158\n", "wine_quality_rmse: 0.6252099275588989\n", "wine_type_accuracy: 0.9936628937721252\n" ] } ], "source": [ "# Gather the evaluation metrics on the validation set\n", "loss, wine_quality_loss, wine_type_loss, wine_quality_rmse, wine_type_accuracy = model.evaluate(x=norm_val_X, y=val_Y)\n", "\n", "print()\n", "print(f'loss: {loss}')\n", "print(f'wine_quality_loss: {wine_quality_loss}')\n", "print(f'wine_type_loss: {wine_type_loss}')\n", "print(f'wine_quality_rmse: {wine_quality_rmse}')\n", "print(f'wine_type_accuracy: {wine_type_accuracy}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Analyze the Model Performance\n", "\n", "Note that the model has two outputs: the output at index 0 is wine quality and the output at index 1 is wine type.\n", "\n", "The quality output is a continuous value, so round the quality predictions to the nearest integer when you compare them with the integer quality labels." ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "predictions = model.predict(norm_test_X)\n", "quality_pred = predictions[0]\n", "type_pred = predictions[1]" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([6.014229], dtype=float32)" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "quality_pred[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot Utilities\n", "\n", "We define a few utilities to visualize the model performance." 
] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "def plot_metrics(history, metric_name, title, ylim=5):\n", "    # plot the training and validation curves of one metric\n", "    plt.title(title)\n", "    plt.ylim(0, ylim)\n", "    plt.plot(history.history[metric_name], color='blue', label=metric_name)\n", "    plt.plot(history.history['val_' + metric_name], color='green', label='val_' + metric_name)\n", "    # show the labels set above\n", "    plt.legend()" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "def plot_confusion_matrix(y_true, y_pred, title='', labels=[0,1]):\n", "    cm = confusion_matrix(y_true, y_pred)\n", "    fig = plt.figure()\n", "    ax = fig.add_subplot(111)\n", "    cax = ax.matshow(cm)\n", "    # use the given title, falling back to a generic one\n", "    plt.title(title if title else 'Confusion matrix of the classifier')\n", "    fig.colorbar(cax)\n", "    ax.set_xticklabels([''] + labels)\n", "    ax.set_yticklabels([''] + labels)\n", "    plt.xlabel('Predicted')\n", "    plt.ylabel('True')\n", "    fmt = 'd'\n", "    thresh = cm.max() / 2.\n", "    # annotate each cell with its count\n", "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", "        plt.text(j, i, format(cm[i, j], fmt),\n", "                 horizontalalignment=\"center\",\n", "                 color=\"black\" if cm[i, j] > thresh else \"white\")\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "def plot_diff(y_true, y_pred, title = '' ):\n", "    # scatter predictions against true values, with the y = x reference line\n", "    plt.scatter(y_true, y_pred)\n", "    plt.title(title)\n", "    plt.xlabel('True Values')\n", "    plt.ylabel('Predictions')\n", "    plt.axis('equal')\n", "    plt.axis('square')\n", "    plt.plot([-100, 100], [-100, 100])\n", "    return plt" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_metrics(history, 'wine_quality_root_mean_squared_error', 'RMSE', ylim=2)" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_metrics(history, 'wine_type_loss', 'Wine Type Loss', ylim=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plots for Confusion Matrix\n", "\n", "Plot the confusion matrices for wine type. You can see that the model performs well for prediction of wine type from the confusion matrix and the loss metrics." ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\kcsgo\\anaconda3\\lib\\site-packages\\ipykernel_launcher.py:8: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " \n", "C:\\Users\\kcsgo\\anaconda3\\lib\\site-packages\\ipykernel_launcher.py:9: UserWarning: FixedFormatter should only be used together with FixedLocator\n", " if __name__ == '__main__':\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_confusion_matrix(test_Y[1], np.round(type_pred), title='Wine Type', labels = [0, 1])" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "scatter_plot = plot_diff(test_Y[0], quality_pred, title='Type')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }