{ "cells": [ { "cell_type": "code", "execution_count": 41, "metadata": { "collapsed": false, "focus": false, "id": "ccee6335-b1ae-4254-8234-72e9ab580983" }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import data\n", "import tensorflow as tf\n", "import numpy as np\n", "import plotly\n", "from plotly.graph_objs import Scatter, Layout\n", "import plotly.graph_objs as go\n", "plotly.offline.init_notebook_mode()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "focus": true, "id": "e5ae9566-1999-4c39-ae29-bd669864d168" }, "outputs": [], "source": [ "idd, seq = data.readseq('train.csv')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "focus": false, "id": "07e6bae7-28af-4e38-9f71-175cb7afddaf" }, "outputs": [], "source": [ "def get_batch(s, ix):\n", " t = s[ix]\n", " return t[:-1], t[1:]\n", "def sine_data(ix,size=50):\n", " x = np.arange(ix, ix + size, step = 0.2)\n", " y = np.cos(x)\n", " return y[:-1], y[1:]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true, "focus": false, "id": "ad0fa101-aab3-4a3b-a0f8-3efd88227223" }, "outputs": [], "source": [ "inp_out_size = 1\n", "hidden_layer_size = 20\n", "lr = 0.01\n", "epoch = 1600\n", "print_step = 200\n", "num_steps = 22" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true, "focus": false, "id": "27113ea9-a6cb-4a3f-ab27-d0319d5209a1" }, "outputs": [], "source": [ "# zoneout as in https://arxiv.org/pdf/1606.01305v1.pdf\n", "def zoneout(h, h_prev):\n", " assert h.get_shape() == h_prev.get_shape()\n", " r = tf.select(tf.random_uniform(h.get_shape()) > keep, tf.ones_like(h), tf.zeros_like(h))\n", " h_z = tf.mul(r, h_prev) + tf.mul(tf.sub(tf.ones_like(r), r), h)\n", " return h_z" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "focus": false, "id": "fd90b47c-fd3f-4f17-842a-36c27825c509" }, "outputs": [], "source": [ "tf.reset_default_graph()\n", "initializer = tf.random_uniform_initializer(minval=-0.01, maxval=0.01, dtype=tf.float32)\n", "Wxh = tf.get_variable('Wxh', shape=[inp_out_size, hidden_layer_size], initializer=initializer)\n", "Whh = tf.get_variable('Whh', shape=[hidden_layer_size, hidden_layer_size], initializer=initializer)\n", "Why = tf.get_variable('Why',shape=[hidden_layer_size, inp_out_size], initializer=initializer)\n", "# weights associated with update gate\n", "Wxz = tf.get_variable('Wxz', shape=[inp_out_size, hidden_layer_size], initializer=initializer)\n", "Whz = tf.get_variable('Whz', shape=[hidden_layer_size, hidden_layer_size], initializer=initializer)\n", "# weights associated with the reset gate\n", "Wxr = tf.get_variable('Wxr', shape=[inp_out_size, hidden_layer_size], initializer=initializer)\n", "Whr = tf.get_variable('Whr', shape=[hidden_layer_size, hidden_layer_size], initializer=initializer)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true, "focus": false, "id": "cba76c8f-07a7-4231-955f-56c33c22c94e" }, "outputs": [], "source": [ "def GRU(prev, inp):\n", " i = tf.reshape(inp, shape=[1, -1])\n", " p = tf.reshape(prev, shape=[1, -1])\n", " z = tf.nn.sigmoid(tf.matmul(i, Wxz) + tf.matmul(p, Whz)) # update gate\n", " r = tf.nn.sigmoid(tf.matmul(i, Wxr) + tf.matmul(p, Whr)) # reset gate\n", " h_ = tf.nn.tanh(tf.matmul(i, Wxh) + tf.matmul(tf.mul(p, r), Whh))\n", " h = tf.mul(tf.sub(tf.ones_like(z), z), h_) + tf.mul(z, p)\n", " h = zoneout(h, p)\n", " return tf.reshape(h, [hidden_layer_size])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false, "focus": false, "id": "e63ef0ba-c142-4a38-9a29-024296f9e222" }, "outputs": [], "source": [ "# model\n", "inputs = tf.placeholder(shape=[None, 1], dtype=tf.float32)\n", "targets = tf.placeholder(shape=[None, 1], dtype=tf.float32)\n", "keep = tf.placeholder(dtype=tf.float32)\n", "initial = tf.placeholder(shape=[hidden_layer_size], dtype=tf.float32)\n", "hiddens = tf.scan(GRU, inputs, initializer=initial)\n", "outputs = tf.matmul(hiddens, Why)\n", "loss = tf.sqrt(tf.reduce_sum(tf.square(tf.sub(outputs, targets))))\n", "optimizer = tf.train.AdagradOptimizer(lr)\n", "grad = optimizer.compute_gradients(loss, [Wxh])[0]\n", "optimize_op = optimizer.minimize(loss)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false, "focus": false, "id": "48d261de-2e79-4276-a375-1bc645d7b5b1" }, "outputs": [], "source": [ "sess = tf.Session()\n", "sess.run(tf.initialize_all_variables())" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false, "focus": false, "id": "f28d73e3-9f0b-4da6-9d3c-42ae13bbea6a" }, "outputs": [], "source": [ "ix = 0\n", "ini = np.zeros([hidden_layer_size])" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false, "focus": false, "id": "4a0c5b1a-9be7-4e06-bfe5-e13207e76c32" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 Loss 7.34901\n", "epoch 200 Loss 2.08433\n", "epoch 400 Loss 0.335419\n", "epoch 600 Loss 0.352606\n", "epoch 800 Loss 0.471073\n", "epoch 1000 Loss 0.546137\n", "epoch 1200 Loss 0.374263\n", "epoch 1400 Loss 0.116386\n" ] } ], "source": [ "for i in range(epoch):\n", " a, b = sine_data(ix, size=num_steps) \n", " a = np.reshape(a, [-1, 1])\n", " b = np.reshape(b, [-1, 1])\n", " feed = {inputs: a, targets: b, initial: ini, keep: 1.0}\n", " l, h, _ = sess.run([loss, hiddens, optimize_op], feed_dict=feed)\n", " ix += 1\n", " if i % print_step == 0:\n", " print('epoch', i, 'Loss', l)\n", " ix += num_steps\n", " ini = h[-1]" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": false, "focus": false, "id": "36d0262b-4e81-4572-aae9-a200f86e868b" }, "outputs": [], "source": [ "def generate(sess, n):\n", " global ix\n", " print(ix)\n", " ixx = [[ix]]\n", " v = []\n", " h = np.zeros(hidden_layer_size)\n", " for i in range(n):\n", " o, h = sess.run([outputs, hiddens], {inputs:ixx, initial: h, keep: 1.0})\n", " h = h.reshape(hidden_layer_size)\n", " ixx = o\n", " v.append(np.squeeze(o))\n", " return v" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": false, "focus": false, "id": "d9e6708a-8523-4399-80fc-c3bb0a011ad1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "36800\n" ] } ], "source": [ "print(ix)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "collapsed": false, "focus": false, "id": "4670ed87-3bfd-4a9b-b57b-36cc0148eb21" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "36800\n", "0.31068776191\n" ] } ], "source": [ "x = np.arange(ix, ix + 50, step = 0.2)\n", "pred = np.array(generate(sess, len(x)))\n", "true = np.cos(x)\n", "test_loss = np.sqrt(np.mean((true - pred) ** 2))\n", "print(test_loss)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot\n", "trace0 = go.Scatter(\n", " x = x,\n", " y = pred,\n", " mode = 'lines',\n", " name = 'predited'\n", ")\n", "trace1 = go.Scatter(\n", " x = x,\n", " y = true,\n", " mode = 'lines',\n", " name = 'true'\n", ")\n", "layout = go.Layout(\n", " title=\"cos : true vs predicted\",\n", " xaxis=dict(\n", " range=[x[0], x[-1]]\n", " ),\n", " yaxis=dict(\n", " range=[min(true), max(true)]\n", " )\n", ")\n", "data = [trace0, trace1]\n", "fig = go.Figure(data=data, layout=layout)\n", "plotly.offline.iplot(fig, filename='cos') " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "focus": false, "id": "546b33a1-9815-431e-81c6-d3c805c770da" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1+" } }, "nbformat": 4, "nbformat_minor": 0 }