{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.1\n", "IPython 6.1.0\n", "\n", "tensorflow 1.1.0\n", "numpy 1.12.1\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p tensorflow,numpy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Using Input Pipelines to Read Data from TFRecords Files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorFlow provides users with multiple options for providing data to the model. One of the probably most common methods is to define placeholders in the TensorFlow graph and feed the data from the current Python session into the TensorFlow `Session` using the `feed_dict` parameter. Using this approach, a large dataset that does not fit into memory is most conveniently and efficiently stored using NumPy archives as explained in [Chunking an Image Dataset for Minibatch Training using NumPy NPZ Archives](image-data-chunking-npz.ipynb) or HDF5 data base files ([Storing an Image Dataset for Minibatch Training using HDF5](image-data-chunking-hdf5.ipynb)).\n", "\n", "Another approach, which is often preferred when it comes to computational efficiency, is to do the \"data loading\" directly in the graph using input queues from so-called TFRecords files, which will be illustrated in this notebook.\n", "\n", "Beyond the examples in this notebook, you are encouraged to read more in TensorFlow's \"[Reading Data](https://www.tensorflow.org/programmers_guide/reading_data)\" guide." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 0. The Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's pretend we have a directory of images containing two subdirectories with images for training, validation, and testing. The following function will create such a dataset of images in JPEG format locally for demonstration purposes." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extracting ./train-images-idx3-ubyte.gz\n", "Extracting ./train-labels-idx1-ubyte.gz\n", "Extracting ./t10k-images-idx3-ubyte.gz\n", "Extracting ./t10k-labels-idx1-ubyte.gz\n" ] } ], "source": [ "# Note that executing the following code \n", "# cell will download the MNIST dataset\n", "# and save all the 60,000 images as separate JPEG\n", "# files. This might take a few minutes depending\n", "# on your machine.\n", "\n", "import numpy as np\n", "\n", "# load utilities from ../helper.py\n", "import sys\n", "sys.path.insert(0, '..') \n", "from helper import mnist_export_to_jpg\n", "\n", "np.random.seed(123)\n", "mnist_export_to_jpg(path='./')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `mnist_export_to_jpg` function called above creates 3 directories, mnist_train, mnist_test, and mnist_validation. 
Note that the names of the subdirectories correspond directly to the class label of the images that are stored under it:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mnist_train subdirectories ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n", "mnist_valid subdirectories ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n", "mnist_test subdirectories ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n" ] } ], "source": [ "import os\n", "\n", "for i in ('train', 'valid', 'test'): \n", " dirs = [d for d in os.listdir('mnist_%s' % i) if not d.startswith('.')]\n", " print('mnist_%s subdirectories' % i, dirs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make sure that the images look okay, the snippet below plots an example image from the subdirectory `mnist_train/9/`:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(28, 28)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEkpJREFUeJzt3W2MVGWWB/D/4cUGupGX7V5AB22MZlQwMkmFbMSY0XFG\nxkyCYxSHmJE1ZtoP7GQnmZhVNsbXGLPuQPywkjALGdjMOrMRVDS6K5AJZgKZ0BBWcFyVxcYBge4G\nku7mRWj67Ie+uK32PaetW3VvFef/SwjddepWPV1d/67qPvd5HlFVEFE8o4oeABEVg+EnCorhJwqK\n4ScKiuEnCorhJwqK4ScKiuEnCorhJwpqTJ531tzcrK2trXneJVFVeGfGikhOI/myjo4OdHd3j+jO\nM4VfRBYAeBHAaAD/qqrPW9dvbW1Fe3t7lrusS+fPnzfr3hNl1KjqvUEr8klc7fu2bj/rbff395v1\nMWNyfV39QqlUGvF1y35WichoAP8C4IcArgewWESuL/f2iChfWV5S5gHYp6r7VfUsgN8BWFiZYRFR\ntWUJ/+UA/jLk84PJZV8iIm0i0i4i7V1dXRnujogqqep/7VfVVapaUtVSS0tLte+OiEYoS/gPAZg5\n5PNvJZcRUR3IEv4dAK4RkVkicgmAnwDYWJlhEVG1ld2PUNV+Efk7AP+FwVbfGlV9v2Iju4iMHj06\n0/Feq3BgYCC1NnbsWPPYrC2vc+fOmXXr/k+ePGke642tsbGx7OO9x9RrrxbVyqukTF+Bqr4F4K0K\njYWIcsTTe4mCYviJgmL4iYJi+ImCYviJgmL4iYKq/2ZlHejr6zPrTU1NZt07T8Cqnzp1KtNtNzQ0\nmHXvPALrHATv686qt7c3teZ9XVnPzagHfOUnCorhJwqK4ScKiuEnCorhJwqK4ScKiq2+HGRtaWVZ\n5XbChAmZ7tvjTY21xt7T01P2sQAwZcoUsz5x4kSzbjlz5oxZ91qBXgu0FvCVnygohp8oKIafKCiG\nnygohp8oKIafKCiGnygo9vlz4PXCvR1fvZ5ylmWkT58+bda9Xrs3NdYa2+TJk81jPWfPnjXrVq/d\nWxZ83LhxZY2pnvCVnygohp8oKIafKCiGnygohp8oKIafKCiGnyioTH1+EekA0AvgPIB+VS1VYlAX\nm1peBtrr03tbVXusefHV7qVn2X7cO//BG3vWrc/zUImTfG5V1e4K3A4R5Yhv+4mCyhp+BbBZRHaK\nSFslBkRE+cj6tv9mVT0kIn8NYJOI/I+qvjv0CskPhTYAuOKKKzLeHRFVSqZXflU9lPzfCeBVAPOG\nuc4qVS2paqmlpSXL3RFRBZUdfhFpFJGJFz4G8AMAeys1MCKqrixv+6cBeDVpaYwB8O+q+p8VGRUR\nVV3Z4VfV/QBurOBYLlrnzp0z69Vc4/3AgQNmffPmzWZ99erVZn379u1m3Vo739pCeyQeeeQRs/7U\nU0+l1saPH28em2WvhHrBVh9RUAw/UVAMP1FQDD9RUAw/UVAMP1FQXLo7B1mW1gaA48ePm/WXXnop\ntbZ27Vrz2P3795t1rw3Z2Nho1q12nnfGZ1dXl1l/4YUXzLq1JPrjjz9uHutt/+0tx17L07gv4Cs/\nUVAMP1FQDD9RUAw/UVAMP1FQDD9RUAw/UVC59vlV1Zze6vWUrd6q11c9deqUWZ8wYYJZt2SdsvvZ\nZ5+Z9dtuu82sf/jhh2bd4m2Tfeutt5r1+fPnm/XZs2eXfaw33fjuu+826xs2bEitPfPMM+ax3nRj\na6pyveArP1FQDD9RUAw/UVAMP1FQDD9RUAw/UVAMP1FQufb5RcTseVvzrwF7XvzZs2fNY7NuB33y\n5MnUmjenfceOHWb9jjvuMOve12Z57LHHzPqyZcvMelNTk1n3xmbVvdv2eI+7de7HsWPHzGO9reWK\nXI69UvjKTxQUw08UFMNPFBTDTxQUw08UFMNPFBTDTxSU2+cXkTUAfgSgU1XnJJdNBfB7AK0AOgAs\nUtUTWQdz5swZs271hS+55JJM9+31q62e8u7du81jb7nlFrPures/apT9M3rTpk2ptdtvv908tru7\n26x7vXjvcc/yfdm6datZP336tFm/7rrrUmveuvyeeujje0byyv8bAAu+ctmjALao6jUAtiSfE1Ed\nccOvqu8C+OqWMQsBXNgKZi2Auyo8LiKqsnJ/55+mqoeTj48AmFah8RBRTjL/wU9VFYCm1UWkTUTa\nRaTd23uNiPJTbviPisgMAEj+70y7oqquUtWSqpa8jRmJKD/lhn8jgCXJx0sAvF6Z4RBRXtzwi8jL\nALYD+LaIHBSRhwA8D+D7IvIxgNuTz4mojrh9flVdnFL6XoXH4vazBwYGyj7Wm3/t9dqt459++mnz\nWG9PAe++vX73DTfcYNYtzc3NZr2np8esX3rppWXf97Zt28z6+vXrzbr1fACAGTNmpN
a8dfe98z6y\nnldSC3iGH1FQDD9RUAw/UVAMP1FQDD9RUAw/UVC5Lt0NDG7TncZbXts6NiuvVfjaa6+l1t544w3z\nWG/654MPPmjWr732WrMuIqm1LMuhA/7W5d735MSJ9Jnezz77rHnswYMHzbo33dg6o9Qbt9fK6+vr\nM+tZlyXPA1/5iYJi+ImCYviJgmL4iYJi+ImCYviJgmL4iYLKvc9v8Xrtlt7eXrPuTeH0vPLKK6k1\nr1c+efJks758+fKyxjQSnZ2piywBAC677DKz7n1t3nkETzzxRGrt7bffNo9taGgw695U6fvuuy+1\n5k3Z9e7bq9cDvvITBcXwEwXF8BMFxfATBcXwEwXF8BMFxfATBVVT8/mteekebxnnrKZPn55aGz9+\nvHmsN3f8k08+MeuzZs0y6xavj+9ti+6tsfDAAw+YdW+tA4v3fPAe1xtvvLHs+/a2//a+5/WAr/xE\nQTH8REEx/ERBMfxEQTH8REEx/ERBMfxEQbl9fhFZA+BHADpVdU5y2ZMAfgagK7naMlV9y7stVc20\nzbY1B3vSpEne3Zu8ddjvvffe1NqKFSsy3XepVDLr8+fPN+v33HNPam327NnmsevWrTPrK1euNOve\n1udz5sxJrXlr2+/Zs8ese+skWGPz9lLw1gq4GIzklf83ABYMc/kKVZ2b/HODT0S1xQ2/qr4L4HgO\nYyGiHGX5nf/nIvKeiKwRkSkVGxER5aLc8K8EcBWAuQAOA/hV2hVFpE1E2kWkvbu7u8y7I6JKKyv8\nqnpUVc+r6gCAXwOYZ1x3laqWVLXU3Nxc7jiJqMLKCr+IzBjy6Y8B7K3McIgoLyNp9b0M4LsAmkXk\nIIAnAHxXROYCUAAdAB6u4hiJqArc8Kvq4mEuXl2FsbisPdO9frO3/rzXc7766qtTa9u3bzePXbp0\nqVnftWuXWffmxL/zzjupNW/Ou7d+vTdv/eGH7Z/7zz33XFk1ANi3b59Z99Y5yLI+hPVcA4Dz58+b\n9Xo4T4Bn+BEFxfATBcXwEwXF8BMFxfATBcXwEwVVU1t0V1OWtg8AtLS0pNYaGxvNY3fu3GnW33zz\nTbPe0dFh1q1W36effmoee//995t1ayozALS2tpr1EydOpNa2bdtmHuudDn7TTTeZda+9a/Fax96U\n4HrAV36ioBh+oqAYfqKgGH6ioBh+oqAYfqKgGH6ioHLt84tIpt6rxZtC2d/fb9azjGvChAllHwsA\nCxYMtzjy//PG1tbWllrzpvQ2NDSYdY93+/v370+teUtzexYtWlT2sd64vfrFgK/8REEx/ERBMfxE\nQTH8REEx/ERBMfxEQTH8REFdNPP5ve29vfn83hLW1vFZ53Z7y0B7Y/eWmc7iyJEjZn369Olm3Vp2\n/OTJk+ax1vbeADB37lyzbvXqvT6+95haW80D/vOxFtT+CImoKhh+oqAYfqKgGH6ioBh+oqAYfqKg\nGH6ioNw+v4jMBLAOwDQACmCVqr4oIlMB/B5AK4AOAItUNX2R9hHw+t1W79TrhWftlVdzfvepU6fM\n+pQpU8q+7WPHjpn1qVOnmnWvj+/16rdu3Zpa83rlV155pVn31miw1njwjvWeD/XQx/eM5CvoB/BL\nVb0ewN8AWCoi1wN4FMAWVb0GwJbkcyKqE274VfWwqu5KPu4F8AGAywEsBLA2udpaAHdVa5BEVHnf\n6L2LiLQC+A6APwGYpqqHk9IRDP5aQER1YsThF5EmAOsB/EJVe4bWdPAX4mF/KRaRNhFpF5H2rq6u\nTIMlosoZUfhFZCwGg/9bVd2QXHxURGYk9RkAOoc7VlVXqWpJVUvWZpdElC83/DL4Z/LVAD5Q1eVD\nShsBLEk+XgLg9coPj4iqZSRTeucD+CmAPSKyO7lsGYDnAfyHiDwE4ACA8tdRTnjtE6sV6B2btTWT\ndYtvS5ZWnmf8+PFmPevX5W1Pbn3Pxo0bZx7rtdtOnz5t1idOnJha87bgruY06Vrhhl9V/wgg7Rny\nvcoOh4jyUv9nKhBRWRh+oqAYfqKgGH6ioBh+oqAYfqKgamrpbq/nnGWKpjdd2Nvi2+JNTfXqXk/5\nzJkzZt0au7d9uNfv9qbsTpo0yaxbS6J73zNvGrV3DoPF+3573zNvbFmeT3nhKz9RUAw/UVAMP1FQ\nDD9RUAw/UVAMP1FQDD9RULn2+QcGBsw52FnmnmfdJjsLb60Aryf8+eefm3Vv3rvF23rce9wmT55s\n1g8cOGDWP/roo9Sad/6Cdx7AmDH209d63LM8poA/tnrAV36ioBh+oqAYfqKgGH6ioBh+oqAYfqKg\nGH6ioHLt848aNSrTHOx6lXV78Cz97qzrz/f19Zn1hoYGs27N9z9xwt7R3dsTwGM97lnPIfDq9YCv\n/ERBMfxEQTH8REEx/ERBMfxEQTH8REEx/ERBuc1KEZkJYB2AaQAUwCpVfVFEngTwMwBdyVWXqepb\n1RroxczbU8BbD8BaY95ba8DjnZfhrU9v9eqnT59uHjtz5kyznmUvhqx9em+dhKznV+RhJI9AP4Bf\nquouEZkIYKeIbEpqK1T1n6s3PCKqFjf8qnoYwOHk414R+QDA5dUeGBFV1zd6TygirQC+A+BPyUU/\nF5H3RGSNiExJOaZNRNpFpL2rq2u4qxBRAUYcfhFpArAewC9UtQfASgBXAZiLwXcGvxruOFVdpaol\nVS21tLRUYMhEVAkjCr+IjMVg8H+rqhsAQFWPqup5VR0A8GsA86o3TCKqNDf8Mjg1ajWAD1R1+ZDL\nZwy52o8B7K388IioWkby1/75AH4KYI+I7E4uWwZgsYjMxWD7rwPAw1UZYQBeOy7rdtFZeK08rxW4\nd2/6a0JPT495rLe8tjc2a0l0r9Xn3XbWFmotGMlf+/8IYLiJ0ezpE9Wx+v/xRURlYfiJgmL4iYJi\n+ImCYviJgmL4iYKq//WHLwJezzjLFuDeOQLesuJePYumpiaz7p2/4E3pterekuMeLt1NRHWL4ScK\niuEnCorhJwqK4ScKiuEnCorhJwpKqjkX/Gt3JtIF4MCQi5oBdOc2gG+mVsdWq+MCOLZyVXJsV6rq\niNbLyzX8X7tzkXZVLRU2AEOtjq1WxwVwbOUqamx8208UFMNPFFTR4V9V8P1banVstTougGMrVyFj\nK/R3fiIqTtGv/ERUkELCLyILRORDEdknIo8WMYY0ItIhIntEZLeItBc8ljUi0ikie4dcNlVENonI\nx8n/w26TVtDYnhSRQ8ljt1tE7ixobDNF5A8i8mcReV9E/j65vNDHzhhXIY9b7m/7RWQ0gI8AfB/A\nQQA7ACxW1T/nOpAUItIBoKSqhfeEReQWAH0A1qnqnOSyfwJwXFWfT35wTlHVf6iRsT0JoK/onZuT\nDWVmDN1ZGsBdAP4WBT52xrgWoYDHrYhX/nkA9qnqflU9C+B3ABYWMI6ap6rvAjj+lYsXAlibfLwW\ng0+e3KWMrSao6mFV3ZV83Avgws7ShT52x
rgKUUT4LwfwlyGfH0RtbfmtADaLyE4RaSt6MMOYlmyb\nDgBHAEwrcjDDcHduztNXdpaumceunB2vK41/8Pu6m1V1LoAfAliavL2tSTr4O1sttWtGtHNzXobZ\nWfoLRT525e54XWlFhP8QgJlDPv9WcllNUNVDyf+dAF5F7e0+fPTCJqnJ/50Fj+cLtbRz83A7S6MG\nHrta2vG6iPDvAHCNiMwSkUsA/ATAxgLG8TUi0pj8IQYi0gjgB6i93Yc3AliSfLwEwOsFjuVLamXn\n5rSdpVHwY1dzO16rau7/ANyJwb/4/y+AfyxiDCnjugrAfyf/3i96bABexuDbwHMY/NvIQwD+CsAW\nAB8D2Axgag2N7d8A7AHwHgaDNqOgsd2Mwbf07wHYnfy7s+jHzhhXIY8bz/AjCop/8CMKiuEnCorh\nJwqK4ScKiuEnCorhJwqK4ScKiuEnCur/ABj13FTNxtqQAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.image as mpimg\n", "import matplotlib.pyplot as plt\n", "import os\n", "\n", "some_img = os.path.join('./mnist_train/9/', os.listdir('./mnist_train/9/')[0])\n", "\n", "img = mpimg.imread(some_img)\n", "print(img.shape)\n", "plt.imshow(img, cmap='binary');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: The JPEG format introduces a few artifacts that we can see in the image above. In this case, we use JPEG instead of PNG. Here, JPEG is used for demonstration purposes since that's still format many image datasets are stored in." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Saving images as TFRecords files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we are going to convert the images into a binary TFRecords file, which is based on Google's [protocol buffer](https://developers.google.com/protocol-buffers/) format:\n", "\n", "> The recommended format for TensorFlow is a TFRecords file containing tf.train.Example protocol buffers (which contain Features as a field). You write a little program that gets your data, stuffs it in an Example protocol buffer, serializes the protocol buffer to a string, and then writes the string to a TFRecords file using the tf.python_io.TFRecordWriter. For example, tensorflow/examples/how_tos/reading_data/convert_to_records.py converts MNIST data to this format. 
\n", "\n", "> [ Excerpt from https://www.tensorflow.org/programmers_guide/reading_data ]\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import glob\n", "import numpy as np\n", "import tensorflow as tf\n", "\n", "\n", "def images_to_tfrecords(data_stempath='./mnist_',\n", " shuffle=False, \n", " random_seed=None):\n", " \n", " def int64_to_feature(value):\n", " return tf.train.Feature(int64_list=tf.train.Int64List(value=value))\n", " \n", " for s in ['train', 'valid', 'test']:\n", "\n", " with tf.python_io.TFRecordWriter('mnist_%s.tfrecords' % s) as writer:\n", "\n", " img_paths = np.array([p for p in glob.iglob('%s%s/**/*.jpg' % \n", " (data_stempath, s), \n", " recursive=True)])\n", "\n", " if shuffle:\n", " rng = np.random.RandomState(random_seed)\n", " rng.shuffle(img_paths)\n", "\n", " for idx, path in enumerate(img_paths):\n", " label = int(os.path.basename(os.path.dirname(path)))\n", " image = mpimg.imread(path)\n", " image = image.reshape(-1).tolist()\n", "\n", "\n", " example = tf.train.Example(features=tf.train.Features(feature={\n", " 'image': int64_to_feature(image),\n", " 'label': int64_to_feature([label])}))\n", "\n", " writer.write(example.SerializeToString())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that it is important to shuffle the dataset so that we can later make use of TensorFlow's [`tf.train.shuffle_batch`](https://www.tensorflow.org/api_docs/python/tf/train/shuffle_batch) function and don't need to load the whole dataset into memory to shuffle epochs." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "images_to_tfrecords(shuffle=True, random_seed=123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just to make sure that the images were serialized correctly, let us load an image back from TFRecords using the [`tf.python_io.tf_record_iterator`](https://www.tensorflow.org/api_docs/python/tf/python_io/tf_record_iterator) and display it:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Label: 2\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEwpJREFUeJzt3W+MVGWWBvDnAI3yH9lusBG0BzQmhGQZU5JFiXF1Z1Ay\nEYn/BuMGybCMCZKFzIc1GCMxfjCrgyFhM4iKwsSVWWUQRLMbIRvYSQxaoCs6rguaHgGhaQIREASb\nPvuhL26jfc8p6lbVreY8v4TQXadv3ZfqfqiuOvd9X1FVEFE8ffIeABHlg+EnCorhJwqK4ScKiuEn\nCorhJwqK4ScKiuEnCorhJwqqXy1P1tjYqC0tLan1LFcbikjZx9a7en5cvLFV8/z1/LjkpbW1FYcP\nHy7pH5cp/CJyG4BlAPoCeEFVn7K+vqWlBcViMbV+5swZ83zWN/uSSy4xj/Xk+UPs+e6778y6NfaG\nhgbz2Kz/Lm9s1vm9x9yrd3Z2mnVLv37Znve8c/fpk88v1YVCoeSvLXuEItIXwL8AuB3ABACzRGRC\nufdHRLWV5b+nyQD2qOoXqnoGwFoAMyozLCKqtizhvwLA3m6f70tuO4+IzBORoogU29vbM5yOiCqp\n6i9MVHWlqhZUtdDU1FTt0xFRibKEfz+Asd0+H5PcRkS9QJbwvw/gGhH5iYj0B/BLABsrMywiqray\n+x2q2iEiDwP4D3S1+lap6ifOMWY7r3///uUOJ7OOjo6yj/XaOl4L02vHefUsvLF5LTFvbNb9e8d6\nj6tXP3v2bGrtxIkTme47a2u5HmRqdqrq2wDertBYiKiGeHkvUVAMP1FQDD9RUAw/UVAMP1FQDD9R\nUDWdz19NVk+3FH379jXrWaZoDhgwwKx7U1e9Xrw1rda776xTWz1Zrt04efKkWfd67db3dPDgwWWN\n6WLCZ36ioBh+oqAYfqKgGH6ioBh+oqAYfqKgatrqE5FMrR+rpeW1+i699NKyz+vdv9cm9MbmHe89\nZtYKvNWcDgxkmxo7cOBA81ivnsXp06fN+rfffmvWBw0aZNar3UKtBD7zEwXF8BMFxfATBcXwEwXF\n8BMFxfATBcXwEwVVV81Irx9u9Yyz9rO9XVetvq937movSW71+b1+tteP9q5BqObUWK/X7i23bo3N\nmw58MSzN7eEzP1FQDD9RUAw/UVAMP1FQDD9RUAw/UVAMP1FQmfr8ItIK4DiAswA6VLWQ5f68Pn+W\nfnnWpb29+dtZHD582Kx7S3c3Nzen1rLOK7fWUACyXV9R7TUYrLF7S7Fb104A/joGQ4cONev1oBIX\n+fytqto/vURUd/hrP1FQWcOvADaLyA4RmVeJARFRbWT9tX+qqu4XkZEA3hGR/1HVbd2/IPlPYR4A\nXHnllRlPR0SVkumZX1X3J38fArAewOQevmalqhZUtdDU1JTldERUQWWHX0QGiciQcx8D+DmAjys1\nMCKqriy/9o8CsD5pifQD8K+q+u8VGRURVV3Z4VfVLwD8dQXH4m4nnYXXlx02bFjZ993W1mbWvT79\nm2++adbfeOMNs753797UmndthLd9uNfv9ub7Hzt2LLU2ceJE89iFCxea9cmTf/Qq8zzWNQjWuAC/\nT98b+vgetvqIgmL4iYJi+ImCYviJgmL4iYJi+ImC6lVLd2eRdWrr0aNHU2uvvPKKeeySJUvMurds\n+DfffGPWLZdffrlZHzt2rFn3WmK7d+8261YrcNeuXeaxXjtt9OjRZn3MmDGptaxLjnvLimedjlwL\nfOYnCorhJwqK4ScKiuEnCorhJwqK4ScKiuEnCqqmfX5VNZdTHjhwoHm81e/2ltb2pp56S1SvWLEi\ntbZ48WLzWK9fPXPmTLN+/fXXm/W77747tTZy5EjzWO/6B+8ag23btpn1/fv3p9ZeeOEF89iXXnrJ\nrG/dutWsb9iwIbU2fvx481jPxbCFN5/5iYJi+ImCYviJgmL4iYJi+ImCYviJgmL4iYKqaZ9fRMx+\nu7fEtdXL7+joMI/15lcfPHjQrK9Zsya1Zs0bB4Ann3zSrE+bNs2se3PyLd6S5d7j4l0/MWXKFLNu\n9cPnzp1rHustp97a2mrW161bl1pbtGiReax33UdvmK/v4TM/UVAMP1FQDD9RUAw/UVAMP1FQDD9R\nUAw/UVBun19EVgH4BYBDqjoxuW0EgD8AaAHQCuBeVU1f2P78+0uteXPurV6+t+a/N2/d66W/9tpr\nZd/3tddea9a9bbC9f1uWxzTrfgZeL/7kyZNl3/eyZcvM+pw5c8z69u3bU2t9+tjPe97W5heDUp75\nXwZw2w9uewTAFlW9BsCW5HMi6kXc8KvqNgBHfnDzDACrk49XA7izwuMioior9zX/KFU9kHx8EMCo\nCo2HiGok8xt+qqoANK0uIvNEpCgixfb29qynI6IKKTf8bSLSDADJ34fSvlBVV6pqQVULTU1NZZ6O\niCqt3PBvBDA7+Xg2gPRlUomoLrnhF5FXAbwL4FoR2ScivwLwFICfichuAH+XfE5EvYjb5FXVWSml\nW8s5obUXvdeTtvrd3jrq3vxsr+87ceJEs27x1inoetskndeLt8Y+YMCATOf29qH37t9aD8D7ntxw\nww2Zzm2t2+/18U+dOpXp3L0Br/AjCorhJwqK4ScKiuEnCorhJwqK4ScKqqZLdwN2O89bfjvL9NOG\nhgaz7m1FbY3Nm9Za7emhx44dS615LVCvbrVmS2EtHe4tC+5t0e214xYsWGDWLV4rz5uq7G03Xw/4\nzE8UFMNPFBTDTxQUw08UFMNPFBTDTxQUw08UVM37/Bavj2/1dbNOXfV6zhZvebLPPvvMrHv/bm/q\n66ZNm1JrjY2N5rFTp04168OHDzfrLS0tZt36vhw58sN1Yc+3dOlSs+5tk21Nw/aWQ/emYXvXR/QG\nfOYnCorhJwqK4ScKiuEnCorhJwqK4ScKiuEnCqqmfX5VNXvW3px7b3ltizf325vPv2XLltTaihUr\nzGO3bt1q1r1lwQuFgllfu3Ztas1betvrlXvLqXtjmzBhQmrt3XffNY/1eun33XefWb///vtTa97P\n0sWwNLeHz/xEQTH8REEx/ERBMfxEQTH8REEx/ERBMfxEQYk3z11EVgH4BYBDqjoxuW0JgH8AcG4i\n+2JVfds7WaFQ0GKxmFr35q1b896zbiXd1tZm1q+66qrU2unTp81jFy1aZNavu+46s/7AAw+Y9T17\n9qTWvvzyS/PY9evXm/Xly5ebdY91HYH3PfN46yQ0Nzen1oYMGWIe21vX5S8UCigWi1LK15byzP8y\ngNt6uP1ZVZ2U/HGDT0T1xQ2/qm4DYC+5QkS9TpbX/AtE5CMRWSUil1VsRERUE+WG/3cAxgGYBOAA\ngN+mfaGIzBORoogUvbXuiKh2ygq/qrap6llV7QTwPIDJxteuVNWCqhaamprKHScRVVhZ4ReR7m+j\nzgTwcWWGQ0S14k7pFZFXAdwMoFFE9gF4HMDNIjIJgAJo
BfDrKo6RiKrADb+qzurh5hfLOVlnZ6c5\nbz7L2vne2vee22+/3axbvfznn3/ePHbu3Llm3Vsj3nP11Ven1qxeNwDcdNNNZv2WW24x63fddZdZ\ntx43Ebsd7V2D8uijj5r1ZcuWpda8Pr/Xx+/o6DDrWX8ea4FX+BEFxfATBcXwEwXF8BMFxfATBcXw\nEwVV035Enz59zHbe0aNHzeMvuyx9CoE3rdZbFvyDDz4w65axY8eadW9sWbYmB+zpyl77dOfOnWbd\nWx7be1xvvvnm1No999xjHvvEE0+Y9ddff92sW1OdZ83qqYP9/+bPn2/WvcfVm65sbRHutRm9Fmmp\n+MxPFBTDTxQUw08UFMNPFBTDTxQUw08UFMNPFFRdzTvM0r8cPHiwWT9x4oRZ9/rVln379pn1devW\nmXVvm+zhw4eb9U2bNqXWvD79rbfeatZHjRpl1p9++mmzPm3atNSatwX39OnTzfrDDz9s1q1lyT//\n/HPzWG+atbX9NwCMGzfOrFtqtWw4n/mJgmL4iYJi+ImCYviJgmL4iYJi+ImCYviJgnK36K6kQqGg\n27dvT6337dvXPN7awjtLnx4AnnnmGbP+2GOPpda8udve3G9rOfNSPPTQQ6m1xx9/3DzWmlcO+P+2\n8ePHm/Vjx46l1oYOHWoe6z0u3rbrc+bMSa2tWbMm033feOONZt36ngD2kufeVvXWz3qlt+gmoosQ\nw08UFMNPFBTDTxQUw08UFMNPFBTDTxSUO59fRMYCWANgFAAFsFJVl4nICAB/ANACoBXAvapqL7yP\nbHP2v/7669RaY2Ojeax3PcPChQvN+pEjR1Jr3nz9MWPGmPU+fez/g5999tmy799bC8Dr83d2dpp1\nT5a55971EcePHzfry5cvT61Z1x8AwFtvvWXWN2/ebNbfe+89s26tVeBdY1AppTzzdwD4japOAPA3\nAOaLyAQAjwDYoqrXANiSfE5EvYQbflU9oKo7k4+PA/gUwBUAZgBYnXzZagB3VmuQRFR5F/SaX0Ra\nAPwUwHYAo1T1QFI6iK6XBUTUS5QcfhEZDGAdgIWqet4LJu16Qd3ji2oRmSciRREptre3ZxosEVVO\nSeEXkQZ0Bf8VVf1jcnObiDQn9WYAh3o6VlVXqmpBVQtNTU2VGDMRVYAbful6e/5FAJ+q6tJupY0A\nZicfzwawofLDI6Jqcaf0ishUAP8FYBeAc32fxeh63f9vAK4E8Bd0tfrS+2HomtJrtUC8lpfVlvKm\nA3tTU73ls62lnPv3728e68myNbnHa9V5j3lHR4dZ96afWvfvLd3tndsbu1e3vPzyy2Z9xIgRZt1r\nFT733HOpNe8xtTI7ZcoU7Nixo6R+utvnV9U/AUi7M3vRdyKqW7zCjygohp8oKIafKCiGnygohp8o\nKIafKKiab9Ft9V5Pnz5tHuv18i1eH9+b2pplKvKpU6fMutfHz7JlszfurNtB9+tX/o+Q18f3rkHJ\n0sf/6quvzPqDDz5o1r3pxHfccceFDul73uOS5THvjs/8REEx/ERBMfxEQTH8REEx/ERBMfxEQTH8\nREHVtM+vquZcZW9+dxZZ5p0Dfs/Z4i3F7G1F7S1hbS1pPmzYMPNYr4/vXXuR5XvmXVuR9efBWoNh\n9OjR5rHez8uQIUPMureOwuHDh1NrI0eONI+1XMj1KHzmJwqK4ScKiuEnCorhJwqK4ScKiuEnCorh\nJwqqpn1+EUFDQ0MtT/m9rOfNspaAx+vje7xefhbVvPaimvcNZNtPIevPi3fdSJZefqXwmZ8oKIaf\nKCiGnygohp8oKIafKCiGnygohp8oKDf8IjJWRP5TRP4sIp+IyD8mty8Rkf0i8mHyZ3r1h0tElVLK\nRT4dAH6jqjtFZAiAHSLyTlJ7VlWfqd7wiKha3PCr6gEAB5KPj4vIpwCuqPbAiKi6Lug1v4i0APgp\ngO3JTQtE5CMRWSUiPe45JSLzRKQoIsX29vZMgyWiyik5/CIyGMA6AAtV9RiA3wEYB2ASun4z+G1P\nx6nqSlUtqGqhqampAkMmokooKfwi0oCu4L+iqn8EAFVtU9WzqtoJ4HkAk6s3TCKqtFLe7RcALwL4\nVFWXdru9uduXzQTwceWHR0TVUsq7/TcC+HsAu0Tkw+S2xQBmicgkAAqgFcCvqzJCIqqKUt7t/xOA\nnhYDf7vywyGiWuEVfkRBMfxEQTH8REEx/ERBMfxEQTH8REEx/ERBMfxEQTH8REEx/ERBMfxEQTH8\nREEx/ERBMfxEQYmq1u5kIu0A/tLtpkYAh2s2gAtTr2Or13EBHFu5Kjm2q1S1pPXyahr+H51cpKiq\nhdwGYKjXsdXruACOrVx5jY2/9hMFxfATBZV3+FfmfH5LvY6tXscFcGzlymVsub7mJ6L85P3MT0Q5\nySX8InKbiHwmIntE5JE8xpBGRFpFZFey83Ax57GsEpFDIvJxt9tGiMg7IrI7+bvHbdJyGltd7Nxs\n7Cyd62NXbzte1/zXfhHpC+B/AfwMwD4A7wOYpap/rulAUohIK4CCqubeExaRmwCcALBGVScmt/0z\ngCOq+lTyH+dlqvpPdTK2JQBO5L1zc7KhTHP3naUB3AngQeT42Bnjuhc5PG55PPNPBrBHVb9Q1TMA\n1gKYkcM46p6qbgNw5Ac3zwCwOvl4Nbp+eGouZWx1QVUPqOrO5OPjAM7tLJ3rY2eMKxd5hP8KAHu7\nfb4P9bXltwLYLCI7RGRe3oPpwahk23QAOAhgVJ6D6YG7c3Mt/WBn6bp57MrZ8brS+Ibfj01V1UkA\nbgcwP/n1ti5p12u2emrXlLRzc630sLP09/J87Mrd8brS8gj/fgBju30+JrmtLqjq/uTvQwDWo/52\nH247t0lq8vehnMfzvXraubmnnaVRB49dPe14nUf43wdwjYj8RET6A/glgI05jONHRGRQ8kYMRGQQ\ngJ+j/nYf3ghgdvLxbAAbchzLeepl5+a0naWR82NXdzteq2rN/wCYjq53/D8H8GgeY0gZ1zgA/538\n+STvsQF4FV2/Bn6HrvdGfgXgrwBsAbAbwGYAI+pobL8HsAvAR+gKWnNOY5uKrl/pPwLwYfJnet6P\nnTGuXB43XuFHFBTf8CMKiuEnCorhJwqK4ScKiuEnCorhJwqK4ScKiuEnCur/APLdT05RF5ZqAAAA\nAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "\n", "\n", "record_iterator = tf.python_io.tf_record_iterator(path='mnist_train.tfrecords')\n", "\n", "for r in record_iterator:\n", " example = 
tf.train.Example()\n", " example.ParseFromString(r)\n", " \n", " label = example.features.feature['label'].int64_list.value[0]\n", " print('Label:', label)\n", " img = np.array(example.features.feature['image'].int64_list.value)\n", " img = img.reshape((28, 28))\n", " plt.imshow(img, cmap='binary')\n", " plt.show\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far so good, the image above looks okay. In the next secction, we will introduce a slightly different approach for loading the images, namely, the [`TFRecordReader`](https://www.tensorflow.org/api_docs/python/tf/TFRecordReader), which we need to load images inside a TensorFlow graph." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Loading images via the TFRecordReader" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Roughly speaking, we can regard the [`TFRecordReader`](https://www.tensorflow.org/api_docs/python/tf/TFRecordReader) as a class that let's us load images \"symbolically\" inside a TensorFlow graph. A `TFRecordReader` uses the state in the graph to remember the location of a `.tfrecord` file that it reads and lets us iterate over training examples and batches after initializing the graph as we will see later.\n", "\n", "To see how it works, let's start with a simple function that reads one image at a time:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def read_one_image(tfrecords_queue, normalize=True):\n", "\n", " reader = tf.TFRecordReader()\n", " key, value = reader.read(tfrecords_queue)\n", " features = tf.parse_single_example(value,\n", " features={'label': tf.FixedLenFeature([], tf.int64),\n", " 'image': tf.FixedLenFeature([784], tf.int64)})\n", " label = tf.cast(features['label'], tf.int32)\n", " image = tf.cast(features['image'], tf.float32)\n", " onehot_label = tf.one_hot(indices=label, depth=10)\n", " \n", " if normalize:\n", " # normalize to [0, 1] range\n", " image = image / 255.\n", " \n", " return onehot_label, image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use this `read_one_image` function to fetch images in a TensorFlow session, we will make use of queue runners as illustrated in the following example:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Label: [ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] \n", "Image dimensions: (784,)\n" ] } ], "source": [ "g = tf.Graph()\n", "with g.as_default():\n", " \n", " queue = tf.train.string_input_producer(['mnist_train.tfrecords'], \n", " num_epochs=10)\n", " label, image = read_one_image(queue)\n", "\n", "\n", "with tf.Session(graph=g) as sess:\n", " sess.run(tf.local_variables_initializer())\n", " sess.run(tf.global_variables_initializer())\n", " \n", " coord = tf.train.Coordinator()\n", " threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n", " \n", " for i in range(10):\n", " one_label, one_image = sess.run([label, image])\n", " \n", " print('Label:', one_label, '\\nImage dimensions:', one_image.shape)\n", " \n", " coord.request_stop()\n", " coord.join(threads)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The `tf.train.string_input_producer` produces a filename queue that we iterate over in the session. Note that we need to call `sess.run(tf.local_variables_initializer())` if we define a fixed number of `num_epochs` in `tf.train.string_input_producer`. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "However, we rarely (want to) train neural networks with one datapoint at a time but use minibatches instead. Conveniently, TensorFlow also provides utility functions for batching. In the following code example, we will use the [`tf.train.shuffle_batch`](https://www.tensorflow.org/api_docs/python/tf/train/shuffle_batch) function to load the images and labels in batches of size 64:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Batch size: 64\n" ] } ], "source": [ "g = tf.Graph()\n", "with g.as_default():\n", "\n", "    queue = tf.train.string_input_producer(['mnist_train.tfrecords'],\n", "                                           num_epochs=10)\n", "    label, image = read_one_image(queue)\n", "\n", "    label_batch, image_batch = tf.train.shuffle_batch([label, image],\n", "                                                      batch_size=64,\n", "                                                      capacity=5000,\n", "                                                      min_after_dequeue=2000,\n", "                                                      num_threads=8,\n", "                                                      seed=123)\n", "\n", "with tf.Session(graph=g) as sess:\n", "    sess.run(tf.local_variables_initializer())\n", "    sess.run(tf.global_variables_initializer())\n", "\n", "    coord = tf.train.Coordinator()\n", "    threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n", "\n", "    for i in range(10):\n", "        many_labels, many_images = sess.run([label_batch, image_batch])\n", "\n", "    print('Batch size:', many_labels.shape[0])\n", "\n", "    coord.request_stop()\n", "    coord.join(threads)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The other relevant arguments we provided to `tf.train.shuffle_batch` are described below:\n", "\n", "- `capacity`: An integer that defines the maximum number of elements in the queue.\n", "- `min_after_dequeue`: The minimum number of elements in the queue after a dequeue, which is used to ensure that a minimum number of data points have been loaded for shuffling.\n", "- `num_threads`: The number of threads for enqueuing.\n" ] }, 
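{ "cell_type": "markdown", "metadata": {}, "source": [ "As a rule of thumb, the `tf.train.shuffle_batch` documentation recommends choosing a `capacity` of at least `min_after_dequeue + (num_threads + a small safety margin) * batch_size`. The following minimal sketch checks the settings used above against this recommendation (the safety margin of 3 is an illustrative choice):\n", "\n", "```python\n", "batch_size, num_threads, min_after_dequeue = 64, 8, 2000\n", "safety_margin = 3  # illustrative choice\n", "\n", "recommended_min_capacity = min_after_dequeue + (num_threads + safety_margin) * batch_size\n", "print(recommended_min_capacity)  # 2704, so the capacity=5000 used above is more than sufficient\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 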
Using queue runners to train a neural network" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this section, we will take the concepts that were introduced in the previous sections and train a multilayer perceptron from the `'mnist_train.tfrecords'` file:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001 | AvgCost: 0.007\n", "Epoch: 002 | AvgCost: 0.469\n", "Epoch: 003 | AvgCost: 0.240\n", "Epoch: 004 | AvgCost: 0.183\n", "Epoch: 005 | AvgCost: 0.151\n", "Epoch: 006 | AvgCost: 0.128\n", "Epoch: 007 | AvgCost: 0.110\n", "Epoch: 008 | AvgCost: 0.099\n", "Epoch: 009 | AvgCost: 0.087\n", "Epoch: 010 | AvgCost: 0.078\n", "Epoch: 011 | AvgCost: 0.070\n", "Epoch: 012 | AvgCost: 0.063\n", "Epoch: 013 | AvgCost: 0.058\n", "Epoch: 014 | AvgCost: 0.051\n", "Epoch: 015 | AvgCost: 0.047\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "learning_rate = 0.1\n", "batch_size = 128\n", "n_epochs = 15\n", "n_iter = n_epochs * (45000 // batch_size)\n", "\n", "# Architecture\n", "n_hidden_1 = 128\n", "n_hidden_2 = 256\n", "height, width = 28, 28\n", "n_classes = 10\n", "\n", "\n", "##########################\n", "### GRAPH DEFINITION\n", "##########################\n", "\n", "g = tf.Graph()\n", "with g.as_default():\n", "\n", "    tf.set_random_seed(123)\n", "\n", "    # Input data\n", "    queue = tf.train.string_input_producer(['mnist_train.tfrecords'],\n", "                                           num_epochs=None)\n", "    label, image = read_one_image(queue)\n", "\n", "    label_batch, image_batch = tf.train.shuffle_batch([label, image],\n", "                                                      batch_size=batch_size,\n", "                                                      seed=123,\n", "                                                      num_threads=8,\n", "                                                      capacity=5000,\n", "                                                      min_after_dequeue=2000)\n", "\n", "    tf_images = tf.placeholder_with_default(image_batch,\n", "                                            shape=[None, 784],\n", "                                            name='images')\n", "    tf_labels = tf.placeholder_with_default(label_batch,\n", "                                            shape=[None, 10],\n", "                                            name='labels')\n", "\n", "    # Model parameters\n", "    weights = {\n", "        'h1': tf.Variable(tf.truncated_normal([height*width, n_hidden_1], stddev=0.1)),\n", "        'h2': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2], stddev=0.1)),\n", "        'out': tf.Variable(tf.truncated_normal([n_hidden_2, n_classes], stddev=0.1))\n", "    }\n", "    biases = {\n", "        'b1': tf.Variable(tf.zeros([n_hidden_1])),\n", "        'b2': tf.Variable(tf.zeros([n_hidden_2])),\n", "        'out': tf.Variable(tf.zeros([n_classes]))\n", "    }\n", "\n", "    # Multilayer perceptron\n", "    layer_1 = tf.add(tf.matmul(tf_images, weights['h1']), biases['b1'])\n", "    layer_1 = tf.nn.relu(layer_1)\n", "    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])\n", "    layer_2 = tf.nn.relu(layer_2)\n", "    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']\n", "\n", "    # Loss and optimizer\n", "    loss = tf.nn.softmax_cross_entropy_with_logits(logits=out_layer, labels=tf_labels)\n", "    cost = tf.reduce_mean(loss, name='cost')\n", "    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n", "    train = optimizer.minimize(cost, name='train')\n", "\n", "    # Prediction\n", "    prediction = tf.argmax(out_layer, 1, name='prediction')\n", "    # compare against tf_labels (not label_batch) so that the accuracy op\n", "    # also works with labels provided via feed_dict\n", "    correct_prediction = tf.equal(tf.argmax(tf_labels, 1), tf.argmax(out_layer, 1))\n", "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')\n", "\n", "\n", "with tf.Session(graph=g) as sess:\n", "    sess.run(tf.global_variables_initializer())\n", "    saver0 = 
tf.train.Saver()\n", " \n", " coord = tf.train.Coordinator()\n", " threads = tf.train.start_queue_runners(sess=sess, coord=coord)\n", " \n", " avg_cost = 0.\n", " iter_per_epoch = n_iter // n_epochs\n", " epoch = 0\n", "\n", " for i in range(n_iter):\n", " _, cost = sess.run(['train', 'cost:0'])\n", " avg_cost += cost\n", " \n", " if not i % iter_per_epoch:\n", " epoch += 1\n", " avg_cost /= iter_per_epoch\n", " print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch, avg_cost))\n", " avg_cost = 0.\n", " \n", " \n", " coord.request_stop()\n", " coord.join(threads)\n", " \n", " saver0.save(sess, save_path='./mlp')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After looking at the graph above, you probably wondered why we used [`tf.placeholder_with_default`](https://www.tensorflow.org/api_docs/python/tf/placeholder_with_default) to define the two placeholders:\n", "\n", "```python\n", "tf_images = tf.placeholder_with_default(image_batch,\n", " shape=[None, 784], \n", " name='images')\n", "tf_labels = tf.placeholder_with_default(label_batch, \n", " shape=[None, 10], \n", " name='labels')\n", "``` \n", "\n", "In the training session above, these placeholders are being ignored if we don't feed them via a session's `feed_dict`, or in other words \"[A `tf.placeholder_with_default` is a] placeholder op that passes through input when its output is not fed\" (https://www.tensorflow.org/api_docs/python/tf/placeholder_with_default).\n", "\n", "However, these placeholders are useful if we want to feed new data to the graph and make predictions after training as in a real-world application, which we will see in the next section." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Feeding new datapoints through placeholders" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To demonstrate how we can feed new data points to the network that are not part of the `mnist_train.tfrecords` file, let's use the test dataset and load the images into Python and pass it to the graph using a `feed_dict`:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./mlp\n", "Test accuracy: 97.3%\n" ] } ], "source": [ "record_iterator = tf.python_io.tf_record_iterator(path='mnist_test.tfrecords')\n", "\n", "with tf.Session() as sess:\n", " \n", " saver1 = tf.train.import_meta_graph('./mlp.meta')\n", " saver1.restore(sess, save_path='./mlp')\n", " \n", " num_correct = 0\n", " for idx, r in enumerate(record_iterator):\n", " example = tf.train.Example()\n", " example.ParseFromString(r)\n", " label = example.features.feature['label'].int64_list.value[0]\n", " image = np.array(example.features.feature['image'].int64_list.value)\n", " \n", " pred = sess.run('prediction:0', \n", " feed_dict={'images:0': image.reshape(1, 784)})\n", "\n", " num_correct += int(label == pred[0])\n", " acc = num_correct / (idx + 1) * 100\n", "\n", "print('Test accuracy: %.1f%%' % acc)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }