{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# High-level RNN MXNet Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import mxnet as mx\n",
    "from common.params_lstm import *\n",
    "from common.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS:  linux\n",
      "Python:  3.5.2 |Anaconda custom (64-bit)| (default, Jul  2 2016, 17:53:06) \n",
      "[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]\n",
      "Numpy:  1.13.3\n",
      "MXNet:  0.12.1\n",
      "GPU:  ['Tesla K80']\n"
     ]
    }
   ],
   "source": [
    "print(\"OS: \", sys.platform)\n",
    "print(\"Python: \", sys.version)\n",
    "print(\"Numpy: \", np.__version__)\n",
    "print(\"MXNet: \", mx.__version__)\n",
    "print(\"GPU: \", get_gpu_name())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_symbol(CUDNN=True):\n",
    "    # https://mxnet.incubator.apache.org/api/python/rnn.html\n",
    "    data = mx.symbol.Variable('data')\n",
    "    embedded_step = mx.symbol.Embedding(data=data, input_dim=MAXFEATURES, output_dim=EMBEDSIZE)\n",
    "    \n",
    "    # Fusing RNN layers across time step into one kernel\n",
    "    # Improves speed but is less flexible\n",
    "    # Currently only supported if using cuDNN on GPU\n",
    "    if not CUDNN:\n",
    "        gru_cell = mx.rnn.GRUCell(num_hidden=NUMHIDDEN)\n",
    "    else:\n",
    "        gru_cell = mx.rnn.FusedRNNCell(num_hidden=NUMHIDDEN, num_layers=1, mode='gru')\n",
    "    \n",
    "    begin_state = gru_cell.begin_state()\n",
    "    # Call the cell to get the output of one time step for a batch.\n",
    "    # TODO: TNC layout (sequence length, batch size, and feature dimensions) is faster for RNN\n",
    "    outputs, states = gru_cell.unroll(length=MAXLEN, inputs=embedded_step, merge_outputs=False)\n",
    "    \n",
    "    fc1 = mx.symbol.FullyConnected(data=outputs[-1], num_hidden=2) \n",
    "    input_y = mx.symbol.Variable('softmax_label')  \n",
    "    m = mx.symbol.SoftmaxOutput(data=fc1, label=input_y, name=\"softmax\")\n",
    "    return m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_model(m):\n",
    "    if GPU:\n",
    "        ctx = [mx.gpu(0)]\n",
    "    else:\n",
    "        ctx = mx.cpu()\n",
    "    mod = mx.mod.Module(context=ctx, symbol=m)\n",
    "    mod.bind(data_shapes=[('data', (BATCHSIZE, MAXLEN))],\n",
    "             label_shapes=[('softmax_label', (BATCHSIZE, ))])\n",
    "    # Glorot-uniform initializer\n",
    "    mod.init_params(initializer=mx.init.Xavier(rnd_type='uniform'))\n",
    "    mod.init_optimizer(optimizer='Adam', \n",
    "                       optimizer_params=(('learning_rate', LR),\n",
    "                                         ('beta1', BETA_1),\n",
    "                                         ('beta2', BETA_2),\n",
    "                                         ('epsilon', EPS)))\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing train set...\n",
      "Preparing test set...\n",
      "Trimming to 30000 max-features\n",
      "Padding to length 150\n",
      "(25000, 150) (25000, 150) (25000,) (25000,)\n",
      "int32 int32 int32 int32\n",
      "CPU times: user 5.48 s, sys: 313 ms, total: 5.8 s\n",
      "Wall time: 5.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Data into format for library\n",
    "x_train, x_test, y_train, y_test = imdb_for_library(seq_len=MAXLEN, max_features=MAXFEATURES)\n",
    "\n",
    "# Use custom iterator instead of mx.io.NDArrayIter() for consistency\n",
    "# Wrap as DataBatch class\n",
    "wrapper_db = lambda args: mx.io.DataBatch(data=[mx.nd.array(args[0])], label=[mx.nd.array(args[1])])\n",
    "\n",
    "print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)\n",
    "print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 81.3 ms, sys: 3.35 ms, total: 84.6 ms\n",
      "Wall time: 84.1 ms\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda/envs/py35/lib/python3.5/site-packages/mxnet/rnn/rnn_cell.py:675: UserWarning: NTC layout detected. Consider using TNC for FusedRNNCell for faster speed\n",
      "  warnings.warn(\"NTC layout detected. Consider using \"\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Load symbol\n",
    "sym = create_symbol()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 789 ms, sys: 359 ms, total: 1.15 s\n",
      "Wall time: 1.17 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Initialise model\n",
    "model = init_model(sym)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Training ('accuracy', 0.77880608974358978)\n",
      "Epoch 1, Training ('accuracy', 0.92219551282051282)\n",
      "Epoch 2, Training ('accuracy', 0.96546474358974355)\n",
      "CPU times: user 25.9 s, sys: 4.29 s, total: 30.2 s\n",
      "Wall time: 28.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# 29s\n",
    "# Train and log accuracy\n",
    "metric = mx.metric.create('acc')\n",
    "for j in range(EPOCHS):\n",
    "    #train_iter.reset()\n",
    "    metric.reset()\n",
    "    #for batch in train_iter:\n",
    "    for batch in map(wrapper_db, yield_mb(x_train, y_train, BATCHSIZE, shuffle=True)):\n",
    "        model.forward(batch, is_train=True) \n",
    "        model.update_metric(metric, batch.label)\n",
    "        model.backward()              \n",
    "        model.update()\n",
    "    print('Epoch %d, Training %s' % (j, metric.get()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.57 s, sys: 350 ms, total: 2.92 s\n",
      "Wall time: 2.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "y_guess = model.predict(mx.io.NDArrayIter(x_test, batch_size=BATCHSIZE, shuffle=False))\n",
    "y_guess = np.argmax(y_guess.asnumpy(), axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:  0.85924\n"
     ]
    }
   ],
   "source": [
    "print(\"Accuracy: \", sum(y_guess == y_test)/len(y_guess))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}