{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
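    "# Knet RNN example\n",
    "\n",
    "Trains a GRU-based two-class classifier on the IMDB movie-review dataset with Knet, timing each training epoch and reporting test-set accuracy."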
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
     "# One-time setup: after installing and starting Julia (this notebook was run on 0.6.1), run the following to install the required packages:\n",
     "# Pkg.init(); Pkg.update()\n",
     "# for p in (\"CUDAdrv\",\"IJulia\",\"PyCall\",\"JLD2\",\"Knet\"); Pkg.add(p); end\n",
     "# Pkg.checkout(\"Knet\",\"ilkarman\") # make sure we have the right Knet version (the \"ilkarman\" branch)\n",
     "# Pkg.build(\"Knet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "using Knet\n",
    "True=true # so we can read the python params\n",
    "include(\"common/params_lstm.py\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OS: Linux\n",
      "Julia: 0.6.1\n",
      "Knet: 0.8.5+\n",
      "GPU: Tesla K80\n",
      "\n"
     ]
    }
   ],
   "source": [
    "println(\"OS: \", Sys.KERNEL)\n",
    "println(\"Julia: \", VERSION)\n",
    "println(\"Knet: \", Pkg.installed(\"Knet\"))\n",
    "println(\"GPU: \", readstring(`nvidia-smi --query-gpu=name --format=csv,noheader`))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define model\n",
    "function initmodel()\n",
    "    rnnSpec,rnnWeights = rnninit(EMBEDSIZE,NUMHIDDEN; rnnType=:gru)\n",
    "    inputMatrix = KnetArray(xavier(Float32,EMBEDSIZE,MAXFEATURES))\n",
    "    outputMatrix = KnetArray(xavier(Float32,2,NUMHIDDEN))\n",
    "    return rnnSpec,(rnnWeights,inputMatrix,outputMatrix)\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define loss and its gradient\n",
    "function predict(weights, inputs, rnnSpec)\n",
    "    rnnWeights, inputMatrix, outputMatrix = weights # (1,1,W), (X,V), (2,H)\n",
    "    indices = hcat(inputs...)' # (B,T)\n",
    "    rnnInput = inputMatrix[:,indices] # (X,B,T)\n",
    "    rnnOutput = rnnforw(rnnSpec, rnnWeights, rnnInput)[1] # (H,B,T)\n",
    "    return outputMatrix * rnnOutput[:,:,end] # (2,H) * (H,B) = (2,B)\n",
    "end\n",
    "\n",
    "loss(w,x,y,r)=nll(predict(w,x,r),y)\n",
    "lossgradient = grad(loss);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mLoading IMDB...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 10.266185 seconds (15.94 M allocations: 835.780 MiB, 3.98% gc time)\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n"
     ]
    }
   ],
   "source": [
    "# load data\n",
    "include(Knet.dir(\"data\",\"imdb.jl\"))\n",
    "@time (xtrn,ytrn,xtst,ytst,imdbdict)=imdb(maxlen=MAXLEN,maxval=MAXFEATURES)\n",
    "for d in (xtrn,ytrn,xtst,ytst); println(summary(d)); end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# prepare for training\n",
    "weights = nothing; knetgc(); # Reclaim memory from previous run\n",
    "rnnSpec,weights = initmodel()\n",
    "optim = optimizers(weights, Adam; lr=LR, beta1=BETA_1, beta2=BETA_2, eps=EPS);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 14.319533 seconds (2.08 M allocations: 138.579 MiB, 3.58% gc time)\n"
     ]
    }
   ],
   "source": [
    "# cold start\n",
    "@time for (x,y) in minibatch(xtrn,ytrn,BATCHSIZE;shuffle=true)\n",
    "    grads = lossgradient(weights,x,y,rnnSpec)\n",
    "    update!(weights, grads, optim)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# prepare for training\n",
    "weights = nothing; knetgc(); # Reclaim memory from previous run\n",
    "rnnSpec,weights = initmodel()\n",
    "optim = optimizers(weights, Adam; lr=LR, beta1=BETA_1, beta2=BETA_2, eps=EPS);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTraining...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  9.776101 seconds (356.68 k allocations: 45.007 MiB, 4.79% gc time)\n",
      "  9.786896 seconds (352.22 k allocations: 44.658 MiB, 5.91% gc time)\n",
      "  9.732747 seconds (352.94 k allocations: 44.669 MiB, 5.92% gc time)\n",
      " 29.298876 seconds (1.07 M allocations: 134.572 MiB, 5.54% gc time)\n"
     ]
    }
   ],
   "source": [
    "# 29s\n",
    "info(\"Training...\")\n",
    "@time for epoch in 1:EPOCHS\n",
    "    @time for (x,y) in minibatch(xtrn,ytrn,BATCHSIZE;shuffle=true)\n",
    "        grads = lossgradient(weights,x,y,rnnSpec)\n",
    "        update!(weights, grads, optim)\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mINFO: \u001b[39m\u001b[22m\u001b[36mTesting...\n",
      "\u001b[39m"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  2.999301 seconds (70.50 k allocations: 34.680 MiB, 11.61% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.844511217948718"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info(\"Testing...\")\n",
    "@time accuracy(weights, minibatch(xtst,ytst,BATCHSIZE), (w,x)->predict(w,x,rnnSpec))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.1",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}