{
"cells": [
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false,
"focus": false,
"id": "ccee6335-b1ae-4254-8234-72e9ab580983"
},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Third-party numerical / deep-learning stack\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"# Plotting (offline mode renders figures inside the notebook)\n",
"import plotly\n",
"import plotly.graph_objs as go\n",
"from plotly.graph_objs import Scatter, Layout\n",
"\n",
"# Local helper module; provides readseq() used to load train.csv\n",
"import data\n",
"\n",
"plotly.offline.init_notebook_mode()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"focus": true,
"id": "e5ae9566-1999-4c39-ae29-bd669864d168"
},
"outputs": [],
"source": [
"# Raw training sequences from disk.\n",
"# NOTE(review): `idd` / `seq` are not used by any later cell -- the model below\n",
"# is trained on synthetic cosine windows from sine_data().\n",
"train_csv_path = 'train.csv'\n",
"idd, seq = data.readseq(train_csv_path)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false,
"focus": false,
"id": "07e6bae7-28af-4e38-9f71-175cb7afddaf"
},
"outputs": [],
"source": [
"def get_batch(s, ix):\n",
"    \"\"\"Return a one-step-ahead (inputs, targets) pair for sequence `s[ix]`.\n",
"\n",
"    inputs is s[ix][:-1] and targets is s[ix][1:], so targets[t] is the\n",
"    element that follows inputs[t].\n",
"    \"\"\"\n",
"    t = s[ix]\n",
"    return t[:-1], t[1:]\n",
"\n",
"\n",
"def sine_data(ix, size=50, step=0.2):\n",
"    \"\"\"Sample cos(x) on [ix, ix + size) and return a one-step-ahead pair.\n",
"\n",
"    Despite the name, the signal is a cosine.\n",
"\n",
"    Args:\n",
"        ix: start of the x-range.\n",
"        size: length of the x-range -- NOT the number of samples; the window\n",
"            holds about size / step points.\n",
"        step: spacing between consecutive x samples (0.2 by default, the\n",
"            value that used to be hard-coded here).\n",
"\n",
"    Returns:\n",
"        (inputs, targets) = (y[:-1], y[1:]) with y = cos(x), so each target is\n",
"        the cosine value one `step` after its input.\n",
"    \"\"\"\n",
"    x = np.arange(ix, ix + size, step=step)\n",
"    y = np.cos(x)\n",
"    return y[:-1], y[1:]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true,
"focus": false,
"id": "ad0fa101-aab3-4a3b-a0f8-3efd88227223"
},
"outputs": [],
"source": [
"# --- architecture ---\n",
"inp_out_size = 1          # size of each input value and each output value\n",
"hidden_layer_size = 20    # width of the GRU hidden state\n",
"\n",
"# --- data ---\n",
"num_steps = 22           # x-range length of each training window (passed as sine_data's size;\n",
"                         # sampled every 0.2, i.e. ~110 points per window)\n",
"\n",
"# --- optimisation ---\n",
"lr = 0.01                # Adagrad learning rate\n",
"epoch = 1600             # number of training windows (loop iterations)\n",
"print_step = 200         # print the loss every this many windows"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true,
"focus": false,
"id": "27113ea9-a6cb-4a3f-ab27-d0319d5209a1"
},
"outputs": [],
"source": [
"# zoneout as in https://arxiv.org/pdf/1606.01305v1.pdf\n",
"def zoneout(h, h_prev, keep_prob=None):\n",
"    \"\"\"Per-unit random choice between the new and the previous hidden state.\n",
"\n",
"    Each unit independently takes its value from `h_prev` with probability\n",
"    1 - keep_prob and from the freshly computed `h` otherwise, so with\n",
"    keep_prob == 1.0 the result is exactly `h`.\n",
"\n",
"    Args:\n",
"        h: newly computed hidden state tensor.\n",
"        h_prev: previous hidden state; must have a shape compatible with `h`.\n",
"        keep_prob: probability (float or scalar tensor) of taking the new\n",
"            value. Defaults to the notebook-level `keep` placeholder, which\n",
"            must be defined before this function is called.\n",
"\n",
"    Returns:\n",
"        Tensor with the same shape and dtype as `h`.\n",
"    \"\"\"\n",
"    assert h.get_shape().is_compatible_with(h_prev.get_shape())\n",
"    if keep_prob is None:\n",
"        keep_prob = keep\n",
"    # r == 1 marks 'zoned out' units that hold their previous value.\n",
"    # tf.cast(bool) and the arithmetic operators replace tf.select / tf.mul /\n",
"    # tf.sub, which were removed in TensorFlow 1.0; the values are identical.\n",
"    r = tf.cast(tf.random_uniform(tf.shape(h)) > keep_prob, h.dtype)\n",
"    return r * h_prev + (1.0 - r) * h"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false,
"focus": false,
"id": "fd90b47c-fd3f-4f17-842a-36c27825c509"
},
"outputs": [],
"source": [
"# Start from a fresh graph so re-running this cell does not try to re-create\n",
"# variables that already exist.\n",
"tf.reset_default_graph()\n",
"# Graph-level seed. It must come after reset_default_graph (the old graph and\n",
"# its seed are discarded). Makes the weight init below and random ops built\n",
"# later in this graph (zoneout masks) repeatable across Restart & Run All.\n",
"tf.set_random_seed(1234)\n",
"initializer = tf.random_uniform_initializer(minval=-0.01, maxval=0.01, dtype=tf.float32)\n",
"\n",
"\n",
"def _weight(name, rows, cols):\n",
"    \"\"\"Create variable `name` of shape [rows, cols] using the shared uniform initializer.\"\"\"\n",
"    return tf.get_variable(name, shape=[rows, cols], initializer=initializer)\n",
"\n",
"\n",
"# candidate hidden state (input->hidden, hidden->hidden) and hidden->output projection\n",
"Wxh = _weight('Wxh', inp_out_size, hidden_layer_size)\n",
"Whh = _weight('Whh', hidden_layer_size, hidden_layer_size)\n",
"Why = _weight('Why', hidden_layer_size, inp_out_size)\n",
"# weights associated with the update gate\n",
"Wxz = _weight('Wxz', inp_out_size, hidden_layer_size)\n",
"Whz = _weight('Whz', hidden_layer_size, hidden_layer_size)\n",
"# weights associated with the reset gate\n",
"Wxr = _weight('Wxr', inp_out_size, hidden_layer_size)\n",
"Whr = _weight('Whr', hidden_layer_size, hidden_layer_size)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true,
"focus": false,
"id": "cba76c8f-07a7-4231-955f-56c33c22c94e"
},
"outputs": [],
"source": [
"def GRU(prev, inp):\n",
"    \"\"\"One GRU step, used as the step function of tf.scan.\n",
"\n",
"    Args:\n",
"        prev: previous hidden state (reshaped here to a row vector).\n",
"        inp: current input element (reshaped here to a row vector).\n",
"\n",
"    Returns:\n",
"        The new hidden state, after zoneout, as a vector of length\n",
"        hidden_layer_size (the shape tf.scan's initializer uses).\n",
"    \"\"\"\n",
"    # NOTE(review): no bias terms on the gates or the candidate state.\n",
"    x_t = tf.reshape(inp, shape=[1, -1])\n",
"    h_prev = tf.reshape(prev, shape=[1, -1])\n",
"    z = tf.nn.sigmoid(tf.matmul(x_t, Wxz) + tf.matmul(h_prev, Whz))  # update gate\n",
"    r = tf.nn.sigmoid(tf.matmul(x_t, Wxr) + tf.matmul(h_prev, Whr))  # reset gate\n",
"    # candidate state: the reset gate scales how much of h_prev feeds in\n",
"    h_cand = tf.nn.tanh(tf.matmul(x_t, Wxh) + tf.matmul(h_prev * r, Whh))\n",
"    # z -> keep the old state, (1 - z) -> take the candidate.\n",
"    # Operators replace tf.mul / tf.sub, which were removed in TensorFlow 1.0.\n",
"    h = (1.0 - z) * h_cand + z * h_prev\n",
"    h = zoneout(h, h_prev)\n",
"    return tf.reshape(h, [hidden_layer_size])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false,
"focus": false,
"id": "e63ef0ba-c142-4a38-9a29-024296f9e222"
},
"outputs": [],
"source": [
"# model: unroll the GRU over the input sequence with tf.scan\n",
"inputs = tf.placeholder(shape=[None, 1], dtype=tf.float32)   # one value per time step\n",
"targets = tf.placeholder(shape=[None, 1], dtype=tf.float32)  # value one step after each input\n",
"keep = tf.placeholder(dtype=tf.float32)  # zoneout keep probability (read by zoneout)\n",
"initial = tf.placeholder(shape=[hidden_layer_size], dtype=tf.float32)  # starting hidden state\n",
"hiddens = tf.scan(GRU, inputs, initializer=initial)  # one hidden state per time step\n",
"outputs = tf.matmul(hiddens, Why)\n",
"# L2 norm of the residual over the whole window (a sum, not a mean, so it grows\n",
"# with window length). '-' replaces tf.sub, which was removed in TensorFlow 1.0.\n",
"loss = tf.sqrt(tf.reduce_sum(tf.square(outputs - targets)))\n",
"optimizer = tf.train.AdagradOptimizer(lr)\n",
"optimize_op = optimizer.minimize(loss)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false,
"focus": false,
"id": "48d261de-2e79-4276-a375-1bc645d7b5b1"
},
"outputs": [],
"source": [
"# Build the initializer op first (renamed tf.global_variables_initializer in\n",
"# later TensorFlow releases), then run it in a new session.\n",
"init_op = tf.initialize_all_variables()\n",
"sess = tf.Session()\n",
"sess.run(init_op)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false,
"focus": false,
"id": "f28d73e3-9f0b-4da6-9d3c-42ae13bbea6a"
},
"outputs": [],
"source": [
"# Training cursor: the hidden state carried from one training window to the\n",
"# next, and the x position where the next window starts.\n",
"ini = np.zeros(hidden_layer_size)\n",
"ix = 0"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false,
"focus": false,
"id": "4a0c5b1a-9be7-4e06-bfe5-e13207e76c32"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 Loss 7.34901\n",
"epoch 200 Loss 2.08433\n",
"epoch 400 Loss 0.335419\n",
"epoch 600 Loss 0.352606\n",
"epoch 800 Loss 0.471073\n",
"epoch 1000 Loss 0.546137\n",
"epoch 1200 Loss 0.374263\n",
"epoch 1400 Loss 0.116386\n"
]
}
],
"source": [
"for i in range(epoch):\n",
"    # next window of cos(x) on [ix, ix + num_steps); targets are the inputs\n",
"    # shifted one sample ahead\n",
"    a, b = sine_data(ix, size=num_steps)\n",
"    a = np.reshape(a, [-1, 1])\n",
"    b = np.reshape(b, [-1, 1])\n",
"    # NOTE(review): keep=1.0 means zoneout never fires during training; a value\n",
"    # below 1.0 is needed for it to act as a regulariser.\n",
"    feed = {inputs: a, targets: b, initial: ini, keep: 1.0}\n",
"    l, h, _ = sess.run([loss, hiddens, optimize_op], feed_dict=feed)\n",
"    if i % print_step == 0:\n",
"        print('epoch', i, 'Loss', l)\n",
"    # Advance by exactly one window so the carried state `ini` lines up with the\n",
"    # next window's first input (only each window's final target sample is never\n",
"    # fed as an input). ix used to be advanced by an extra 1 as well, leaving a\n",
"    # 5-sample hole between windows that the carried state did not see.\n",
"    ix += num_steps\n",
"    ini = h[-1]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false,
"focus": false,
"id": "36d0262b-4e81-4572-aae9-a200f86e868b"
},
"outputs": [],
"source": [
"def generate(sess, n, start=None, step=0.2):\n",
"    \"\"\"Free-run the trained model for `n` steps and return its predictions.\n",
"\n",
"    The model is trained to map cos(x) -> cos(x + step), so it is seeded with\n",
"    the value cos(start - step). Prediction k then targets cos(start + k * step),\n",
"    which lines up with np.cos(np.arange(start, ..., step)). Each prediction is\n",
"    fed back in as the next input.\n",
"\n",
"    Args:\n",
"        sess: session holding the trained graph.\n",
"        n: number of predictions to produce.\n",
"        start: x position of the first prediction; defaults to the global\n",
"            training cursor `ix`.\n",
"        step: x spacing between predictions (0.2, the spacing used for the\n",
"            training windows).\n",
"\n",
"    Returns:\n",
"        List of `n` scalar predictions.\n",
"    \"\"\"\n",
"    if start is None:\n",
"        start = ix\n",
"    # The seed must be a cosine value in [-1, 1]; seeding with the raw position\n",
"    # (e.g. 36800) drove the gates far outside anything seen in training.\n",
"    x_in = [[np.cos(start - step)]]\n",
"    h = np.zeros(hidden_layer_size)\n",
"    v = []\n",
"    for _ in range(n):\n",
"        o, h = sess.run([outputs, hiddens], {inputs: x_in, initial: h, keep: 1.0})\n",
"        # scan returns one state per input; with a single input that is [1, hidden]\n",
"        h = h.reshape(hidden_layer_size)\n",
"        x_in = o  # feed the prediction back as the next input\n",
"        v.append(np.squeeze(o))\n",
"    return v"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false,
"focus": false,
"id": "d9e6708a-8523-4399-80fc-c3bb0a011ad1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"36800\n"
]
}
],
"source": [
"# x position where training stopped; generation/evaluation starts here\n",
"ix"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false,
"focus": false,
"id": "4670ed87-3bfd-4a9b-b57b-36cc0148eb21"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"36800\n",
"0.31068776191\n"
]
}
],
"source": [
"# Free-running evaluation over the 50 x-units after the training cursor:\n",
"# one prediction per x value, scored by RMSE against the true cosine.\n",
"x = np.arange(ix, ix + 50, step=0.2)\n",
"true = np.cos(x)\n",
"pred = np.array(generate(sess, x.shape[0]))\n",
"residual = true - pred\n",
"test_loss = np.sqrt(np.mean(np.square(residual)))\n",
"print(test_loss)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot predicted vs. true cosine over the evaluation window\n",
"trace_pred = go.Scatter(\n",
"    x=x,\n",
"    y=pred,\n",
"    mode='lines',\n",
"    name='predicted'\n",
")\n",
"trace_true = go.Scatter(\n",
"    x=x,\n",
"    y=true,\n",
"    mode='lines',\n",
"    name='true'\n",
")\n",
"# The y-range spans both curves so predictions outside [min(true), max(true)]\n",
"# are not clipped off the plot.\n",
"y_all = np.concatenate([np.ravel(true), np.ravel(pred)])\n",
"layout = go.Layout(\n",
"    title=\"cos : true vs predicted\",\n",
"    xaxis=dict(title='x', range=[x[0], x[-1]]),\n",
"    yaxis=dict(title='cos(x)', range=[np.nanmin(y_all), np.nanmax(y_all)])\n",
")\n",
"# Named `traces`, not `data`, so the imported `data` module is not shadowed\n",
"# (re-running the readseq cell after this one would otherwise fail).\n",
"traces = [trace_pred, trace_true]\n",
"fig = go.Figure(data=traces, layout=layout)\n",
"plotly.offline.iplot(fig, filename='cos')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"focus": false,
"id": "546b33a1-9815-431e-81c6-d3c805c770da"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1+"
}
},
"nbformat": 4,
"nbformat_minor": 0
}