{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.1\n", "IPython 6.0.0\n", "\n", "tensorflow 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p tensorflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Multilayer Perceptron with Dropout" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Typically, dropout is applied after the non-linear activation function (a). However, when using rectified linear units (ReLUs), it might make sense to apply dropout before the non-linear activation (b) for reasons of computational efficiency depending on the particular code implementation.\n", "\n", "> (a): Fully connected, linear activation -> ReLU -> Dropout -> ... \n", "> (b): Fully connected, linear activation -> Dropout -> ReLU -> ...\n", "\n", "Why do (a) and (b) produce the same results in case of ReLU?. Let's answer this question with a simple example starting with the following *logits* (outputs of the linear activation of the fully connected layer):\n", "\n", "> `[-1, -2, -3, 4, 5, 6]`\n", "\n", "Let's walk through scenario (a), applying the ReLU activation first. The output of the non-linear ReLU functions are as follows:\n", "\n", "> `[0, 0, 0, 4, 5, 6]`\n", "\n", "Remember, the ReLU activation function is defined as $f(x) = max(0, x)$; thus, all non-zero values will be changed to zeros. Now, applying dropout with a probability 0f 50%, let's assume that the units being deactivated are units 2, 4, and 6:\n", "\n", "\n", "> `[0*2, 0, 0*2, 0, 0*2, 0] = [0, 0, 0, 0, 10, 0]`\n", "\n", "\n", "Note that in dropout, units are deactivated randomly by default. In the preceding example, we assumed that the 2nd, 4th, and 6th unit were deactivated during the training iteration. Also, because we applied dropout with 50% dropout probability, we scaled the remaining units by a factor of 2.\n", "\n", "Now, let's take a look at scenario (b). Again, we assume a 50% dropout rate and that units 2, 4, and 6 are deactivated:\n", "\n", "> `[-1, -2, -3, 4, 5, 6] -> [-1*2, 0, -3*2, 0, 5*2, 0]`\n", "\n", "\n", "Now, if we pass this array to the ReLU function, the resulting array will look exactly like the one in scenario (a):\n", "\n", "\n", "> `[-2, 0, -6, 0, 10, 0] -> [0, 0, 0, 0, 10, 0]`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Low-level Implementation" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extracting ./train-images-idx3-ubyte.gz\n", "Extracting ./train-labels-idx1-ubyte.gz\n", "Extracting ./t10k-images-idx3-ubyte.gz\n", "Extracting ./t10k-labels-idx1-ubyte.gz\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "\n", "\n", "##########################\n", "### DATASET\n", "##########################\n", "\n", "mnist = input_data.read_data_sets(\"./\", one_hot=True)\n", "\n", "\n", "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "learning_rate = 0.1\n", "training_epochs = 20\n", "batch_size = 64\n", "dropout_keep_proba = 0.5\n", "\n", "# Architecture\n", "n_hidden_1 = 128\n", "n_hidden_2 = 256\n", "n_input = 784\n", "n_classes = 10\n", "\n", "# Other\n", "random_seed = 123\n", "\n", "\n", "##########################\n", "### GRAPH DEFINITION\n", "##########################\n", "\n", "g = tf.Graph()\n", "with g.as_default():\n", " \n", " tf.set_random_seed(random_seed)\n", "\n", " # Dropout settings\n", " keep_proba = tf.placeholder(tf.float32, None, name='keep_proba')\n", " \n", " # Input data\n", " tf_x = tf.placeholder(tf.float32, [None, n_input], name='features')\n", " tf_y = tf.placeholder(tf.float32, [None, n_classes], name='targets')\n", "\n", " # Model parameters\n", " weights = {\n", " 'h1': tf.Variable(tf.truncated_normal([n_input, n_hidden_1], stddev=0.1)),\n", " 'h2': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2], stddev=0.1)),\n", " 'out': tf.Variable(tf.truncated_normal([n_hidden_2, n_classes], stddev=0.1))\n", " }\n", " biases = {\n", " 'b1': tf.Variable(tf.zeros([n_hidden_1])),\n", " 'b2': tf.Variable(tf.zeros([n_hidden_2])),\n", " 'out': tf.Variable(tf.zeros([n_classes]))\n", " }\n", "\n", " # Multilayer perceptron\n", " layer_1 = tf.add(tf.matmul(tf_x, weights['h1']), biases['b1'])\n", " layer_1 = tf.nn.relu(layer_1)\n", " layer_1 = tf.nn.dropout(layer_1, keep_prob=keep_proba)\n", " \n", " layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])\n", " layer_2 = tf.nn.relu(layer_2)\n", " layer_2 = tf.nn.dropout(layer_2, keep_prob=keep_proba)\n", " \n", " out_layer = tf.add(tf.matmul(layer_2, weights['out']), biases['out'], name='logits')\n", "\n", " # Loss and optimizer\n", " loss = tf.nn.softmax_cross_entropy_with_logits(logits=out_layer, labels=tf_y)\n", " cost = tf.reduce_mean(loss, name='cost')\n", " optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n", " train = optimizer.minimize(cost, name='train')\n", "\n", " # Prediction\n", " correct_prediction = tf.equal(tf.argmax(tf_y, 1), tf.argmax(out_layer, 1))\n", " accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001 | AvgCost: 0.669 | Train/Valid ACC: 0.927/0.935\n", "Epoch: 002 | AvgCost: 0.372 | Train/Valid ACC: 0.944/0.953\n", "Epoch: 003 | AvgCost: 0.308 | Train/Valid ACC: 0.952/0.956\n", "Epoch: 004 | AvgCost: 0.271 | Train/Valid ACC: 0.962/0.961\n", "Epoch: 005 | AvgCost: 0.251 | Train/Valid ACC: 0.964/0.966\n", "Epoch: 006 | AvgCost: 0.231 | Train/Valid ACC: 0.968/0.966\n", "Epoch: 007 | AvgCost: 0.219 | Train/Valid ACC: 0.970/0.970\n", "Epoch: 008 | AvgCost: 0.204 | Train/Valid ACC: 0.972/0.971\n", "Epoch: 009 | AvgCost: 0.194 | Train/Valid ACC: 0.974/0.970\n", "Epoch: 010 | AvgCost: 0.187 | Train/Valid ACC: 0.976/0.970\n", "Epoch: 011 | AvgCost: 0.178 | Train/Valid ACC: 0.977/0.972\n", "Epoch: 012 | AvgCost: 0.175 | Train/Valid ACC: 0.978/0.972\n", "Epoch: 013 | AvgCost: 0.170 | Train/Valid ACC: 0.979/0.973\n", "Epoch: 014 | AvgCost: 0.162 | Train/Valid ACC: 0.980/0.975\n", "Epoch: 015 | AvgCost: 0.157 | Train/Valid ACC: 0.980/0.974\n", "Epoch: 016 | AvgCost: 0.153 | Train/Valid ACC: 0.982/0.976\n", "Epoch: 017 | AvgCost: 0.151 | Train/Valid ACC: 0.982/0.976\n", "Epoch: 018 | AvgCost: 0.147 | Train/Valid ACC: 0.983/0.973\n", "Epoch: 019 | AvgCost: 0.144 | Train/Valid ACC: 0.984/0.974\n", "Epoch: 020 | AvgCost: 0.143 | Train/Valid ACC: 0.985/0.975\n", "Test ACC: 0.974\n" ] } ], "source": [ "from numpy.random import seed\n", "\n", "##########################\n", "### TRAINING & EVALUATION\n", "##########################\n", "\n", "with tf.Session(graph=g) as sess:\n", " sess.run(tf.global_variables_initializer())\n", "\n", " seed(random_seed) # random seed for mnist iterator\n", " for epoch in range(training_epochs):\n", " avg_cost = 0.\n", " total_batch = mnist.train.num_examples // batch_size\n", "\n", " for i in range(total_batch):\n", " batch_x, batch_y = mnist.train.next_batch(batch_size)\n", " _, c = sess.run(['train', 'cost:0'], feed_dict={'features:0': batch_x,\n", " 'targets:0': batch_y,\n", " 'keep_proba:0': dropout_keep_proba})\n", " avg_cost += c\n", " \n", " train_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.train.images,\n", " 'targets:0': mnist.train.labels,\n", " 'keep_proba:0': 1.0})\n", " valid_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.validation.images,\n", " 'targets:0': mnist.validation.labels,\n", " 'keep_proba:0': 1.0})\n", " \n", " print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch + 1, avg_cost / (i + 1)), end=\"\")\n", " print(\" | Train/Valid ACC: %.3f/%.3f\" % (train_acc, valid_acc))\n", " \n", " test_acc = sess.run(accuracy, feed_dict={'features:0': mnist.test.images,\n", " 'targets:0': mnist.test.labels,\n", " 'keep_proba:0': 1.0}) \n", " print('Test ACC: %.3f' % test_acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### tensorflow.layers Abstraction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Bote that we define the *dropout rate*, not the *keep probability* when we are using dropout from `tf.layers`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extracting ./train-images-idx3-ubyte.gz\n", "Extracting ./train-labels-idx1-ubyte.gz\n", "Extracting ./t10k-images-idx3-ubyte.gz\n", "Extracting ./t10k-labels-idx1-ubyte.gz\n" ] } ], "source": [ "import tensorflow as tf\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "\n", "\n", "##########################\n", "### DATASET\n", "##########################\n", "\n", "mnist = input_data.read_data_sets(\"./\", one_hot=True)\n", "\n", "\n", "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Hyperparameters\n", "learning_rate = 0.1\n", "training_epochs = 20\n", "batch_size = 64\n", "dropout_rate = 0.5 \n", "# note that we define the dropout rate, not\n", "# the \"keep probability\" when using\n", "# dropout from tf.layers\n", "\n", "# Architecture\n", "n_hidden_1 = 128\n", "n_hidden_2 = 256\n", "n_input = 784\n", "training_epochs = 15\n", "\n", "# Other\n", "random_seed = 123\n", "\n", "\n", "##########################\n", "### GRAPH DEFINITION\n", "##########################\n", "\n", "g = tf.Graph()\n", "with g.as_default():\n", " \n", " tf.set_random_seed(random_seed)\n", "\n", " # Dropout settings\n", " is_training = tf.placeholder(tf.bool, name='is_training')\n", " \n", " # Input data\n", " tf_x = tf.placeholder(tf.float32, [None, n_input], name='features')\n", " tf_y = tf.placeholder(tf.float32, [None, n_classes], name='targets')\n", "\n", " # Multilayer perceptron\n", " layer_1 = tf.layers.dense(tf_x, n_hidden_1, activation=tf.nn.relu, \n", " kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))\n", " layer_1 = tf.layers.dropout(layer_1, rate=dropout_rate, training=is_training)\n", " \n", " layer_2 = tf.layers.dense(layer_1, n_hidden_2, activation=tf.nn.relu,\n", " kernel_initializer=tf.truncated_normal_initializer(stddev=0.1))\n", " layer_2 = tf.layers.dropout(layer_1, rate=dropout_rate, training=is_training)\n", " \n", " out_layer = tf.layers.dense(layer_2, n_classes, activation=None, name='logits')\n", "\n", " # Loss and optimizer\n", " loss = tf.nn.softmax_cross_entropy_with_logits(logits=out_layer, labels=tf_y)\n", " cost = tf.reduce_mean(loss, name='cost')\n", " optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n", " train = optimizer.minimize(cost, name='train')\n", "\n", " # Prediction\n", " correct_prediction = tf.equal(tf.argmax(tf_y, 1), tf.argmax(out_layer, 1))\n", " accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001 | AvgCost: 0.814 | Train/Valid ACC: 0.917/0.925\n", "Epoch: 002 | AvgCost: 0.520 | Train/Valid ACC: 0.931/0.938\n", "Epoch: 003 | AvgCost: 0.457 | Train/Valid ACC: 0.940/0.945\n", "Epoch: 004 | AvgCost: 0.408 | Train/Valid ACC: 0.948/0.952\n", "Epoch: 005 | AvgCost: 0.393 | Train/Valid ACC: 0.952/0.956\n", "Epoch: 006 | AvgCost: 0.376 | Train/Valid ACC: 0.954/0.957\n", "Epoch: 007 | AvgCost: 0.355 | Train/Valid ACC: 0.956/0.958\n", "Epoch: 008 | AvgCost: 0.348 | Train/Valid ACC: 0.958/0.960\n", "Epoch: 009 | AvgCost: 0.338 | Train/Valid ACC: 0.961/0.964\n", "Epoch: 010 | AvgCost: 0.334 | Train/Valid ACC: 0.962/0.964\n", "Epoch: 011 | AvgCost: 0.324 | Train/Valid ACC: 0.963/0.965\n", "Epoch: 012 | AvgCost: 0.315 | Train/Valid ACC: 0.964/0.963\n", "Epoch: 013 | AvgCost: 0.310 | Train/Valid ACC: 0.965/0.965\n", "Epoch: 014 | AvgCost: 0.305 | Train/Valid ACC: 0.966/0.965\n", "Epoch: 015 | AvgCost: 0.305 | Train/Valid ACC: 0.967/0.965\n", "Test ACC: 0.961\n" ] } ], "source": [ "from numpy.random import seed\n", "\n", "##########################\n", "### TRAINING & EVALUATION\n", "##########################\n", " \n", "with tf.Session(graph=g) as sess:\n", " sess.run(tf.global_variables_initializer())\n", "\n", " seed(random_seed) # random seed for mnist iterator\n", " for epoch in range(training_epochs):\n", " avg_cost = 0.\n", " total_batch = mnist.train.num_examples // batch_size\n", "\n", " for i in range(total_batch):\n", " batch_x, batch_y = mnist.train.next_batch(batch_size)\n", " _, c = sess.run(['train', 'cost:0'], feed_dict={'features:0': batch_x,\n", " 'targets:0': batch_y,\n", " 'is_training:0': True})\n", " avg_cost += c\n", " \n", " train_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.train.images,\n", " 'targets:0': mnist.train.labels,\n", " 'is_training:0': False})\n", " \n", " valid_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.validation.images,\n", " 'targets:0': mnist.validation.labels,\n", " 'is_training:0': False})\n", " \n", " print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch + 1, avg_cost / (i + 1)), end=\"\")\n", " print(\" | Train/Valid ACC: %.3f/%.3f\" % (train_acc, valid_acc))\n", " \n", " test_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.test.images,\n", " 'targets:0': mnist.test.labels,\n", " 'is_training:0': False})\n", " print('Test ACC: %.3f' % test_acc)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }