{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important: This notebook only works with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository, because it will load fastai-0.7.x.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "from fastai.io import *\n",
    "from fastai.conv_learner import *\n",
    "\n",
    "from fastai.column_data import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We're going to download the collected works of Nietzsche to use as our data for this class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH='data/nietzsche/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "corpus length: 600893\n"
     ]
    }
   ],
   "source": [
    "get_data(\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\", f'{PATH}nietzsche.txt')\n",
    "text = open(f'{PATH}nietzsche.txt').read()\n",
    "print('corpus length:', len(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not ground\\nfor suspecting that all philosophers, in so far as they have been\\ndogmatists, have failed to understand women--that the terrible\\nseriousness and clumsy importunity with which they have usually paid\\ntheir addresses to Truth, have been unskilled and unseemly methods for\\nwinning a woman? Certainly she has never allowed herself '"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text[:400]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total chars: 85\n"
     ]
    }
   ],
   "source": [
    "chars = sorted(list(set(text)))\n",
    "vocab_size = len(chars)+1\n",
    "print('total chars:', vocab_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes it's useful to have a zero value in the dataset, e.g. for padding, so we insert a null character at index 0. This is why `vocab_size` above is `len(chars)+1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n !\"\\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxy'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chars.insert(0, \"\\0\")\n",
    "\n",
    "''.join(chars[1:-6])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Map from chars to indices and back again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "char_indices = {c: i for i, c in enumerate(chars)}\n",
    "indices_char = {i: c for i, c in enumerate(chars)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*idx* will be the data we use from now on: it is the text with every character converted to its index (based on the mapping above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = [char_indices[c] for c in text]\n",
    "\n",
    "idx[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not gro'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "''.join(indices_char[i] for i in idx[:70])"
   ]
  },
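  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small self-contained illustration of the mapping above, the sketch below builds the same kind of dictionaries for a toy string and checks that encoding followed by decoding returns the original text. The names `toy_text`, `toy_chars`, `char_to_idx`, `idx_to_char`, `encode` and `decode` are made up for this example and are not part of the notebook or of fastai.\n",
    "\n",
    "```python\n",
    "toy_text = 'hello world'\n",
    "\n",
    "# Unique characters in sorted order, with a null character reserved at index 0 for padding\n",
    "toy_chars = [chr(0)] + sorted(set(toy_text))\n",
    "\n",
    "char_to_idx = {c: i for i, c in enumerate(toy_chars)}\n",
    "idx_to_char = {i: c for i, c in enumerate(toy_chars)}\n",
    "\n",
    "def encode(s):\n",
    "    # Map each character to its integer index\n",
    "    return [char_to_idx[c] for c in s]\n",
    "\n",
    "def decode(ids):\n",
    "    # Map each index back to its character\n",
    "    return ''.join(idx_to_char[i] for i in ids)\n",
    "\n",
    "toy_idx = encode(toy_text)\n",
    "print(toy_idx)\n",
    "print(decode(toy_idx) == toy_text)\n",
    "```\n",
    "\n",
    "Because index 0 is reserved for the null character, no character of the toy text is ever encoded as 0."
   ]
  },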
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Three char model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
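    "Create four lists, each taking every 3rd character of the text, starting at the 0th, 1st, 2nd, and 3rd characters respectively. The first three lists are the model inputs and the fourth is the character to predict."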
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cs=3\n",
    "c1_dat = [idx[i]   for i in range(0, len(idx)-cs, cs)]\n",
    "c2_dat = [idx[i+1] for i in range(0, len(idx)-cs, cs)]\n",
    "c3_dat = [idx[i+2] for i in range(0, len(idx)-cs, cs)]\n",
    "c4_dat = [idx[i+3] for i in range(0, len(idx)-cs, cs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = np.stack(c1_dat)\n",
    "x2 = np.stack(c2_dat)\n",
    "x3 = np.stack(c3_dat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.stack(c4_dat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first 4 inputs and outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([40, 30, 29,  1]), array([42, 25,  1, 43]), array([29, 27,  1, 45]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1[:4], x2[:4], x3[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([30, 29,  1, 40])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200295,), (200295,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.shape, y.shape"
   ]
  },
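  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the slicing above concrete, here is a minimal sketch on a toy sequence of ten integers instead of the real `idx`. The names `toy_idx`, `step` and `starts` are assumptions made for this illustration; `step` plays the role of `cs`.\n",
    "\n",
    "```python\n",
    "toy_idx = list(range(10, 20))   # stand-in for the encoded text: 10, 11, ..., 19\n",
    "step = 3                         # plays the role of cs in the cell above\n",
    "\n",
    "starts = range(0, len(toy_idx) - step, step)\n",
    "c1 = [toy_idx[i]     for i in starts]\n",
    "c2 = [toy_idx[i + 1] for i in starts]\n",
    "c3 = [toy_idx[i + 2] for i in starts]\n",
    "c4 = [toy_idx[i + 3] for i in starts]   # the character to predict\n",
    "\n",
    "for a, b, c, target in zip(c1, c2, c3, c4):\n",
    "    print((a, b, c), '->', target)\n",
    "```\n",
    "\n",
    "Each example uses three consecutive items as input and the fourth as the label. Since the stride equals the number of inputs, the label of one example is also the first input of the next."
   ]
  },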
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create and train model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pick a size for our hidden state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_hidden = 256"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of latent factors to create (i.e. the width of the embedding matrix, which is the length of each character's embedding vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_fac = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Char3Model(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "\n",
    "        # The 'green arrow' from our diagram - the layer operation from input to hidden\n",
    "        self.l_in = nn.Linear(n_fac, n_hidden)\n",
    "\n",
    "        # The 'orange arrow' from our diagram - the layer operation from hidden to hidden\n",
    "        self.l_hidden = nn.Linear(n_hidden, n_hidden)\n",
    "        \n",
    "        # The 'blue arrow' from our diagram - the layer operation from hidden to output\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, c1, c2, c3):\n",
    "        in1 = F.relu(self.l_in(self.e(c1)))\n",
    "        in2 = F.relu(self.l_in(self.e(c2)))\n",
    "        in3 = F.relu(self.l_in(self.e(c3)))\n",
    "        \n",
    "        h = V(torch.zeros(in1.size()).cuda())\n",
    "        h = F.tanh(self.l_hidden(h+in1))\n",
    "        h = F.tanh(self.l_hidden(h+in2))\n",
    "        h = F.tanh(self.l_hidden(h+in3))\n",
    "        \n",
    "        return F.log_softmax(self.l_out(h), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "md = ColumnarModelData.from_arrays('.', [-1], np.stack([x1,x2,x3], axis=1), y, bs=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = Char3Model(vocab_size, n_fac).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xs,yt = next(it)\n",
    "t = m(*V(xs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = optim.Adam(m.parameters(), 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73483a3ac1804c3e81c8de6744d5c4bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       2.09627  6.52849]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7278ec0864e451795d91bac0ff944c7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.84525  6.52312]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = T(np.array([char_indices[c] for c in inp]))\n",
    "    p = m(*VV(idxs))\n",
    "    i = np.argmax(to_np(p))\n",
    "    return chars[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'T'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('y. ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('ppl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next(' th')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' '"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('and')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Our first RNN!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the size of our unrolled RNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cs=8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
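    "For every starting position in the text, create a list of the 8 consecutive characters beginning at that position. Consecutive lists overlap, because each one starts one character after the previous one. The 8 characters in each list will be the 8 inputs to our model."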
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_in_dat = [[idx[i+j] for i in range(cs)] for j in range(len(idx)-cs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then create a list of the next character in each of these series. This will be the labels for our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_out_dat = [idx[j+cs] for j in range(len(idx)-cs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.stack(c_in_dat, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(600884, 8)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.stack(c_out_dat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So each row below is one series of 8 consecutive characters from the text, and each row starts one character later than the row above it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[40, 42, 29, 30, 25, 27, 29,  1],\n",
       "       [42, 29, 30, 25, 27, 29,  1,  1],\n",
       "       [29, 30, 25, 27, 29,  1,  1,  1],\n",
       "       [30, 25, 27, 29,  1,  1,  1, 43],\n",
       "       [25, 27, 29,  1,  1,  1, 43, 45],\n",
       "       [27, 29,  1,  1,  1, 43, 45, 40],\n",
       "       [29,  1,  1,  1, 43, 45, 40, 40],\n",
       "       [ 1,  1,  1, 43, 45, 40, 40, 39]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs[:cs,:cs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...and this is the next character after each sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1,  1, 43, 45, 40, 40, 39, 43])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[:cs]"
   ]
  },
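  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The overlapping windows and their labels can be reproduced on a toy sequence as follows. This is only a sketch: `toy_idx`, `win`, `windows` and `labels` are names invented for the example, with `win` standing in for `cs`.\n",
    "\n",
    "```python\n",
    "toy_idx = list(range(10, 20))   # stand-in for the encoded text: 10, 11, ..., 19\n",
    "win = 4                          # plays the role of cs\n",
    "\n",
    "# One window per starting position j, so consecutive windows share win-1 items\n",
    "windows = [[toy_idx[j + i] for i in range(win)] for j in range(len(toy_idx) - win)]\n",
    "\n",
    "# The label for each window is the item that immediately follows it\n",
    "labels = [toy_idx[j + win] for j in range(len(toy_idx) - win)]\n",
    "\n",
    "for w, label in zip(windows, labels):\n",
    "    print(w, '->', label)\n",
    "```\n",
    "\n",
    "With a stride of 1, nearly every position in the sequence yields a training example, at the cost of a lot of repeated computation over shared characters."
   ]
  },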
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_idx = get_cv_idxs(len(idx)-cs-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "md = ColumnarModelData.from_arrays('.', val_idx, xs, y, bs=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharLoopModel(nn.Module):\n",
    "    # This is an RNN!\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.l_in = nn.Linear(n_fac, n_hidden)\n",
    "        self.l_hidden = nn.Linear(n_hidden, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(bs, n_hidden).cuda())\n",
    "        for c in cs:\n",
    "            inp = F.relu(self.l_in(self.e(c)))\n",
    "            h = F.tanh(self.l_hidden(h+inp))\n",
    "        \n",
    "        return F.log_softmax(self.l_out(h), dim=-1)"
   ]
  },
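  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop in `CharLoopModel.forward` applies the same hidden-to-hidden weights at every step. The sketch below mimics only that recurrence in plain Python, under simplifying assumptions: it skips the embedding and the `l_in` layer, uses no bias, and works on a single example rather than a batch. `matvec`, `rnn_forward`, `W_hidden` and `seq` are hypothetical names with arbitrary values chosen for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def matvec(W, v):\n",
    "    # Multiply matrix W (a list of rows) by vector v\n",
    "    return [sum(w * x for w, x in zip(row, v)) for row in W]\n",
    "\n",
    "def rnn_forward(inputs, W_hidden):\n",
    "    # Start from a zero hidden state and reuse the same weights at every step\n",
    "    h = [0.0] * len(W_hidden)\n",
    "    for inp in inputs:\n",
    "        summed = [a + b for a, b in zip(h, inp)]      # h + inp, as in CharLoopModel\n",
    "        h = [math.tanh(z) for z in matvec(W_hidden, summed)]\n",
    "    return h\n",
    "\n",
    "W_hidden = [[0.5, -0.3],\n",
    "            [0.2,  0.8]]           # arbitrary illustrative weights\n",
    "seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]\n",
    "\n",
    "h_fwd = rnn_forward(seq, W_hidden)\n",
    "h_rev = rnn_forward(seq[::-1], W_hidden)\n",
    "print(h_fwd, h_rev)\n",
    "```\n",
    "\n",
    "Feeding the same inputs in reverse order gives a different final state, which shows that the hidden state depends on the order of the sequence and not just on which items it contains."
   ]
  },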
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharLoopModel(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d1c8fb012c74fe191921d467c80b5ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       2.02986  1.99268]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e4a151c0f274c129e346a22fd4bdece",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.73588  1.75103]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharLoopConcatModel(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.l_in = nn.Linear(n_fac+n_hidden, n_hidden)\n",
    "        self.l_hidden = nn.Linear(n_hidden, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(bs, n_hidden).cuda())\n",
    "        for c in cs:\n",
    "            inp = torch.cat((h, self.e(c)), 1)\n",
    "            inp = F.relu(self.l_in(inp))\n",
    "            h = F.tanh(self.l_hidden(inp))\n",
    "        \n",
    "        return F.log_softmax(self.l_out(h), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharLoopConcatModel(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xs,yt = next(it)\n",
    "t = m(*V(xs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1b3572e787441d8b2e5d80317245596",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.81654  1.78501]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9aa67fbb4a2f42509dbe7753bc86d9a1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.69008  1.69936]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = T(np.array([char_indices[c] for c in inp]))\n",
    "    p = m(*VV(idxs))\n",
    "    i = np.argmax(to_np(p))\n",
    "    return chars[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('for thos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'t'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('part of ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'n'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('queens a')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN with PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharRnn(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNN(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(1, bs, n_hidden))\n",
    "        inp = self.e(torch.stack(cs))\n",
    "        outp,h = self.rnn(inp, h)\n",
    "        \n",
    "        return F.log_softmax(self.l_out(outp[-1]), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharRnn(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xs,yt = next(it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 512, 42])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = m.e(V(torch.stack(xs)))\n",
    "t.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([8, 512, 256]), torch.Size([1, 512, 256]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ht = V(torch.zeros(1, 512,n_hidden))\n",
    "outp, hn = m.rnn(t, ht)\n",
    "outp.size(), hn.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([512, 85])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = m(*V(xs)); t.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "497078d15ec348149442681039df2e50",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.86065  1.84255]                                 \n",
      "[ 1.       1.68014  1.67387]                                 \n",
      "[ 2.       1.58828  1.59169]                                 \n",
      "[ 3.       1.52989  1.54942]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65a2f7bedaa34de2a40296a07387c1c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.46841  1.50966]                                 \n",
      "[ 1.       1.46482  1.5039 ]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 2, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = T(np.array([char_indices[c] for c in inp]))\n",
    "    p = m(*VV(idxs))\n",
    "    i = np.argmax(to_np(p))\n",
    "    return chars[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('for thos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next_n(inp, n):\n",
    "    res = inp\n",
    "    for i in range(n):\n",
    "        c = get_next(inp)\n",
    "        res += c\n",
    "        inp = inp[1:]+c\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'for those the same the same the same the same th'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next_n('for thos', 40)"
   ]
  },
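  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The generation loop in `get_next_n` does not depend on the model itself: it only needs a function that maps a window of characters to the next character. The sketch below shows the same sliding-window logic with a toy predictor in place of the trained network. `generate` and `toy_predict` are hypothetical names used only for this illustration.\n",
    "\n",
    "```python\n",
    "def generate(seed, n, predict_next):\n",
    "    # Repeatedly predict one character and slide the fixed-size window forward by one\n",
    "    result = seed\n",
    "    window = seed\n",
    "    for _ in range(n):\n",
    "        c = predict_next(window)\n",
    "        result += c\n",
    "        window = window[1:] + c    # drop the oldest character, append the new one\n",
    "    return result\n",
    "\n",
    "def toy_predict(window):\n",
    "    # Hypothetical stand-in for the trained model: return the letter after the last one\n",
    "    alphabet = 'abcdefghijklmnopqrstuvwxyz'\n",
    "    return alphabet[(alphabet.index(window[-1]) + 1) % 26]\n",
    "\n",
    "out = generate('abc', 5, toy_predict)\n",
    "print(out)\n",
    "```\n",
    "\n",
    "Because each prediction is a deterministic function of the current window, generation enters a cycle as soon as a window repeats. That is consistent with the repeated phrase in the model output above."
   ]
  },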
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-output model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take non-overlapping sets of characters this time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_in_dat = [[idx[i+j] for i in range(cs)] for j in range(0, len(idx)-cs-1, cs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then create the exact same thing, offset by 1, as our labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_out_dat = [[idx[i+j] for i in range(cs)] for j in range(1, len(idx)-cs, cs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(75111, 8)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs = np.stack(c_in_dat)\n",
    "xs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(75111, 8)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ys = np.stack(c_out_dat)\n",
    "ys.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[40, 42, 29, 30, 25, 27, 29,  1],\n",
       "       [ 1,  1, 43, 45, 40, 40, 39, 43],\n",
       "       [33, 38, 31,  2, 73, 61, 54, 73],\n",
       "       [ 2, 44, 71, 74, 73, 61,  2, 62],\n",
       "       [72,  2, 54,  2, 76, 68, 66, 54],\n",
       "       [67,  9,  9, 76, 61, 54, 73,  2],\n",
       "       [73, 61, 58, 67, 24,  2, 33, 72],\n",
       "       [ 2, 73, 61, 58, 71, 58,  2, 67]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs[:cs,:cs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[42, 29, 30, 25, 27, 29,  1,  1],\n",
       "       [ 1, 43, 45, 40, 40, 39, 43, 33],\n",
       "       [38, 31,  2, 73, 61, 54, 73,  2],\n",
       "       [44, 71, 74, 73, 61,  2, 62, 72],\n",
       "       [ 2, 54,  2, 76, 68, 66, 54, 67],\n",
       "       [ 9,  9, 76, 61, 54, 73,  2, 73],\n",
       "       [61, 58, 67, 24,  2, 33, 72,  2],\n",
       "       [73, 61, 58, 71, 58,  2, 67, 68]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ys[:cs,:cs]"
   ]
  },
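  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the construction above, the following sketch builds non-overlapping input windows and their offset-by-one targets on a toy sequence. The names `toy_idx`, `win`, `inputs` and `targets` are assumptions made for the example, with `win` in place of `cs`.\n",
    "\n",
    "```python\n",
    "toy_idx = list(range(10, 30))   # stand-in for the encoded text: 10, 11, ..., 29\n",
    "win = 4                          # plays the role of cs\n",
    "\n",
    "# Inputs start at 0, win, 2*win, ... so the windows do not overlap\n",
    "inputs = [[toy_idx[j + i] for i in range(win)]\n",
    "          for j in range(0, len(toy_idx) - win - 1, win)]\n",
    "\n",
    "# Targets are the same windows shifted one position to the right\n",
    "targets = [[toy_idx[j + i] for i in range(win)]\n",
    "           for j in range(1, len(toy_idx) - win, win)]\n",
    "\n",
    "for x, t in zip(inputs, targets):\n",
    "    print(x, '->', t)\n",
    "```\n",
    "\n",
    "Every position in an input window now has its own label (the item that follows it), so the model receives a training signal at each time step rather than only at the end of the window."
   ]
  },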
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_idx = get_cv_idxs(len(xs)-cs-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "md = ColumnarModelData.from_arrays('.', val_idx, xs, ys, bs=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqRnn(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNN(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(1, bs, n_hidden))\n",
    "        inp = self.e(torch.stack(cs))\n",
    "        outp,h = self.rnn(inp, h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqRnn(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xst,yt = next(it)"
   ]
  },
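  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nll_loss_seq(inp, targ):\n",
    "    # inp: (seq_len, bs, vocab) log-probabilities, so nh here is the vocabulary size\n",
    "    sl,bs,nh = inp.size()\n",
    "    # targ arrives as (bs, seq_len); transpose so it lines up with inp once both are flattened\n",
    "    targ = targ.transpose(0,1).contiguous().view(-1)\n",
    "    return F.nll_loss(inp.view(-1,nh), targ)"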
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "725ca331d28b482e9c7a4f83f741498e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       2.59241  2.40251]                                \n",
      "[ 1.       2.28474  2.19859]                                \n",
      "[ 2.       2.13883  2.08836]                                \n",
      "[ 3.       2.04892  2.01564]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 4, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "adb9aa22524d4bfd8b001d2efd10dbc3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.99819  2.00106]                               \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 1, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Identity init!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqRnn(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "    1     0     0  ...      0     0     0\n",
       "    0     1     0  ...      0     0     0\n",
       "    0     0     1  ...      0     0     0\n",
       "       ...          ⋱          ...       \n",
       "    0     0     0  ...      1     0     0\n",
       "    0     0     0  ...      0     1     0\n",
       "    0     0     0  ...      0     0     1\n",
       "[torch.cuda.FloatTensor of size 256x256 (GPU 0)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.rnn.weight_hh_l0.data.copy_(torch.eye(n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e141251f24d4083a6e8b2fa15dea724",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       2.39428  2.21111]                                \n",
      "[ 1.       2.10381  2.03275]                                \n",
      "[ 2.       1.99451  1.96393]                               \n",
      "[ 3.       1.93492  1.91763]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 4, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ddf833e8b7ec4a3aa29dd271911f76ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.84035  1.85742]                                \n",
      "[ 1.       1.82896  1.84887]                                \n",
      "[ 2.       1.81879  1.84281]                               \n",
      "[ 3.       1.81337  1.83801]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 4, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stateful model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0m\u001b[01;34mmodels\u001b[0m/  nietzsche.txt  \u001b[01;34mtrn\u001b[0m/  \u001b[01;34mval\u001b[0m/\r\n"
     ]
    }
   ],
   "source": [
    "from torchtext import vocab, data\n",
    "\n",
    "from fastai.nlp import *\n",
    "from fastai.lm_rnn import *\n",
    "\n",
    "PATH='data/nietzsche/'\n",
    "\n",
    "TRN_PATH = 'trn/'\n",
    "VAL_PATH = 'val/'\n",
    "TRN = f'{PATH}{TRN_PATH}'\n",
    "VAL = f'{PATH}{VAL_PATH}'\n",
    "\n",
    "# Note: before running this cell, split the corpus yourself (good shell practice) into:\n",
    "# - trn/trn.txt (first 80% of nietzsche.txt)\n",
    "# - val/val.txt (last 20% of nietzsche.txt)\n",
    "\n",
    "%ls {PATH}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trn.txt\r\n"
     ]
    }
   ],
   "source": [
    "%ls {PATH}trn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(963, 56, 1, 493747)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT = data.Field(lower=True, tokenize=list)\n",
    "bs=64; bptt=8; n_fac=42; n_hidden=256\n",
    "\n",
    "FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)\n",
    "md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=3)\n",
    "\n",
    "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulRnn(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs):\n",
    "        super().__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNN(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        if self.h.size(1) != bs: self.init_hidden(bs)\n",
    "        outp,h = self.rnn(self.e(cs), self.h)\n",
    "        self.h = repackage_var(h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulRnn(md.nt, n_fac, 512).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2a9e0a39ef174c72bac575be7e20579c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.81983  1.81247]                                 \n",
      "[ 1.       1.63097  1.66228]                                 \n",
      "[ 2.       1.54433  1.57824]                                 \n",
      "[ 3.       1.48563  1.54505]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b15bf8bcc7445e694dbcb3beb658b74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.4187   1.50374]                                 \n",
      "[ 1.       1.41492  1.49391]                                 \n",
      "[ 2.       1.41001  1.49339]                                 \n",
      "[ 3.       1.40756  1.486  ]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "set_lrs(opt, 1e-4)\n",
    "\n",
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RNN loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From the PyTorch source code - for reference\n",
    "\n",
    "def RNNCell(input, hidden, w_ih, w_hh, b_ih, b_hh):\n",
    "    return F.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulRnn2(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs):\n",
    "        super().__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNNCell(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        # the hidden state is (bs, n_hidden) here, so the batch dimension is dim 0\n",
    "        if self.h.size(0) != bs: self.init_hidden(bs)\n",
    "        outp = []\n",
    "        o = self.h\n",
    "        for c in cs:\n",
    "            o = self.rnn(self.e(c), o)\n",
    "            outp.append(o)\n",
    "        outp = self.l_out(torch.stack(outp))\n",
    "        self.h = repackage_var(o)\n",
    "        return F.log_softmax(outp, dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    # nn.RNNCell takes a 2-D hidden state of shape (bs, n_hidden), with no layer dimension\n",
    "    def init_hidden(self, bs): self.h = V(torch.zeros(bs, n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulRnn2(md.nt, n_fac, 512).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c46f24bfa194e1ba9d73e22283ca6af",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.81013  1.7969 ]                                 \n",
      "[ 1.       1.62515  1.65346]                                 \n",
      "[ 2.       1.53913  1.58065]                                 \n",
      "[ 3.       1.48698  1.54217]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulGRU(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs):\n",
    "        super().__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.GRU(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        if self.h.size(1) != bs: self.init_hidden(bs)\n",
    "        outp,h = self.rnn(self.e(cs), self.h)\n",
    "        self.h = repackage_var(h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From the PyTorch source code - for reference\n",
    "\n",
    "def GRUCell(input, hidden, w_ih, w_hh, b_ih, b_hh):\n",
    "    gi = F.linear(input, w_ih, b_ih)\n",
    "    gh = F.linear(hidden, w_hh, b_hh)\n",
    "    i_r, i_i, i_n = gi.chunk(3, 1)\n",
    "    h_r, h_i, h_n = gh.chunk(3, 1)\n",
    "\n",
    "    resetgate = F.sigmoid(i_r + h_r)\n",
    "    inputgate = F.sigmoid(i_i + h_i)\n",
    "    newgate = F.tanh(i_n + resetgate * h_n)\n",
    "    return newgate + inputgate * (hidden - newgate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulGRU(md.nt, n_fac, 512).cuda()\n",
    "\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e518384d71c345a8b145b35d4ee894fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.68409  1.67784]                                 \n",
      "[ 1.       1.49813  1.52661]                                 \n",
      "[ 2.       1.41674  1.46769]                                 \n",
      "[ 3.       1.36359  1.43818]                                 \n",
      "[ 4.       1.33223  1.41777]                                 \n",
      "[ 5.       1.30217  1.40511]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 6, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be385370c27f4b788920caf48f90aeea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.22708  1.36926]                                 \n",
      "[ 1.       1.21948  1.3696 ]                                 \n",
      "[ 2.       1.22541  1.36969]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 3, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Putting it all together: LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import sgdr\n",
    "\n",
    "n_hidden=512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulLSTM(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs, nl):\n",
    "        super().__init__()\n",
    "        self.vocab_size,self.nl = vocab_size,nl\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.LSTM(n_fac, n_hidden, nl, dropout=0.5)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        if self.h[0].size(1) != bs: self.init_hidden(bs)\n",
    "        outp,h = self.rnn(self.e(cs), self.h)\n",
    "        self.h = repackage_var(h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    def init_hidden(self, bs):\n",
    "        self.h = (V(torch.zeros(self.nl, bs, n_hidden)),\n",
    "                  V(torch.zeros(self.nl, bs, n_hidden)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulLSTM(md.nt, n_fac, 512, 2).cuda()\n",
    "lo = LayerOptimizer(optim.Adam, m, 1e-2, 1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(f'{PATH}models', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6943ca600bbf4a49a0020b2467c2ddb8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.72032  1.64016]                                 \n",
      "[ 1.       1.62891  1.58176]                                 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(m, md, 2, lo.opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "765d0d78da6647d48276a638f70aeec9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.47969  1.4472 ]                                 \n",
      "[ 1.       1.51411  1.46612]                                 \n",
      "[ 2.       1.412    1.39909]                                 \n",
      "[ 3.       1.53689  1.48337]                                 \n",
      "[ 4.       1.47375  1.43169]                                 \n",
      "[ 5.       1.39828  1.37963]                                 \n",
      "[ 6.       1.34546  1.35795]                                 \n",
      "[ 7.       1.51999  1.47165]                                 \n",
      "[ 8.       1.48992  1.46146]                                 \n",
      "[ 9.       1.45492  1.42829]                                 \n",
      "[ 10.        1.42027   1.39028]                              \n",
      "[ 11.        1.3814    1.36539]                              \n",
      "[ 12.        1.33895   1.34178]                              \n",
      "[ 13.        1.30737   1.32871]                              \n",
      "[ 14.        1.28244   1.31518]                              \n",
      "\n"
     ]
    }
   ],
   "source": [
    "on_end = lambda sched, cycle: save_model(m, f'{PATH}models/cyc_{cycle}')\n",
    "cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]\n",
    "fit(m, md, 2**4-1, lo.opt, F.nll_loss, callbacks=cb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4394818ec37f4b419397628b7cc8b815",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       1.46053  1.43462]                                 \n",
      "[ 1.       1.51537  1.47747]                                 \n",
      "[ 2.       1.39208  1.38293]                                 \n",
      "[ 3.       1.53056  1.49371]                                 \n",
      "[ 4.       1.46812  1.43389]                                 \n",
      "[ 5.       1.37624  1.37523]                                 \n",
      "[ 6.       1.3173   1.34022]                                 \n",
      "[ 7.       1.51783  1.47554]                                 \n",
      "[ 8.       1.4921   1.45785]                                 \n",
      "[ 9.       1.44843  1.42215]                                 \n",
      "[ 10.        1.40948   1.40858]                              \n",
      "[ 11.        1.37098   1.36648]                              \n",
      "[ 12.        1.32255   1.33842]                              \n",
      "[ 13.        1.28243   1.31106]                              \n",
      "[ 14.        1.25031   1.2918 ]                              \n",
      "[ 15.        1.49236   1.45316]                              \n",
      "[ 16.        1.46041   1.43622]                              \n",
      "[ 17.        1.45043   1.4498 ]                              \n",
      "[ 18.        1.43331   1.41297]                              \n",
      "[ 19.        1.43841   1.41704]                              \n",
      "[ 20.        1.41536   1.40521]                              \n",
      "[ 21.        1.39829   1.37656]                              \n",
      "[ 22.        1.37001   1.36891]                              \n",
      "[ 23.        1.35469   1.35909]                              \n",
      "[ 24.        1.32202   1.34228]                              \n",
      "[ 25.        1.29972   1.32256]                              \n",
      "[ 26.        1.28007   1.30903]                              \n",
      "[ 27.        1.24503   1.29125]                              \n",
      "[ 28.        1.22261   1.28316]                              \n",
      "[ 29.        1.20563   1.27397]                              \n",
      "[ 30.        1.18764   1.27178]                              \n",
      "[ 31.        1.18114   1.26694]                              \n",
      "[ 32.        1.44344   1.42405]                              \n",
      "[ 33.        1.43344   1.41616]                              \n",
      "[ 34.        1.4346    1.40442]                              \n",
      "[ 35.        1.42152   1.41359]                              \n",
      "[ 36.        1.42072   1.40835]                              \n",
      "[ 37.        1.41732   1.40498]                              \n",
      "[ 38.        1.41268   1.395  ]                              \n",
      "[ 39.        1.40725   1.39433]                              \n",
      "[ 40.        1.40181   1.39864]                              \n",
      "[ 41.        1.38621   1.37549]                              \n",
      "[ 42.        1.3838    1.38587]                              \n",
      "[ 43.        1.37644   1.37118]                              \n",
      "[ 44.        1.36287   1.36211]                              \n",
      "[ 45.        1.35942   1.36145]                              \n",
      "[ 46.        1.34712   1.34924]                              \n",
      "[ 47.        1.32994   1.34884]                              \n",
      "[ 48.        1.32788   1.33387]                              \n",
      "[ 49.        1.31553   1.342  ]                              \n",
      "[ 50.        1.30088   1.32435]                              \n",
      "[ 51.        1.28446   1.31166]                              \n",
      "[ 52.        1.27058   1.30807]                              \n",
      "[ 53.        1.26271   1.29935]                              \n",
      "[ 54.        1.24351   1.28942]                              \n",
      "[ 55.        1.23119   1.2838 ]                              \n",
      "[ 56.        1.2086    1.28364]                              \n",
      "[ 57.        1.19742   1.27375]                              \n",
      "[ 58.        1.18127   1.26758]                              \n",
      "[ 59.        1.17475   1.26858]                              \n",
      "[ 60.        1.15349   1.25999]                              \n",
      "[ 61.        1.14718   1.25779]                              \n",
      "[ 62.        1.13174   1.2524 ]                              \n",
      "\n"
     ]
    }
   ],
   "source": [
    "on_end = lambda sched, cycle: save_model(m, f'{PATH}models/cyc_{cycle}')\n",
    "cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]\n",
    "fit(m, md, 2**6-1, lo.opt, F.nll_loss, callbacks=cb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = TEXT.numericalize(inp)\n",
    "    p = m(VV(idxs.transpose(0,1)))\n",
    "    # p holds log-probabilities; exponentiate the last step and sample the next character from it\n",
    "    r = torch.multinomial(p[-1].exp(), 1)\n",
    "    return TEXT.vocab.itos[to_np(r)[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('for thos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next_n(inp, n):\n",
    "    res = inp\n",
    "    for i in range(n):\n",
    "        c = get_next(inp)\n",
    "        res += c\n",
    "        # slide the input window forward by one character\n",
    "        inp = inp[1:]+c\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for those the skemps), or\n",
      "imaginates, though they deceives. it should so each ourselvess and new\n",
      "present, step absolutely for the\n",
      "science.\" the contradity and\n",
      "measuring, \n",
      "the whole!\n",
      "\n",
      "293. perhaps, that every life a values of blood\n",
      "of\n",
      "intercourse when it senses there is unscrupulus, his very rights, and still impulse, love?\n",
      "just after that thereby how made with the way anything, and set for harmless philos\n"
     ]
    }
   ],
   "source": [
    "print(get_next_n('for thos', 400))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}