In [3]:
import tensorflow as tf
import numpy as np
import os.path as op
import os
import shutil
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

In [4]:
data_dir = op.expanduser("~/data/mnist")
mnist = read_data_sets(data_dir, one_hot=True)
logs_dir = '/tmp/tensorflow_logs'

Extracting /home/ogrisel/data/mnist/train-images-idx3-ubyte.gz
Extracting /home/ogrisel/data/mnist/train-labels-idx1-ubyte.gz
Extracting /home/ogrisel/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /home/ogrisel/data/mnist/t10k-labels-idx1-ubyte.gz


In [6]:
def vec_normalize(vec):
    vec_norm = tf.sqrt(tf.reduce_sum(tf.square(vec)))
    return vec / (vec_norm + 1e-7)

In [13]:
tf.reset_default_graph()
sess = tf.Session()
dtype = tf.float32
learning_rate = tf.Variable(tf.constant(0.001, dtype=dtype))


with tf.name_scope('input'):
    x = tf.placeholder(dtype=dtype, shape=[None, 784], name='x-input')
    y = tf.placeholder(dtype=dtype, shape=[None, 10], name='y-input')


with tf.name_scope('variables'):
    W = tf.Variable(tf.truncated_normal(shape=(784, 10), stddev=0.1,
                                        dtype=dtype),
                    name='W')
    tf.histogram_summary('weights', W)
    b = tf.Variable(tf.zeros(shape=(10,), dtype=dtype), name='b')
    tf.histogram_summary('biases', b)
    slow_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))
    fast_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))
    dir_similarity = tf.matmul(tf.reshape(slow_direction, [1, -1]),
                               tf.reshape(fast_direction, [-1, 1]))[0, 0]
    tf.scalar_summary('dir_similarity', dir_similarity)


with tf.name_scope('model'):
    preactivations = tf.matmul(x, W) + b
    tf.histogram_summary('preactivations', preactivations)
    y_pred = tf.nn.softmax(preactivations)
    tf.histogram_summary('predicted_probabilities', y_pred)


with tf.name_scope('loss'):
    cross_entropies = tf.nn.softmax_cross_entropy_with_logits(preactivations, y)
    cross_entropy = tf.reduce_mean(cross_entropies, name='cross_entropy')
    tf.scalar_summary('cross_entropy', cross_entropy)


with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
        correct_prediction = tf.equal(tf.argmax(y, 1),
                                      tf.argmax(y_pred, 1))
    with tf.name_scope('correct_prediction'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype))
    tf.scalar_summary('accuracy', accuracy)


with tf.name_scope('gradient_directions'):
    [gW, gb] = tf.gradients(cross_entropy, [W, b])
    gW_norm = tf.sqrt(tf.reduce_sum(tf.square(gW)))
    g_norm = tf.sqrt(tf.reduce_sum(tf.square(gW)) + tf.reduce_sum(tf.square(gb)))
    tf.scalar_summary('gradient norm', g_norm)
    gW_normed = tf.reshape(gW / (gW_norm + 1e-7), [-1])
    

with tf.name_scope('updates'):
    W_update = W.assign_add(-learning_rate * gW)
    b_update = b.assign_add(-learning_rate * gb)

    slow_rate = 0.05
    new_slow_dir = slow_rate * gW_normed + (1 - slow_rate) * slow_direction
    slow_dir_update = slow_direction.assign(vec_normalize(new_slow_dir))
    slow_dir_reset = slow_direction.assign(gW_normed)

    fast_rate = 0.5
    new_fast_dir = fast_rate * gW_normed + (1 - fast_rate) * fast_direction
    fast_dir_update = fast_direction.assign(vec_normalize(new_fast_dir))
    fast_dir_reset = fast_direction.assign(gW_normed)
    
    lr_up = learning_rate.assign(2 * learning_rate)
    lr_down = learning_rate.assign(0.1 * learning_rate)

In [14]:
def data_dict(train=True, batch_size=128):
    """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
    if train:
        xs, ys = mnist.train.next_batch(batch_size)
    else:
        xs, ys = mnist.test.images, mnist.test.labels
    return {x: xs.astype(np.float32), y: ys.astype(np.float32)}

In [42]:
sess.run(tf.unpack(cross_entropies, num=128), feed_dict=data_dict(train=True))

[3.7324717,
 1.8400075,
 3.6116686,
 1.7056865,
 2.0967669,
 2.3198204,
 2.0171783,
 1.7580715,
 2.3690248,
 1.6259246,
 3.3400445,
 2.0483375,
 1.8097279,
 2.1330061,
 1.8711376,
 2.6461897,
 3.4301701,
 1.9341406,
 1.4128025,
 3.1168551,
 3.7157907,
 2.9077392,
 3.325181,
 0.89536273,
 4.3979654,
 0.92203724,
 2.2007186,
 2.2937737,
 2.1817045,
 1.8752966,
 1.6373912,
 2.0365462,
 1.6608343,
 2.6484132,
 3.4957781,
 1.9901035,
 1.9084624,
 2.6680474,
 2.0449464,
 3.0129635,
 2.3355575,
 2.6466174,
 2.5199924,
 1.9306296,
 2.47153,
 2.5479219,
 2.2662649,
 2.6136937,
 1.7118273,
 1.3672709,
 3.4149623,
 2.152194,
 2.3145103,
 2.0982614,
 1.5726066,
 5.6496277,
 1.0098101,
 3.9320879,
 1.1875807,
 1.6345842,
 4.9190865,
 2.1639643,
 3.1059659,
 2.1172271,
 1.7091054,
 1.7676189,
 2.1179478,
 3.7244837,
 2.3459547,
 4.1443148,
 2.1852176,
 3.9915304,
 1.9832186,
 2.7600195,
 1.882431,
 3.7910852,
 0.71117669,
 2.0551419,
 4.1546483,
 3.7213643,
 1.8376471,
 2.0418112,
 1.8615329,
 3.679

In [22]:
def cosine_similarities(x):
    x = vec_normalize(x)
    return tf.matmul(x, tf.transpose(x))

# flat_gW = tf.reshape(gW, [-1])
[gWs, gbs] = tf.gradients(cross_entropies, [W, b])
sims = cosine_similarities(gWs)

# sess.run(tf.initialize_all_variables())
sess.run(sims, feed_dict=data_dict(train=False)).shape

(784, 784)

In [19]:
128 ** 2

16384

In [None]:
summaries = tf.merge_all_summaries()
shutil.rmtree(logs_dir)
train_writer = tf.train.SummaryWriter(logs_dir + '/train', sess.graph)
test_writer = tf.train.SummaryWriter(logs_dir + '/test')

sess.run(tf.initialize_all_variables())

In [94]:
last_lr_change = 0
cool_down = 100

for i in range(10000):
    if i % 100 == 0:
        # Evaluate on test set
        test_summaries, test_acc, test_dir_similarity, lr = sess.run(
            [summaries, accuracy, dir_similarity, learning_rate],
            feed_dict=data_dict(train=False))
        test_writer.add_summary(test_summaries, i)
        print("Accuracy on test: %0.3f, gdir similarity: %0.3f, lr: %f"
              % (test_acc, test_dir_similarity, lr))
        if lr < 1e-10:
            print('Converged!')
            break
    else:
        # Evaluate on train mini_batch
        train_summaries, _, _, _, _, train_dir_similarity = sess.run(
            [summaries, W_update, b_update, slow_dir_update,
             fast_dir_update, dir_similarity],
            feed_dict=data_dict(train=True))
        train_writer.add_summary(train_summaries, i)
#         if i - last_lr_change > cool_down:
#             if train_dir_similarity > 0.5:
#                 print("up")
#                 sess.run([lr_up, slow_dir_reset, fast_dir_reset],
#                          feed_dict=data_dict(train=True))
#                 last_lr_change = i
#                 cool_down = 1000
#             elif train_dir_similarity < 0:
#                 print('down')
#                 sess.run([lr_down, slow_dir_reset, fast_dir_reset],
#                          feed_dict=data_dict(train=True))
#                 last_lr_change = i
#                 cool_down = 1000

Accuracy on test: 0.111, gdir similarity: 0.000, lr: 0.001000
Accuracy on test: 0.138, gdir similarity: 0.960, lr: 0.001000
Accuracy on test: 0.166, gdir similarity: 0.956, lr: 0.001000
Accuracy on test: 0.216, gdir similarity: 0.938, lr: 0.001000
Accuracy on test: 0.273, gdir similarity: 0.907, lr: 0.001000
Accuracy on test: 0.329, gdir similarity: 0.945, lr: 0.001000
Accuracy on test: 0.379, gdir similarity: 0.930, lr: 0.001000
Accuracy on test: 0.421, gdir similarity: 0.949, lr: 0.001000
Accuracy on test: 0.458, gdir similarity: 0.935, lr: 0.001000
Accuracy on test: 0.491, gdir similarity: 0.948, lr: 0.001000
Accuracy on test: 0.524, gdir similarity: 0.904, lr: 0.001000


KeyboardInterrupt: 