{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# after https://github.com/hans/ipython-notebooks/blob/master/tf/TF%20tutorial.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "seq_length = 5  # number of timesteps\n",
    "batch_size = 64\n",
    "vocab_size = 7\n",
    "embedding_size = 50\n",
    "state_size = 100\n",
    "\n",
    "# Seed both RNGs before any graph op is created, so that the sampled\n",
    "# batches and the weight initialisation repeat from run to run.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "tf.set_random_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[,\n",
       " ,\n",
       " ,\n",
       " ,\n",
       " ]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# tensors are input as a list of size (number of timesteps);\n",
    "# each placeholder holds one timestep for the whole batch (batch size left open)\n",
    "enc_inp = [tf.placeholder(tf.int32, shape=(None,), name=\"inp%i\" % t) for t in range(seq_length)]\n",
    "labels = [tf.placeholder(tf.int32, shape=(None,), name=\"labels%i\" % t) for t in range(seq_length)]\n",
    "enc_inp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[,\n",
       " ,\n",
       " ,\n",
       " ,\n",
       " ]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss weights: all ones, i.e. every target position counts equally\n",
    "weights = [tf.ones_like(labels_t, dtype=tf.float32) for labels_t in labels]\n",
    "weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[,\n",
       " ,\n",
       " ,\n",
       " ,\n",
       " ]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decoder input: prepend some \"GO\" token and drop the final token of the encoder input\n",
    "# NOTE: the GO token is id 0, which is also an ordinary symbol of the vocabulary\n",
    "dec_inp = ([tf.zeros_like(enc_inp[0], dtype=np.int32, name=\"GO\")] + enc_inp[:-1])\n",
    "dec_inp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initial memory value for recurrence.\n",
    "# NOTE(review): prev_mem is not referenced by any later cell -- confirm it can be dropped\n",
    "prev_mem = tf.zeros((batch_size, state_size))\n",
    "prev_mem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Different kinds of RNN cells can be used for seq2seq; here we use a GRU cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "cell = tf.nn.rnn_cell.GRUCell(state_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tf.nn.seq2seq.embedding_rnn_seq2seq:\n",
    "\n",
    "> Embedding RNN sequence-to-sequence model.\n",
    " This model first embeds encoder_inputs by a newly created embedding (of shape\n",
    " [num_encoder_symbols x input_size]). Then it runs an RNN to encode\n",
    " embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs\n",
    " by another newly created embedding (of shape [num_decoder_symbols x\n",
    " input_size]). Then it runs RNN decoder, initialized with the last\n",
    " encoder state, on embedded decoder_inputs."
] },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([,\n",
       " ,\n",
       " ,\n",
       " ,\n",
       " ],\n",
       " )"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inputs will be embedded, so have to specify maximum number of symbols that can appear (vocab_size)\n",
    "# feed_previous keeps its default (False): the decoder is always fed dec_inp,\n",
    "# never its own previous prediction\n",
    "dec_outputs, dec_state = tf.nn.seq2seq.embedding_rnn_seq2seq(\n",
    "    enc_inp, dec_inp, cell, vocab_size, vocab_size, embedding_size)\n",
    "dec_outputs, dec_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# sequence_loss takes (logits, targets, weights); a fourth positional argument\n",
    "# would land in average_across_timesteps, so the vocabulary size is not passed\n",
    "loss = tf.nn.seq2seq.sequence_loss(dec_outputs, labels, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.scalar_summary(\"loss\", loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dec_state is the decoder's final state with one row per batch example, so\n",
    "# dec_state[1] would pick the second example rather than a timestep.\n",
    "# Track the L2 norm of each example's state, averaged over the batch.\n",
    "magnitude = tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(dec_state), 1)))\n",
    "tf.scalar_summary(\"mean final decoder state norm\", magnitude)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "summary_op = tf.merge_all_summaries()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "learning_rate = 0.05\n",
    "momentum = 0.9\n",
    "optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)\n",
    "train_op = optimizer.minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/tmp/tmpo_41xu3j\n"
     ]
    }
   ],
   "source": [
    "# TensorBoard event files go to a fresh temporary directory\n",
    "logdir = tempfile.mkdtemp()\n",
    "print(logdir)\n",
    "summary_writer = tf.train.SummaryWriter(logdir, sess.graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess.run(tf.initialize_all_variables())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train network, step-by-step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([3, 4, 2, 6, 5]),\n",
       " array([5, 3, 4, 6, 1]),\n",
       " array([4, 2, 6, 5, 1]),\n",
       " array([2, 4, 5, 6, 1]),\n",
       " array([5, 4, 1, 2, 3])]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per example: seq_length distinct symbols; the targets are the inputs themselves\n",
    "X_rows = [np.random.choice(vocab_size, size=(seq_length,), replace=False)\n",
    "          for _ in range(batch_size)]\n",
    "Y_rows = X_rows[:]\n",
    "X_rows[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[3, 5, 4, 2, 5, 3, 1, 2, 3, 3, 0, 6, 4, 0, 5, 1, 4, 4, 1, 1, 0, 3,\n",
       " 4, 3, 6, 4, 4, 0, 0, 2, 6, 5, 3, 0, 5, 2, 4, 5, 1, 3, 5, 6, 2, 0,\n",
       " 2, 3, 0, 6, 2, 6, 1, 1, 4, 5, 4, 0, 0, 2, 3, 0, 5, 2, 1, 0],\n",
       " [4, 3, 2, 4, 4, 1, 2, 3, 0, 1, 1, 4, 5, 2, 0, 5, 0, 2, 4, 6, 4, 5,\n",
       " 6, 1, 2, 2, 3, 1, 5, 5, 2, 2, 5, 3, 6, 1, 5, 3, 5, 1, 0, 1, 1, 5,\n",
       " 5, 0, 5, 2, 5, 4, 5, 6, 2, 6, 1, 2, 2, 4, 0, 2, 6, 1, 3, 3],\n",
       " [2, 4, 6, 5, 1, 0, 3, 1, 1, 2, 4, 0, 1, 4, 1, 6, 1, 3, 2, 5, 1, 4,\n",
       " 3, 2, 1, 3, 2, 3, 6, 1, 1, 0, 4, 4, 2, 6, 1, 4, 0, 2, 1, 0, 3, 4,\n",
       " 3, 6, 4, 4, 6, 5, 4, 5, 1, 4, 3, 5, 5, 0, 6, 1, 3, 5, 4, 4],\n",
       " [6, 6, 5, 6, 2, 6, 5, 5, 6, 5, 5, 2, 6, 6, 4, 4, 3, 1, 6, 2, 5, 6,\n",
       " 0, 5, 0, 6, 0, 2, 4, 3, 3, 6, 2, 2, 3, 3, 3, 1, 6, 5, 4, 2, 4, 3,\n",
       " 0, 2, 2, 5, 3, 0, 0, 4, 6, 3, 2, 6, 3, 5, 2, 5, 1, 3, 2, 1],\n",
       " [5, 1, 1, 1, 3, 4, 6, 4, 5, 4, 6, 1, 0, 3, 2, 0, 6, 6, 0, 4, 2, 2,\n",
       " 2, 0, 3, 0, 6, 4, 1, 0, 4, 4, 1, 5, 1, 0, 0, 0, 4, 4, 2, 3, 0, 1,\n",
       " 1, 4, 6, 1, 0, 2, 3, 3, 5, 1, 5, 3, 1, 1, 1, 4, 4, 6, 6, 5]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transpose to (seq_length, batch_size) so that row t feeds the placeholder of timestep t.\n",
    "# Built from X_rows / Y_rows under new names, so re-running this cell cannot flip the layout back.\n",
    "X = np.array(X_rows).T\n",
    "Y = np.array(Y_rows).T\n",
    "X[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([3, 5, 4, 2, 5, 3, 1, 2, 3, 3, 0, 6, 4, 0, 5, 1, 4, 4, 1, 1, 0, 3, 4,\n",
       " 3, 6, 4, 4, 0, 0, 2, 6, 5, 3, 0, 5, 2, 4, 5, 1, 3, 5, 6, 2, 0, 2, 3,\n",
       " 0, 6, 2, 6, 1, 1, 4, 5, 4, 0, 0, 2, 3, 0, 5, 2, 1, 0]),\n",
       " array([4, 3, 2, 4, 4, 1, 2, 3, 0, 1, 1, 4, 5, 2, 0, 5, 0, 2, 4, 6, 4, 5, 6,\n",
       " 1, 2, 2, 3, 1, 5, 5, 2, 2, 5, 3, 6, 1, 5, 3, 5, 1, 0, 1, 1, 5, 5, 0,\n",
       " 5, 2, 5, 4, 5, 6, 2, 6, 1, 2, 2, 4, 0, 2, 6, 1, 3, 3]),\n",
       " array([2, 4, 6, 5, 1, 0, 3, 1, 1, 2, 4, 0, 1, 4, 1, 6, 1, 3, 2, 5, 1, 4, 3,\n",
       " 2, 1, 3, 2, 3, 6, 1, 1, 0, 4, 4, 2, 6, 1, 4, 0, 2, 1, 0, 3, 4, 3, 6,\n",
       " 4, 4, 6, 5, 4, 5, 1, 4, 3, 5, 5, 0, 6, 1, 3, 5, 4, 4]),\n",
       " array([6, 6, 5, 6, 2, 6, 5, 5, 6, 5, 5, 2, 6, 6, 4, 4, 3, 1, 6, 2, 5, 6, 0,\n",
       " 5, 0, 6, 0, 2, 4, 3, 3, 6, 2, 2, 3, 3, 3, 1, 6, 5, 4, 2, 4, 3, 0, 2,\n",
       " 2, 5, 3, 0, 0, 4, 6, 3, 2, 6, 3, 5, 2, 5, 1, 3, 2, 1]),\n",
       " array([5, 1, 1, 1, 3, 4, 6, 4, 5, 4, 6, 1, 0, 3, 2, 0, 6, 6, 0, 4, 2, 2, 2,\n",
       " 0, 3, 0, 6, 4, 1, 0, 4, 4, 1, 5, 1, 0, 0, 0, 4, 4, 2, 3, 0, 1, 1, 4,\n",
       " 6, 1, 0, 2, 3, 3, 5, 1, 5, 3, 1, 1, 1, 4, 4, 6, 6, 5])]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[X[t] for t in range(seq_length)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Feed input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "feed_dict.update({labels[t]: Y[t] for t in range(seq_length)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One training step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.9546013,\n",
       " b'\\n\\x0b\\n\\x04loss\\x15`0\\xfa?\\n\\x17\\n\\x10magnitude at t=1\\x15/\\xc5\\x8a?')"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, loss_t, summary = sess.run([train_op, loss, summary_op], feed_dict)\n",
    "loss_t, summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([5, 0, 4, 2, 1]),\n",
       " array([1, 2, 5, 3, 0]),\n",
       " array([3, 0, 5, 1, 4]),\n",
       " array([6, 3, 5, 2, 1]),\n",
       " array([6, 4, 1, 5, 2]),\n",
       " array([1, 5, 3, 0, 6]),\n",
       " array([4, 6, 0, 2, 3]),\n",
       " array([6, 4, 0, 5, 1]),\n",
       " array([0, 2, 1, 6, 4]),\n",
       " array([4, 3, 1, 5, 0])]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_batch_rows = [np.random.choice(vocab_size, size=(seq_length,), replace=False)\n",
    "                for _ in range(10)]\n",
    "X_batch_rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[5, 1, 3, 6, 6, 1, 4, 6, 0, 4],\n",
       " [0, 2, 0, 3, 4, 5, 6, 4, 2, 3],\n",
       " [4, 5, 5, 5, 1, 3, 0, 0, 1, 1],\n",
       " [2, 3, 1, 2, 5, 0, 2, 5, 6, 5],\n",
       " [1, 0, 4, 1, 2, 6, 3, 1, 4, 0]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same time-major transpose as for training; built from X_batch_rows so a re-run is harmless\n",
    "X_batch = np.array(X_batch_rows).T\n",
    "X_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Only the encoder inputs are fed: dec_inp is built from enc_inp, and the labels\n",
    "# are needed by the loss only. A separate name keeps the training feed_dict intact.\n",
    "test_feed_dict = {enc_inp[t]: X_batch[t] for t in range(seq_length)}\n",
    "dec_outputs_batch = sess.run(dec_outputs, test_feed_dict)\n",
    "# One logit array per timestep; show the shapes instead of dumping every value\n",
    "[logits_t.shape for logits_t in dec_outputs_batch]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([4, 0, 4, 0, 0, 0, 0, 0, 4, 0]),\n",
       " array([1, 0, 0, 4, 0, 0, 4, 0, 4, 4]),\n",
       " array([4, 1, 4, 4, 4, 0, 4, 4, 2, 0]),\n",
       " array([4, 1, 1, 1, 6, 1, 4, 4, 2, 0]),\n",
       " array([1, 1, 5, 1, 6, 3, 2, 0, 
2, 6])]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted symbol per timestep: index of the largest logit for each example\n",
    "[logits_t.argmax(axis=1) for logits_t in dec_outputs_batch]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_batch(batch_size):\n",
    "    \"\"\"Run one optimizer step on a freshly sampled copy-task batch.\n",
    "\n",
    "    Every example consists of `seq_length` distinct symbols drawn from\n",
    "    `range(vocab_size)`; the labels are the inputs themselves.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of examples to sample for this step.\n",
    "\n",
    "    Returns:\n",
    "        (loss_t, summary): the loss value fetched in this step and the\n",
    "        result of `summary_op`, ready for `summary_writer.add_summary`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: from `np.random.choice` if `seq_length > vocab_size`\n",
    "            (sampling without replacement).\n",
    "    \"\"\"\n",
    "    X = [np.random.choice(vocab_size, size=(seq_length,), replace=False)\n",
    "         for _ in range(batch_size)]\n",
    "\n",
    "    # Dimshuffle to seq_len * batch_size: row t feeds the placeholders of timestep t\n",
    "    X = np.array(X).T\n",
    "    Y = X  # copy task: targets equal the inputs (fed read-only, so sharing is safe)\n",
    "\n",
    "    feed_dict = {enc_inp[t]: X[t] for t in range(seq_length)}\n",
    "    feed_dict.update({labels[t]: Y[t] for t in range(seq_length)})\n",
    "\n",
    "    _, loss_t, summary = sess.run([train_op, loss, summary_op], feed_dict)\n",
    "    return loss_t, summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "n_steps = 500  # number of training batches\n",
    "for t in range(n_steps):\n",
    "    loss_t, summary = train_batch(batch_size)\n",
    "    summary_writer.add_summary(summary, t)\n",
    "summary_writer.flush()"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:tensorflow2]",
   "language": "python",
   "name": "conda-env-tensorflow2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}