{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST digit classification with a dropout MLP\n",
    "\n",
    "A 784-600-400-10 fully connected network (tanh hidden layers, dropout, softmax output) trained with Adam and an exponentially decaying learning rate. The code uses the TensorFlow 1.x graph API (`tf.placeholder`, `tf.Session`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Download (if needed) and load the MNIST data set:\n",
    "# handwritten-digit images of 28 * 28 pixels each.\n",
    "# Images are a [num_images, 784] tensor indexed by [image index, pixel index].\n",
    "# read_data_sets holds 5,000 training images back for validation by default,\n",
    "# so mnist.train has 55,000 examples (not 60,000).\n",
    "#\n",
    "# `one-hot vectors`: exactly one entry is 1 and every other entry is 0.\n",
    "# Labels become a [num_images, 10] tensor indexed by [image index, digit class].\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Placeholder for flattened images: 28 * 28 = 784 pixels.\n",
    "# None lets the batch dimension take any size.\n",
    "x = tf.placeholder(tf.float32, [None, 784])\n",
    "y = tf.placeholder(tf.float32, [None, 10])\n",
    "# Dropout keep probability (0.8 means 80% of the neurons stay active).\n",
    "z = tf.placeholder(tf.float32)\n",
    "# Learning rate; it is decayed every epoch so the optimizer can settle\n",
    "# more precisely once it is close to the minimum.\n",
    "lr = tf.Variable(0.001, dtype = tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Weights (truncated normal initialisation) and biases (0.1)\n",
    "W1 = tf.Variable(tf.truncated_normal([784, 600], stddev = 0.1))\n",
    "b1 = tf.Variable(tf.zeros([600]) + 0.1)\n",
    "L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)\n",
    "L1_drop = tf.nn.dropout(L1, z)\n",
    "\n",
    "# Hidden layer\n",
    "W2 = tf.Variable(tf.truncated_normal([600, 400], stddev = 0.1))\n",
    "b2 = tf.Variable(tf.zeros([400]) + 0.1)\n",
    "L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)\n",
    "L2_drop = tf.nn.dropout(L2, z)\n",
    "\n",
    "W3 = tf.Variable(tf.truncated_normal([400, 10], stddev = 0.1))\n",
    "b3 = tf.Variable(tf.zeros([10]) + 0.1)\n",
    "\n",
    "# Keep the unnormalised scores (logits) separate from the softmax output:\n",
    "# the loss must be fed the logits, while accuracy uses the probabilities.\n",
    "logits = tf.matmul(L2_drop, W3) + b3\n",
    "\n",
    "# Softmax regression output\n",
    "prediction = tf.nn.softmax(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Quadratic loss (earlier experiment)\n",
    "# loss = tf.reduce_mean(tf.square(y - prediction))\n",
    "# Hand-written cross-entropy loss (earlier experiment)\n",
    "# loss = tf.reduce_mean(-tf.reduce_sum(y * tf.log(prediction), reduction_indices=[1]))\n",
    "# softmax_cross_entropy_with_logits applies softmax internally, so it must\n",
    "# receive the raw logits; passing `prediction` would apply softmax twice.\n",
    "loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y, logits = logits))\n",
    "\n",
    "# Plain gradient descent (earlier experiment)\n",
    "# train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)\n",
    "train_step = tf.train.AdamOptimizer(lr).minimize(loss)\n",
    "\n",
    "# Learning-rate update op, built once and fed a new value every epoch.\n",
    "# Calling tf.assign inside the training loop would add a new op to the\n",
    "# graph on every iteration.\n",
    "lr_value = tf.placeholder(tf.float32, shape = [])\n",
    "update_lr = tf.assign(lr, lr_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the model.\n",
    "# For every example, check whether the position of the largest entry in y\n",
    "# matches the position of the largest entry in prediction.\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))\n",
    "# Accuracy: cast the boolean vector correct_prediction to float32\n",
    "# ([True, False, False, ...] => [1.0, 0.0, 0.0, ...]) and take its mean.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "batch_size = 100\n",
    "num_epochs = 101\n",
    "# Derive the number of steps from the data set instead of hard-coding\n",
    "# 60000: mnist.train contains 55,000 examples.\n",
    "steps_per_epoch = mnist.train.num_examples // batch_size\n",
    "\n",
    "# No explicit tf.device scope: the model graph is already built, and the\n",
    "# default device placement also works on CPU-only machines.\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        # Exponential decay: 0.001 * 0.95 ** epoch\n",
    "        sess.run(update_lr, feed_dict = {lr_value: 0.001 * (0.95 ** epoch)})\n",
    "        for step in range(steps_per_epoch):\n",
    "            batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n",
    "            sess.run(train_step, feed_dict = {x: batch_xs, y: batch_ys, z: 0.9973})\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            # Dropout is switched off (keep probability 1.0) for evaluation.\n",
    "            test_accuracy = sess.run(accuracy, feed_dict = {x: mnist.test.images, y: mnist.test.labels, z: 1.0})\n",
    "            train_accuracy = sess.run(accuracy, feed_dict = {x: mnist.train.images, y: mnist.train.labels, z: 1.0})\n",
    "            # The number printed after the \"Batch:\" label is the epoch index; the\n",
    "            # label is kept so new logs stay comparable with the recorded ones.\n",
    "            print(\"Batch: \", epoch, \"Accuracy: [\", test_accuracy, \",\", train_accuracy, \"]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recorded results\n",
    "\n",
    "Accuracy logged every 10 epochs by earlier runs of this notebook. Bracketed pairs are `[test accuracy, train accuracy]`, in the order printed by the training cell.\n",
    "\n",
    "**Note:** these numbers were recorded before the loss was changed to consume the raw logits and before the step count was derived from `mnist.train.num_examples`. Re-run the notebook to refresh them.\n",
    "\n",
    "### Quadratic loss\n",
    "\n",
    "```\n",
    "Batch: 0 Accuracy: 0.8394\n",
    "Batch: 10 Accuracy: 0.9067\n",
    "Batch: 20 Accuracy: 0.9142\n",
    "Batch: 30 Accuracy: 0.9187\n",
    "Batch: 40 Accuracy: 0.9199\n",
    "Batch: 50 Accuracy: 0.9219\n",
    "```\n",
    "\n",
    "### Cross-entropy loss\n",
    "\n",
    "```\n",
    "Batch: 0 Accuracy: 0.8262\n",
    "Batch: 10 Accuracy: 0.9183\n",
    "Batch: 20 Accuracy: 0.9224\n",
    "Batch: 30 Accuracy: 0.9232\n",
    "Batch: 40 Accuracy: 0.9273\n",
    "Batch: 50 Accuracy: 0.9274\n",
    "```\n",
    "\n",
    "### Hidden layers + dropout\n",
    "\n",
    "```\n",
    "Batch: 0 Accuracy: [ 0.9176 , 0.915527 ]\n",
    "Batch: 10 Accuracy: [ 0.9565 , 0.963182 ]\n",
    "Batch: 20 Accuracy: [ 0.9669 , 0.975236 ]\n",
    "Batch: 30 Accuracy: [ 0.9718 , 0.982 ]\n",
    "Batch: 40 Accuracy: [ 0.9737 , 0.984836 ]\n",
    "Batch: 50 Accuracy: [ 0.9768 , 0.987036 ]\n",
    "```\n",
    "\n",
    "### AdamOptimizer\n",
    "\n",
    "```\n",
    "Batch: 0 Accuracy: [ 0.9573 , 0.962128 ]\n",
    "Batch: 10 Accuracy: [ 0.9803 , 0.994455 ]\n",
    "Batch: 20 Accuracy: [ 0.9795 , 0.997073 ]\n",
    "Batch: 30 Accuracy: [ 0.9816 , 0.997709 ]\n",
    "Batch: 40 Accuracy: [ 0.9828 , 0.997946 ]\n",
    "Batch: 50 Accuracy: [ 0.9823 , 0.998128 ]\n",
    "Batch: 60 Accuracy: [ 0.9828 , 0.998237 ]\n",
    "Batch: 70 Accuracy: [ 0.9828 , 0.998309 ]\n",
    "Batch: 80 Accuracy: [ 0.9831 , 0.998346 ]\n",
    "Batch: 90 Accuracy: [ 0.9828 , 0.9984 ]\n",
    "Batch: 100 Accuracy: [ 0.9831 , 0.9984 ]\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}