{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "USSV_OlCFKOD"
      },
      "source": [
        "# Training a neural network on MNIST with Keras\n",
        "\n",
        "This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J8y9ZkLXmAZc"
      },
      "source": [
        "Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OGw9EgE0tC0C"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/datasets/keras_example\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/keras_example.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/datasets/blob/master/docs/keras_example.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/datasets/docs/keras_example.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TTBSvHcSLBzc"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VjI6VgOBf0v0"
      },
      "source": [
        "## Step 1: Create your input pipeline\n",
        "\n",
        "Start by building an efficient input pipeline using advices from:\n",
        "* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide\n",
        "* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c3aH3vP_XLI8"
      },
      "source": [
        "### Load a dataset\n",
        "\n",
        "Load the MNIST dataset with the following arguments:\n",
        "\n",
        "* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.\n",
        "* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZUMhCXhFXdHQ"
      },
      "outputs": [],
      "source": [
        "(ds_train, ds_test), ds_info = tfds.load(\n",
        "    'mnist',\n",
        "    split=['train', 'test'],\n",
        "    shuffle_files=True,\n",
        "    as_supervised=True,\n",
        "    with_info=True,\n",
        ")"
      ]
    },
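    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because `with_info=True` was passed above, `tfds.load` also returns a `tfds.core.DatasetInfo` object. As an optional sanity check (not required for the pipeline), the next cell prints a few of its attributes; `splits` and `features` are standard TFDS metadata fields."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional: inspect the dataset metadata returned by `with_info=True`.\n",
        "print('Train examples:', ds_info.splits['train'].num_examples)\n",
        "print('Test examples:', ds_info.splits['test'].num_examples)\n",
        "print('Image shape:', ds_info.features['image'].shape)\n",
        "print('Number of classes:', ds_info.features['label'].num_classes)"
      ]
    },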
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rgwCFAcWXQTx"
      },
      "source": [
        "### Build a training pipeline\n",
        "\n",
        "Apply the following transformations:\n",
        "\n",
        "* `tf.data.Dataset.map`: TFDS provide images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to normalize images.\n",
        "* `tf.data.Dataset.cache` As you fit the dataset in memory, cache it before shuffling for a better performance.\u003cbr/\u003e\n",
        "__Note:__ Random transformations should be applied after caching.\n",
        "* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.\u003cbr/\u003e\n",
        "__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.\n",
        "* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.\n",
        "* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "haykx2K9XgiI"
      },
      "outputs": [],
      "source": [
        "def normalize_img(image, label):\n",
        "  \"\"\"Normalizes images: `uint8` -\u003e `float32`.\"\"\"\n",
        "  return tf.cast(image, tf.float32) / 255., label\n",
        "\n",
        "ds_train = ds_train.map(\n",
        "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
        "ds_train = ds_train.cache()\n",
        "ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n",
        "ds_train = ds_train.batch(128)\n",
        "ds_train = ds_train.prefetch(tf.data.AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RbsMy4X1XVFv"
      },
      "source": [
        "### Build an evaluation pipeline\n",
        "\n",
        "Your testing pipeline is similar to the training pipeline with small differences:\n",
        "\n",
        " * You don't need to call `tf.data.Dataset.shuffle`.\n",
        " * Caching is done after batching because batches can be the same between epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A0KjuDf7XiqY"
      },
      "outputs": [],
      "source": [
        "ds_test = ds_test.map(\n",
        "    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n",
        "ds_test = ds_test.batch(128)\n",
        "ds_test = ds_test.cache()\n",
        "ds_test = ds_test.prefetch(tf.data.AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nTFoji3INMEM"
      },
      "source": [
        "## Step 2: Create and train the model\n",
        "\n",
        "Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XWqxdmS1NLKA"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.models.Sequential([\n",
        "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(10)\n",
        "])\n",
        "model.compile(\n",
        "    optimizer=tf.keras.optimizers.Adam(0.001),\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
        ")\n",
        "\n",
        "model.fit(\n",
        "    ds_train,\n",
        "    epochs=6,\n",
        "    validation_data=ds_test,\n",
        ")"
      ]
    }
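    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an optional final step (a small addition to the example above), you can re-evaluate the trained model on the test pipeline with `tf.keras.Model.evaluate`, which reports the same loss and accuracy metrics configured in `model.compile`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Evaluate the trained model on the test pipeline.\n",
        "loss, accuracy = model.evaluate(ds_test)\n",
        "print('Test loss:', loss)\n",
        "print('Test accuracy:', accuracy)"
      ]
    }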
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "tensorflow/datasets",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}