{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# after http://r2rt.com/recurrent-neural-networks-in-tensorflow-i.html" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Global config variables\n", "num_steps = 5 # number of truncated backprop steps ('n' in the discussion above)\n", "batch_size = 200\n", "num_classes = 2\n", "state_size = 4\n", "learning_rate = 0.1" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def gen_data(size=1000000):\n", " X = np.array(np.random.choice(2, size=(size,)))\n", " Y = []\n", " for i in range(size):\n", " threshold = 0.5\n", " if X[i-3] == 1:\n", " threshold += 0.5\n", " if X[i-8] == 1:\n", " threshold -= 0.25\n", " if np.random.rand() > threshold:\n", " Y.append(0)\n", " else:\n", " Y.append(1)\n", " return X, np.array(Y)\n", "\n", "# adapted from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/ptb/reader.py\n", "def gen_batch(raw_data, batch_size, num_steps):\n", " raw_x, raw_y = raw_data\n", " data_length = len(raw_x)\n", "\n", " # partition raw data into batches and stack them vertically in a data matrix\n", " batch_partition_length = data_length // batch_size\n", " data_x = np.zeros([batch_size, batch_partition_length], dtype=np.int32)\n", " data_y = np.zeros([batch_size, batch_partition_length], dtype=np.int32)\n", " for i in range(batch_size):\n", " data_x[i] = raw_x[batch_partition_length * i:batch_partition_length * (i + 1)]\n", " data_y[i] = raw_y[batch_partition_length * i:batch_partition_length * (i + 1)]\n", " # further divide batch partitions into num_steps for truncated backprop\n", " epoch_size = batch_partition_length // num_steps\n", "\n", " for i in range(epoch_size):\n", " x = data_x[:, i * num_steps:(i + 1) * num_steps]\n", " y = data_y[:, i * num_steps:(i + 1) * num_steps]\n", " yield (x, y)\n", "\n", "def gen_epochs(n, num_steps):\n", " for i in range(n):\n", " yield gen_batch(gen_data(), batch_size, num_steps)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "\"\"\"\n", "Placeholders\n", "\"\"\"\n", "\n", "x = tf.placeholder(tf.int32, [batch_size, num_steps], name='input_placeholder')\n", "y = tf.placeholder(tf.int32, [batch_size, num_steps], name='labels_placeholder')\n", "init_state = tf.zeros([batch_size, state_size])\n", "\n", "\"\"\"\n", "RNN Inputs\n", "\"\"\"\n", "\n", "# Turn our x placeholder into a list of one-hot tensors:\n", "# rnn_inputs is a list of num_steps tensors with shape [batch_size, num_classes]\n", "x_one_hot = tf.one_hot(x, num_classes)\n", "rnn_inputs = tf.unpack(x_one_hot, axis=1)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'\\nAdding rnn_cells to graph\\n\\nThis is a simplified version of the \"rnn\" function from Tensorflow\\'s api. See:\\nhttps://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py\\n\\n\\nstate = init_state\\nrnn_outputs = []\\nfor rnn_input in rnn_inputs:\\n state = rnn_cell(rnn_input, state)\\n rnn_outputs.append(state)\\nfinal_state = rnn_outputs[-1]\\n'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "Definition of rnn_cell\n", "\n", "This is very similar to the __call__ method on Tensorflow's BasicRNNCell. See:\n", "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py\n", "\n", "with tf.variable_scope('rnn_cell', reuse=False):\n", " W = tf.get_variable('W', [num_classes + state_size, state_size])\n", " b = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0))\n", "def rnn_cell(rnn_input, state):\n", " with tf.variable_scope('rnn_cell', reuse=True):\n", " W = tf.get_variable('W', [num_classes + state_size, state_size])\n", " b = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0))\n", " return tf.tanh(tf.matmul(tf.concat(1, [rnn_input, state]), W) + b)\n", " \n", "\"\"\"\n", "\"\"\"\n", "Adding rnn_cells to graph\n", "\n", "This is a simplified version of the \"rnn\" function from Tensorflow's api. See:\n", "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py\n", "\n", "\n", "state = init_state\n", "rnn_outputs = []\n", "for rnn_input in rnn_inputs:\n", " state = rnn_cell(rnn_input, state)\n", " rnn_outputs.append(state)\n", "final_state = rnn_outputs[-1]\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "cell = tf.nn.rnn_cell.BasicRNNCell(state_size)\n", "rnn_outputs, final_state = tf.nn.rnn(cell, rnn_inputs, initial_state=init_state)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "5" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(rnn_inputs)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#with tf.Session().as_default(): print(t.eval())\n", " " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "\"\"\"\n", "Predictions, loss, training step\n", "\n", "Losses and total_loss are simlar to the \"sequence_loss_by_example\" and \"sequence_loss\"\n", "functions, respectively, from Tensorflow's api. See:\n", "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/seq2seq.py\n", "\"\"\"\n", "\n", "#logits and predictions\n", "with tf.variable_scope('softmax'):\n", " W = tf.get_variable('W', [state_size, num_classes])\n", " b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))\n", "logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs]\n", "predictions = [tf.nn.softmax(logit) for logit in logits]\n", "\n", "# Turn our y placeholder into a list labels\n", "y_as_list = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(1, num_steps, y)]\n", "\n", "#losses and train_step\n", "\"\"\"\n", "losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logit,label) for \\\n", " logit, label in zip(logits, y_as_list)]\n", "\"\"\"\n", "loss_weights = [tf.ones([batch_size]) for i in range(num_steps)]\n", "losses = tf.nn.seq2seq.sequence_loss_by_example(logits, y_as_list, loss_weights)\n", "total_loss = tf.reduce_mean(losses)\n", "train_step = tf.train.AdagradOptimizer(learning_rate).minimize(total_loss)\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "\"\"\"\n", "Function to train the network\n", "\"\"\"\n", "\n", "def train_network(num_epochs, num_steps, state_size=4, verbose=True):\n", " with tf.Session() as sess:\n", " sess.run(tf.initialize_all_variables())\n", " training_losses = []\n", " for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)):\n", " training_loss = 0\n", " training_state = np.zeros((batch_size, state_size))\n", " if verbose:\n", " print(\"\\nEPOCH\", idx)\n", " for step, (X, Y) in enumerate(epoch):\n", " tr_losses, training_loss_, training_state, _ = \\\n", " sess.run([losses,\n", " total_loss,\n", " final_state,\n", " train_step],\n", " feed_dict={x:X, y:Y, init_state:training_state})\n", " training_loss += training_loss_\n", " if step % 100 == 0 and step > 0:\n", " if verbose:\n", " print(\"Average loss at step\", step,\n", " \"for last 250 steps:\", training_loss/100)\n", " training_losses.append(training_loss/100)\n", " training_loss = 0\n", "\n", " return training_losses" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "EPOCH 0\n", "Average loss at step 100 for last 250 steps: 0.604084196687\n", "Average loss at step 200 for last 250 steps: 0.562621814013\n", "Average loss at step 300 for last 250 steps: 0.553829800487\n", "Average loss at step 400 for last 250 steps: 0.530663509965\n", "Average loss at step 500 for last 250 steps: 0.521151310503\n", "Average loss at step 600 for last 250 steps: 0.520915334821\n", "Average loss at step 700 for last 250 steps: 0.518614354432\n", "Average loss at step 800 for last 250 steps: 0.519400427639\n", "Average loss at step 900 for last 250 steps: 0.517323300242\n" ] }, { "data": { "text/plain": [ "[]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAg0AAAFdCAYAAACAfl7+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3XmYlXX9//Hnm01UDEpUXHDBhXApZcxESi1UKq9E1MRR\nc0vF1FKCb2mbS5mlKZh7ZaKiY26kfcslTTMX8ito/lTUIg0xRDTDBZXt8/vjcybGkeWcYWbuOWee\nj+s618y5z33f8zqCnNfc9+f+3JFSQpIkaWW6FB1AkiRVB0uDJEkqi6VBkiSVxdIgSZLKYmmQJEll\nsTRIkqSyWBokSVJZuhUdoKUiYm1gOPAC8G6xaSRJqio9gU2BO1NKr5W7UdWWBnJhuLboEJIkVbFD\ngOvKXbmaS8MLAJMmTWLQoEEFR2lbY8aMYfz48UXHaHO+z9ri+6wtvs/aMn36dA499FAofZaWq5pL\nw7sAgwYNYvDgwUVnaVO9e/eu+fcIvs9a4/usLb7PmlXR6X0HQkqSpLJYGiRJUlksDZIkqSyWhipQ\nX19fdIR24fusLb7P2uL7FECklIrO0CIRMRiYOnXq1M42aEWSpFUybdo06urqAOpSStPK3c4jDZIk\nqSyWBkmSVBZLgyRJKoulQZIklcXSIEmSymJpkCRJZbE0SJKkslgaJElSWSwNkiSpLJYGSZJUFkuD\nJEkqi6VBkiSVpepLw8svF51AkqTOoepLw/XXF51AkqTOoepLw803w7x5RaeQJKn2VX1pWLAAfvnL\nolNIklT7qr40fO5zMGECLFxYdBJJkmpb1ZeGQw+FWbPgxhuLTiJJUm1rUWmIiBMi4vmIeCcipkTE\nJ1ayfo+IOCsiXoiIdyPiHxFxRLN1vhQR00v7/GtEfL6cLFtuCcOHw09/Cim15N1IkqRyVFwaImIU\ncB5wGrAD8Ffgzojou4LNbgQ+AxwJbAXUA8822ecuwHXAL4DtgVuB30TE1uVkGjsWHnsM7r230ncj\nSZLK1ZIjDWOAy1NKV6eUngGOA+YDRy1r5Yj4HPBp4AsppXtTSjNTSn9JKT3cZLWvA7enlM5PKT2b\nUvo+MA04sZxAe+wBH/tYPtogSZLaRkWlISK6A3XAPY3LUkoJuBsYspzNvgg8CnwrImZFxLMRcW5E\n9GyyzpDSPpq6cwX7bJYLxo2D22+Hp54q881IkqSKVHqkoS/QFZjTbPkcoN9ythlAPtKwDbAvcBJw\nAHBxk3X6VbjPDxg1CjbcEM4/v9wtJElSJbq1w8/oAiwBDk4pvQUQEd8AboyI41NK763KzseMGUPv\n3r0B+NCH4Moroa6unuOPr1/V3JIkVb2GhgYaGhret2xeC2dFrLQ0vAosBtZrtnw9YHl3gZgNvNRY\nGEqmAwFsBMwobVvJPv9r/PjxDB48GID//Af694eXXlrZVpIkdQ719fXU17//F+lp06ZRV1dX8b4q\nOj2RUloITAWGNS6LiCg9f2g5mz0IbBARazRZNpB89GFW6fnDTfdZsmdpedn69IFjjoFLL4W3365k\nS0mStDItuXrifOCYiDgsIj4KXAasAUwEiIizI+KqJutfB7wGXBkRgyJiV+Ac4IompyYuAD4XEd+I\niIERcTp5wOVFlYY76SR44418mkKSJLWeiktDSukGYBxwJvAY8DFgeEppbmmVfkD/Juu/TT5q0Af4\nP+Aa8jwMJzVZ52HgYOBY4HFgP2BESunpSvNtsgkceGAeELl4caVbS5Kk5WnRQMiU0iXAJct57chl\nLHsOGL6Sfd4M3NySPM2NHQs77giTJ8MBB7TGHiVJUtXfe2JZ6upg992dWlqSpNZUk6UB8mRPf/kL\nPLS84ZmSJKkiNVsaPv95GDTIqaUlSWotNVsaunSBb3wDbr0Vnnuu6DSSJFW/mi0NAIceCuusA+PH\nF51EkqTqV9OloWdP+NrXYOJEmDt3patLkqQVqOnSAPDVr+a7YF56adFJJEmqbjVfGtZeG448Ei66\nCN55p+g0kiRVr5ovDQBjxsCrr8I11xSdRJKk6tUpSsMWW8DIkXlq6SVLik4jSVJ16hSlAfJkT88+\nC7/7XdFJJEmqTp2mNAwZArvs4mRPkiS1VKcpDZBvZHX//fDII0UnkSSp+nSq0jBiBGy+OZx3XtFJ\nJEmqPp2qNHTtmqeWvukmeOGFotNIklRdOlVpADjiCPjwh2HChKKTSJJUXTpdaVhjjTxL5C9/Ca+/\nXnQaSZKqR6crDQAnnggLF8LllxedRJKk6tEpS8N668Fhh8HPfgYLFhSdRpKk6tApSwPkAZGzZ0ND\nQ9FJJEmqDp22NAwaBHvvnSd7SqnoNJIkdXydtjRAnlr6ySfhrruKTiJJUsfXqUvDbrtBXZ2TPUmS\nVI5OXRoi8tGGP/wB/vrXotNIktSxderSAHDAAbDxxh5tkCRpZTp9aejWDU4+OV9FMWtW0WkkSeq4\nOn1pADj6aFhzzTxvgyRJWjZLA7DWWjB6dJ4h8o03ik4jSVLHZGko+frXYf58uOKKopNIktQxWRpK\nNtwQ6uvz3S8XLiw6jSRJHY+loYmxY2HmTLjppqKTSJLU8Vgamvj4x2HPPZ1aWpKkZbE0NDNuHEyb\nBn/6U9FJJEnqWCwNzey5J2y3XT7aIEmSlrI0NBORxzb87ncwfXrRaSRJ6jgsDctQXw8bbADnn190\nEkmSOg5LwzL06JHnbbj6anj55aLTSJLUMVgalmP06FweLr646CSSJHUMlobl6NMHvvIVuOQSePvt\notNIklQ8S8MKnHwy/Oc/MHFi0UkkSSqepWEFNt0UvvQlGD8eFi8uOo0kScWyNKzEuHEwYwbcemvR\nSSRJKpalYSV23BF23dXJniRJsjSUYdw4ePhheOihopNIklQcS0MZ9t4bBg6E884rOokkScWxNJSh\nS5c8tfTkyfD3vxedRpKkYlgayvTlL0PfvvlKCkmSOiNLQ5l69oQTT4Qrr4TXXis6jSRJ7c/SUIHj\nj4eU4NJLi04iSVL7szRUoG9fOPJIuPBCePfdotNIktS+LA0VGjMG5s6FSZOKTiJJUvtqUWmIiBMi\n4vmIeCcipkTEJ1aw7m4RsaTZY3FErNtsvZMj4pmImB8RMyPi/IhYrSX52tKWW8KIEfnyyyVLik4j\nSVL7qbg0RMQo4DzgNGAH4K/AnRHRdwWbJWBLoF/psX5K6ZUm+zwYOLu0z48CRwEHAmdVmq89jBsH\nzzwDt99edBJJktpPS440jAEuTyldnVJ6BjgOmE/+oF+RuSmlVxofzV4bAjyQUvp1SmlmSulu4Hpg\npxbka3O77AI77+zU0pKkzqWi0hAR3YE64J7GZSmlBNxN/uBf7qbA4xHxr4i4KyJ2afb6Q0Bd42mO\niBgAfAH4XSX52ktEPtpw333w6KNFp5EkqX1UeqShL9AVmNNs+RzyaYdlmQ2MBvYH9gNeBO6LiO0b\nV0gpNZBPTTwQEQuAvwH3ppR+UmG+drPvvjBggFNLS5I6j25t/QNSSs8BzzVZNCUiNief5jgcICJ2\nB75NPtXxCLAF8LOImJ1S+uGK9j9mzBh69+79vmX19fXU19e32ntYlq5d85UUJ58MP/4xbLJJm/44\nSZJapKGhgYaGhvctmzdvXov2FfnsQpkr59MT84H9U0q3NVk+EeidUhpZ5n7OAYamlIaWnt8PTEkp\nfbPJOoeQx070Ws4+BgNTp06dyuDBg8t+D63p7behf3844gg4//xCIkiSVLFp06ZRV1cHUJdSmlbu\ndhWdnkgpLQSmAsMal0VElJ5XcuPo7cmnLRqtASxqts6SJvvvkNZcM88S+YtfwH/+U3QaSZLaVkuu\nnjgfOCYiDouIjwKXkT/0JwJExNkRcVXjyhFxUkTsExGbR8Q2ETEB+AxwUZN9/hY4PiJGRcSmEbEn\ncCZwW6rkUEgBTjwRFiyAn/+86CSSJLWtisc0pJRuKM3JcCawHvA4MDylNLe0Sj+gf5NNepDnddiA\nfGrjCWBYSun+Juv8gHxk4QfAhsBc4Dbgu5Xma2/9+sGhh8IFF+TxDT16FJ1IkqS2UdGYho6kI4xp\naPTUU7DttnD11fkW2pIkdWTtMqZBy7bNNvCFL+TJnqq0g0mStFKWhlYybhw88QTcfXfRSSRJahuW\nhlay++6www5OLS1Jql2WhlbSOLX0XXflIw6SJNUaS0Mr+tKX8mRPTvQkSapFloZW1L17vuzyuuvg\npZeKTiNJUuuyNLSyo4+G1VeHCy8sOokkSa3L0tDKPvQhOPZYuOwyePPNotNIktR6LA1t4Otfzzez\n+tWvik4iSVLrsTS0gf794aCDYPx4WNT8NlySJFUpS0MbGTsW/vlPuPnmopNIktQ6LA1tZPvtYdgw\np5aWJNUOS0MbGjcOHn0U/vznopNIkrTqLA1taPjwfPdLp5aWJNUCS0MbishjG377W3jmmaLTSJK0\naiwNbay+Hvr1c2ppSVL1szS0sdVWy/M2XH01vPJK0WkkSWo5S0M7GD0aunWDiy8uOokkSS1naWgH\nH/kIfOUruTTMn190GkmSWsbS0E5OPhlefx2uuqroJJIktYyloZ1sthnsv38eELl4cdFpJEmqnKWh\nHY0dC3//e74EU5KkamNpaEef/CR8+tNO9iRJqk6WhnY2diw8+CA8/HDRSSRJqoyloZ198Yuw5ZZw\n3nlFJ5EkqTKWhnbWpUs+2jB5MsyYUXQaSZLKZ2kowGGH5bkbJkwoOokkSeWzNBRg9dXhhBPgV7+C\n114rOo0kSeWxNBTk+ONhyRK47LKik0iSVB5LQ0HWXRcOPxwuvBDee6/oNJIkrZyloUBjxuQ7X157\nbdFJJElaOUtDgQYOhH32yZM9LVlSdBpJklbM0lCwsWNh+nS4446ik0iStGKWhoJ96lOw005OLS1J\n6vgsDQWLgHHj4N57Ydq0otNIkrR8loYOYOTIfOtsp5aWJHVkloYOoFs3OPlk+PWvYebMotNIkrRs\nloYO4qijYK214IILik4iSdKyWRo6iF694KtfhV/8AubNKzqNJEkfZGnoQL72NXj33VwcJEnqaCwN\nHcj668Mhh+S7Xy5YUHQaSZLez9LQwYwdCy+9BDfcUHQSSZLez9LQwWy7LXzuc/nyy5SKTiNJ0lKW\nhg5o3Dh4/HH44x+LTiJJ0lKWhg7os5+F7bd3amlJUsdiaeiAIvLYhjvugCefLDqNJEmZpaGDGjUK\nNtwQzj+/6CSSJGWWhg6qe/c8tfSkSTBjRtFpJEmyNHRoxxwDa68NW28No0fD3/9edCJJUmdmaejA\neveGZ56BM86AW2+FgQPzaQtvoS1JKoKloYPr3RtOOQWefx4uvhgefRTq6mD4cLj3XudykCS1nxaV\nhog4ISKej4h3ImJKRHxiBevuFhFLmj0WR8S6zdbrHREXR8S/IuLdiHgmIj7Xkny1aPXV4bjj4Nln\noaEB5szJl2buvDNMngxLlhSdUJJU6youDRExCjgPOA3YAfgrcGdE9F3BZgnYEuhXeqyfUnqlyT67\nA3cDGwP7AVsBxwAvVZqv1nXrBgcdBI89BrffnsvEfvvBNtvAlVd6zwpJUttpyZGGMcDlKaWrU0rP\nAMcB84GjVrLd3JTSK42PZq99BegD7JtSmpJSmplS+nNK6f+1IF+nEJGnm77vPnjooTze4aijYPPN\nYfx4eOutohNKkmpNRaWhdESgDrincVlKKZGPEgxZ0abA46VTD3dFxC7NXv8i8DBwSUS8HBH/LyJO\njQjHXJRhyBD4zW/gqadg2DD45jdh443htNPg1VeLTidJqhWVfij3BboCc5otn0M+7bAss4HRwP7k\nUw8vAvdFxPZN1hkAfKmU5/PAmcBY4DsV5uvUtt4aJk7M8zocdliehnqTTfJ8DzNnFp1OklTtIlUw\n/D4i1iePMxiSUvpLk+U/AXZNKa3oaEPT/dwH/DOldHjp+bPAasBmpSMXRMQYYFxKacPl7GMwMHXX\nXXeld+/e73utvr6e+vr6st9XrXr1VbjoIrjwQnjjDTjkkHwUYuuti04mSWovDQ0NNDQ0vG/ZvHnz\nuP/++wHqUkplX8hfaWnoTh6/sH9K6bYmyycCvVNKI8vczznA0JTS0NLz+4AFKaW9mqzzOeB3wGop\npUXL2MdgYOrUqVMZPHhw2e+hM3rrLfjlL/PttmfNghEj8mWcO+9cdDJJUhGmTZtGXV0dVFgaKjo9\nkVJaCEwFhjUui4goPX+ogl1tTz5t0ehBYItm6wwEZi+rMKgyvXrlUxQzZuQrLJ59No+D2H33fFMs\n53qQJJWjJQMNzweOiYjDIuKjwGXAGsBEgIg4OyKualw5Ik6KiH0iYvOI2CYiJgCfAS5qss9LgY9E\nxM8iYsuI2Bs4tdk6WkU9esARR+QBk5MnwzvvwOc/DzvsANdfD4usZ5KkFai4NKSUbgDGkQcrPgZ8\nDBieUppbWqUf0L/JJj3I8zo8AdwHbAcMSynd12Sfs4DhwI7keR8mAOOBn1SaTyvXpQvsuy9MmZJn\nlezXD+rr82Wbl10G775bdEJJUkdU0ZiGjsQxDa3rscfgJz+BG2+EddbJpzO++tU8jbUkqba0y5gG\n1a7GUxTPPpuPQpx2Wp7r4ZRT4OWXi04nSeoILA16ny22yKcoXnghH2m45BLYdNN83wtvzS1JnZul\nQcu0/vrw4x/nSaFOPz0PnBw4cOl9LyRJnY+lQSvUp08+RfHCC3miqEcegcGDl973okqHxEiSWsDS\noLKsvno+XfHcc3DddTB7NnzmM0vve+GtuSWp9lkaVJFu3fLlmY8/Dr//Pay2GowcmW/NPXGit+aW\npFpmaVCLROSJof70J3jwQdhqKzjyyHxr7gkTvDW3JNUiS4NW2S67wK23wpNPwmc/C+PG5btrnn46\nvPZa0ekkSa3F0qBWs802cNVV+R4Xhx4K55yT53o4+WR48cWi00mSVpWlQa1uk03gggvgn//MRx2u\nvhoGDMinL6ZPLzqdJKmlLA1qM+usA2ecked6OOcc+MMfYOut88DJKVOKTidJqpSlQW2uVy8YMwb+\n8Q+44op8tGHIkHzJ5p13OteDJFULS4PaTY8ecNRR+dbcN98Mb7+dJ4kaPDjf92Lx4qITSpJWxNKg\ndte1K+y3H/zlL3DPPfk0Rn19HvPgUQdJ6rgsDSpMRL5E86674Npr4ZprYPz4olNJkpanW9EBJICD\nD4YnnoD/+R/YdlvYa6+iE0mSmvNIgzqMs86C4cPznTRnzCg6jSSpOUuDOoyuXfPNsNZZB0aMgDff\nLDqRJKkpS4M6lD598l0zZ86Eww/37pmS1JFYGtThDBqUB0b+5jfwwx8WnUaS1MjSoA7pi1+EM8+E\n007L5UGSVDxLgzqs73wHDjgAvvzlPCGUJKlYlgZ1WBFw5ZWw2Waw777w+utFJ5Kkzs3SoA6tVy+4\n9Vb497/zpZiLFhWdSJI6L0uDOrzNNoMbbshTTp96atFpJKnzsjSoKgwbBj/9aX5ce23RaSSpc3Ia\naVWNk06Cxx+Ho4+Gj34U6uqKTiRJnYtHGlQ1IuCyy2C77WDkSHjllaITSVLnYmlQVenZEyZPhoUL\n8+WYCxYUnUiSOg9Lg6rOhhvCzTfDlCn5lIUkqX1YGlSVdtkFLrkkn674+c+LTiNJnYMDIVW1jj4a\nHnsMTjwRttkGhg4tOpEk1TaPNKiqTZgAQ4bA/vvDrFlFp5Gk2mZpUFXr3h1uuglWWy1fUfHOO0Un\nkqTaZWlQ1VtnnXwnzKeegmOPhZSKTiRJtcnSoJqwww5wxRUwaRKMH190GkmqTQ6EVM2or88zRv7P\n/+QJoPbcs+hEklRbPNKgmvKjH8Fee8GoUTBjRtFpJKm2WBpUU7p2hYYG6NsXRoyAN98sOpEk1Q5L\ng2pOnz5w660wcyYcfjgsWVJ0IkmqDZYG1aRBg/KgyMmT4Yc/LDqNJNUGS4Nq1j77wJlnwmmn5SMP\nkqRVY2lQTfvOd/JskYcemudxkCS1nKVBNa1LF5g4ETbbDPbdF15/vehEklS9LA2qeb165Rkj//1v\nOOggWLy46ESSVJ0sDeoUBgyAX/8a7r4bTj216DSSVJ0sDeo09tgDfvpTOPdcuO66otNIUvVxGml1\nKiefnKea/spXYOBAqKsrOpEkVQ+PNKhTiYDLL8/3phg5El55pehEklQ9LA3qdHr2hFtugQUL4IAD\n8ldJ0spZGtQpbbRRLg5TpuRTFpKklWtRaYiIEyLi+Yh4JyKmRMQnVrDubhGxpNljcUSsu5z1Dyqt\nc0tLsknl2mUXuPhiuPRS+MUvik4jSR1fxQMhI2IUcB5wLPAIMAa4MyK2Sim9upzNErAV8N97DqaU\nPnA2OSI2Bc4F7q80l9QSxxyTB0aecAJsvTUMHVp0IknquFpypGEMcHlK6eqU0jPAccB84KiVbDc3\npfRK46P5ixHRBZgEfB94vgW5pBaZMAGGDMnTTc+aVXQaSeq4KioNEdEdqAPuaVyWUkrA3cCQFW0K\nPB4R/4qIuyJil2WscxowJ6V0ZSWZpFXVvTvceCP06JGvqHjnnaITSVLHVOmRhr5AV2BOs+VzgH7L\n2WY2MBrYH9gPeBG4LyK2b1whIj4FHAkcXWEeqVWsu26eavqpp2D0aEip6ESS1PG0+eROKaXngOea\nLJoSEZuTT3McHhG9gKuBY1JKFd9OaMyYMfTu3ft9y+rr66mvr1+F1OqMBg+GK66Agw+GHXaAMWOK\nTiRJq66hoYGGhob3LZs3b16L9hWpgl+pSqcn5gP7p5Rua7J8ItA7pTSyzP2cAwxNKQ2NiI8D04DF\n5NMYsPQIyGJgYErpA2McImIwMHXq1KkMHjy47Pcgrcy3vpWnm77jDthzz6LTSFLrmzZtGnV5Sty6\nlNK0crer6PRESmkhMBUY1rgsIqL0/KEKdrU9+bQFwDPAdqVlHy89bgP+WPr+xUoySqvqRz+CvfaC\nUaNgxoyi00hSx9GS0xPnAxMjYipLL7lcA5gIEBFnAxuklA4vPT+JfDXEU0BP4BjgM8CeACml94Cn\nm/6AiPhPfilNb0E+aZV07ZpvaLXTTrDvvvDww/n22pLU2VV8yWVK6QZgHHAm8BjwMWB4SmluaZV+\nQP8mm/Qgz+vwBHAf+ajCsJTSfS1OLbWxD38Ybr0VXngBDj8cliwpOpEkFa9FAyFTSpcAlyzntSOb\nPT+XPGFTJfs/cuVrSW1r661h0qR8tOGss+B73ys6kSQVy3tPSCswYgSccQZ8//v5yIMkdWaWBmkl\nvvtd2G8/OPRQePrpla8vSbXK0iCtRJcucNVVsOmm+cjD6xXPJiJJtcHSIJWhV688Y+Rrr+XJnxYv\nLjqRJLU/S4NUps03hxtugLvugm9/u+g0ktT+LA1SBfbYA849F845J8/lIEmdSZvfe0KqNWPGwOOP\nw1e+Ah/9aL5nhSR1Bh5pkCoUAZdfDttum+dweOWVohNJUvuwNEgtsPrqMHkyLFgABxyQv0pSrbM0\nSC200UZw880wZQqcfHLRaSSp7VkapFUwdChcdBFcein84hdFp5GktuVASGkVHXtsHhh5wgn5fhVD\nhxadSJLahkcapFYwYQLsvDPsvz/MmlV0GklqG5YGqRX06AE33ZS/jhwJ77xTdCJJan2WBqmVrLtu\nvqLiySdh9GhIqehEktS6LA1SK6qrgyuugGuuyacsJKmWOBBSamUHH5wHRo4bB9ttl6eelqRa4JEG\nqQ2cfTbsuSeMGgX/+EfRaSSpdVgapDbQtSs0NMBHPgIjRsBbbxWdSJJWnaVBaiMf/jD85jfwwgtw\n+OGwZEnRiSRp1VgapDa0zTYwaRLccgucdVbRaSRp1VgapDY2YgSccQZ8//tw221Fp5GklrM0SO3g\nu9/Nkz4deig88UTRaSSpZSwNUjvo0gWuugo23TTP5TB6NMycWXQqSaqMpUFqJ2utBQ8/nC/HvOUW\n2GKLfJOrl14qOpkklcfSILWjNdfMkz49/zyceSZcfz1svjmcdBLMnl10OklaMUuDVIBeveCUU3J5\n+O534eqrYcAA+MY3YM6cotNJ0rJZGqQCfehDuTQ8/zx861v5vhWbbQbf/CbMnVt0Okl6P0uD1AH0\n6QOnn57Lw9ixcOmluTx8+9vw2mtFp5OkzNIgdSAf+Qj84Ae5PHzta3DBBbk8fP/78PrrRaeT1NlZ\nGqQOqG/ffJXF88/nyzN/+tNcHs48E+bNKzqdpM7K0iB1YOuuC+eem++UeeSR8KMf5fJw1lnw5ptF\np5PU2VgapCrQrx+MH5/LwyGH5CMOm20GP/mJd9CU1H4sDVIV2WADuPBC+Pvf4cAD4Xvfy5dqnnce\nzJ9fdDpJtc7SIFWh/v3hkkvgb3+DfffNcz4MGAATJsA77xSdTlKtsjRIVWyTTeDnP4dnn4UvfCHP\nNrn55nDRRfDee0Wnk1RrLA1SDRgwAH71K5g+HfbYI09LvcUWcNllsGBB0ekk1QpLg1RDttwyT0n9\n9NOw665w/PF52S9/CQsXFp1OUrWzNEg1aOBAuPZaePJJ2HlnOOaYvGziRFi0qOh0kqqVpUGqYVtv\nDb/+NTzxBOywQ57rYdAguOYaWLy46HSSqo2lQeoEttsObr4Zpk3LReKww2CbbaChwfIgqXyWBqkT\n2WEHuPVW+L//ywMlDz4YPvYxuPFGWLKk6HSSOjpLg9QJ7bgj/O//wpQpec6HAw+E7beHyZMhpaLT\nSeqoLA1SJ/bJT8Idd8ADD+T7XOy3H9TVwW23WR4kfZClQRJDh8Ldd8Of/gQf+hCMGAE77QS//73l\nQdJSlgZJ/7XrrnDvvXDPPbDaarD33jBkCNx1l+VBkqVBUjMR8NnPwp//nMtCBAwfDp/+dC4Tlgep\n87I0SFqmCNhzT3jooXya4r338hTVu++eT2NI6nwsDZJWKAI+/3l45JE8QPLNN3NxGDYMHnyw6HSS\n2pOlQVJZIuCLX4SpU/Olma++Cp/6VD51MWVK0ekktQdLg6SKRMC++8Jjj+VJoWbNyoMl994bHn20\n6HSS2lKLSkNEnBARz0fEOxExJSI+sYJ1d4uIJc0eiyNi3SbrHB0R90fEv0uPP6xon5KK16ULHHBA\nvq9FQwMRu38dAAAL2ElEQVTMmAGf+ATss08uFJJqT7dKN4iIUcB5wLHAI8AY4M6I2Cql9OpyNkvA\nVsCb/12Q0itNXt8NuA54CHgXOAW4KyK2TinNrjSjpPbTtSscdBB86Utw/fVwxhkweHC+t0WPHvn1\nbt3y18ZH8+flrNOSbdpjv716wVprFf2nILWPiksDuSRcnlK6GiAijgP2Bo4CzlnBdnNTSm8s64WU\n0pebPo+Io4H9gWHApBZklNTOunaFQw6BUaPguuvywMnFiz/4WLTog88XLFj5Oivbz/K2a49LRPv0\ngU02gU03XfbXj3wkn9aRql1FpSEiugN1wI8al6WUUkTcDQxZ0abA4xHRE3gSOD2l9NAK1l8T6A78\nu5J8korXrVu+i+ZhhxWdJEtp1QvJitZ54w345z/z44UX4A9/yN/Pn780w5prfrBINP1+vfUsFaoO\nlR5p6At0BeY0Wz4HGLicbWYDo4FHgdWAY4D7ImKnlNLjy9nmJ8BLwN0V5pOk94nIRaZbS46rtlBK\n+eqSxiLR9OsDD8CkSblsNOrZEzbeeNlHKTbZBDbYIB/JkYrW5v8bpZSeA55rsmhKRGxOPs1xePP1\nI+IU4EBgt5TSgpXtf8yYMfTu3ft9y+rr66mvr1+l3JLUUhGwzjr5seOOy17nP/9ZWiSaloqpU+GW\nW+C115au261bvhvp8o5WbLQRdO/e5m9LVaqhoYGGhob3LZs3b16L9hWpghN+pdMT84H9U0q3NVk+\nEeidUhpZ5n7OAYamlIY2Wz4O+DYwLKW0wvHXETEYmDp16lQGDx5c9nuQpGrw1lsfLBRNS8bLLy9d\nt0sX2HDDD572aPy68cb5aEZnt3AhvP320sf8+R/8fuFC6Ncvl7SNNoJmv5PWjGnTplFXVwdQl1Ka\nVu52FR1pSCktjIip5AGKtwFERJSe/6yCXW1PPm3xXxHxTeBUYK+VFQZJqnW9euUrULbZZtmvv/su\nzJy57ELxpz/BSy+9fxBov37LP/2xySb55xVt0aIPfpgv78O9nOfNX1u0qPJMa62Vy8NGGy0tEs2/\n792784xJacnpifOBiaXy0HjJ5RrARICIOBvYIKV0eOn5ScDzwFNAT/KYhs8AezbuMCK+BZwB1AMz\nI2K90ktvpZTebkFGSappPXvCVlvlx7IsWJAn3lrW0YpHHoEXX3z/h+jaay97kGbj1z598sDPcj/A\nW/LhvmClJ6Sz1VfPg0sbH2ussfT73r3zGJBlvbay52uskU8FzZ6d/9vNmpX/OzV+/+STcMcd+fWm\nhaxXrxWXiv79a6dYVFwaUko3RERf4ExgPeBxYHhKaW5plX5A/yab9CDP67AB+dTGE+TTD/c3Wec4\n8tUSNzX7cWeUfo4kqQI9esCAAfmxLIsXw7/+tezTH//7v/koxnvvLV2/e/d86L4cPXsu/wN6rbXy\nUY9KP8wbv1999Xw6pi01FqflWbhwabFoWipefBGefjrfHXb2bFiyZOk2a6654lKx0Ua5mHX0YlHR\nmIaOxDENktR2liyBV15ZWiTmzs0f3iv7cF9jDa/0gFwsXn552cWi8ft//euDxaKxSDQtE00Lxoc/\n3DrFol3GNEiSOocuXfIRgX79YOedi05Tfbp3zx/2/fvne7Msy6JFHywWjV+few7uueeDxWKNNVZ8\ntKJ//9YrFstiaZAkqQDdui390F9eMVu0CObMWfbRir/9Df74x1wsFi9eus3qq698jEVLTzJYGiRJ\n6qC6dcuX02644fLXWbx4+cVixgy4775cLJoOfF1ttRbmadlmkiSpI+jaNV8xssEG8MlPLnudxYvz\nGJXGMvHQQ3DeeZX/rDYegypJkorWtSusvz7stBPstx8cfHDL9mNpkCRJZbE0SJKkslgaJElSWSwN\nkiSpLJYGSZJUFkuDJEkqi6VBkiSVxdIgSZLKYmmQJEllsTRIkqSyWBokSVJZLA2SJKkslgZJklQW\nS4MkSSqLpaEKNDQ0FB2hXfg+a4vvs7b4PgWWhqrQWf4S+z5ri++ztvg+BZYGSZJUJkuDJEkqi6VB\nkiSVpVvRAVZBT4Dp06cXnaPNzZs3j2nTphUdo835PmuL77O2+D5rS5PPzp6VbBcppdZP0w4i4mDg\n2qJzSJJUxQ5JKV1X7srVXBrWBoYDLwDvFptGkqSq0hPYFLgzpfRauRtVbWmQJEnty4GQkiSpLJYG\nSZJUFkuDJEkqi6VBkiSVxdIgSZLKUpWlISJOiIjnI+KdiJgSEZ8oOlNri4hPR8RtEfFSRCyJiH2K\nztTaIuLUiHgkIt6IiDkRMTkitio6V1uIiOMi4q8RMa/0eCgiPld0rrYUEaeU/u6eX3SW1hYRp5Xe\nW9PH00XnagsRsUFEXBMRr0bE/NLf48FF52pNpc+T5n+eSyLiwqKztaaI6BIRP4iIf5T+LP8eEd+t\nZB9VVxoiYhRwHnAasAPwV+DOiOhbaLDWtybwOHA8UKvXxX4auBD4JLAH0B24KyJWLzRV23gR+BYw\nGKgD/gjcGhGDCk3VRkpF/ljy/5+16klgPaBf6fGpYuO0vojoAzwIvEeeF2cQMBZ4vchcbWBHlv45\n9gP2JP+7e0ORodrAKcBo8ufKR4FvAt+MiBPL3UHVzdMQEVOAv6SUTio9D/I/yD9LKZ1TaLg2EhFL\ngH1TSrcVnaUtlYrfK8CuKaUHis7T1iLiNWBcSunKorO0pojoBUwFvgp8D3gspfSNYlO1rog4DRiR\nUqqp37ibi4gfA0NSSrsVnaU9RcQE4AsppZo68hkRvwVeTikd02TZTcD8lNJh5eyjqo40RER38m9p\n9zQuS7n13A0MKSqXWk0fcrv/d9FB2lLpEOFBwBrAw0XnaQMXA79NKf2x6CBtbMvS6cMZETEpIvoX\nHagNfBF4NCJuKJ1CnBYRRxcdqi2VPmcOAa4oOksbeAgYFhFbAkTEx4GhwO/L3UG13bCqL9AVmNNs\n+RxgYPvHUWspHTGaADyQUqrVc8PbkktCT+BNYGRK6ZliU7WuUhnanny4t5ZNAY4AngXWB04H7o+I\nbVNKbxeYq7UNIB8xOg84C9gJ+FlEvJdSuqbQZG1nJNAbuKroIG3gx8CHgGciYjH5wMF3UkrXl7uD\naisNql2XAFuTW2+tegb4OPkfpAOAqyNi11opDhGxEbn47ZFSWlh0nraUUrqzydMnI+IR4J/AgUAt\nnW7qAjySUvpe6flfS+X3OKBWS8NRwO0ppZeLDtIGRgEHAwcBT5ML/gUR8a9yS2C1lYZXgcXkwUdN\nrQfU4h9wpxARFwFfAD6dUppddJ62klJaBPyj9PSxiNgJOIn8m1wtqAPWAaaVjhxBPjK4a2mg1Wqp\n2gZRlSmlNC8ingO2KDpLK5sNTG+2bDqwXwFZ2lxEbEwelL1v0VnayDnA2SmlG0vPn4qITYFTKbME\nVtWYhtJvL1OBYY3LSv84DSOfq1GVKRWGEcBnUkozi87TzroAqxUdohXdDWxH/u3l46XHo8Ak4OO1\nWhjgv4M/tyB/yNaSB/ngqd+B5KMqtego8unuss/xV5k1yL94N7WECrpAtR1pADgfmBgRU4FHgDHk\n/xATiwzV2iJiTfI/Qo2/sQ0oDVr5d0rpxeKStZ6IuASoB/YB3o6IxiNI81JKNXW784j4EXA7MBNY\nizzQajdgryJztabSufz3jUeJiLeB11JKzX9brWoRcS7wW/KH54bAGcBCoKHIXG1gPPBgRJxKvvzw\nk8DRwDEr3KoKlX4BPQKYmFJaUnCctvJb4LsRMQt4inwJ+Bjgl+XuoOpKQ0rphtKleWeST0s8DgxP\nKc0tNlmr2xG4l3w1QSIPRII8OOeookK1suPI7+2+ZsuPBK5u9zRta13yn936wDzgCWCvTnCFQa0e\nXdgIuA5YG5gLPADsnFJ6rdBUrSyl9GhEjCQPoPse8DxwUiUD56rIHkB/amtMSnMnAj8gX+G0LvAv\n4NLSsrJU3TwNkiSpGFU1pkGSJBXH0iBJkspiaZAkSWWxNEiSpLJYGiRJUlksDZIkqSyWBkmSVBZL\ngyRJKoulQZIklcXSIEmSymJpkCRJZfn/eYPAmTpIKHUAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "training_losses = train_network(1,num_steps)\n", "plt.plot(training_losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python [conda env:tensorflow2]", "language": "python", "name": "conda-env-tensorflow2-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 1 }