{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository because it will load fastai-0.7.x**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "from fastai.imports import *\n",
    "from fastai.torch_imports import *\n",
    "from fastai.core import *\n",
    "from fastai.model import fit\n",
    "from fastai.dataset import *\n",
    "\n",
    "import torchtext\n",
    "from torchtext import vocab, data\n",
    "from torchtext.datasets import language_modeling\n",
    "\n",
    "from fastai.rnn_reg import *\n",
    "from fastai.rnn_train import *\n",
    "from fastai.nlp import *\n",
    "from fastai.lm_rnn import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0m\u001b[01;34mmodels\u001b[0m/  \u001b[01;34mtmp\u001b[0m/  wiki.test.tokens  wiki.train.tokens  wiki.valid.tokens\r\n"
     ]
    }
   ],
   "source": [
    "PATH='data/wikitext-2/'\n",
    "%ls {PATH}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " \r\n",
      " = Valkyria Chronicles III = \r\n",
      " \r\n",
      " Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the \" Nameless \" , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit \" <unk> Raven \" . \r\n",
      " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk> for series newcomers . Character designer <unk> Honjou and composer Hitoshi Sakimoto both returned from previous entries , along with Valkyria Chronicles II director Takeshi Ozawa . A large team of writers handled the script . The game 's opening theme was sung by May 'n . \r\n"
     ]
    }
   ],
   "source": [
    "!head -5 {PATH}wiki.train.tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   36718  2051910 10797148 data/wikitext-2/wiki.train.tokens\r\n"
     ]
    }
   ],
   "source": [
    "!wc -lwc {PATH}wiki.train.tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   3760  213886 1121681 data/wikitext-2/wiki.valid.tokens\r\n"
     ]
    }
   ],
   "source": [
    "!wc -lwc {PATH}wiki.valid.tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(371, 12981, 1, 2088628)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT = data.Field(lower=True)\n",
    "FILES = dict(train='wiki.train.tokens', validation='wiki.valid.tokens', test='wiki.test.tokens')\n",
    "bs,bptt = 80,70\n",
    "md = LanguageModelData(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)\n",
    "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#md.trn_ds[0].text[:12], next(iter(md.trn_dl))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "em_sz = 200\n",
    "nh = 500\n",
    "nl = 3\n",
    "learner = md.get_model(SGD_Momentum(0.7), bs, em_sz, nh, nl)\n",
    "reg_fn=partial(seq2seq_reg, alpha=2, beta=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a88f0d1d9d1d443e81d3cdc61027f8cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.      5.8902  5.5658]                                   \n",
      "\n"
     ]
    }
   ],
   "source": [
    "clip=0.3\n",
    "learner.fit(10, 1, wds=1e-6, reg_fn=reg_fn, clip=clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "21dd231e1f8e48419cc3c17cf43f0ecb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.      5.5214  5.3178]                                   \n",
      "[ 1.      5.36    5.1435]                                   \n",
      "[ 2.      5.2619  5.0643]                                   \n",
      "[ 3.      5.166   4.9541]                                   \n",
      "[ 4.      5.0275  4.8416]                                   \n",
      "[ 5.      4.9453  4.7837]                                   \n",
      "[ 6.      4.9306  4.7692]                                   \n",
      "[ 7.      4.9195  4.7462]                                   \n",
      "[ 8.      4.8391  4.6762]                                   \n",
      "[ 9.      4.767   4.6317]                                   \n",
      "[ 10.       4.718    4.5859]                                \n",
      "[ 11.       4.6623   4.5497]                                \n",
      "[ 12.       4.6262   4.5275]                                \n",
      "[ 13.       4.6103   4.5196]                                \n",
      "[ 14.       4.6206   4.5137]                                \n",
      "[ 15.       4.6716   4.5614]                                \n",
      "[ 16.       4.6237   4.5333]                                \n",
      "[ 17.       4.5918   4.4998]                                \n",
      "[ 18.       4.5519   4.47  ]                                \n",
      "[ 19.       4.5188   4.4514]                                \n",
      "[ 20.       4.4907   4.4342]                                \n",
      "[ 21.       4.464    4.4094]                                \n",
      "[ 22.       4.4473   4.4016]                                \n",
      "[ 23.       4.4196   4.3848]                                \n",
      "[ 24.       4.3956   4.3653]                                \n",
      "[ 25.       4.384    4.3596]                                \n",
      "[ 26.       4.3692   4.3493]                                \n",
      "[ 27.       4.3546   4.3422]                                \n",
      "[ 28.       4.3548   4.34  ]                                \n",
      "[ 29.       4.3457   4.3364]                                \n",
      "[ 30.       4.3517   4.3371]                                \n",
      "[ 31.       4.4461   4.401 ]                                \n",
      "[ 32.       4.4343   4.39  ]                                \n",
      "[ 33.       4.4064   4.3759]                                \n",
      "[ 34.      4.401   4.374]                                   \n",
      "[ 35.       4.3683   4.3583]                                \n",
      "[ 36.       4.3535   4.3478]                                \n",
      "[ 37.       4.3403   4.3402]                                \n",
      "[ 38.       4.3303   4.3356]                                \n",
      "[ 39.       4.3214   4.3346]                                \n",
      "[ 40.       4.2934   4.3226]                                \n",
      "[ 41.       4.2919   4.3188]                                \n",
      "[ 42.       4.2742   4.3055]                                \n",
      "[ 43.       4.2566   4.3009]                                \n",
      "[ 44.       4.2374   4.2973]                                \n",
      "[ 45.       4.2297   4.2854]                                \n",
      "[ 46.       4.217    4.2823]                                \n",
      "[ 47.       4.2086   4.2732]                                \n",
      "[ 48.       4.1899   4.2642]                                \n",
      "[ 49.       4.196    4.2624]                                \n",
      "[ 50.       4.1726   4.2562]                                \n",
      "[ 51.       4.182    4.2528]                                \n",
      "[ 52.       4.1619   4.2527]                                \n",
      "[ 53.       4.1868   4.248 ]                                \n",
      "[ 54.       4.1704   4.2455]                                \n",
      "[ 55.       4.156    4.2425]                                \n",
      "[ 56.       4.138    4.2428]                                \n",
      "[ 57.       4.1476   4.2367]                                \n",
      "[ 58.       4.1454   4.2349]                                \n",
      "[ 59.       4.1273   4.2312]                                \n",
      "[ 60.       4.1331   4.2317]                                \n",
      "[ 61.       4.1471   4.231 ]                                \n",
      "[ 62.       4.1283   4.2335]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(10, 6, wds=1e-6, reg_fn=reg_fn, cycle_len=1, cycle_mult=2, clip=clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.save('lm_420')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0113c5be9cd24acd98dddb2de59966e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.      4.2051  4.242 ]                                   \n",
      "[ 1.      4.2466  4.2633]                                   \n",
      "[ 2.      4.2073  4.2377]                                   \n",
      "[ 3.      4.2657  4.2765]                                   \n",
      "[ 4.      4.2209  4.2627]                                   \n",
      "[ 5.      4.1965  4.2458]                                   \n",
      "[ 6.      4.1746  4.2335]                                   \n",
      "[ 7.      4.2657  4.2897]                                   \n",
      "[ 8.      4.2401  4.277 ]                                   \n",
      "[ 9.      4.2266  4.269 ]                                   \n",
      "[ 10.       4.189    4.2546]                                \n",
      "[ 11.       4.1639   4.2369]                                \n",
      "[ 12.       4.1531   4.2288]                                \n",
      "[ 13.       4.1233   4.2223]                                \n",
      "[ 14.       4.1328   4.2161]                                \n",
      "[ 15.       4.2515   4.2909]                                \n",
      "[ 16.       4.2246   4.2872]                                \n",
      "[ 17.       4.2081   4.2727]                                \n",
      "[ 18.       4.1989   4.2651]                                \n",
      "[ 19.       4.1836   4.2634]                                \n",
      "[ 20.       4.1853   4.2647]                                \n",
      "[ 21.       4.1435   4.2518]                                \n",
      "[ 22.       4.1407   4.2388]                                \n",
      "[ 23.       4.1227   4.2285]                                \n",
      "[ 24.       4.1024   4.2244]                                \n",
      "[ 25.       4.0866   4.218 ]                                \n",
      "[ 26.       4.0916   4.2135]                                \n",
      "[ 27.       4.0695   4.2084]                                \n",
      "[ 28.       4.0615   4.2036]                                \n",
      "[ 29.       4.0672   4.2034]                                \n",
      "[ 30.       4.0859   4.2014]                                \n",
      "[ 31.       4.1937   4.2731]                                \n",
      "[ 32.       4.1961   4.2628]                                \n",
      "[ 33.       4.1725   4.2582]                                \n",
      "[ 34.       4.1909   4.2662]                                \n",
      "[ 35.       4.1622   4.2586]                                \n",
      "[ 36.       4.1525   4.2556]                                \n",
      "[ 37.       4.1601   4.2616]                                \n",
      "[ 38.       4.1521   4.2604]                                \n",
      "[ 39.       4.1471   4.2571]                                \n",
      "[ 40.       4.1131   4.2425]                                \n",
      "[ 41.       4.1072   4.2362]                                \n",
      "[ 42.       4.0949   4.2399]                                \n",
      "[ 43.       4.1014   4.2315]                                \n",
      "[ 44.       4.0926   4.2266]                                \n",
      "[ 45.       4.0721   4.2225]                                \n",
      "[ 46.       4.0694   4.219 ]                                \n",
      "[ 47.       4.0742   4.2188]                                \n",
      "[ 48.       4.0433   4.2181]                                \n",
      "[ 49.       4.0666   4.2083]                                \n",
      "[ 50.       4.0395   4.2066]                                \n",
      "[ 51.       4.0591   4.2007]                                \n",
      "[ 52.       4.0166   4.2046]                                \n",
      "[ 53.       4.0279   4.1973]                                \n",
      "[ 54.       4.0178   4.1957]                                \n",
      "[ 55.       4.0241   4.1937]                                \n",
      "[ 56.       3.9977   4.1901]                                \n",
      "[ 57.       3.997    4.1909]                                \n",
      "[ 58.       4.0194   4.189 ]                                \n",
      "[ 59.       3.9832   4.1833]                                \n",
      "[ 60.       4.0069   4.1856]                                \n",
      "[ 61.       3.9893   4.1873]                                \n",
      "[ 62.       3.986    4.1828]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(10, 6, wds=1e-6, reg_fn=reg_fn, cycle_len=1, cycle_mult=2, clip=clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.save('lm_419')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ff63287f1ed47f2afe629863278982a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.      4.0781  4.2001]                                   \n",
      "[ 1.      4.1227  4.2212]                                   \n",
      "[ 2.      4.1012  4.1981]                                   \n",
      "[ 3.      4.1524  4.243 ]                                   \n",
      "[ 4.      4.1137  4.2235]                                   \n",
      "[ 5.      4.0761  4.1977]                                   \n",
      "[ 6.      4.072   4.1966]                                   \n",
      "[ 7.      4.1489  4.2491]                                   \n",
      "[ 8.      4.1355  4.2431]                                   \n",
      "[ 9.      4.1167  4.2323]                                   \n",
      "[ 10.       4.0797   4.2189]                                \n",
      "[ 11.       4.048    4.2054]                                \n",
      "[ 12.       4.0284   4.2013]                                \n",
      "[ 13.       4.0436   4.1854]                                \n",
      "[ 14.       4.0416   4.1864]                                \n",
      "[ 15.       4.131    4.2443]                                \n",
      "[ 16.       4.1205   4.2414]                                \n",
      "[ 17.       4.116    4.2427]                                \n",
      "[ 18.       4.1111   4.2443]                                \n",
      "[ 19.       4.1028   4.2313]                                \n",
      "[ 20.       4.0806   4.224 ]                                \n",
      "[ 21.       4.0879   4.2192]                                \n",
      "[ 22.       4.0452   4.2083]                                \n",
      "[ 23.       4.0272   4.2042]                                \n",
      "[ 24.       4.0196   4.2019]                                \n",
      "[ 25.       4.0273   4.1909]                                \n",
      "[ 26.       4.0178   4.1868]                                \n",
      "[ 27.       3.9833   4.1847]                                \n",
      "[ 28.       3.9777   4.1805]                                \n",
      "[ 29.       3.9841   4.1805]                                \n",
      "[ 30.       4.0016   4.181 ]                                \n",
      "[ 31.       4.1057   4.236 ]                                \n",
      "[ 32.       4.126    4.2456]                                \n",
      "[ 33.       4.1018   4.2358]                                \n",
      "[ 34.       4.0977   4.2421]                                \n",
      "[ 35.       4.1049   4.2377]                                \n",
      "[ 36.       4.1079   4.2429]                                \n",
      "[ 37.       4.0851   4.2365]                                \n",
      "[ 38.       4.0778   4.2304]                                \n",
      "[ 39.       4.0849   4.2331]                                \n",
      "[ 40.       4.058    4.2278]                                \n",
      "[ 41.       4.047    4.2225]                                \n",
      "[ 42.       4.0585   4.2251]                                \n",
      "[ 43.       4.0324   4.2226]                                \n",
      "[ 44.       4.0459   4.2161]                                \n",
      "[ 45.       4.0147   4.2054]                                \n",
      "[ 46.       4.0259   4.2097]                                \n",
      "[ 47.       3.9967   4.205 ]                                \n",
      "[ 48.       4.0014   4.1984]                                \n",
      "[ 49.       3.9926   4.1985]                                \n",
      "[ 50.      3.968   4.195]                                   \n",
      "[ 51.       3.98     4.1888]                                \n",
      "[ 52.       3.9527   4.1912]                                \n",
      "[ 53.       3.9452   4.1864]                                \n",
      "[ 54.       3.9587   4.1848]                                \n",
      "[ 55.       3.9286   4.184 ]                                \n",
      "[ 56.       3.9302   4.1778]                                \n",
      "[ 57.       3.9271   4.1751]                                \n",
      "[ 58.       3.9452   4.1743]                                \n",
      "[ 59.       3.9443   4.1792]                                \n",
      "[ 60.      3.924   4.176]                                   \n",
      "[ 61.       3.9347   4.1745]                                \n",
      "[ 62.       3.924    4.1718]                                \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(10, 6, wds=1e-6, reg_fn=reg_fn, cycle_len=1, cycle_mult=2, clip=clip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.save('lm_418')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64.71545210740304"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "math.exp(4.17)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m=learner.model\n",
    "s=[\"\"\". <eos> The game began development in 2010 , carrying over a large portion of the work \n",
    "done on Valkyria Chronicles II . While it retained the standard features of \"\"\".split()]\n",
    "t=TEXT.numericalize(s)\n",
    "\n",
    "m[0].bs=1\n",
    "m.reset(False)\n",
    "res,*_ = m(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['the', '<unk>', 'a', '\"', 'an', 'its', 'this', 'their', 'all', 'other']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nexts = torch.topk(res[-1], 10)[1]\n",
    "[TEXT.vocab.itos[o] for o in to_np(nexts)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the film , and the first of the two @-@ year @-@ old , the \" black \" , a "
     ]
    }
   ],
   "source": [
    "for i in range(20):\n",
    "    n=res[-1].topk(2)[1]\n",
    "    n = n[1] if n.data[0]==0 else n[0]\n",
    "    print(TEXT.vocab.itos[n.data[0]], end=' ')\n",
    "    res,*_ = m(n[0].unsqueeze(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### End"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}