{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "fTFj8ft5dlbS" }, "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "lzyBOpYMdp3F" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "m_x4KfSJ7Vt7" }, "outputs": [], "source": [ "#@title MIT License\n", "#\n", "# Copyright (c) 2017 François Chollet\n", "#\n", "# Permission is hereby granted, free of charge, to any person obtaining a\n", "# copy of this software and associated documentation files (the \"Software\"),\n", "# to deal in the Software without restriction, including without limitation\n", "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", "# and/or sell copies of the Software, and to permit persons to whom the\n", "# Software is furnished to do so, subject to the following conditions:\n", "#\n", "# The above copyright notice and this permission notice shall be included in\n", "# all copies or substantial portions of the Software.\n", "#\n", "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", "# DEALINGS IN THE SOFTWARE." ] }, { "cell_type": "markdown", "metadata": { "id": "C9HmC2T4ld5B" }, "source": [ "# Overfit and underfit" ] }, { "cell_type": "markdown", "metadata": { "id": "kRTxFhXAlnl1" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "19rPukKZsPG6" }, "source": [ "As always, the code in this example will use the `tf.keras` API, which you can learn more about in the TensorFlow [Keras guide](https://www.tensorflow.org/guide/keras).\n", "\n", "In both of the previous examples—[classifying text](https://www.tensorflow.org/tutorials/keras/text_classification_with_hub) and [predicting fuel efficiency](https://www.tensorflow.org/tutorials/keras/regression) — we saw that the accuracy of our model on the validation data would peak after training for a number of epochs, and would then stagnate or start decreasing.\n", "\n", "In other words, our model would *overfit* to the training data. Learning how to deal with overfitting is important. Although it's often possible to achieve high accuracy on the *training set*, what we really want is to develop models that generalize well to a *testing set* (or data they haven't seen before).\n", "\n", "The opposite of overfitting is *underfitting*. Underfitting occurs when there is still room for improvement on the train data. This can happen for a number of reasons: If the model is not powerful enough, is over-regularized, or has simply not been trained long enough. This means the network has not learned the relevant patterns in the training data.\n", "\n", "If you train for too long though, the model will start to overfit and learn patterns from the training data that don't generalize to the test data. We need to strike a balance. Understanding how to train for an appropriate number of epochs as we'll explore below is a useful skill.\n", "\n", "To prevent overfitting, the best solution is to use more complete training data. The dataset should cover the full range of inputs that the model is expected to handle. Additional data may only be useful if it covers new and interesting cases.\n", "\n", "A model trained on more complete data will naturally generalize better. When that is no longer possible, the next best solution is to use techniques like regularization. These place constraints on the quantity and type of information your model can store. If a network can only afford to memorize a small number of patterns, the optimization process will force it to focus on the most prominent patterns, which have a better chance of generalizing well.\n", "\n", "In this notebook, we'll explore several common regularization techniques, and use them to improve on a classification model." 
] }, { "cell_type": "markdown", "metadata": { "id": "WL8UoOTmGGsL" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "9FklhSI0Gg9R" }, "source": [ "Before getting started, import the necessary packages:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5pZ8A2liqvgk" }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "from tensorflow.keras import layers\n", "from tensorflow.keras import regularizers\n", "\n", "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QnAtAjqRYVXe" }, "outputs": [], "source": [ "!pip install git+https://github.com/tensorflow/docs\n", "\n", "import tensorflow_docs as tfdocs\n", "import tensorflow_docs.modeling\n", "import tensorflow_docs.plots" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-pnOU-ctX27Q" }, "outputs": [], "source": [ "from IPython import display\n", "from matplotlib import pyplot as plt\n", "\n", "import numpy as np\n", "\n", "import pathlib\n", "import shutil\n", "import tempfile\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jj6I4dvTtbUe" }, "outputs": [], "source": [ "logdir = pathlib.Path(tempfile.mkdtemp())/\"tensorboard_logs\"\n", "shutil.rmtree(logdir, ignore_errors=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "1cweoTiruj8O" }, "source": [ "## The Higgs Dataset\n", "\n", "The goal of this tutorial is not to do particle physics, so don't dwell on the details of the dataset. It contains 11 000 000 examples, each with 28 features, and a binary class label." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YPjAvwb-6dFd" }, "outputs": [], "source": [ "gz = tf.keras.utils.get_file('HIGGS.csv.gz', 'http://mlphysics.ics.uci.edu/data/higgs/HIGGS.csv.gz')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AkiyUdaWIrww" }, "outputs": [], "source": [ "FEATURES = 28" ] }, { "cell_type": "markdown", "metadata": { "id": "SFggl9gYKKRJ" }, "source": [ "The `tf.data.experimental.CsvDataset` class can be used to read csv records directly from a gzip file with no intermediate decompression step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QHz4sLVQEVIU" }, "outputs": [], "source": [ "ds = tf.data.experimental.CsvDataset(gz,[float(),]*(FEATURES+1), compression_type=\"GZIP\")" ] }, { "cell_type": "markdown", "metadata": { "id": "HzahEELTKlSV" }, "source": [ "That csv reader class returns a list of scalars for each record. The following function repacks that list of scalars into a (feature_vector, label) pair." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zPD6ICDlF6Wf" }, "outputs": [], "source": [ "def pack_row(*row):\n", " label = row[0]\n", " features = tf.stack(row[1:],1)\n", " return features, label" ] }, { "cell_type": "markdown", "metadata": { "id": "4oa8tLuwLsbO" }, "source": [ "TensorFlow is most efficient when operating on large batches of data.\n", "\n", "So, instead of repacking each row individually, make a new `Dataset` that takes batches of 10,000 examples, applies the `pack_row` function to each batch, and then splits the batches back up into individual records:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-w-VHTwwGVoZ" }, "outputs": [], "source": [ "packed_ds = ds.batch(10000).map(pack_row).unbatch()" ] }, { "cell_type": "markdown", "metadata": { "id": "lUbxc5bxNSXV" }, "source": [ "Have a look at some of the records from this new `packed_ds`.\n", "\n", "The features are not perfectly normalized, but this is sufficient for this tutorial." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TfcXuv33Fvka" }, "outputs": [], "source": [ "for features,label in packed_ds.batch(1000).take(1):\n", " print(features[0])\n", " plt.hist(features.numpy().flatten(), bins = 101)" ] }, { "cell_type": "markdown", "metadata": { "id": "ICKZRY7gN-QM" }, "source": [ "To keep this tutorial relatively short, use just the first 1,000 samples for validation, and the next 10,000 for training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hmk49OqZIFZP" }, "outputs": [], "source": [ "N_VALIDATION = int(1e3)\n", "N_TRAIN = int(1e4)\n", "BUFFER_SIZE = int(1e4)\n", "BATCH_SIZE = 500\n", "STEPS_PER_EPOCH = N_TRAIN//BATCH_SIZE" ] }, { "cell_type": "markdown", "metadata": { "id": "FP3M9DmvON32" }, "source": [ "The `Dataset.skip` and `Dataset.take` methods make this easy.\n", "\n", "At the same time, use the `Dataset.cache` method to ensure that the loader doesn't need to re-read the data from the file on each epoch:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "H8H_ZzpBOOk-" }, "outputs": [], "source": [ "validate_ds = packed_ds.take(N_VALIDATION).cache()\n", "train_ds = packed_ds.skip(N_VALIDATION).take(N_TRAIN).cache()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zAOqk2_Px7K" }, "outputs": [], "source": [ "train_ds" ] }, { "cell_type": "markdown", "metadata": { "id": "6PMliHoVO3OL" }, "source": [ "These datasets return individual examples. Use the `.batch` method to create batches of an appropriate size for training. Before batching, also remember to `.shuffle` and `.repeat` the training set." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y7I4J355O223" }, "outputs": [], "source": [ "validate_ds = validate_ds.batch(BATCH_SIZE)\n", "train_ds = train_ds.shuffle(BUFFER_SIZE).repeat().batch(BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": { "id": "lglk41MwvU5o" }, "source": [ "## Demonstrate overfitting\n", "\n", "The simplest way to prevent overfitting is to start with a small model: a model with a small number of learnable parameters (which is determined by the number of layers and the number of units per layer). 
In deep learning, the number of learnable parameters in a model is often referred to as the model's \"capacity\".\n", "\n", "Intuitively, a model with more parameters will have more \"memorization capacity\" and will therefore be able to easily learn a perfect dictionary-like mapping between training samples and their targets: a mapping without any generalization power, which is useless when making predictions on previously unseen data.\n", "\n", "Always keep this in mind: deep learning models tend to be good at fitting the training data, but the real challenge is generalization, not fitting.\n", "\n", "On the other hand, if the network has limited memorization resources, it will not be able to learn the mapping as easily. To minimize its loss, it will have to learn compressed representations that have more predictive power. At the same time, if you make your model too small, it will have difficulty fitting the training data. There is a balance between \"too much capacity\" and \"not enough capacity\".\n", "\n", "Unfortunately, there is no magical formula to determine the right size or architecture of your model (in terms of the number of layers, or the right size for each layer). You will have to experiment using a series of different architectures.\n", "\n", "To find an appropriate model size, it's best to start with relatively few layers and parameters, then begin increasing the size of the layers or adding new layers until you see diminishing returns on the validation loss.\n", "\n", "Start with a simple model using only `layers.Dense` as a baseline, then create larger versions, and compare them." ] }, { "cell_type": "markdown", "metadata": { "id": "_ReKHdC2EgVu" }, "source": [ "### Training procedure" ] }, { "cell_type": "markdown", "metadata": { "id": "pNzkSkkXSP5l" }, "source": [ "Many models train better if you gradually reduce the learning rate during training. Use `optimizers.schedules` to reduce the learning rate over time:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwQp-ERhAD6F" }, "outputs": [], "source": [ "lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(\n", " 0.001,\n", " decay_steps=STEPS_PER_EPOCH*1000,\n", " decay_rate=1,\n", " staircase=False)\n", "\n", "def get_optimizer():\n", " return tf.keras.optimizers.Adam(lr_schedule)" ] }, { "cell_type": "markdown", "metadata": { "id": "kANLx6OYTQ8B" }, "source": [ "The code above sets a `schedules.InverseTimeDecay` to hyperbolically decrease the learning rate: the schedule computes `initial_rate / (1 + decay_rate * step / decay_steps)`, so with `decay_rate=1` and `decay_steps = STEPS_PER_EPOCH*1000` the learning rate falls to 1/2 of the base rate at 1000 epochs, 1/3 at 2000 epochs, and so on." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HIo_yPjEAFgn" }, "outputs": [], "source": [ "step = np.linspace(0,100000)\n", "lr = lr_schedule(step)\n", "plt.figure(figsize = (8,6))\n", "plt.plot(step/STEPS_PER_EPOCH, lr)\n", "plt.ylim([0,max(plt.ylim())])\n", "plt.xlabel('Epoch')\n", "_ = plt.ylabel('Learning Rate')\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ya7x7gr9UjU0" }, "source": [ "Each model in this tutorial will use the same training configuration, so set these up in a reusable way, starting with the list of callbacks.\n", "\n", "The training for this tutorial runs for many short epochs. To reduce the logging noise, use `tfdocs.EpochDots`, which simply prints a `.` for each epoch and a full set of metrics every 100 epochs.\n", "\n", "Next, include `callbacks.EarlyStopping` to avoid long and unnecessary training times. 
Note that this callback is set to monitor the `val_binary_crossentropy`, not the `val_loss`. This difference will be important later.\n", "\n", "Use `callbacks.TensorBoard` to generate TensorBoard logs for the training.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vSv8rfw_T85n" }, "outputs": [], "source": [ "def get_callbacks(name):\n", " return [\n", " tfdocs.modeling.EpochDots(),\n", " tf.keras.callbacks.EarlyStopping(monitor='val_binary_crossentropy', patience=200),\n", " tf.keras.callbacks.TensorBoard(logdir/name),\n", " ]" ] }, { "cell_type": "markdown", "metadata": { "id": "VhctzKhBWVDD" }, "source": [ "Similarly, each model will use the same `Model.compile` and `Model.fit` settings:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xRCGwU3YH5sT" }, "outputs": [], "source": [ "def compile_and_fit(model, name, optimizer=None, max_epochs=10000):\n", " if optimizer is None:\n", " optimizer = get_optimizer()\n", " model.compile(optimizer=optimizer,\n", " loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n", " metrics=[\n", " tf.keras.losses.BinaryCrossentropy(\n", " from_logits=True, name='binary_crossentropy'),\n", " 'accuracy'])\n", "\n", " model.summary()\n", "\n", " history = model.fit(\n", " train_ds,\n", " steps_per_epoch = STEPS_PER_EPOCH,\n", " epochs=max_epochs,\n", " validation_data=validate_ds,\n", " callbacks=get_callbacks(name),\n", " verbose=0)\n", " return history" ] }, { "cell_type": "markdown", "metadata": { "id": "mxBeiLUiWHJV" }, "source": [ "### Tiny model" ] }, { "cell_type": "markdown", "metadata": { "id": "a6JDv12scLTI" }, "source": [ "Start by training the tiny model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EZh-QFjKHb70" }, "outputs": [], "source": [ "tiny_model = tf.keras.Sequential([\n", " layers.Dense(16, activation='elu', input_shape=(FEATURES,)),\n", " layers.Dense(1)\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X72IUdWYipIS" }, "outputs": [], "source": [ "size_histories = {}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bdOcJtPGHhJ5" }, "outputs": [], "source": [ "size_histories['Tiny'] = compile_and_fit(tiny_model, 'sizes/Tiny')" ] }, { "cell_type": "markdown", "metadata": { "id": "rS_QGT6icwdI" }, "source": [ "Now check how the model did:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dkEvb2x5XsjE" }, "outputs": [], "source": [ "plotter = tfdocs.plots.HistoryPlotter(metric = 'binary_crossentropy', smoothing_std=10)\n", "plotter.plot(size_histories)\n", "plt.ylim([0.5, 0.7])" ] }, { "cell_type": "markdown", "metadata": { "id": "LGxGzh_FWOJ8" }, "source": [ "### Small model" ] }, { "cell_type": "markdown", "metadata": { "id": "YjMb6E72f2pN" }, "source": [ "To see if you can beat the performance of the tiny model, progressively train some larger models.\n", "\n", "Try two hidden layers with 16 units each:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QKgdXPx9usBa" }, "outputs": [], "source": [ "small_model = tf.keras.Sequential([\n", " # `input_shape` is only required here so that `.summary` works.\n", " layers.Dense(16, activation='elu', input_shape=(FEATURES,)),\n", " layers.Dense(16, activation='elu'),\n", " layers.Dense(1)\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LqG3MXF5xSjR" }, "outputs": [], "source": [ "size_histories['Small'] = compile_and_fit(small_model, 'sizes/Small')" ] }, { "cell_type": 
"markdown", "metadata": { "id": "L-DGRBbGxI6G" }, "source": [ "### Medium model" ] }, { "cell_type": "markdown", "metadata": { "id": "SrfoVQheYSO5" }, "source": [ "Now try 3 hidden layers with 64 units each:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jksi-XtaxDAh" }, "outputs": [], "source": [ "medium_model = tf.keras.Sequential([\n", " layers.Dense(64, activation='elu', input_shape=(FEATURES,)),\n", " layers.Dense(64, activation='elu'),\n", " layers.Dense(64, activation='elu'),\n", " layers.Dense(1)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "jbngCZliYdma" }, "source": [ "And train the model using the same data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ofn1AwDhx-Fe" }, "outputs": [], "source": [ "size_histories['Medium'] = compile_and_fit(medium_model, \"sizes/Medium\")" ] }, { "cell_type": "markdown", "metadata": { "id": "vIPuf23FFaVn" }, "source": [ "### Large model\n", "\n", "As an exercise, you can create an even larger model, and see how quickly it begins overfitting. Next, let's add to this benchmark a network that has much more capacity, far more than the problem would warrant:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ghQwwqwqvQM9" }, "outputs": [], "source": [ "large_model = tf.keras.Sequential([\n", " layers.Dense(512, activation='elu', input_shape=(FEATURES,)),\n", " layers.Dense(512, activation='elu'),\n", " layers.Dense(512, activation='elu'),\n", " layers.Dense(512, activation='elu'),\n", " layers.Dense(1)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "D-d-i5DaYmr7" }, "source": [ "And, again, train the model using the same data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U1A99dhqvepf" }, "outputs": [], "source": [ "size_histories['large'] = compile_and_fit(large_model, \"sizes/large\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Fy3CMUZpzH3d" }, "source": [ "### Plot the training and validation losses" ] }, { "cell_type": "markdown", "metadata": { "id": "HSlo1F4xHuuM" }, "source": [ "The solid lines show the training loss, and the dashed lines show the validation loss (remember: a lower validation loss indicates a better model)." ] }, { "cell_type": "markdown", "metadata": { "id": "OLhL1AszdLfM" }, "source": [ "While building a larger model gives it more power, if this power is not constrained somehow it can easily overfit to the training set.\n", "\n", "In this example, typically, only the `\"Tiny\"` model manages to avoid overfitting altogether, and each of the larger models overfit the data more quickly. This becomes so severe for the `\"large\"` model that you need to switch the plot to a log-scale to really see what's happening.\n", "\n", "This is apparent if you plot and compare the validation metrics to the training metrics.\n", "\n", "* It's normal for there to be a small difference.\n", "* If both metrics are moving in the same direction, everything is fine.\n", "* If the validation metric begins to stagnate while the training metric continues to improve, you are probably close to overfitting.\n", "* If the validation metric is going in the wrong direction, the model is clearly overfitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0XmKDtOWzOpk" }, "outputs": [], "source": [ "plotter.plot(size_histories)\n", "a = plt.xscale('log')\n", "plt.xlim([5, max(plt.xlim())])\n", "plt.ylim([0.5, 0.7])\n", "plt.xlabel(\"Epochs [Log Scale]\")" ] }, { "cell_type": "markdown", "metadata": { "id": "UekcaQdmZxnW" }, "source": [ "Note: All the above training runs used the `callbacks.EarlyStopping` to end the training once it was clear the model was not making progress." ] }, { "cell_type": "markdown", "metadata": { "id": "DEQNKadHA0M3" }, "source": [ "### View in TensorBoard\n", "\n", "These models all wrote TensorBoard logs during training.\n", "\n", "Open an embedded TensorBoard viewer inside a notebook:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6oa1lkJddZ-m" }, "outputs": [], "source": [ "#docs_infra: no_execute\n", "\n", "# Load the TensorBoard notebook extension\n", "%load_ext tensorboard\n", "\n", "# Open an embedded TensorBoard viewer\n", "%tensorboard --logdir {logdir}/sizes" ] }, { "cell_type": "markdown", "metadata": { "id": "fjqx3bywDPjf" }, "source": [ "You can view the [results of a previous run](https://tensorboard.dev/experiment/vW7jmmF9TmKmy3rbheMQpw/#scalars&_smoothingWeight=0.97) of this notebook on [TensorBoard.dev](https://tensorboard.dev/).\n", "\n", "TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.\n", "\n", "It's also included in an `