{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.simplefilter(action='ignore')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "from time import time\n", "\n", "import tensorflow as tf\n", "tf.logging.set_verbosity(tf.logging.ERROR)  # filter out TensorFlow warning messages\n", "\n", "import tensorflow.examples.tutorials.mnist.input_data as input_data\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# Seed NumPy's global RNG and TensorFlow's graph-level RNG so training runs are repeatable\n", "SEED = 42\n", "np.random.seed(SEED)\n", "tf.set_random_seed(SEED)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. 数据预处理" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/train-images-idx3-ubyte.gz\n", "Extracting data/train-labels-idx1-ubyte.gz\n", "Extracting data/t10k-images-idx3-ubyte.gz\n", "Extracting data/t10k-labels-idx1-ubyte.gz\n" ] } ], "source": [ "mnist = input_data.read_data_sets('data/', one_hot=True)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train images shape: (55000, 784)\n", "train labels shape: (55000, 10)\n", "\n", "validation images shape: (5000, 784)\n", "validation labels shape: (5000, 10)\n", "\n", "test images shape: (10000, 784)\n", "test labels shape: (10000, 10)\n" ] } ], "source": [ "print('train images shape:', mnist.train.images.shape)\n", "print('train labels shape:', mnist.train.labels.shape)\n", "print()\n", "print('validation images shape:', mnist.validation.images.shape)\n", "print('validation labels shape:', mnist.validation.labels.shape)\n", "print()\n", "print('test images shape:', mnist.test.images.shape)\n", "print('test labels shape:', mnist.test.labels.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
建立共享函数" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.1 定义 weight 函数, 用于建立权重张量" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def weight(shape):\n", "    \"\"\"Create a trainable Variable named 'W' of the given shape, drawn from a truncated normal with stddev 0.1.\"\"\"\n", "    initial_value = tf.truncated_normal(shape, stddev=0.1)\n", "    return tf.Variable(initial_value, name='W')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.2 定义 bias 函数, 用于建立偏差张量" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def bias(shape):\n", "    \"\"\"Create a trainable Variable named 'b' of the given shape with every element set to 0.1.\"\"\"\n", "    initial_value = tf.constant(0.1, shape=shape)\n", "    return tf.Variable(initial_value, name='b')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.3 定义 conv2d 函数, 用于进行卷积运算" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def conv2d(x, W):\n", "    \"\"\"Convolve x with filter W using stride 1 in every dimension and 'SAME' padding.\"\"\"\n", "    unit_strides = [1, 1, 1, 1]\n", "    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.4 定义 max_pool_2x2 函数, 用于建立池化层" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def max_pool_2x2(x):\n", "    \"\"\"Max-pool x with a 2x2 window, stride 2 and 'SAME' padding over its 2nd and 3rd dimensions.\"\"\"\n", "    window = [1, 2, 2, 1]\n", "    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. 
建立模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.1 输入层" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('Input_Layer'):\n", "    x = tf.placeholder('float', shape=[None, 784], name='x')\n", "    x_image = tf.reshape(x, [-1, 28, 28, 1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.2 卷积层1" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('C1_Conv'):\n", "    W1 = weight([5, 5, 1, 16])\n", "    b1 = bias([16])\n", "    Conv1 = conv2d(x_image, W1) + b1\n", "    C1_Conv = tf.nn.relu(Conv1)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('C1_Pool'):\n", "    C1_Pool = max_pool_2x2(C1_Conv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.3 卷积层2" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('C2_Conv'):\n", "    W2 = weight([5, 5, 16, 36])\n", "    b2 = bias([36])\n", "    Conv2 = conv2d(C1_Pool, W2) + b2\n", "    C2_Conv = tf.nn.relu(Conv2)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('C2_Pool'):\n", "    C2_Pool = max_pool_2x2(C2_Conv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.4 全连接层" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('D_Flat'):\n", "    # 1764 = 7 * 7 * 36: two 2x2 poolings shrink 28x28 to 7x7, and C2_Conv has 36 channels\n", "    D_Flat = tf.reshape(C2_Pool, [-1, 1764])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('D_Hidden_Layer'):\n", "    W3 = weight([1764, 128])\n", "    b3 = bias([128])\n", "    D_Hidden = tf.nn.relu(tf.matmul(D_Flat, W3) + b3)\n", "    # keep_prob defaults to 1.0 (no dropout) so evaluation and prediction use the whole network;\n", "    # the training loop feeds keep_prob=0.8 to enable dropout only for optimizer steps.\n", "    keep_prob = tf.placeholder_with_default(1.0, shape=[], name='keep_prob')\n", "    D_Hidden_Dropout = tf.nn.dropout(D_Hidden, keep_prob=keep_prob)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.5 输出层" ] }, { "cell_type": "code", "execution_count": 16, 
"metadata": {}, "outputs": [], "source": [ "with tf.name_scope('Output_Layer'):\n", "    W4 = weight([128, 10])\n", "    b4 = bias([10])\n", "    # Keep the unscaled logits for the loss; y_predict holds the softmax probabilities\n", "    logits = tf.matmul(D_Hidden_Dropout, W4) + b4\n", "    y_predict = tf.nn.softmax(logits)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. 定义训练方式" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 4.1 建立训练数据 label 真实值的 placeholder" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('optimizer'):\n", "    # Placeholder for the one-hot ground-truth labels\n", "    y_label = tf.placeholder('float', [None, 10], name='y_label')\n", "\n", "    # Loss: softmax_cross_entropy_with_logits applies softmax itself, so it must receive the raw logits\n", "    loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_label))\n", "\n", "    # Optimizer\n", "    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss_function)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. 定义评估模型准确率的方式" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "with tf.name_scope('evaluate_model'):\n", "    correct_prediction = tf.equal(tf.argmax(y_label, axis=1),\n", "                                  tf.argmax(y_predict, axis=1))\n", "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. 
开始训练" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "train_epochs = 15\n", "batch_size = 100\n", "total_batch = int(mnist.train.num_examples / batch_size)\n", "epoch_list = []\n", "loss_list = []\n", "acc_list = []" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train epoch: 01 loss: 1.667564631 acc: 0.8122\n", "train epoch: 02 loss: 1.628974557 acc: 0.8394\n", "train epoch: 03 loss: 1.615294456 acc: 0.8528\n", "train epoch: 04 loss: 1.520540953 acc: 0.9522\n", "train epoch: 05 loss: 1.505755424 acc: 0.961\n", "train epoch: 06 loss: 1.498941064 acc: 0.9678\n", "train epoch: 07 loss: 1.494952202 acc: 0.9696\n", "train epoch: 08 loss: 1.493835688 acc: 0.9722\n", "train epoch: 09 loss: 1.489154339 acc: 0.9734\n", "train epoch: 10 loss: 1.488664269 acc: 0.9754\n", "train epoch: 11 loss: 1.486197829 acc: 0.9776\n", "train epoch: 12 loss: 1.484781146 acc: 0.9784\n", "train epoch: 13 loss: 1.484792829 acc: 0.9784\n", "train epoch: 14 loss: 1.483059764 acc: 0.9798\n", "train epoch: 15 loss: 1.481593132 acc: 0.981\n", "\n", "train finished. takes 788.1317231655121 seconds\n" ] } ], "source": [ "start_time = time()\n", "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())\n", "\n", "for epoch in range(train_epochs):\n", "    for i in range(total_batch):\n", "        x_batch, y_batch = mnist.train.next_batch(batch_size=batch_size)\n", "        # Dropout is enabled only for the optimizer step; the validation run below uses keep_prob's default of 1.0\n", "        sess.run(optimizer, feed_dict={x: x_batch, y_label: y_batch, keep_prob: 0.8})\n", "\n", "    loss, acc = sess.run([loss_function, accuracy],\n", "                         feed_dict={x: mnist.validation.images, y_label: mnist.validation.labels})\n", "    epoch_list.append(epoch + 1)\n", "    loss_list.append(loss)\n", "    acc_list.append(acc)\n", "    print('train epoch:', '%02d' % (epoch + 1), 'loss:', '{:.9f}'.format(loss), 'acc:', acc)\n", "\n", "print()\n", "print('train finished. 
takes', time() - start_time, 'seconds')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7. 以图形显示训练过程" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def show_train_history(x_values, y_values, title):\n", "    \"\"\"Plot one recorded training metric and display the figure.\n", "\n", "    x_values: values for the x axis (the epoch numbers in this notebook)\n", "    y_values: metric value for each x value\n", "    title: metric name, used both as the y-axis label and as the legend entry\n", "    \"\"\"\n", "    ax = plt.gca()\n", "    ax.plot(x_values, y_values, label=title)\n", "    ax.set_xlabel('Epoch')\n", "    ax.set_ylabel(title)\n", "    ax.legend([title], loc='upper left')\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_train_history(epoch_list, acc_list, 'acc')" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_train_history(epoch_list, loss_list, 'loss')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 8. 评估模型准确率" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy: 0.9811\n" ] } ], "source": [ "print('accuracy:', sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 9. 进行预测" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 9.1 执行预测" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "prediction_result = sess.run(tf.argmax(y_predict, axis=1), feed_dict={x: mnist.test.images})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 9.2 预测结果" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prediction_result[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 9.3 定义函数以显示10项预测结果" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "def plot_images_labels_prediction(images, labels, predictions, idx, num=10):\n", "    \"\"\"Show images in a 5x5 grid, titled with the true label and, optionally, the prediction.\n", "\n", "    images: array of digit images, each reshapeable to 28x28\n", "    labels: one-hot ground-truth labels (decoded with np.argmax)\n", "    predictions: predicted digits; pass an empty sequence to show labels only\n", "    idx: index of the first item to display\n", "    num: number of items to display; default 10, capped at 25 and at the items left after idx\n", "    \"\"\"\n", "    fig = plt.gcf()\n", "    fig.set_size_inches(12, 14)\n", "    # The grid has 25 cells, and indexing must stay inside images\n", "    num = max(0, min(num, 25, len(images) - idx))\n", "    for i in range(num):\n", "        ax = plt.subplot(5, 5, i + 1)\n", "        ax.imshow(images[idx].reshape(28, 28), cmap='binary')\n", "        title = 'label=' + str(np.argmax(labels[idx]))\n", "        if len(predictions) > 0:\n", "            title += ',predict=' + str(predictions[idx])\n", "        ax.set_title(title, fontsize=10)\n", "        ax.set_xticks([])\n", 
"        ax.set_yticks([])\n", "        idx += 1\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_images_labels_prediction(mnist.test.images, mnist.test.labels, prediction_result, 0, 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10. 找出预测错误" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "i=18 label=3 predict=8\n", "i=151 label=9 predict=8\n", "i=247 label=4 predict=6\n", "i=290 label=8 predict=4\n" ] } ], "source": [ "for i in range(300):\n", "    label = np.argmax(mnist.test.labels[i])\n", "    predict = prediction_result[i]\n", "    if predict != label:\n", "        print('i=' + str(i), 'label=' + str(label), 'predict=' + str(predict))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "def show_images_labels_predict_error(images, labels, prediction_result, max_errors=10):\n", "    \"\"\"Display the first misclassified images in a 5x5 grid.\n", "\n", "    images: array of digit images, each reshapeable to 28x28\n", "    labels: one-hot ground-truth labels aligned with images (decoded with np.argmax)\n", "    prediction_result: predicted digit for each image\n", "    max_errors: how many misclassified images to show; at most 25 fit in the grid\n", "    \"\"\"\n", "    fig = plt.gcf()\n", "    fig.set_size_inches(12, 14)\n", "    max_errors = min(max_errors, 25)\n", "    num = 0\n", "    idx = 0\n", "    # Also stop at the end of the data, so fewer than max_errors mistakes does not raise IndexError\n", "    while num < max_errors and idx < len(images):\n", "        label = np.argmax(labels[idx])\n", "        predict = prediction_result[idx]\n", "        if predict != label:\n", "            ax = plt.subplot(5, 5, num + 1)\n", "            ax.imshow(np.reshape(images[idx], (28, 28)), cmap='binary')\n", "            ax.set_title('idx:' + str(idx) + ',l:' + str(label) + ',p:' + str(predict), fontsize=9)\n", "            num += 1\n", "        idx += 1\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_images_labels_predict_error(mnist.test.images, mnist.test.labels, prediction_result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 11. 保存模型" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model saved in file: save_model/tf/model_mnist_cnn_tf\n" ] } ], "source": [ "# In TF 1.x, Saver.save fails if the target directory does not exist yet\n", "os.makedirs('save_model/tf', exist_ok=True)\n", "saver = tf.train.Saver()\n", "save_path = saver.save(sess, 'save_model/tf/model_mnist_cnn_tf')\n", "print('model saved in file: %s' % save_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 12. 将计算图写入log文件, 用于在 TensorBoard 中查看" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# Only the graph is logged: this notebook defines no tf.summary ops, so merge_all() would just return None.\n", "# close() flushes the event file so TensorBoard can read it right away.\n", "train_writer = tf.summary.FileWriter('log/graph_mnist_cnn', sess.graph)\n", "train_writer.close()" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "sess.close()" ] } ], "metadata": { "kernelspec": { "display_name": "tensorflow-keras-practice", "language": "python", "name": "tensorflow-keras-practice" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }