{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/\n", "# http://learningtensorflow.com/index.html\n", "# http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/\n", "\n", "import tensorflow as tf\n", "import numpy as np\n", "from tensorflow.contrib import rnn\n", "import pprint\n", "pp = pprint.PrettyPrinter(indent=4)\n", "sess = tf.InteractiveSession()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# One hot encoding for each char in 'hello'\n", "h = [1, 0, 0, 0]\n", "e = [0, 1, 0, 0]\n", "l = [0, 0, 1, 0]\n", "o = [0, 0, 0, 1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![image](https://cloud.githubusercontent.com/assets/901975/23348727/cc981856-fce7-11e6-83ea-4b187473466b.png)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2 2\n", "array([[[ 1., 0., 0., 0.]]], dtype=float32)\n", "array([[[-0.42409304, 0.64651132]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('one_cell') as scope:\n", " # One cell RNN input_dim (4) -> output_dim (2)\n", " hidden_size = 2\n", " cell = tf.contrib.rnn.BasicRNNCell(num_units=hidden_size)\n", " print(cell.output_size, cell.state_size)\n", "\n", " x_data = np.array([[h]], dtype=np.float32) # x_data = [[[1,0,0,0]]]\n", " pp.pprint(x_data)\n", " outputs, _states = tf.nn.dynamic_rnn(cell, x_data, dtype=tf.float32)\n", "\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![image](https://cloud.githubusercontent.com/assets/901975/23383634/649efd0a-fd82-11e6-925d-8041242743b0.png)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 5, 4)\n", "array([[[ 1., 0., 0., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 0., 1.]]], dtype=float32)\n", "array([[[ 0.19709368, 0.24918222],\n", " [-0.11721198, 0.1784237 ],\n", " [-0.35297349, -0.66278851],\n", " [-0.70915914, -0.58334434],\n", " [-0.38886023, 0.47304463]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('two_sequances') as scope:\n", " # One cell RNN input_dim (4) -> output_dim (2). sequence: 5\n", " hidden_size = 2\n", " cell = tf.contrib.rnn.BasicRNNCell(num_units=hidden_size)\n", " x_data = np.array([[h, e, l, l, o]], dtype=np.float32)\n", " print(x_data.shape)\n", " pp.pprint(x_data)\n", " outputs, states = tf.nn.dynamic_rnn(cell, x_data, dtype=tf.float32)\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![image](https://cloud.githubusercontent.com/assets/901975/23383681/9943a9fc-fd82-11e6-8121-bd187994e249.png)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "array([[[ 1., 0., 0., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 0., 1.]],\n", "\n", " [[ 0., 1., 0., 0.],\n", " [ 0., 0., 0., 1.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.]],\n", "\n", " [[ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 0., 1., 0.]]], dtype=float32)\n", "array([[[-0.0173022 , -0.12929453],\n", " [-0.14995177, -0.23189341],\n", " [ 0.03294011, 0.01962204],\n", " [ 0.12852104, 0.12375218],\n", " [ 0.13597946, 0.31746736]],\n", "\n", " [[-0.15243632, -0.14177315],\n", " [ 0.04586344, 0.12249056],\n", " [ 0.14292534, 0.15872268],\n", " [ 0.18998367, 0.21004884],\n", " [ 0.21788891, 
0.24151592]],\n", "\n", " [[ 0.10713603, 0.11001928],\n", " [ 0.17076059, 0.1799853 ],\n", " [-0.03531617, 0.08993293],\n", " [-0.1881337 , -0.08296411],\n", " [-0.00404597, 0.07156041]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('3_batches') as scope:\n", " # One cell RNN input_dim (4) -> output_dim (2). sequence: 5, batch 3\n", " # 3 batches 'hello', 'eolll', 'lleel'\n", " x_data = np.array([[h, e, l, l, o],\n", " [e, o, l, l, l],\n", " [l, l, e, e, l]], dtype=np.float32)\n", " pp.pprint(x_data)\n", " \n", " hidden_size = 2\n", " cell = rnn.BasicLSTMCell(num_units=hidden_size, state_is_tuple=True)\n", " outputs, _states = tf.nn.dynamic_rnn(\n", " cell, x_data, dtype=tf.float32)\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "array([[[ 1., 0., 0., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 0., 1.]],\n", "\n", " [[ 0., 1., 0., 0.],\n", " [ 0., 0., 0., 1.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.]],\n", "\n", " [[ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 0., 1., 0.]]], dtype=float32)\n", "array([[[-0.1560633 , -0.15812504],\n", " [-0.12457105, 0.00623323],\n", " [-0.12050693, -0.04313403],\n", " [-0.13090043, -0.08644461],\n", " [-0.00809618, 0.01956913]],\n", "\n", " [[-0.03981951, 0.08950347],\n", " [ 0.08891603, 0.13232458],\n", " [ 0.04445181, 0.12076475],\n", " [ 0. , 0. ],\n", " [ 0. , 0. ]],\n", "\n", " [[-0.03411232, -0.05148866],\n", " [-0.0663683 , -0.09379878],\n", " [-0.0947878 , 0.03129581],\n", " [-0.09255724, 0.1121003 ],\n", " [ 0. , 0. ]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('3_batches_dynamic_length') as scope:\n", " # One cell RNN input_dim (4) -> output_dim (5). 
sequence: 5, batch 3\n", " # 3 batches 'hello', 'eolll', 'lleel'\n", " x_data = np.array([[h, e, l, l, o],\n", " [e, o, l, l, l],\n", " [l, l, e, e, l]], dtype=np.float32)\n", " pp.pprint(x_data)\n", " \n", " hidden_size = 2\n", " cell = rnn.BasicLSTMCell(num_units=hidden_size, state_is_tuple=True)\n", " outputs, _states = tf.nn.dynamic_rnn(\n", " cell, x_data, sequence_length=[5,3,4], dtype=tf.float32)\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval())" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "array([[[ 1., 0., 0., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 0., 1.]],\n", "\n", " [[ 0., 1., 0., 0.],\n", " [ 0., 0., 0., 1.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.]],\n", "\n", " [[ 0., 0., 1., 0.],\n", " [ 0., 0., 1., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 1., 0., 0.],\n", " [ 0., 0., 1., 0.]]], dtype=float32)\n", "array([[[ 0.08037324, 0.09708502],\n", " [ 0.13482611, 0.22225909],\n", " [ 0.31230038, 0.21865457],\n", " [ 0.37461194, 0.23103678],\n", " [ 0.27929804, 0.19694683]],\n", "\n", " [[ 0.08168668, 0.16866113],\n", " [ 0.06738912, 0.16512491],\n", " [ 0.22980295, 0.25232255],\n", " [ 0.32049009, 0.25064784],\n", " [ 0.37890342, 0.24961403]],\n", "\n", " [[ 0.17865573, 0.09529682],\n", " [ 0.29475945, 0.15692782],\n", " [ 0.20178071, 0.26526704],\n", " [ 0.20977789, 0.3048915 ],\n", " [ 0.38907003, 0.26467156]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('initial_state') as scope:\n", " batch_size = 3\n", " x_data = np.array([[h, e, l, l, o],\n", " [e, o, l, l, l],\n", " [l, l, e, e, l]], dtype=np.float32)\n", " pp.pprint(x_data)\n", " \n", " # One cell RNN input_dim (4) -> output_dim (2). 
sequence: 5, batch: 3\n", " hidden_size=2\n", " cell = rnn.BasicLSTMCell(num_units=hidden_size, state_is_tuple=True)\n", " initial_state = cell.zero_state(batch_size, tf.float32)\n", " outputs, _states = tf.nn.dynamic_rnn(cell, x_data,\n", " initial_state=initial_state, dtype=tf.float32)\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval())" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "array([[[ 0., 1., 2.],\n", " [ 3., 4., 5.],\n", " [ 6., 7., 8.],\n", " [ 9., 10., 11.],\n", " [ 12., 13., 14.]],\n", "\n", " [[ 15., 16., 17.],\n", " [ 18., 19., 20.],\n", " [ 21., 22., 23.],\n", " [ 24., 25., 26.],\n", " [ 27., 28., 29.]],\n", "\n", " [[ 30., 31., 32.],\n", " [ 33., 34., 35.],\n", " [ 36., 37., 38.],\n", " [ 39., 40., 41.],\n", " [ 42., 43., 44.]]], dtype=float32)\n" ] } ], "source": [ "# Create input data\n", "batch_size=3\n", "sequence_length=5\n", "input_dim=3\n", "\n", "x_data = np.arange(45, dtype=np.float32).reshape(batch_size, sequence_length, input_dim)\n", "pp.pprint(x_data) # batch, sequence_length, input_dim" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "array([[[ 2.45132565e-01, -2.26720855e-01, -9.13775116e-02,\n", " -2.41879746e-01, -1.91152543e-02],\n", " [ 3.88628483e-01, -4.98127311e-01, -5.63271251e-03,\n", " -3.56282324e-01, -2.27528125e-01],\n", " [ 4.56008732e-01, -4.88269717e-01, -2.17594192e-04,\n", " -3.37665766e-01, -4.73884165e-01],\n", " [ 4.68175769e-01, -3.43346447e-01, -6.79731329e-06,\n", " -2.79633790e-01, -6.64093494e-01],\n", " [ 4.53517973e-01, -1.80738673e-01, -1.71420652e-07,\n", " -2.21518710e-01, -7.89903104e-01]],\n", "\n", " [[ 3.25161479e-02, -3.76356095e-02, -4.70733952e-10,\n", " -1.30296305e-01, -5.30268252e-01],\n", " [ 5.25916032e-02, -2.39814539e-02, -1.25837223e-11,\n", " 
-1.25736266e-01, -8.11463356e-01],\n", " [ 6.13600165e-02, -1.12387687e-02, -2.15500184e-13,\n", " -1.06639825e-01, -9.18502569e-01],\n", " [ 6.30336851e-02, -4.92686359e-03, -2.89691854e-15,\n", " -8.40380341e-02, -9.58424091e-01],\n", " [ 6.10851720e-02, -2.18412513e-03, -3.52614224e-17,\n", " -6.41328543e-02, -9.75986063e-01]],\n", "\n", " [[ 1.38840056e-03, -9.62406339e-04, -5.32535570e-19,\n", " -3.17904465e-02, -6.70124114e-01],\n", " [ 2.27924506e-03, -4.67757985e-04, -6.79441632e-21,\n", " -3.17342579e-02, -9.22334671e-01],\n", " [ 2.62538693e-03, -2.18001660e-04, -6.31787844e-23,\n", " -2.62530502e-02, -9.81252372e-01],\n", " [ 2.63347290e-03, -1.02423473e-04, -4.19843804e-25,\n", " -1.99035294e-02, -9.93987679e-01],\n", " [ 2.49430514e-03, -4.83657859e-05, 3.01315592e-27,\n", " -1.47517556e-02, -9.97247159e-01]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('generated_data') as scope:\n", " # One cell RNN input_dim (3) -> output_dim (5). sequence: 5, batch: 3\n", " cell = rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)\n", " initial_state = cell.zero_state(batch_size, tf.float32)\n", " outputs, _states = tf.nn.dynamic_rnn(cell, x_data,\n", " initial_state=initial_state, dtype=tf.float32)\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dynamic rnn: Tensor(\"MultiRNNCell/rnn/transpose:0\", shape=(3, 5, 5), dtype=float32)\n", "array([[[-0.00083933, -0.00015323, -0.00033779, 0.00080626, 0.00034253],\n", " [-0.00583554, 0.00118117, -0.00456069, 0.00529532, 0.0044482 ],\n", " [-0.01360667, 0.00480944, -0.01308769, 0.01040751, 0.0129455 ],\n", " [-0.02190471, 0.01003678, -0.023829 , 0.01284032, 0.02427816],\n", " [-0.02970388, 0.01583032, -0.03484126, 0.01160637, 0.03700195]],\n", "\n", " [[-0.00554347, 0.00480109, -0.00781848, 0.00443573, 
0.00793266],\n", " [-0.01351871, 0.01181496, -0.01981916, 0.00755195, 0.02100384],\n", " [-0.02202674, 0.01874873, -0.03256253, 0.00714555, 0.03617828],\n", " [-0.03014897, 0.02444301, -0.04405964, 0.00303673, 0.05133053],\n", " [-0.03755479, 0.0285395 , -0.05335546, -0.00401108, 0.06519946]],\n", "\n", " [[-0.00576473, 0.00494064, -0.00812288, 0.00460927, 0.00833436],\n", " [-0.01375382, 0.01163429, -0.01992369, 0.00778724, 0.02128934],\n", " [-0.02198837, 0.01793554, -0.03203027, 0.0074151 , 0.03568052],\n", " [-0.02965781, 0.02297196, -0.04270053, 0.0034585 , 0.04962701],\n", " [-0.03656508, 0.02657485, -0.05117567, -0.0032303 , 0.06217124]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('MultiRNNCell') as scope:\n", " # Make rnn\n", " cell = rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)\n", " cell = rnn.MultiRNNCell([cell] * 3, state_is_tuple=True) # 3 layers\n", "\n", " # rnn in/out\n", " outputs, _states = tf.nn.dynamic_rnn(cell, x_data, dtype=tf.float32)\n", " print(\"dynamic rnn: \", outputs)\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval()) # batch size, unrolling (time), hidden_size" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dynamic rnn: Tensor(\"dynamic_rnn/rnn/transpose:0\", shape=(3, 5, 5), dtype=float32)\n", "array([[[ -7.77400890e-03, -1.28562674e-01, 6.49908483e-02,\n", " -1.13532230e-01, 1.58623144e-01],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00]],\n", "\n", " [[ 4.84958527e-08, -1.01035475e-05, -1.17934204e-03,\n", " -4.40897048e-01, 
7.96075852e-04],\n", " [ 6.29663877e-09, -4.45833075e-06, -3.47505091e-04,\n", " -4.97627676e-01, 5.89297793e-04],\n", " [ 7.23061000e-10, -1.57925604e-06, -1.08419306e-04,\n", " -5.05440235e-01, 3.45583743e-04],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00]],\n", "\n", " [[ 1.14308214e-14, -1.68528996e-10, -3.61130469e-06,\n", " -5.15216947e-01, 9.89917226e-07],\n", " [ 1.21262994e-15, -7.43059780e-11, -1.03032392e-06,\n", " -5.22881925e-01, 7.17830119e-07],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00]]], dtype=float32)\n" ] } ], "source": [ "with tf.variable_scope('dynamic_rnn') as scope:\n", " cell = rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)\n", " outputs, _states = tf.nn.dynamic_rnn(cell, x_data, dtype=tf.float32,\n", " sequence_length=[1, 3, 2])\n", " # length 1 for batch 0, length 3 for batch 1, length 2 for batch 2 (outputs past each length are zero)\n", " \n", " print(\"dynamic rnn: \", outputs)\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(outputs.eval()) # batch size, unrolling (time), hidden_size" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "( array([[[ -6.91036135e-02, -8.50935578e-02, 1.11118190e-01,\n", " 9.62263346e-02, 5.16710952e-02],\n", " [ -7.26657510e-02, -8.60558674e-02, 1.28619120e-01,\n", " 1.28620356e-01, 1.27893269e-01],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 
0.00000000e+00]],\n", "\n", " [[ -4.88487422e-04, -1.19689852e-04, -1.31409352e-05,\n", " 6.31166389e-03, 2.36047944e-03],\n", " [ -1.79336712e-04, -8.58747517e-05, -1.47025303e-05,\n", " 3.67163564e-03, 7.57511822e-04],\n", " [ -5.09979509e-05, -4.93491898e-05, -1.48784202e-05,\n", " 1.73931674e-03, 2.30893813e-04],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00]],\n", "\n", " [[ -7.59716784e-07, -3.01464169e-08, -3.27219529e-10,\n", " 1.28019688e-04, 6.67082486e-06],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00],\n", " [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n", " 0.00000000e+00, 0.00000000e+00]]], dtype=float32),\n", " array([[[-0.25029412, 0.00926055, 0.31124777, 0.34806553, 0.10433573],\n", " [-0.17661008, -0.03030178, 0.33094952, 0.53041464, 0.09067145],\n", " [ 0. , 0. , 0. , 0. , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. ]],\n", "\n", " [[-0.16921706, -0.02368811, 0.204624 , 0.75845337, 0.06130206],\n", " [-0.15185349, -0.01456582, 0.1524749 , 0.76045614, 0.03964303],\n", " [-0.12275493, -0.00838695, 0.10047279, 0.7611382 , 0.01887123],\n", " [ 0. , 0. , 0. , 0. , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. ]],\n", "\n", " [[-0.07493417, -0.00187773, 0.03880407, 0.76157337, 0.00505197],\n", " [ 0. , 0. , 0. , 0. , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. ],\n", " [ 0. , 0. , 0. , 0. , 0. 
]]], dtype=float32))\n", "( LSTMStateTuple(c=array([[ -6.96631312e-01, -2.70583540e-01, 1.34543031e-01,\n", " 1.22980523e+00, 2.07242519e-01],\n", " [ -2.66435528e+00, -5.09080105e-03, -1.48784211e-05,\n", " 2.98778200e+00, 2.32175487e-04],\n", " [ -9.66530681e-01, -2.15361561e-05, -3.27219557e-10,\n", " 1.00000000e+00, 6.67449740e-06]], dtype=float32), h=array([[ -7.26657510e-02, -8.60558674e-02, 1.28619120e-01,\n", " 1.28620356e-01, 1.27893269e-01],\n", " [ -5.09979509e-05, -4.93491898e-05, -1.48784202e-05,\n", " 1.73931674e-03, 2.30893813e-04],\n", " [ -7.59716784e-07, -3.01464169e-08, -3.27219529e-10,\n", " 1.28019688e-04, 6.67082486e-06]], dtype=float32)),\n", " LSTMStateTuple(c=array([[-0.45588541, 0.02305964, 0.60151225, 0.64902705, 0.18270651],\n", " [-0.20150106, -0.6873486 , 2.57025671, 0.99971294, 0.0740222 ],\n", " [-0.07767717, -0.83411807, 0.99998546, 1. , 0.00545057]], dtype=float32), h=array([[-0.25029412, 0.00926055, 0.31124777, 0.34806553, 0.10433573],\n", " [-0.16921706, -0.02368811, 0.204624 , 0.75845337, 0.06130206],\n", " [-0.07493417, -0.00187773, 0.03880407, 0.76157337, 0.00505197]], dtype=float32)))\n" ] } ], "source": [ "with tf.variable_scope('bi-directional') as scope:\n", " # bi-directional rnn\n", " cell_fw = rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)\n", " cell_bw = rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)\n", "\n", " outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, x_data,\n", " sequence_length=[2, 3, 1],\n", " dtype=tf.float32)\n", "\n", " sess.run(tf.global_variables_initializer())\n", " pp.pprint(sess.run(outputs))\n", " pp.pprint(sess.run(states))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "array([[[ 0., 1., 2.],\n", " [ 3., 4., 5.],\n", " [ 6., 7., 8.],\n", " [ 9., 10., 11.],\n", " [ 12., 13., 14.]],\n", "\n", " [[ 15., 16., 17.],\n", " [ 18., 19., 20.],\n", " [ 21., 22., 
23.],\n", " [ 24., 25., 26.],\n", " [ 27., 28., 29.]],\n", "\n", " [[ 30., 31., 32.],\n", " [ 33., 34., 35.],\n", " [ 36., 37., 38.],\n", " [ 39., 40., 41.],\n", " [ 42., 43., 44.]]], dtype=float32)\n", "array([[ 0., 1., 2.],\n", " [ 3., 4., 5.],\n", " [ 6., 7., 8.],\n", " [ 9., 10., 11.],\n", " [ 12., 13., 14.],\n", " [ 15., 16., 17.],\n", " [ 18., 19., 20.],\n", " [ 21., 22., 23.],\n", " [ 24., 25., 26.],\n", " [ 27., 28., 29.],\n", " [ 30., 31., 32.],\n", " [ 33., 34., 35.],\n", " [ 36., 37., 38.],\n", " [ 39., 40., 41.],\n", " [ 42., 43., 44.]], dtype=float32)\n", "array([[[ 25., 28., 31., 34., 37.],\n", " [ 70., 82., 94., 106., 118.],\n", " [ 115., 136., 157., 178., 199.],\n", " [ 160., 190., 220., 250., 280.],\n", " [ 205., 244., 283., 322., 361.]],\n", "\n", " [[ 250., 298., 346., 394., 442.],\n", " [ 295., 352., 409., 466., 523.],\n", " [ 340., 406., 472., 538., 604.],\n", " [ 385., 460., 535., 610., 685.],\n", " [ 430., 514., 598., 682., 766.]],\n", "\n", " [[ 475., 568., 661., 754., 847.],\n", " [ 520., 622., 724., 826., 928.],\n", " [ 565., 676., 787., 898., 1009.],\n", " [ 610., 730., 850., 970., 1090.],\n", " [ 655., 784., 913., 1042., 1171.]]], dtype=float32)\n" ] } ], "source": [ "# flatten-based softmax: one weight matrix applied to every time step\n", "hidden_size=3\n", "sequence_length=5\n", "batch_size=3\n", "num_classes=5\n", "\n", "pp.pprint(x_data) # batch_size=3, sequence_length=5, hidden_size=3\n", "x_data = x_data.reshape(-1, hidden_size)\n", "pp.pprint(x_data)\n", "\n", "softmax_w = np.arange(15, dtype=np.float32).reshape(hidden_size, num_classes)\n", "outputs = np.matmul(x_data, softmax_w)\n", "outputs = outputs.reshape(-1, sequence_length, num_classes) # batch, seq, class\n", "pp.pprint(outputs)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss1: 0.313262 Loss2: 1.31326 Loss3: 0.646595\n" ] } ], "source": [ "# logits: [batch_size, sequence_length, num_classes]\n",
"prediction1 = tf.constant([[[0, 1], [0, 1], [0, 1]]], dtype=tf.float32)\n", "prediction2 = tf.constant([[[1, 0], [1, 0], [1, 0]]], dtype=tf.float32)\n", "prediction3 = tf.constant([[[0, 1], [1, 0], [0, 1]]], dtype=tf.float32)\n", "\n", "# [batch_size, sequence_length]\n", "y_data = tf.constant([[1, 1, 1]])\n", "\n", "# [batch_size, sequence_length]: per-timestep loss weights\n", "weights = tf.constant([[1, 1, 1]], dtype=tf.float32)\n", "\n", "sequence_loss1 = tf.contrib.seq2seq.sequence_loss(prediction1, y_data, weights)\n", "sequence_loss2 = tf.contrib.seq2seq.sequence_loss(prediction2, y_data, weights)\n", "sequence_loss3 = tf.contrib.seq2seq.sequence_loss(prediction3, y_data, weights)\n", "\n", "sess.run(tf.global_variables_initializer())\n", "print(\"Loss1: \", sequence_loss1.eval(),\n", " \"Loss2: \", sequence_loss2.eval(),\n", " \"Loss3: \", sequence_loss3.eval())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.0" } }, "nbformat": 4, "nbformat_minor": 0 }