{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient-direction similarity for softmax regression on MNIST\n",
    "\n",
    "Trains a single-layer softmax classifier on MNIST with plain mini-batch gradient descent and tracks two renormalized moving averages (a slow and a fast one) of the unit-norm gradient of the loss with respect to `W`. Their dot product, `dir_similarity`, is logged to TensorBoard next to the loss and the accuracy.\n",
    "\n",
    "The code uses the legacy (pre-1.0) TensorFlow graph and summary API."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as op\n",
    "import shutil\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "N_FEATURES = 784  # width of the x placeholder / number of rows of W\n",
    "N_CLASSES = 10  # width of the one-hot y placeholder / number of columns of W\n",
    "BATCH_SIZE = 128\n",
    "\n",
    "data_dir = op.expanduser(\"~/data/mnist\")\n",
    "logs_dir = '/tmp/tensorflow_logs'\n",
    "\n",
    "# NOTE(review): the MNIST helper presumably shuffles mini-batches with NumPy's\n",
    "# global RNG, so it is seeded here -- confirm against the installed version.\n",
    "np.random.seed(SEED)\n",
    "mnist = read_data_sets(data_dir, one_hot=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model, loss and update ops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vec_normalize(vec):\n",
    "    \"\"\"Scale `vec` to unit L2 norm, the norm being taken over all its entries.\n",
    "\n",
    "    A small epsilon is added to the norm so that an all-zero `vec` does not\n",
    "    cause a division by zero.\n",
    "    \"\"\"\n",
    "    vec_norm = tf.sqrt(tf.reduce_sum(tf.square(vec)))\n",
    "    return vec / (vec_norm + 1e-7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "# Graph-level seed: has to be set after the reset, on the fresh default graph.\n",
    "tf.set_random_seed(SEED)\n",
    "sess = tf.Session()\n",
    "dtype = tf.float32\n",
    "learning_rate = tf.Variable(tf.constant(0.001, dtype=dtype))\n",
    "\n",
    "\n",
    "with tf.name_scope('input'):\n",
    "    x = tf.placeholder(dtype=dtype, shape=[None, N_FEATURES], name='x-input')\n",
    "    y = tf.placeholder(dtype=dtype, shape=[None, N_CLASSES], name='y-input')\n",
    "\n",
    "\n",
    "with tf.name_scope('variables'):\n",
    "    W = tf.Variable(\n",
    "        tf.truncated_normal(shape=(N_FEATURES, N_CLASSES), stddev=0.1,\n",
    "                            dtype=dtype),\n",
    "        name='W')\n",
    "    tf.histogram_summary('weights', W)\n",
    "    b = tf.Variable(tf.zeros(shape=(N_CLASSES,), dtype=dtype), name='b')\n",
    "    tf.histogram_summary('biases', b)\n",
    "    # Slow and fast running estimates of the (flattened) gradient direction.\n",
    "    slow_direction = tf.Variable(\n",
    "        tf.zeros(shape=[N_FEATURES * N_CLASSES], dtype=dtype))\n",
    "    fast_direction = tf.Variable(\n",
    "        tf.zeros(shape=[N_FEATURES * N_CLASSES], dtype=dtype))\n",
    "    # Dot product of the two direction estimates, as a scalar.\n",
    "    dir_similarity = tf.matmul(tf.reshape(slow_direction, [1, -1]),\n",
    "                               tf.reshape(fast_direction, [-1, 1]))[0, 0]\n",
    "    tf.scalar_summary('dir_similarity', dir_similarity)\n",
    "\n",
    "\n",
    "with tf.name_scope('model'):\n",
    "    preactivations = tf.matmul(x, W) + b\n",
    "    tf.histogram_summary('preactivations', preactivations)\n",
    "    y_pred = tf.nn.softmax(preactivations)\n",
    "    tf.histogram_summary('predicted_probabilities', y_pred)\n",
    "\n",
    "\n",
    "with tf.name_scope('loss'):\n",
    "    # One loss value per example; the scalar training loss is their mean.\n",
    "    cross_entropies = tf.nn.softmax_cross_entropy_with_logits(preactivations, y)\n",
    "    cross_entropy = tf.reduce_mean(cross_entropies, name='cross_entropy')\n",
    "    tf.scalar_summary('cross_entropy', cross_entropy)\n",
    "\n",
    "\n",
    "with tf.name_scope('accuracy'):\n",
    "    with tf.name_scope('correct_prediction'):\n",
    "        correct_prediction = tf.equal(tf.argmax(y, 1),\n",
    "                                      tf.argmax(y_pred, 1))\n",
    "    with tf.name_scope('accuracy'):\n",
    "        accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype))\n",
    "    tf.scalar_summary('accuracy', accuracy)\n",
    "\n",
    "\n",
    "with tf.name_scope('gradient_directions'):\n",
    "    [gW, gb] = tf.gradients(cross_entropy, [W, b])\n",
    "    g_norm = tf.sqrt(tf.reduce_sum(tf.square(gW)) +\n",
    "                     tf.reduce_sum(tf.square(gb)))\n",
    "    tf.scalar_summary('gradient norm', g_norm)\n",
    "    # Unit-norm gradient w.r.t. W, flattened to the shape of the direction\n",
    "    # variables.\n",
    "    gW_normed = tf.reshape(vec_normalize(gW), [-1])\n",
    "\n",
    "\n",
    "with tf.name_scope('updates'):\n",
    "    # Plain gradient-descent step on W and b.\n",
    "    W_update = W.assign_add(-learning_rate * gW)\n",
    "    b_update = b.assign_add(-learning_rate * gb)\n",
    "\n",
    "    # Blend the current gradient direction into each running estimate, then\n",
    "    # renormalize; the *_reset ops overwrite the estimate with the current\n",
    "    # direction.\n",
    "    slow_rate = 0.05\n",
    "    new_slow_dir = slow_rate * gW_normed + (1 - slow_rate) * slow_direction\n",
    "    slow_dir_update = slow_direction.assign(vec_normalize(new_slow_dir))\n",
    "    slow_dir_reset = slow_direction.assign(gW_normed)\n",
    "\n",
    "    fast_rate = 0.5\n",
    "    new_fast_dir = fast_rate * gW_normed + (1 - fast_rate) * fast_direction\n",
    "    fast_dir_update = fast_direction.assign(vec_normalize(new_fast_dir))\n",
    "    fast_dir_reset = fast_direction.assign(gW_normed)\n",
    "\n",
    "    lr_up = learning_rate.assign(2 * learning_rate)\n",
    "    lr_down = learning_rate.assign(0.1 * learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feeding data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_dict(train=True, batch_size=BATCH_SIZE):\n",
    "    \"\"\"Make a TensorFlow feed_dict: maps data onto Tensor placeholders.\n",
    "\n",
    "    With `train=True` the next mini-batch of `batch_size` training examples\n",
    "    is fed; otherwise the whole test set is fed and `batch_size` is ignored.\n",
    "    \"\"\"\n",
    "    if train:\n",
    "        xs, ys = mnist.train.next_batch(batch_size)\n",
    "    else:\n",
    "        xs, ys = mnist.test.images, mnist.test.labels\n",
    "    return {x: xs.astype(np.float32), y: ys.astype(np.float32)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summaries and variable initialization\n",
    "\n",
    "Variables are initialized here, before any cell evaluates the graph, so that the notebook runs top to bottom on a fresh kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "summaries = tf.merge_all_summaries()\n",
    "# Start from empty logs; ignore_errors covers the first run, when the\n",
    "# directory does not exist yet.\n",
    "shutil.rmtree(logs_dir, ignore_errors=True)\n",
    "train_writer = tf.train.SummaryWriter(op.join(logs_dir, 'train'), sess.graph)\n",
    "test_writer = tf.train.SummaryWriter(op.join(logs_dir, 'test'))\n",
    "\n",
    "sess.run(tf.initialize_all_variables())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exploration: per-example losses and gradient similarities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-example losses on one training mini-batch (first few values only).\n",
    "sess.run(cross_entropies, feed_dict=data_dict(train=True))[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cosine_similarities(x):\n",
    "    \"\"\"Pairwise cosine similarities between the rows of the 2-D tensor `x`.\n",
    "\n",
    "    Each row is scaled to unit L2 norm (with the same epsilon as\n",
    "    `vec_normalize`) before the row-by-row dot products are taken.\n",
    "    \"\"\"\n",
    "    row_norms = tf.sqrt(tf.reduce_sum(tf.square(x), reduction_indices=[1],\n",
    "                                      keep_dims=True))\n",
    "    x = x / (row_norms + 1e-7)\n",
    "    return tf.matmul(x, tf.transpose(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: tf.gradients differentiates the sum of the entries of\n",
    "# `cross_entropies`, so gWs has the shape of W; it is not one gradient per\n",
    "# example. `sims` therefore compares the rows of that summed gradient.\n",
    "[gWs, gbs] = tf.gradients(cross_entropies, [W, b])\n",
    "sims = cosine_similarities(gWs)\n",
    "\n",
    "sess.run(sims, feed_dict=data_dict(train=False)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably the number of entries a per-example\n",
    "# (batch x batch) similarity matrix would have -- TODO confirm intent.\n",
    "BATCH_SIZE ** 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 10000\n",
    "eval_every = 100\n",
    "\n",
    "# Experimental step-size heuristic, off by default: double the learning rate\n",
    "# when the slow and fast gradient directions agree (similarity > 0.5), divide\n",
    "# it by 10 when they point in opposite directions (similarity < 0).\n",
    "adapt_learning_rate = False\n",
    "last_lr_change = 0\n",
    "cool_down = 100\n",
    "\n",
    "try:\n",
    "    for i in range(n_steps):\n",
    "        if i % eval_every == 0:\n",
    "            # Evaluate on test set\n",
    "            test_summaries, test_acc, test_dir_similarity, lr = sess.run(\n",
    "                [summaries, accuracy, dir_similarity, learning_rate],\n",
    "                feed_dict=data_dict(train=False))\n",
    "            test_writer.add_summary(test_summaries, i)\n",
    "            print(\"Accuracy on test: %0.3f, gdir similarity: %0.3f, lr: %f\"\n",
    "                  % (test_acc, test_dir_similarity, lr))\n",
    "            if lr < 1e-10:\n",
    "                print('Converged!')\n",
    "                break\n",
    "        else:\n",
    "            # One gradient step on a training mini-batch\n",
    "            train_summaries, _, _, _, _, train_dir_similarity = sess.run(\n",
    "                [summaries, W_update, b_update, slow_dir_update,\n",
    "                 fast_dir_update, dir_similarity],\n",
    "                feed_dict=data_dict(train=True))\n",
    "            train_writer.add_summary(train_summaries, i)\n",
    "            if adapt_learning_rate and i - last_lr_change > cool_down:\n",
    "                if train_dir_similarity > 0.5:\n",
    "                    print(\"up\")\n",
    "                    sess.run([lr_up, slow_dir_reset, fast_dir_reset],\n",
    "                             feed_dict=data_dict(train=True))\n",
    "                    last_lr_change = i\n",
    "                    cool_down = 1000\n",
    "                elif train_dir_similarity < 0:\n",
    "                    print('down')\n",
    "                    sess.run([lr_down, slow_dir_reset, fast_dir_reset],\n",
    "                             feed_dict=data_dict(train=True))\n",
    "                    last_lr_change = i\n",
    "                    cool_down = 1000\n",
    "finally:\n",
    "    # Push buffered events to disk even when the loop is interrupted.\n",
    "    train_writer.flush()\n",
    "    test_writer.flush()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}