{
 "cells": [
  { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
    "import warnings\n",
    "\n",
    "# Ignore every warning issued through the warnings module from here on.\n",
    "warnings.simplefilter('ignore')"
  ] },
  { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
    "import os\n",
    "from time import time\n",
    "\n",
    "import tensorflow as tf\n",
    "tf.logging.set_verbosity(tf.logging.ERROR)  # filter out TensorFlow warning messages\n",
    "\n",
    "import tensorflow.examples.tutorials.mnist.input_data as input_data\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. 数据预处理" ] },
  { "cell_type": "code", "execution_count": 3, "metadata": {},
    "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
      "Extracting data/train-images-idx3-ubyte.gz\n",
      "Extracting data/train-labels-idx1-ubyte.gz\n",
      "Extracting data/t10k-images-idx3-ubyte.gz\n",
      "Extracting data/t10k-labels-idx1-ubyte.gz\n"
    ] } ],
    "source": [
    "# Read the MNIST splits from data/ with one-hot encoded labels.\n",
    "mnist = input_data.read_data_sets('data/', one_hot=True)"
  ] },
  { "cell_type": "code", "execution_count": 4, "metadata": {},
    "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
      "train images shape: (55000, 784)\n",
      "train labels shape: (55000, 10)\n",
      "\n",
      "validation images shape: (5000, 784)\n",
      "validation labels shape: (5000, 10)\n",
      "\n",
      "test images shape: (10000, 784)\n",
      "test labels shape: (10000, 10)\n"
    ] } ],
    "source": [
    "# Report the image/label array shapes of each split, separated by blank lines.\n",
    "splits = [('train', mnist.train), ('validation', mnist.validation), ('test', mnist.test)]\n",
    "for position, (split_name, split) in enumerate(splits):\n",
    "    if position > 0:\n",
    "        print()\n",
    "    print(split_name + ' images shape:', split.images.shape)\n",
    "    print(split_name + ' labels shape:', split.labels.shape)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. 
建立模型" ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.1 定义 layer 函数" ] },
  { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
    "def layer(output_dim, input_dim, inputs, activation=None):\n",
    "    \"\"\"Build one fully connected layer and return its output tensor.\n",
    "\n",
    "    Creates a weight variable of shape [input_dim, output_dim] and a bias\n",
    "    variable of shape [1, output_dim], both initialised with\n",
    "    tf.random_normal, and computes matmul(inputs, weights) + bias.\n",
    "\n",
    "    Args:\n",
    "        output_dim: number of columns of the weight matrix (layer width).\n",
    "        input_dim: number of rows of the weight matrix.\n",
    "        inputs: tensor that is multiplied by the weight matrix.\n",
    "        activation: optional callable applied to the affine result; when\n",
    "            None the affine result is returned unchanged.\n",
    "\n",
    "    Returns:\n",
    "        The (optionally activated) output tensor of the layer.\n",
    "    \"\"\"\n",
    "    weights = tf.Variable(tf.random_normal([input_dim, output_dim]))\n",
    "    bias = tf.Variable(tf.random_normal([1, output_dim]))\n",
    "    affine = tf.matmul(inputs, weights) + bias\n",
    "\n",
    "    # Without an activation the layer is purely linear.\n",
    "    if activation is None:\n",
    "        return affine\n",
    "    return activation(affine)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.2 建立输入层" ] },
  { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
    "x = tf.placeholder(dtype='float', shape=[None, 784])"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.3 建立隐藏层" ] },
  { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
    "h1 = layer(inputs=x, input_dim=784, output_dim=256, activation=tf.nn.relu)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.4 建立输出层" ] },
  { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
    "# No activation here: the output layer produces raw logits.\n",
    "y_predict = layer(inputs=h1, input_dim=256, output_dim=10, activation=None)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. 
定义训练方式" ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.1 建立训练数据 label 真实值的 placeholder" ] },
  { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
    "y_label = tf.placeholder(dtype='float', shape=[None, 10])"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.2 定义损失函数" ] },
  { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
    "# Softmax cross-entropy between the logits and the labels, reduced to its mean.\n",
    "cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y_predict)\n",
    "loss_function = tf.reduce_mean(cross_entropy)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.3 定义优化器" ] },
  { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
    "# `optimizer` is the training op: running it performs one Adam update step.\n",
    "adam = tf.train.AdamOptimizer(learning_rate=0.001)\n",
    "optimizer = adam.minimize(loss_function)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. 定义评估模型准确率的方式" ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 4.1 计算每一项数据是否预测正确" ] },
  { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
    "# A prediction is correct when the arg-max of the label row equals the arg-max of the logits.\n",
    "true_class = tf.argmax(y_label, axis=1)\n",
    "predicted_class = tf.argmax(y_predict, axis=1)\n",
    "correct_prediction = tf.equal(true_class, predicted_class)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 4.2 计算预测正确结果的平均值" ] },
  { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
    "# Cast the booleans to floats so their mean is the fraction of correct predictions.\n",
    "hits = tf.cast(correct_prediction, 'float')\n",
    "accuracy = tf.reduce_mean(hits)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. 
开始训练" ] },
  { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
    "train_epochs = 15\n",
    "batch_size = 100\n",
    "# Number of whole batches per epoch; floor division keeps this an exact integer.\n",
    "total_batch = mnist.train.num_examples // batch_size"
  ] },
  { "cell_type": "code", "execution_count": 15, "metadata": {},
    "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
      "train epoch: 01 loss: 6.757495403 acc: 0.8368\n",
      "train epoch: 02 loss: 4.420877934 acc: 0.8854\n",
      "train epoch: 03 loss: 3.375313997 acc: 0.9038\n",
      "train epoch: 04 loss: 2.792876720 acc: 0.91\n",
      "train epoch: 05 loss: 2.404265642 acc: 0.9208\n",
      "train epoch: 06 loss: 2.152849913 acc: 0.9242\n",
      "train epoch: 07 loss: 1.942529678 acc: 0.9256\n",
      "train epoch: 08 loss: 1.796358347 acc: 0.931\n",
      "train epoch: 09 loss: 1.639527202 acc: 0.9368\n",
      "train epoch: 10 loss: 1.564802527 acc: 0.9388\n",
      "train epoch: 11 loss: 1.503402591 acc: 0.9392\n",
      "train epoch: 12 loss: 1.392196417 acc: 0.9442\n",
      "train epoch: 13 loss: 1.427558780 acc: 0.9424\n",
      "train epoch: 14 loss: 1.377343774 acc: 0.9472\n",
      "train epoch: 15 loss: 1.309303880 acc: 0.9476\n",
      "\n",
      "train finished. takes 27.56218981742859 seconds\n"
    ] } ],
    "source": [
    "# Start from empty histories on every run: this cell re-initialises the\n",
    "# variables and retrains from scratch, so appending to lists left over from\n",
    "# a previous execution would put two runs on the curves plotted below.\n",
    "epoch_list = []\n",
    "loss_list = []\n",
    "acc_list = []\n",
    "\n",
    "start_time = time()\n",
    "# The session stays open for the evaluation and prediction cells below.\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "for epoch in range(train_epochs):\n",
    "    for batch_index in range(total_batch):\n",
    "        x_batch, y_batch = mnist.train.next_batch(batch_size=batch_size)\n",
    "        sess.run(optimizer, feed_dict={x: x_batch, y_label: y_batch})\n",
    "\n",
    "    # The loss and accuracy reported per epoch are measured on the validation split.\n",
    "    loss, acc = sess.run([loss_function, accuracy],\n",
    "                         feed_dict={x: mnist.validation.images, y_label: mnist.validation.labels})\n",
    "    epoch_list.append(epoch + 1)\n",
    "    loss_list.append(loss)\n",
    "    acc_list.append(acc)\n",
    "    print('train epoch:', '%02d' % (epoch + 1), 'loss:', '{:.9f}'.format(loss), 'acc:', acc)\n",
    "\n",
    "print()\n",
    "print('train finished. takes', time() - start_time, 'seconds')"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. 以图形显示训练过程" ] },
  { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
    "def show_train_history(x_values, y_values, title):\n",
    "    \"\"\"Plot one recorded metric against the epoch axis and show the figure.\n",
    "\n",
    "    Args:\n",
    "        x_values: values for the x axis, which is labelled 'Epoch'.\n",
    "        y_values: metric values, one per entry of x_values.\n",
    "        title: metric name, used as the y-axis label and the legend entry.\n",
    "    \"\"\"\n",
    "    plt.plot(x_values, y_values, label=title)\n",
    "    plt.xlabel('Epoch')\n",
    "    plt.ylabel(title)\n",
    "    plt.legend([title], loc='upper left')\n",
    "    plt.show()"
  ] },
  { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
    "source": [
    "# Accuracy recorded after each epoch.\n",
    "show_train_history(x_values=epoch_list, y_values=acc_list, title='acc')"
  ] },
  { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
    "source": [
    "show_train_history(epoch_list, loss_list, 'loss')"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 7. 评估模型准确率" ] },
  { "cell_type": "code", "execution_count": 19, "metadata": {},
    "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy: 0.9424\n" ] } ],
    "source": [
    "print('accuracy:', sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 8. 进行预测" ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 8.1 执行预测" ] },
  { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [
    "prediction_result = sess.run(tf.argmax(y_predict, axis=1), feed_dict={x: mnist.test.images})"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 8.2 预测结果" ] },
  { "cell_type": "code", "execution_count": 21, "metadata": {},
    "outputs": [ { "data": { "text/plain": [ "array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ],
    "source": [
    "prediction_result[:10]"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "#### 8.3 定义函数以显示10项预测结果" ] },
  { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
    "def plot_images_labels_prediction(images, labels, predictions, idx, num=10):\n",
    "    \"\"\"Show consecutive digit images in a 5x5 grid, titled with label and prediction.\n",
    "\n",
    "    Args:\n",
    "        images: array of digit images; each entry is reshaped to 28x28.\n",
    "        labels: array of label rows; the displayed label is the arg-max of a row.\n",
    "        predictions: predicted classes indexed like `images`; pass an empty\n",
    "            sequence to show the labels only.\n",
    "        idx: index of the first item to display.\n",
    "        num: number of items to display, 10 by default. It is capped at 25\n",
    "            (the grid size) and at the number of items available from `idx`.\n",
    "    \"\"\"\n",
    "    # Cap the request so that neither the 5x5 grid nor the end of the data is\n",
    "    # overrun; running past the last image would raise IndexError.\n",
    "    num = max(0, min(num, 25, len(images) - idx))\n",
    "\n",
    "    fig = plt.gcf()\n",
    "    fig.set_size_inches(12, 14)\n",
    "    for position, item in enumerate(range(idx, idx + num)):\n",
    "        ax = plt.subplot(5, 5, position + 1)\n",
    "        ax.imshow(images[item].reshape(28, 28), cmap='binary')\n",
    "        title = 'label=' + str(np.argmax(labels[item]))\n",
    "        if len(predictions) > 0:\n",
    "            title += ',predict=' + str(predictions[item])\n",
    "        ax.set_title(title, fontsize=10)\n",
    "        # Tick marks carry no information for an image thumbnail.\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "    plt.show()"
  ] },
  { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
    "source": [
    "plot_images_labels_prediction(images=mnist.test.images, labels=mnist.test.labels,\n",
    "                              predictions=prediction_result, idx=0, num=10)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### 9. 找出预测错误" ] },
  { "cell_type": "code", "execution_count": 24, "metadata": {},
    "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
      "i=8 label=5 predict=6\n",
      "i=33 label=4 predict=6\n",
      "i=66 label=6 predict=2\n",
      "i=96 label=1 predict=3\n",
      "i=115 label=4 predict=8\n",
      "i=212 label=9 predict=3\n",
      "i=241 label=9 predict=8\n",
      "i=245 label=3 predict=6\n",
      "i=247 label=4 predict=6\n",
      "i=259 label=6 predict=0\n",
      "i=261 label=5 predict=3\n",
      "i=264 label=9 predict=8\n",
      "i=268 label=8 predict=9\n",
      "i=274 label=9 predict=3\n"
    ] } ],
    "source": [
    "# Scan the first 300 test items and report each one whose prediction differs from its label.\n",
    "for i in range(300):\n",
    "    label, predict = np.argmax(mnist.test.labels[i]), prediction_result[i]\n",
    "    if predict == label:\n",
    "        continue\n",
    "    print('i={} label={} predict={}'.format(i, label, predict))"
  ] },
  { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [
    "# Close the session opened by the training cell.\n",
    "sess.close()"
  ] }
 ],
 "metadata": {
  "kernelspec": { "display_name": "tensorflow-keras-practice", "language": "python", "name": "tensorflow-keras-practice" },
  "language_info": {
   "codemirror_mode": { "name": "ipython", "version": 3 },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}