{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "USSV_OlCFKOD" }, "source": [ "# Training a neural network on MNIST with Keras\n", "\n", "This simple example demonstrates how to plug TFDS into a Keras model.\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "J8y9ZkLXmAZc" }, "source": [ "Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "OGw9EgE0tC0C" }, "source": [ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", " \u003ctd\u003e\n", " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/datasets/keras_example\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", " \u003c/td\u003e\n", " \u003ctd\u003e\n", " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/keras_example.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", " \u003c/td\u003e\n", " \u003ctd\u003e\n", " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/datasets/blob/master/docs/keras_example.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", " \u003c/td\u003e\n", "\u003c/table\u003e" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "TTBSvHcSLBzc" }, "outputs": [], "source": [ "import tensorflow.compat.v2 as tf\n", "import tensorflow_datasets as tfds\n", "\n", "tfds.disable_progress_bar()\n", "tf.enable_v2_behavior()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VjI6VgOBf0v0" }, "source": [ "## Step 1: Create your input pipeline\n", "\n", "Build an efficient input pipeline using advice from:\n", "* [TFDS performance guide](https://www.tensorflow.org/datasets/performances)\n", "* [tf.data performance guide](https://www.tensorflow.org/guide/data_performance#optimize_performance)\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "c3aH3vP_XLI8" }, "source": [ "### Load MNIST\n", "\n", "Load with the following arguments:\n", "\n", "* `shuffle_files`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.\n", "* `as_supervised`: Returns a tuple `(img, label)` instead of a dict `{'image': img, 'label': label}`" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "ZUMhCXhFXdHQ" }, "outputs": [], "source": [ "(ds_train, ds_test), ds_info = tfds.load(\n", " 'mnist',\n", " split=['train', 'test'],\n", " shuffle_files=True,\n", " as_supervised=True,\n", " with_info=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rgwCFAcWXQTx" }, "source": [ "### Build training pipeline\n", "\n", "Apply the following transformations:\n", "\n", "* `ds.map`: TFDS provides the images as `tf.uint8`, while the model expects `tf.float32`, so normalize the images.\n", "* `ds.cache`: As the dataset fits in memory, cache it before shuffling for better performance.\u003cbr/\u003e\n", "__Note:__ Random transformations should be applied after caching.\n", "* `ds.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.\u003cbr/\u003e\n", "__Note:__ For larger datasets which do not fit in memory, a standard value is 1000 if your system allows it.\n", "* `ds.batch`: Batch after shuffling to get unique batches at each epoch.\n", "* `ds.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching)."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "haykx2K9XgiI" }, "outputs": [], "source": [ "def normalize_img(image, label):\n", " \"\"\"Normalizes images: `uint8` -\u003e `float32`.\"\"\"\n", " return tf.cast(image, tf.float32) / 255., label\n", "\n", "ds_train = ds_train.map(\n", " normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", "ds_train = ds_train.cache()\n", "ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n", "ds_train = ds_train.batch(128)\n", "ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RbsMy4X1XVFv" }, "source": [ "### Build evaluation pipeline\n", "\n", "The testing pipeline is similar to the training pipeline, with a few small differences:\n", "\n", " * No `ds.shuffle()` call\n", " * Caching is done after batching (as batches can be the same between epochs)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "A0KjuDf7XiqY" }, "outputs": [], "source": [ "ds_test = ds_test.map(\n", " normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n", "ds_test = ds_test.batch(128)\n", "ds_test = ds_test.cache()\n", "ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nTFoji3INMEM" }, "source": [ "## Step 2: Create and train the model\n", "\n", "Plug the input pipeline into Keras."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "XWqxdmS1NLKA" }, "outputs": [], "source": [ "model = tf.keras.models.Sequential([\n", " tf.keras.layers.Flatten(input_shape=(28, 28, 1)),\n", " tf.keras.layers.Dense(128, activation='relu'),\n", " tf.keras.layers.Dense(10, activation='softmax')\n", "])\n", "model.compile(\n", " loss='sparse_categorical_crossentropy',\n", " optimizer=tf.keras.optimizers.Adam(0.001),\n", " metrics=['accuracy'],\n", ")\n", "\n", "model.fit(\n", " ds_train,\n", " epochs=6,\n", " validation_data=ds_test,\n", ")" ] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "//learning/deepmind/dm_python:dm_notebook3", "kind": "private" }, "name": "tensorflow/datasets", "private_outputs": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }