{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# The Keras Functional API\n", "> In this chapter, you'll become familiar with the basics of the Keras functional API. You'll build a simple functional network using functional building blocks, fit it to data, and make predictions. This is a summary of the lecture \"Advanced Deep Learning with Keras\", via DataCamp.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Datacamp, Tensorflow-Keras, Deep_Learning]\n", "- image: images/plot_model.png" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "plt.rcParams['figure.figsize'] = (8, 8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Keras input and dense layers\n", "- Inputs and outputs\n", " - Input layer\n", " - Output layer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Input layers\n", "The first step in creating a neural network model is to define the Input layer. This layer takes in raw data, usually in the form of NumPy arrays. The shape of the Input layer defines how many variables your neural network will use. For example, if the input data has 10 columns, you define an Input layer with a shape of `(10,)`.\n", "\n", "In this case, you are only using one input in your network." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.layers import Input\n", "\n", "# Create an input layer of shape 1\n", "input_tensor = Input(shape=(1, ))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dense layers\n", "Once you have an Input layer, the next step is to add a Dense layer.\n", "\n", "Dense layers learn a weight matrix, where the first dimension of the matrix is the dimension of the input data, and the second dimension is the dimension of the output data. Recall that your Input layer has a shape of 1. In this case, your output layer will also have a shape of 1. This means that the Dense layer will learn a 1x1 weight matrix, plus a single bias term, for two trainable parameters in total.\n", "\n", "In this exercise, you will add a dense layer to your model, after the input layer." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.layers import Dense\n", "\n", "# Input layer\n", "input_tensor = Input(shape=(1, ))\n", "\n", "# Dense layer\n", "output_layer = Dense(1)\n", "\n", "# Connect the dense layer to the input_tensor\n", "output_tensor = output_layer(input_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Output layers\n", "Output layers are simply Dense layers! Output layers are used to reduce the dimension of the inputs to the dimension of the outputs. You'll learn more about output dimensions in chapter 4, but for now, you'll always use a single output in your neural networks, which is equivalent to `Dense(1)` or a dense layer with a single unit." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Input layer\n", "input_tensor = Input(shape=(1, ))\n", "\n", "# Create a dense layer and connect the dense layer to the input_tensor in one step\n", "# Note that we did this in 2 steps in the previous exercise, but are doing it in one step now\n", "output_tensor = Dense(1)(input_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build and compile a model\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build a model\n", "Once you've defined an input layer and an output layer, you can build a Keras model. The model object is how you tell Keras where the model starts and stops: where data comes in and where predictions come out.\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.models import Model\n", "\n", "input_tensor = Input(shape=(1, ))\n", "output_tensor = Dense(1)(input_tensor)\n", "\n", "# Build the model\n", "model = Model(input_tensor, output_tensor)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compile a model\n", "The final step in creating a model is compiling it, which you must do before you can fit it to data. This finalizes your model, freezes all its settings, and prepares it to meet some data!\n", "\n", "During compilation, you specify the optimizer to use for fitting the model to the data, and a loss function. `'adam'` is a good default optimizer to use, and will generally work well. The loss function depends on the problem at hand. Mean squared error is a common loss function and will optimize for predicting the mean, as is done in least squares regression.\n", "\n", "Mean absolute error optimizes for predicting the median, which makes it the 50th-percentile case of quantile regression. For this dataset, `'mean_absolute_error'` works pretty well, so use it as your loss function." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Compile the model\n", "model.compile(optimizer='adam', loss='mean_absolute_error')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize a model\n", "Now that you've compiled the model, take a look at the result of your hard work! You can do this by looking at the model summary, as well as its plot.\n", "\n", "The summary will tell you the names of the layers, as well as how many units they have and how many parameters are in the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: Before using `plot_model`, you need to install pydot, pydotplus, and graphviz. After installing them, restart the kernel.\n", "```\n", "sudo apt install graphviz\n", "pip install pydot pydotplus graphviz\n", "```" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_1\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "input_5 (InputLayer) [(None, 1)] 0 \n", "_________________________________________________________________\n", "dense_3 (Dense) (None, 1) 2 \n", "=================================================================\n", "Total params: 2\n", "Trainable params: 2\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAeYAAAFfCAYAAACbVm9yAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deZRcVZ3A8e/PhLAjS4KGJEyAQRRBJbSYgCCHRQMCQUxYZAkSjYwgILIKAodNEARlEDwRkaDscZCMuMBEgVEI0mFAwLAEEIgECCIoqGDgzh9dl+rudHV3ekndrv5+zulTVa9evfq9eq/6V7/37r0vUkpIkqQyvKPeAUiSpCoTsyRJBTExS5JUEBOzJEkFMTFLklQQE7MkSQXpt8QcERMj4pGIWBARJ/TX+0iS1EiiP/oxR8QQ4FFgZ2AhcA+wX0rpD33+ZpIkNZD+qpi3AhaklJ5IKb0BXAtM6qf3kiSpYQztp+WOAp5p9Xgh8JFaMw8fPjyNHTu2n0KRJKk88+bNezGlNKL99P5KzNHBtDbHzCNiOjAdYP3116e5ubmfQpEkqTwR8VRH0/vrUPZCYEyrx6OBZ1vPkFKakVJqSik1jRix1A8GSZIGpf5KzPcAG0fEBhExDNgXmN1P7yVJUsPol0PZKaUlEXE48EtgCHB5Sumh/ngvSZIaSX+dYyal9DPgZ/21fEmSGpEjf0mSVBATsyRJBTExS5JUkH47x9yX7rrrLgAuuOCCOkciSVLnJkyYAMDRRx/do9dbMUuSVJABUTE/80zL6J6zZs0CYPLkyfUMR5KkpcydO7dPlmPFLElSQQZExdzeDTfcUO8QJElqY8qUKX2yHCtmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpID1OzBExJiJ+HRHzI+KhiDiyMn3tiLg1Ih6r3K7Vd+FKktTYelMxLwG+klJ6HzAeOCwiNgVOAOaklDYG5lQeS5KkbuhxYk4pLUop3Vu5/zdgPjAKmATMrMw2E9izt0FKkjRY9Mk55ogYC2wB3A28K6W0CFqSN7BujddMj4jmiGhevHhxX4QhSdKAN7S3C4iI1YAfA0ellP4aEd16XUppBjADoKmpKfU2Dqm7PvrRjwLw29/+dpled+SRRwLwrW99q89j6o7VVlsNgNdee63T+c477zwAjjnmmH6PSS26u23A7aOu9apijogVaEnKV6WU/qsy+fmIGFl5fiTwQu9ClCRp8OhxxRwtpfH3gfkppQtaPTUbmAqcU7m9qVcRLkevvvoqAFtssQUAm2yyCQA//elP6xaTlOX987777gOq++mkSZMA+MlPflKfwNTltgG3j7qvN4eytwEOBB6IiPsq075KS0K+PiKmAU8DU3oXoiRJg0ePE3NK6TdArRPKO/Z0ufWUUsup7rfeeqvNrXovn4P70Ic+BMBvfvObeoYDwD333ANAU1NTnSNRKUrcTzX4OPKXJEkF6XWr7Eay+uqrA/D444/XORJJ0mBlxSxJUkFMzJIkFcTELElSQTzHTLV/4ac+9akOn//HP/4BwEorrdRm/s5e8+STTwJw/PHHA/DLX/4SgCFDhgAwYcIEAL797W8DsNFGG7V5/fnnnw/Ascce22b6qFGjALjxxhsBOOGElmuE/O53vwPgzTffBOAjH/kIAGeeeSYA22yzTZvl5Olf+9rX2kzP87VvjfqLX/wCgF122aXN9HXWWQeAF198sVvx59G22o8Qlz+XJUuWsLz88Ic/BOBzn/scAI899hgAQ4e2fC0+8IEPAPAf//EfAHzmM59ZbrH1ha7267yPgvtpyftplt/zxz/+MQCXXXYZAA888AAAr7zyCgD
//u//DlT36y996UsAvOMdLXXYyy+/DMBaa3V+4b8zzjgDgJNPPrnN+6+wwgodzv/pT38agFmzZrWZnodczsubPXs2AM8++ywA73znOwHYdtttATjllFOAasv4rKv9GeDhhx8GqvvLnDlzAHjppZc6jGn48OE1l1VPVsySJBUkct/dempqakrNzc01n7/++usB2GeffYBqf+P+sueeLRfEuummlkHL2lfM3XlNHvEnVyK5+rrrrrsA2GOPPQDYbLPNgGolUUv+9bhgwYI2y/vmN7/Z5vH8+fMBmDZtGlD9BXnLLbcA8LGPfazD5S9r/83c9/ePf/wjsHQl0tvl96c8VvaGG24IwFFHHQVUR3pbuHAhAGeddRZQraxz5XHRRRd1uvwddtgBgPvvvx+Am2++GYDx48f3Sfw9Hfmr1j4K7qc9XX57/TnyVx6BcPfddwfg7LPPBuDQQw8FqkchrrnmGqC6Xx999NFAdYzu9iZOnAjArbfeCsCjjz4KLH10pL2tt94agMMPPxxY+ojSokWLgOpRl3/+858AXH755QBst912ADz11FMAHHbYYQDcfffdAPzqV79q8/r22u/PUN1vTjvtNAC22moroHpUIR9pee6554C+r5inTGk7ntYNN9zQ6fwRMS+ltNRAClbMkiQVxHPM/SSf32n/a2+nnXYC4JOf/CRQPR+Tf8l39QsuX73mkksuAZY+D5MrhB/96EdAtULJV0bKv+gHs66qoFw5X3nllQA88sgjAPznf/4nAPvvvz9QPT/aXh4xLh/ZKeGoVEfyPgrupwPJ9ttvD8CJJ57Y4fP5yE4+upHbB+TzrmussUab+XNFndsXXHBBy6UPvvOd73S4/Hz+/emnnwaWrhKzHF+uiK+66ioAdt111zbzvf/97wfg2muvBWDs2LFt1qOzo6nt5SM/+TPK8ne1Hm0DesKKWZKkglgx95MPf/jDnT4/ZsyYNo9zC8WuKpFVV10VWLoCaW/zzTcHYL311gOq5zvzeZ+RI0d2+npVTZ48GahWIP/93/8N1K6Yb7vttuUSV291tY+C+2lJdttttza3XfngBz8IVI9KPPTQQ8DSR0c+/vGPA9VtccUVVwBw+umnA9UW7Vk+V50r2lqttPM59dwavKu43/3udwPVCnrevHlAtc3H6NGjO309VM8pD3RWzJIkFcSKuZ/kvnm1DBs2rM3j7l7Jas0111ymONZdd12gWum88MILgJXIsmj/WeXPcKDrah8F99OS5H7KuYV77iOeK8rcP7mWv//9750+n1tx55byuX1APjedW2vn1tI/+MEPOlzO66+/3iberDv7W0fy+ALdqZjzkZqBzopZkqSCWDEPMH/+85+Bakvf9iMTtde+usuVSXv5PNAbb7zRrTi6+nXeXldxlixXcVmtz1BV7qd9L/df/t///V+g2tp6v/32A6rn/fM6fOtb3wLgy1/+MtB174Dc2+CrX/0qABdffDEAxx13HFCt1KdOnQrUHjlsxRVXBKpHTV599VWgOh5EHllPtVkxS5JUEH+6DDB59Jx77rkHqN0KMY90k6u93EKz1jm7PP1Pf/pTp++fR8zJfRjb94msZZVVVgFqVzq57/BXvvIVAKZPn96t5S6LPLbwpZdeClRbfdaSK4w88lyWKxfV5n7ae7myzC3Vc//h3Hr5iCOO6PT1uULtrlzpfvGLXwTg1FNPBaqVcu6H3N0+5nvttRdQHekrx19rVLfs3HPPBar9qJ944glgcFXaVsySJBVk8PwEaRC5ZWM+D5Sv2FJrDOLcqjafj6ol92XM55Xy7cEHHwzA888/3+Z98znAXBl1Zdy4cUB1DOZnnnkGqLYozb+K8xVm+tO9994LVMfmzefgcqvPPFJR/mxzZZ37bdbqv5z191jZA4H7ad/JV7TKo1nlVtG5P/FnP/tZoNoiee7cuQB897vf7dH75Yr5nHPOAapXl8rjpuerV3Xl61//OgC33347AIcccghQ3WZ5rO08xnc
eVzr3n86tvgdTpZxZMUuSVJDB91OkA11d53PllVcGqq0W89VUoPaVT/JrTjrpJKB6XdlarT7z1Wjy2MT5SjLt5avf5HGbc7V35513AtWxYPM5vfzruv11btvL8eXKIl+5JrfI3HLLLQG48MILAXj88ceBajWZ1yuPVZt/bWe5hejnP/95AN73vvcBsPbaawPVSilP7w8HHnggUG0tevXVVwPVq+vkqihfRSxvkzxfbv3albwN+nqs7Lzt8zjUWb66Tt4GuZLKV9Hqah8F99Osp/tpV9um9Xv31HXXXQdUK9j82eZtl2PM16LOV3vK67jzzjsD1c+o1hjUuXX3AQccAMD3vvc9oDqmdnfloxV5xLx8tbb8/zMfjcjfx7xv5c8sj9ee5SMBtfZnaLtPQ7nj1HfFilmSpIJ4PeYBIo85nK/uk6s7qSTup40jn+PNraOX5SpPg5XXY5YkqQF5jlmStJTcqntZzy2r96yYJUkqiBWzJA1CeSS83M/+hz/8IVC9fvNf/vIXAPbee+86RDe4WTFLklQQK+ZCnX/++QAce+yxHT6f+0S273/aKPryKj95zN/TTjutz5apFoN9P20EeRyHfLWoTTfdFIBrr70WGJwjb9WbFbMkSQXpdT/miBgCNAN/SintFhEbANcCawP3AgemlDq9eKr9mCVJA11J/ZiPBOa3enwucGFKaWPgL8C0PngPSZIGhV4l5ogYDXwSuKzyOIAdgFmVWWYCe/bmPSRJGkx6WzF/CzgOeKvyeB3g5ZTSksrjhcCojl4YEdMjojkimhcvXtzLMCRJagw9TswRsRvwQkppXuvJHcza4QnhlNKMlFJTSqlpxIgRPQ1DkqSG0pt28NsAe0TErsBKwBq0VNBrRsTQStU8Gni292FKkjQ49LhiTimdmFIanVIaC+wL/CqltD/wa2ByZbapwE01FiFJktrpj37MxwNHR8QCWs45f78f3kOSpIbUJ0O6pJRuA26r3H8C2KovlitJ0mDjyF+SJBXExCxJUkFMzJIkFcTELElSQUzMkiQVxMQsSVJBTMySJBXExCxJUkFMzJIkFcTELElSQUzMkiQVxMQsSVJBTMySJBXExCxJUkFMzJIkFcTELElSQUzMkiQVxMQsSVJBTMySJBVkaL0D6IkpU6bUOwRJktqYO3cuAOPHj+/VcqyYJUkqyIComMeMGQPA5MmT6xyJ1DjuuOOOt++/733vA2DEiBH1Ckca8HKlPGHChF4tx4pZkqSCREqp3jHQ1NSUmpub6x2GNKhExNv3r7vuOgD23nvveoUjDToRMS+l1NR+uhWzJEkFMTFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkF6VVijog1I2JWRDwcEfMjYkJErB0Rt0bEY5XbtfoqWEmSGl1vK+ZvA79IKb0X+CAwHzgBmJNS2hiYU3ksSZK6oceJOSLWALYDvg+QUnojpfQyMAmYWZltJrBnb4OUJGmw6E3FvCGwGPhBRPxfRFwWEasC70opLQKo3K7bB3FKkjQo9CYxDwXGAZemlLYAXmMZDltHxPSIaI6I5sWLF/ciDEmSGkdvEvNCYGFK6e7K41m0JOrnI2IkQOX2hY5enFKakVJqSik1eak5SZJa9Dgxp5SeA56JiE0qk3YE/gDMBqZWpk0FbupVhJIkDSJDe/n6LwFXRcQw4Angs7Qk++sjYhrwNDCll+8hSdKg0avEnFK6D1jqWpK0VM+SJGkZOfKXJEkFMTFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkFMTFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkFMTFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkFMTFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkFMTFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkFMTFLklSQXiXmiPhyRDwUEQ9GxDURsVJ
EbBARd0fEYxFxXUQM66tgJUlqdD1OzBExCjgCaEopbQYMAfYFzgUuTCltDPwFmNYXgUqSNBj09lD2UGDliBgKrAIsAnYAZlWenwns2cv3kCRp0OhxYk4p/Qk4H3ialoT8CjAPeDmltKQy20JgVEevj4jpEdEcEc2LFy/uaRiSJDWU3hzKXguYBGwArAesCuzSwaypo9enlGaklJpSSk0jRozoaRiSJDWU3hzK3gl4MqW0OKX0L+C/gK2BNSuHtgFGA8/2MkZJkgaN3iTmp4HxEbFKRASwI/AH4NfA5Mo8U4GbeheiJEmDR2/OMd9NSyOve4EHKsuaARwPHB0RC4B1gO/3QZySJA0KQ7uepbaU0qnAqe0mPwFs1ZvlSpI0WDnylyRJBTExS5JUEBOzJEkFMTFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkFMTFLklQQE7MkSQXp1VjZksrzhS98AYBHHnmk0/mGDq1+/U8//XQALrnkkg7nHTJkCAAzZ84EYPTo0b2OU1LHrJglSSqIFbPUYNZdd10AZsyY0e3XPPTQQ50+v8EGGwBWytLyYMUsSVJBrJilBrP//vsDcOaZZ/Z6WcOGDQPg4IMP7vWyJHWPFbMkSQWxYpYazHvf+14ANt10UwDmz58PQEppmZf1xhtvALDvvvv2UXSSumLFLElSQayYpQZ10EEHAXDyyScDsGTJkm6/NiIA2HzzzQF4z3ve08fRSarFilmSpIJYMUsN6jOf+QwAJ5544jK/No8KNnXq1D6NSVLXrJglSSqIFbPUoMaMGQPAVlttBcA999wDwFtvvdXla/P56H322aefopNUixWzJEkFsWKWGlxund3c3NzlvO94R8tv9a233hqAUaNG9V9gkjpkxSxJUkGsmKUGt/feewNwxBFHdDlv7r+cq2xJy58VsyRJBbFilhrc8OHDAdhxxx0BmDNnDgBvvvlmzdfstdde/R+YpA5ZMUuSVJAuE3NEXB4RL0TEg62mrR0Rt0bEY5XbtSrTIyIuiogFEfH7iBjXn8FLktRoulMxXwFMbDftBGBOSmljYE7lMcAuwMaVv+nApX0TpqTeOuCAAzjggANIKS11CcghQ4YwZMgQJk6cyMSJE1lnnXVYZ5116hSpNLh1mZhTSncAL7WbPAmYWbk/E9iz1fQrU4u5wJoRMbKvgpUkqdH1tPHXu1JKiwBSSosiYt3K9FHAM63mW1iZtqjnIUrqC3vu2fL7eYUVVgDg9ddff/u5XEEfcMAByz8wSW30deOv6GBa6mAaETE9Ipojonnx4sV9HIYkSQNTTyvm5yNiZKVaHgm8UJm+EBjTar7RwLMdLSClNAOYAdDU1NRh8pbUd1ZffXUAdt99dwBmzZr19nPDhg1r85yk+ulpxTwbyBdqnQrc1Gr6QZXW2eOBV/Ihb0mS1LUuK+aIuAbYHhgeEQuBU4FzgOsjYhrwNDClMvvPgF2BBcDfgc/2Q8xFWrhwIQB33nlnnSOROjd27Nilpo0b19Kz8eabb17O0UjLJl/OdMKECXWOpP90mZhTSvvVeGrHDuZNwGG9DUqSpMHKITn7SK6UvbC8BqK8/3rER6WbPHkyADfccEOdI+k/DskpSVJBrJj7SfuRlaTSHHPMMW/fP/vss4Fq62ypNFOmTOl6pgZhxSxJUkGsmKVB6owzznj7vpWyVA4rZkmSCmLFLA1SK6+8cr1DkNQBK2ZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKMrTeAah/XHvttQDst99+baavuOKKAPzzn/9c7jGpb6WUALjzzjsBuPrqqwG49dZbAXjqqacAeOc73wnAe97zHgAOPfRQAPbff/+3lxURyyH
ipa222moAvPbaa53Ol+PL67L++usDsM022wAwbdo0ALbccst+iVNanqyYJUkqSJcVc0RcDuwGvJBS2qwy7Txgd+AN4HHgsymllyvPnQhMA94Ejkgp/bKfYlcn9t133za3O+20EwC/+c1v6haT+tYjjzwCwEc/+lGguo1nzZoFVCvkZ599FoCTTz4ZgAMPPBCA++677+1lnX/++csh4qW9+uqrbWLZYostAJg0aRIAP/nJTwB48803AXjxxRcBmDt3LgDf/va3AWhqagLg4IMPBuA73/kOAKusskq/xi/1h+5UzFcAE9tNuxXYLKX0AeBR4ESAiNgU2Bd4f+U1l0TEkD6LVpKkBtdlxZxSuiMixrabdkurh3OByZX7k4BrU0qvA09GxAJgK+CuPolW0lKGDm35Gl9//fUArLXWWm2e33DDDQG44oorALjllpav78UXX/z2PGeddRZQbYNQmiFDWn7fv+td7wKqFXW+Pf744wH4xje+AcBLL70EVCvuep1Dl3qiL84xHwL8vHJ/FPBMq+cWVqYtJSKmR0RzRDQvXry4D8KQJGng61Wr7Ig4CVgCXJUndTBb6ui1KaUZwAyApqamDueRVNt73/teAP71r391a/5hw4YBMGbMGKDtOebcSr/Uirkr55xzDgC33347ALNnzwZq906QStbjxBwRU2lpFLZjyv02WirkMa1mGw082/PwJEkaXHqUmCNiInA88LGU0t9bPTUbuDoiLgDWAzYGftfrKCX12ssvvwzAY489BlRbQEO1f/BAlc8hH3744QDcfffdAFxyySWAFbMGlu50l7oG2B4YHhELgVNpaYW9InBr5QsxN6V0aErpoYi4HvgDLYe4D0spvdlfwUuS1Gi60yq7o5+a3+9k/rOAs3oTlLr28MMPA3DCCScA8Otf/xqAJUuWADBu3DgAvv71r/do+blB3hlnnAFUz9nlPrG5wtp2220BOOWUUwD40Ic+1GY5uVXspz71qQ7f58knnwSqrWp/+cuWbu+5FS7AhAkTgGqf1Y022qjDZb3++utAtYVxbqX89NNPA7DSSisB1dGiPv/5zwPwyU9+cqn3bK2vPovl7a9//SsADzzwAAAnnngiAO9+97sBuPLKK+sTWD/Kfbqz3N85n4dfYYUV2jy/vPZz6Hpfdz9X5shfkiQVxLGyB5gFCxYA1V/Xq666KlAd7SlPz7/OjznmGAAef/zxbi1/0aJFbZaTW+tefvnlAGy33XZAdRzmww47rM38v/rVr9o83nPPPYHquM758U033QTAUUcdBVSriPw+d91V7fq+xx57ANXzhL/7XcfNFvL5xRtuuKHNba6icgWZR7nKfWDz0Ybtt9++Xz+L5eXMM88E4Gtf+1qb6Xn9brzxRgA222yzbi9zhx12AOD+++8H4OabbwZg/PjxvYq1r+WjAVk+gpRHDBs5ciSw/Pdz6Hpfdz9XZsUsSVJBrJgHmK9+9atAtYXtZZddBsDOO+/cZr7NN98cgB/84AdAdfSnruTzkPnX8VVXtXRR33XXXdvM9/73vx+o9hMdO3YsAF/60pcAaG5u7tb7fe5znwOW/rWdx32G6rmxfFQgVz/Dhw9v85o5c+a0ia39Z7LyyisDcN555wHVc2i1LO/Poq/kMbGPO+44AJ544gmgeu4yt8Y+9dRT335N++q6vbfeeguoVoTVHpJl6W5c9di2Xe3r7ufKrJglSSqIFfMA84tf/KLN40984hOdzr/eeusB1SsNPfroo53On1uXvuMdLb/Zdtttt07nz+f08q/pefPmAbBw4UIARo8e3enrP/zhD3f6PFRHqspyK9H2lcTEiS3XWrn00ksBmD59OgCHHHJIm/fKrVLz1ZlqWd6fRV/LI33lEcLy5/L8888D1Va1UK3iWh+paO22227rrzD7VD5fmuVW2O33lXps2672dfdzZVbMkiQVxIp5gMh9F//2t78B1b6Kq622Wrdev+666wK1K+a8/FdeeaXN9J6OCJV
Hl+rq13N3lp8rvyyf72wvX4M3V38zZ84EYMcdd2wzX+6H+YUvfAFYuu9pvT6L5WX33XcHqq2zAX76058CtSvmgaL99cbzvpAr53pu267ew/1cmRWzJEkFsWIeIPJVf1ZffXWgWjm/+uqrQNeVc74+bVfLX3PNNdss9x//+AdQveZvyfJ4yQceeGCb2zzqUz5Pmvt37rXXXgB885vfBODoo48GGuOz6ExHV5Dqav8oXa4uczWZ5T63WSNsW/fzxmfFLElSQfxJNMDssssuQHV83NxKe/LkyR3On/tCdtUyM8u/rvOoP7/97W8B+NjHPtbp684991ygWrHkvrPL81d3/uWfx0fOrZHz+cXc3zOPkJRHTcujWOVKIiv9s8ijuj333HMA/OhHP+rW637+858vNa07reNLlvvi5tGy8vnUKVOmdDh/6du2M4NtPx+MrJglSSqIP20GmLPPPhuA//mf/wGq4+/mlpS5pWa+0kz+dZzPQbdvgdlevhrV7bffDlT7Rl588cUAbL311gC8+WbL1TzzOL2nn346UB1prJ6/mg899FAALrroIgA22WQToLruuf9nHiUqjwPd3kD5LK6++mqg2lf9oIMOApYeFzqvd66st9xyy7eXkUelqqXeY2Xnc8j5CFAeXzpv4zxec/ttlM/HtjdQtm1nBtt+PphYMUuSVJAoYczbpqamNNDHWc3nfPfZZx+g/8cSzv2R85VqcsWQW2bmKwfl8ZAvvPBCoDrObjZt2jSgOuZ2llvp5mu+5tGBnnnmGaB6niuPu3zssccCS/eDzefBurryzEknnQRUr4xUq9JpLY8tnPvg5mouVwp33HEHUB0DOPf9zpVlXvd8W+s9++qz6Gv5KkLtrzKU2xPkkaNy/9hcUeX2CEceeeTby8rjK9eSrzD04IMPAtWKubdXFMpHcl577bVO58vbZo011gBg/fXXB6rnUXPFn69D3l3Lez+HZd/XB/t+nrVvL5D394EsIuallJraT7diliSpIFbMfWR5V8ySNJhYMUuSpLowMUuSVBATsyRJBTExS5JUEBOzJEkFcagWaTnpTt/s7sh900877bQ+WZ6kslgxS5JUECtmaTmxb7uk7rBiliSpICtM4O8AAAd/SURBVCZmSZIKYmKWJKkgJmZJkgpiYpYkqSBdJuaIuDwiXoiIBzt47piISBExvPI4IuKiiFgQEb+PiGW7OKokSYNcdyrmK4CJ7SdGxBhgZ+DpVpN3ATau/E0HLu19iJIkDR5dJuaU0h3ASx08dSFwHNC6c+Yk4MrUYi6wZkSM7JNIJUkaBHp0jjki9gD+lFK6v91To4BnWj1eWJnW0TKmR0RzRDQvXry4J2FIktRwljkxR8QqwEnAKR093cG0Doc7SinNSCk1pZSaRowYsaxhSJLUkHoyJOdGwAbA/ZVB+UcD90bEVrRUyGNazTsaeLa3QUqSNFgsc8WcUnogpbRuSmlsSmksLcl4XErpOWA2cFCldfZ44JWU0qK+DVmSpMbVne5S1wB3AZtExMKImNbJ7D8DngAWAN8DvtgnUUqSNEh0eSg7pbRfF8+PbXU/AYf1PixJkgYnR/6SJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgpiYpYkqSAmZkmSCmJiliSpICZmSZIKYmKWJKkgJmZJkgrS5WUf1TPXX399vUOQpIaxcOFCAEaPHl3nSPqfFbMkSQWxYu4n++yzT71DkKSGM3ny5HqH0O+smCVJKkiklOodA01NTam5ubneYUiStNxExLyUUlP76VbMkiQVxMQsSVJBTMySJBXExCxJUkFMzJIkFcTELElSQUzMkiQVxMQsSVJBTMySJBWkiJG/ImIx8BrwYr1j6SPDaYx1aZT1gMZZl0ZZD2icdWmU9YDGWZeBsh7/llIa0X5iEYkZICKaOxqabCBqlHVplPWAxlmXRlkPaJx1aZT1gMZZl4G+Hh7KliSpICZmSZIKUlJ
inlHvAPpQo6xLo6wHNM66NMp6QOOsS6OsBzTOugzo9SjmHLMkSSqrYpYkadArIjFHxMSIeCQiFkTECfWOp7siYkxE/Doi5kfEQxFxZGX6aRHxp4i4r/K3a71j7Y6I+GNEPFCJubkybe2IuDUiHqvcrlXvODsTEZu0+tzvi4i/RsRRA2WbRMTlEfFCRDzYalqH2yBaXFT53vw+IsbVL/K2aqzHeRHxcCXWGyNizcr0sRHxj1bb5rv1i3xpNdal5v4UESdWtskjEfGJ+kS9tBrrcV2rdfhjRNxXmV76Nqn1v3fAfVc6lFKq6x8wBHgc2BAYBtwPbFrvuLoZ+0hgXOX+6sCjwKbAacAx9Y6vB+vzR2B4u2nfAE6o3D8BOLfecS7D+gwBngP+baBsE2A7YBzwYFfbANgV+DkQwHjg7nrH38V6fBwYWrl/bqv1GNt6vtL+aqxLh/tT5ft/P7AisEHlf9uQeq9DrfVo9/w3gVMGyDap9b93wH1XOvoroWLeCliQUnoipfQGcC0wqc4xdUtKaVFK6d7K/b8B84FR9Y2qz00CZlbuzwT2rGMsy2pH4PGU0lP1DqS7Ukp3AC+1m1xrG0wCrkwt5gJrRsTI5RNp5zpaj5TSLSmlJZWHc4HRyz2wHqixTWqZBFybUno9pfQksICW/3F119l6REQAewPXLNegeqiT/70D7rvSkRIS8yjgmVaPFzIAk1tEjAW2AO6uTDq8csjk8tIP/7aSgFsiYl5ETK9Me1dKaRG0fBmAdesW3bLbl7b/aAbiNoHa22Agf3cOoaWCyTaIiP+LiNsjYtt6BbWMOtqfBuo22RZ4PqX0WKtpA2KbtPvf2xDflRISc3QwbUA1FY+I1YAfA0ellP4KXApsBHwIWETLIaKBYJuU0jhgF+CwiNiu3gH1VEQMA/YAbqhMGqjbpDMD8rsTEScBS4CrKpMWAeunlLYAjgaujog16hVfN9XanwbkNgH2o+2P2AGxTTr431tz1g6mFbtdSkjMC4ExrR6PBp6tUyzLLCJWoGXHuCql9F8AKaXnU0pvppTeAr5HIYeyupJSerZy+wJwIy1xP58P+VRuX6hfhMtkF+DelNLzMHC3SUWtbTDgvjsRMRXYDdg/VU7+VQ77/rlyfx4t52XfU78ou9bJ/jQQt8lQYC/gujxtIGyTjv730iDflRIS8z3AxhGxQaXK2ReYXeeYuqVyXub7wPyU0gWtprc+d/Ep4MH2ry1NRKwaEavn+7Q01HmQlm0xtTLbVOCm+kS4zNpUAANxm7RSaxvMBg6qtDgdD7ySD+OVKCImAscDe6SU/t5q+oiIGFK5vyGwMfBEfaLsnk72p9nAvhGxYkRsQMu6/G55x7eMdgIeTiktzBNK3ya1/vfSIN+Vurc+S9UWc4/S8qvspHrHswxxf5SWwyG/B+6r/O0K/BB4oDJ9NjCy3rF2Y102pKU16f3AQ3k7AOsAc4DHKrdr1zvWbqzLKsCfgXe2mjYgtgktPyYWAf+i5Vf+tFrbgJbDc9+pfG8eAJrqHX8X67GAlvN8+bvy3cq8n67sc/cD9wK71zv+bqxLzf0JOKmyTR4Bdql3/J2tR2X6FcCh7eYtfZvU+t874L4rHf058pckSQUp4VC2JEmqMDFLklQQE7MkSQUxMUuSVBATsyRJBTExS5JUEBOzJEkFMTFLklSQ/weRENPxaWQ0DwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from tensorflow.keras.utils import plot_model\n", "\n", "# Summarize the model\n", "model.summary()\n", "\n", "# Plot the model\n", "plot_model(model, to_file='../images/plot_model.png')\n", "\n", "# Display the image\n", "data = plt.imread('../images/plot_model.png')\n", "plt.imshow(data);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit and evaluate a model\n", "- Basketball Data\n", " - Goal: Predict tournament outcomes\n", " - Data Available: team ratings from the tournament organizers\n", "- Input\n", " - Seed difference (`seed_diff`)\n", "- Output\n", " - Score difference (`score_diff`)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fit the model to the tournament basketball data\n", "Now that the model is compiled, you are ready to fit it to some data!\n", "\n", "In this exercise, you'll use a dataset of scores from US College Basketball tournament games. Each row of the dataset has the team ids: `team_1` and `team_2`, as integers. It also has the seed difference between the teams (seeds are assigned by the tournament committee and represent a ranking of how strong the teams are) and the score difference of the game (e.g. if `team_1` wins by 5 points, the score difference is 5).\n", "\n", "To fit the model, you provide a matrix of X variables (in this case one column: the seed difference) and a matrix of Y variables (in this case one column: the score difference)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
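<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n",
" <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>season</th>\n", " <th>team_1</th>\n", " <th>team_2</th>\n", " <th>home</th>\n", " <th>seed_diff</th>\n", " <th>score_diff</th>\n", " <th>score_1</th>\n", " <th>score_2</th>\n", " <th>won</th>\n", " </tr>\n",
" </thead>\n", " <tbody>\n",
" <tr>\n", " <th>0</th>\n", " <td>1985</td>\n", " <td>288</td>\n", " <td>73</td>\n", " <td>0</td>\n", " <td>-3</td>\n", " <td>-9</td>\n", " <td>41</td>\n", " <td>50</td>\n", " <td>0</td>\n", " </tr>\n",
" <tr>\n", " <th>1</th>\n", " <td>1985</td>\n", " <td>5929</td>\n", " <td>73</td>\n", " <td>0</td>\n", " <td>4</td>\n", " <td>6</td>\n", " <td>61</td>\n", " <td>55</td>\n", " <td>1</td>\n", " </tr>\n",
" <tr>\n", " <th>2</th>\n", " <td>1985</td>\n", " <td>9884</td>\n", " <td>73</td>\n", " <td>0</td>\n", " <td>5</td>\n", " <td>-4</td>\n", " <td>59</td>\n", " <td>63</td>\n", " <td>0</td>\n", " </tr>\n",
" <tr>\n", " <th>3</th>\n", " <td>1985</td>\n", " <td>73</td>\n", " <td>288</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>9</td>\n", " <td>50</td>\n", " <td>41</td>\n", " <td>1</td>\n", " </tr>\n",
" <tr>\n", " <th>4</th>\n", " <td>1985</td>\n", " <td>3920</td>\n", " <td>410</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>-9</td>\n", " <td>54</td>\n", " <td>63</td>\n", " <td>0</td>\n", " </tr>\n",
" </tbody>\n", "</table>\n", "</div>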
" ], "text/plain": [ " season team_1 team_2 home seed_diff score_diff score_1 score_2 won\n", "0 1985 288 73 0 -3 -9 41 50 0\n", "1 1985 5929 73 0 4 6 61 55 1\n", "2 1985 9884 73 0 5 -4 59 63 0\n", "3 1985 73 288 0 3 9 50 41 1\n", "4 1985 3920 410 0 1 -9 54 63 0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "games_tourney = pd.read_csv('./dataset/games_tourney.csv')\n", "games_tourney.head()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "games_tourney_train, games_tourney_test = train_test_split(games_tourney, test_size=0.3)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "input_tensor = Input(shape=(1, ))\n", "output_tensor = Dense(1)(input_tensor)\n", "\n", "model = Model(input_tensor, output_tensor)\n", "model.compile(optimizer='adam', loss='mean_absolute_error')" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "21/21 [==============================] - 0s 8ms/step - loss: 9.5143 - val_loss: 9.5148\n" ] } ], "source": [ "# Now fit the model\n", "model.fit(games_tourney_train['seed_diff'], games_tourney_train['score_diff'],\n", " epochs=1,\n", " batch_size=128,\n", " validation_split=0.1,\n", " verbose=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate the model on a test set\n", "After fitting the model, you can evaluate it on new data. You will give the model a new `X` matrix (also called test data), allow it to make predictions, and then compare to the known `y` variable (also called target data).\n", "\n", "In this case, you'll use data from the post-season tournament to evaluate your model. 
These test games are the 30% of tournament rows held out by `train_test_split` above. The model never saw them during fitting, so they give a reasonable estimate of how well your model performs out-of-sample." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9.419431686401367\n" ] } ], "source": [ "# Load the X variable from the test data\n", "X_test = games_tourney_test['seed_diff']\n", "\n", "# Load the y variable from the test data\n", "y_test = games_tourney_test['score_diff']\n", "\n", "# Evaluate the model on the test data\n", "print(model.evaluate(X_test, y_test, verbose=False))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }