{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "from fastai.model import fit\n",
    "from fastai.dataset import *\n",
    "\n",
    "import torchtext\n",
    "from torchtext import vocab, data\n",
    "from torchtext.datasets import language_modeling\n",
    "\n",
    "from fastai.rnn_reg import *\n",
    "from fastai.rnn_train import *\n",
    "from fastai.nlp import *\n",
    "from fastai.lm_rnn import *\n",
    "\n",
    "import dill as pickle\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs,bptt = 64,70"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, requests, time\n",
    "# feedparser isn't a fastai dependency so you may need to install it.\n",
    "import feedparser\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "class GetArXiv(object):\n",
    "    def __init__(self, pickle_path, categories=list()):\n",
    "        \"\"\"\n",
    "        :param pickle_path (str): path to pickle data file to save/load\n",
    "        :param pickle_name (str): file name to save pickle to path\n",
    "        :param categories (list): arXiv categories to query\n",
    "        \"\"\"\n",
    "        if os.path.isdir(pickle_path):\n",
    "            pickle_path = f\"{pickle_path}{'' if pickle_path[-1] == '/' else '/'}all_arxiv.pkl\"\n",
    "        if len(categories) < 1:\n",
    "            categories = ['cs*', 'cond-mat.dis-nn', 'q-bio.NC', 'stat.CO', 'stat.ML']\n",
    "        # categories += ['cs.CV', 'cs.AI', 'cs.LG', 'cs.CL']\n",
    "\n",
    "        self.categories = categories\n",
    "        self.pickle_path = pickle_path\n",
    "        self.base_url = 'http://export.arxiv.org/api/query'\n",
    "\n",
    "    @staticmethod\n",
    "    def build_qs(categories):\n",
    "        \"\"\"Build query string from categories\"\"\"\n",
    "        return '+OR+'.join(['cat:'+c for c in categories])\n",
    "\n",
    "    @staticmethod\n",
    "    def get_entry_dict(entry):\n",
    "        \"\"\"Return a dictionary with the items we want from a feedparser entry\"\"\"\n",
    "        try:\n",
    "            return dict(title=entry['title'], authors=[a['name'] for a in entry['authors']],\n",
    "                        published=pd.Timestamp(entry['published']), summary=entry['summary'],\n",
    "                        link=entry['link'], category=entry['category'])\n",
    "        except KeyError:\n",
    "            print('Missing keys in row: {}'.format(entry))\n",
    "            return None\n",
    "\n",
    "    @staticmethod\n",
    "    def strip_version(link):\n",
    "        \"\"\"Strip version number from arXiv paper link\"\"\"\n",
    "        return link[:-2]\n",
    "\n",
    "    def fetch_updated_data(self, max_retry=5, pg_offset=0, pg_size=1000, wait_time=15):\n",
    "        \"\"\"\n",
    "        Get new papers from arXiv server\n",
    "        :param max_retry: max number of time to retry request\n",
    "        :param pg_offset: number of pages to offset\n",
    "        :param pg_size: num abstracts to fetch per request\n",
    "        :param wait_time: num seconds to wait between requests\n",
    "        \"\"\"\n",
    "        i, retry = pg_offset, 0\n",
    "        df = pd.DataFrame()\n",
    "        past_links = []\n",
    "        if os.path.isfile(self.pickle_path):\n",
    "            df = pd.read_pickle(self.pickle_path)\n",
    "            df.reset_index()\n",
    "        if len(df) > 0: past_links = df.link.apply(self.strip_version)\n",
    "\n",
    "        while True:\n",
    "            params = dict(search_query=self.build_qs(self.categories),\n",
    "                          sortBy='submittedDate', start=pg_size*i, max_results=pg_size)\n",
    "            response = requests.get(self.base_url, params='&'.join([f'{k}={v}' for k, v in params.items()]))\n",
    "            entries = feedparser.parse(response.text).entries\n",
    "            if len(entries) < 1:\n",
    "                if retry < max_retry:\n",
    "                    retry += 1\n",
    "                    time.sleep(wait_time)\n",
    "                    continue\n",
    "                break\n",
    "\n",
    "            results_df = pd.DataFrame([self.get_entry_dict(e) for e in entries])\n",
    "            max_date = results_df.published.max().date()\n",
    "            new_links = ~results_df.link.apply(self.strip_version).isin(past_links)\n",
    "            print(f'{i}. Fetched {len(results_df)} abstracts published {max_date} and earlier')\n",
    "            if not new_links.any():\n",
    "                break\n",
    "\n",
    "            df = pd.concat((df, results_df.loc[new_links]), ignore_index=True)\n",
    "            i += 1\n",
    "            retry = 0\n",
    "            time.sleep(wait_time)\n",
    "\n",
    "        print(f'Downloaded {len(df)-len(past_links)} new abstracts')\n",
    "        df.sort_values('published', ascending=False).groupby('link').first().reset_index()\n",
    "        df.to_pickle(self.pickle_path)\n",
    "        return df\n",
    "\n",
    "    @classmethod\n",
    "    def load(cls, pickle_path):\n",
    "        \"\"\"Load data from pickle and remove duplicates\"\"\"\n",
    "        return pd.read_pickle(cls(pickle_path).pickle_path)\n",
    "\n",
    "    @classmethod\n",
    "    def update(cls, pickle_path, categories=list(), **kwargs):\n",
    "        \"\"\"\n",
    "        Update arXiv data pickle with the latest abstracts\n",
    "        \"\"\"\n",
    "        cls(pickle_path, categories).fetch_updated_data(**kwargs)\n",
    "        return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH='data/arxiv/'\n",
    "\n",
    "ALL_ARXIV = f'{PATH}all_arxiv.pkl'\n",
    "\n",
    "# all_arxiv.pkl: if arxiv hasn't been downloaded yet, it'll take some time to get it - go get some coffee\n",
    "if not os.path.exists(ALL_ARXIV): GetArXiv.update(ALL_ARXIV)\n",
    "\n",
    "# arxiv.csv: see dl1/nlp-arxiv.ipynb to get this one\n",
    "df_mb = pd.read_csv(f'{PATH}arxiv.csv')\n",
    "df_all = pd.read_pickle(ALL_ARXIV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "49600"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_txt(df):\n",
    "    return '<CAT> ' + df.category.str.replace(r'[\\.\\-]','') + ' <SUMM> ' + df.summary + ' <TITLE> ' + df.title\n",
    "df_mb['txt'] = get_txt(df_mb)\n",
    "df_all['txt'] = get_txt(df_all)\n",
    "n=len(df_all); n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(f'{PATH}trn/yes', exist_ok=True)\n",
    "os.makedirs(f'{PATH}val/yes', exist_ok=True)\n",
    "os.makedirs(f'{PATH}trn/no', exist_ok=True)\n",
    "os.makedirs(f'{PATH}val/no', exist_ok=True)\n",
    "os.makedirs(f'{PATH}all/trn', exist_ok=True)\n",
    "os.makedirs(f'{PATH}all/val', exist_ok=True)\n",
    "os.makedirs(f'{PATH}models', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for (i,(_,r)) in enumerate(df_all.iterrows()):\n",
    "    dset = 'trn' if random.random()>0.1 else 'val'\n",
    "    open(f'{PATH}all/{dset}/{i}.txt', 'w').write(r['txt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for (i,(_,r)) in enumerate(df_mb.iterrows()):\n",
    "    lbl = 'yes' if r.tweeted else 'no'\n",
    "    dset = 'trn' if random.random()>0.1 else 'val'\n",
    "    open(f'{PATH}{dset}/{lbl}/{i}.txt', 'w').write(r['txt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from spacy.symbols import ORTH\n",
    "\n",
    "# install the 'en' model if the next line of code fails by running:\n",
    "#python -m spacy download en              # default English model (~50MB)\n",
    "#python -m spacy download en_core_web_md  # larger English model (~1GB)\n",
    "my_tok = spacy.load('en')\n",
    "\n",
    "my_tok.tokenizer.add_special_case('<SUMM>', [{ORTH: '<SUMM>'}])\n",
    "my_tok.tokenizer.add_special_case('<CAT>', [{ORTH: '<CAT>'}])\n",
    "my_tok.tokenizer.add_special_case('<TITLE>', [{ORTH: '<TITLE>'}])\n",
    "my_tok.tokenizer.add_special_case('<BR />', [{ORTH: '<BR />'}])\n",
    "my_tok.tokenizer.add_special_case('<BR>', [{ORTH: '<BR>'}])\n",
    "\n",
    "def my_spacy_tok(x): return [tok.text for tok in my_tok.tokenizer(x)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEXT = data.Field(lower=True, tokenize=my_spacy_tok)\n",
    "FILES = dict(train='trn', validation='val', test='val')\n",
    "md = LanguageModelData.from_text_files(f'{PATH}all/', TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)\n",
    "pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl','wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2129, 17951, 1, 9543290)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<unk>', '<pad>', '\\n', 'the', ',', '.', 'of', '-', 'and', 'a', 'to', 'in']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT.vocab.itos[:12]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<cat> csni <summ> the exploitation of mm - wave bands is one of the key - enabler for 5 g mobile \\n radio networks . however , the introduction of mm - wave technologies in cellular \\n networks is not straightforward due to harsh propagation conditions that limit \\n the mm - wave access availability . mm - wave technologies require high - gain antenna \\n systems to compensate for high path loss and limited power . as a consequence , \\n directional transmissions must be used for cell discovery and synchronization \\n processes : this can lead to a non - negligible access delay caused by the \\n exploration of the cell area with multiple transmissions along different \\n directions . \\n    the integration of mm - wave technologies and conventional wireless access \\n networks with the objective of speeding up the cell search process requires new \\n'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join(md.trn_ds[0].text[:150])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "em_sz = 200\n",
    "nh = 500\n",
    "nl = 3\n",
    "opt_fn = partial(optim.Adam, betas=(0.7, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = md.get_model(opt_fn, em_sz, nh, nl,\n",
    "    dropout=0.05, dropouth=0.1, dropouti=0.05, dropoute=0.02, wdrop=0.2)\n",
    "# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5\n",
    "#                dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)\n",
    "learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)\n",
    "learner.clip=0.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76e5d02069da4309a92517b46519b2c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       4.36189  4.2185 ]                                  \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(3e-3, 1, wds=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "040b60fa7679477caea6460046ba9577",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       4.11236  3.99207]                                  \n",
      "[ 1.       4.03207  3.89298]                                  \n",
      "[ 2.       3.91653  3.81915]                                  \n",
      "[ 3.       3.97808  3.8428 ]                                  \n",
      "[ 4.       3.88482  3.76226]                                  \n",
      "[ 5.       3.79955  3.70472]                                  \n",
      "[ 6.       3.75721  3.69048]                                  \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(3e-3, 3, wds=1e-6, cycle_len=1, cycle_mult=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.save_encoder('adam2_enc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "21b13ef5eb7f4ba6bddb0059f5ec85d0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       3.89388  3.76575]                                  \n",
      "[ 1.       3.82548  3.71875]                                  \n",
      "[ 2.       3.76471  3.66974]                                  \n",
      "[ 3.       3.71713  3.63861]                                  \n",
      "[ 4.       3.67534  3.62983]                                  \n",
      "[ 5.       3.83938  3.71551]                                  \n",
      "[ 6.       3.78093  3.68056]                                  \n",
      "[ 7.       3.72828  3.63638]                                  \n",
      "[ 8.       3.66743  3.60355]                                  \n",
      "[ 9.       3.65793  3.59448]                                  \n",
      "[ 10.        3.80545   3.68213]                               \n",
      "[ 11.        3.75299   3.65219]                               \n",
      "[ 12.        3.7057    3.61324]                               \n",
      "[ 13.        3.63348   3.58048]                               \n",
      "[ 14.        3.62304   3.57257]                               \n",
      "[ 15.        3.78656   3.66324]                               \n",
      "[ 16.        3.73522   3.63348]                               \n",
      "[ 17.        3.67258   3.59369]                               \n",
      "[ 18.        3.6242    3.56674]                               \n",
      "[ 19.        3.61123   3.55783]                               \n",
      "[ 20.        3.77443   3.65472]                               \n",
      "[ 21.        3.7374    3.62169]                               \n",
      "[ 22.        3.68367   3.58247]                               \n",
      "[ 23.        3.61606   3.55567]                               \n",
      "[ 24.        3.58527   3.54725]                               \n",
      "[ 25.        3.75456   3.63861]                               \n",
      "[ 26.        3.72061   3.61084]                               \n",
      "[ 27.        3.65141   3.57073]                               \n",
      "[ 28.        3.59711   3.54414]                               \n",
      "[ 29.        3.57052   3.53622]                               \n",
      "[ 30.        3.75229   3.62935]                               \n",
      "[ 31.        3.70693   3.60198]                               \n",
      "[ 32.        3.65193   3.56444]                               \n",
      "[ 33.        3.59173   3.53742]                               \n",
      "[ 34.        3.58699   3.53152]                               \n",
      "[ 35.        3.74211   3.62154]                               \n",
      "[ 36.        3.70016   3.59831]                               \n",
      "[ 37.        3.64095   3.55689]                               \n",
      "[ 38.        3.60686   3.53296]                               \n",
      "[ 39.       3.5627   3.523 ]                                  \n",
      "[ 40.        3.72944   3.61956]                               \n",
      "[ 41.        3.68161   3.58779]                               \n",
      "[ 42.        3.62305   3.55187]                               \n",
      "[ 43.        3.58559   3.52524]                               \n",
      "[ 44.        3.56087   3.51682]                               \n",
      "[ 45.        3.72533   3.61458]                               \n",
      "[ 46.        3.68025   3.58452]                               \n",
      "[ 47.        3.64447   3.55002]                               \n",
      "[ 48.        3.575     3.52066]                               \n",
      "[ 49.        3.54424   3.5133 ]                               \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(3e-3, 10, wds=1e-6, cycle_len=5, cycle_save_name='adam3_10')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.save_encoder('adam3_10_enc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9df7476b9cc14a5892bc72f30f34e856",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       3.70587  3.61666]                                  \n",
      "[ 1.       3.71738  3.61174]                                  \n",
      "[ 2.       3.68606  3.59661]                                  \n",
      "[ 3.       3.65407  3.5742 ]                                  \n",
      "[ 4.       3.62901  3.55795]                                  \n",
      "[ 5.       3.59921  3.53632]                                  \n",
      "[ 6.       3.58401  3.52149]                                  \n",
      "[ 7.       3.55126  3.50797]                                  \n",
      "[ 8.       3.52965  3.50178]                                  \n",
      "[ 9.       3.52336  3.49997]                                  \n",
      "[ 10.        3.7109    3.60817]                               \n",
      "[ 11.        3.69879   3.60047]                               \n",
      "[ 12.        3.6735    3.58623]                               \n",
      "[ 13.        3.64365   3.56568]                               \n",
      "[ 14.        3.6099    3.54776]                               \n",
      "[ 15.        3.58244   3.52829]                               \n",
      "[ 16.        3.54894   3.51071]                               \n",
      "[ 17.        3.52702   3.50173]                               \n",
      "[ 18.        3.51357   3.49522]                               \n",
      "[ 19.        3.50302   3.49272]                               \n",
      "[ 20.        3.72505   3.60198]                               \n",
      "[ 21.        3.70037   3.59914]                               \n",
      "[ 22.        3.68386   3.58279]                               \n",
      "[ 23.        3.64176   3.56435]                               \n",
      "[ 24.        3.60259   3.54304]                               \n",
      "[ 25.        3.58669   3.52432]                               \n",
      "[ 26.        3.54075   3.50703]                               \n",
      "[ 27.        3.50951   3.49534]                               \n",
      "[ 28.        3.51915   3.4896 ]                               \n",
      "[ 29.        3.48695   3.48968]                               \n",
      "[ 30.        3.70563   3.59631]                               \n",
      "[ 31.        3.68822   3.58723]                               \n",
      "[ 32.        3.67549   3.58141]                               \n",
      "[ 33.        3.63267   3.55537]                               \n",
      "[ 34.        3.60638   3.5386 ]                               \n",
      "[ 35.        3.58803   3.52154]                               \n",
      "[ 36.        3.53987   3.50394]                               \n",
      "[ 37.        3.51036   3.49244]                               \n",
      "[ 38.        3.48651   3.48652]                               \n",
      "[ 39.        3.49061   3.48673]                               \n",
      "[ 40.        3.70093   3.59211]                               \n",
      "[ 41.        3.67371   3.58516]                               \n",
      "[ 42.        3.66558   3.57032]                               \n",
      "[ 43.        3.65089   3.55939]                               \n",
      "[ 44.        3.59885   3.53445]                               \n",
      "[ 45.        3.56369   3.51585]                               \n",
      "[ 46.        3.55304   3.50237]                               \n",
      "[ 47.        3.50469   3.48919]                               \n",
      "[ 48.        3.49559   3.48289]                               \n",
      "[ 49.        3.50912   3.48136]                               \n",
      "[ 50.        3.70603   3.59182]                               \n",
      "[ 51.        3.669     3.58069]                               \n",
      "[ 52.        3.64965   3.56896]                               \n",
      "[ 53.        3.62839   3.55251]                               \n",
      "[ 54.        3.59578   3.53297]                               \n",
      "[ 55.        3.55814   3.51205]                               \n",
      "[ 56.        3.53653   3.49682]                               \n",
      "[ 57.        3.50043   3.48502]                               \n",
      "[ 58.        3.49535   3.4797 ]                               \n",
      "[ 59.        3.48039   3.47882]                               \n",
      "[ 60.        3.68319   3.58874]                               \n",
      "[ 61.        3.68893   3.58173]                               \n",
      "[ 62.        3.6516    3.56403]                               \n",
      "[ 63.        3.63432   3.55047]                               \n",
      "[ 64.        3.59697   3.52815]                               \n",
      "[ 65.        3.55784   3.50832]                               \n",
      "[ 66.        3.52815   3.49319]                               \n",
      "[ 67.        3.50618   3.48222]                               \n",
      "[ 68.        3.48319   3.47645]                               \n",
      "[ 69.        3.49879   3.47596]                               \n",
      "[ 70.        3.68466   3.58318]                               \n",
      "[ 71.        3.67045   3.57351]                               \n",
      "[ 72.        3.64409   3.5606 ]                               \n",
      "[ 73.        3.61991   3.54552]                               \n",
      "[ 74.        3.60503   3.52782]                               \n",
      "[ 75.        3.56681   3.50743]                               \n",
      "[ 76.        3.52401   3.49046]                               \n",
      "[ 77.        3.50519   3.47875]                               \n",
      "[ 78.        3.49343   3.47452]                               \n",
      "[ 79.        3.49275   3.47175]                               \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(3e-3, 8, wds=1e-6, cycle_len=10, cycle_save_name='adam3_5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "717717efd1e441029d5618ed0be24416",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       3.47841  3.4751 ]                                  \n",
      "[ 1.       3.69717  3.57883]                                  \n",
      "[ 2.       3.68267  3.57793]                                  \n",
      "[ 3.       3.66797  3.57299]                                  \n",
      "[ 4.       3.66805  3.56847]                                  \n",
      "[ 5.       3.63489  3.56238]                                  \n",
      "[ 6.       3.62479  3.54928]                                  \n",
      "[ 7.       3.60663  3.53879]                                  \n",
      "[ 8.       3.59124  3.53175]                                  \n",
      "[ 9.       3.58617  3.52009]                                  \n",
      "[ 10.        3.56924   3.51174]                               \n",
      "[ 11.        3.5509    3.49974]                               \n",
      "[ 12.        3.51595   3.49008]                               \n",
      "[ 13.        3.50939   3.48222]                               \n",
      "[ 14.        3.48886   3.47952]                               \n",
      "[ 15.        3.4676    3.47311]                               \n",
      "[ 16.        3.4856    3.46577]                               \n",
      "[ 17.        3.44909   3.46499]                               \n",
      "[ 18.        3.46791   3.46314]                               \n",
      "[ 19.        3.44028   3.46231]                               \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learner.fit(3e-3, 1, wds=1e-6, cycle_len=20, cycle_save_name='adam3_20')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.save_encoder('adam3_20_enc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.save('adam3_20')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def proc_str(s): return TEXT.preprocess(TEXT.tokenize(s))\n",
    "def num_str(s): return TEXT.numericalize([proc_str(s)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m=learner.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s=\"\"\"<CAT> cscv <SUMM> algorithms that\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_model(m, s, l=50):\n",
    "    t = num_str(s)\n",
    "    m[0].bs=1\n",
    "    m.eval()\n",
    "    m.reset()\n",
    "    res,*_ = m(t)\n",
    "    print('...', end='')\n",
    "\n",
    "    for i in range(l):\n",
    "        n=res[-1].topk(2)[1]\n",
    "        n = n[1] if n.data[0]==0 else n[0]\n",
    "        word = TEXT.vocab.itos[n.data[0]]\n",
    "        print(word, end=' ')\n",
    "        if word=='<eos>': break\n",
    "        res,*_ = m(n[0].unsqueeze(0))\n",
    "\n",
    "    m[0].bs=bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "...use the same network as a single node are not able to \n",
      " achieve the same performance as the traditional network - based routing \n",
      " algorithms . in this paper , we propose a novel routing scheme for routing \n",
      " protocols in wireless networks . the proposed scheme is based ..."
     ]
    }
   ],
   "source": [
    "sample_model(m,\"<CAT> csni <SUMM> algorithms that\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "...use the same data to perform image classification are \n",
      " increasingly being used to improve the performance of image classification \n",
      " algorithms . in this paper , we propose a novel method for image classification \n",
      " using a deep convolutional neural network ( cnn ) . the proposed method is ..."
     ]
    }
   ],
   "source": [
    "sample_model(m,\"<CAT> cscv <SUMM> algorithms that\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "...the performance of deep learning for image classification <eos> "
     ]
    }
   ],
   "source": [
    "sample_model(m,\"<CAT> cscv <SUMM> algorithms. <TITLE> on \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "...the performance of wireless networks <eos> "
     ]
    }
   ],
   "source": [
    "sample_model(m,\"<CAT> csni <SUMM> algorithms. <TITLE> on \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "...a new approach to image classification <eos> "
     ]
    }
   ],
   "source": [
    "sample_model(m,\"<CAT> cscv <SUMM> algorithms. <TITLE> towards \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "...a new approach to the analysis of wireless networks <eos> "
     ]
    }
   ],
   "source": [
    "sample_model(m,\"<CAT> csni <SUMM> algorithms. <TITLE> towards \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sentiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEXT = pickle.load(open(f'{PATH}models/TEXT.pkl','rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ArxivDataset(torchtext.data.Dataset):\n",
    "    def __init__(self, path, text_field, label_field, **kwargs):\n",
    "        fields = [('text', text_field), ('label', label_field)]\n",
    "        examples = []\n",
    "        for label in ['yes', 'no']:\n",
    "            fnames = glob(os.path.join(path, label, '*.txt'));\n",
    "            assert fnames, f\"can't find 'yes.txt' or 'no.txt' under {path}/{label}\"\n",
    "            for fname in fnames:\n",
    "                with open(fname, 'r') as f: text = f.readline()\n",
    "                examples.append(data.Example.fromlist([text, label], fields))\n",
    "        super().__init__(examples, fields, **kwargs)\n",
    "\n",
    "    @staticmethod\n",
    "    def sort_key(ex): return len(ex.text)\n",
    "    \n",
    "    @classmethod\n",
    "    def splits(cls, text_field, label_field, root='.data',\n",
    "               train='train', test='test', **kwargs):\n",
    "        return super().splits(\n",
    "            root, text_field=text_field, label_field=label_field,\n",
    "            train=train, validation=None, test=test, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ARX_LABEL = data.Field(sequential=False)\n",
    "splits = ArxivDataset.splits(TEXT, ARX_LABEL, PATH, train='trn', test='val')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "md2 = TextData.from_splits(PATH, splits, bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#            dropout=0.3, dropouti=0.4, wdrop=0.3, dropoute=0.05, dropouth=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_curve\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def prec_at_6(preds,targs):\n",
    "    precision, recall, _ = precision_recall_curve(targs==2, preds[:,2])\n",
    "    print(recall[precision>=0.6][0])\n",
    "    return recall[precision>=0.6][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5\n",
    "m3 = md2.get_model(opt_fn, 1500, bptt, emb_sz=em_sz, n_hid=nh, n_layers=nl, \n",
    "           dropout=0.1, dropouti=0.65, wdrop=0.5, dropoute=0.1, dropouth=0.3)\n",
    "m3.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)\n",
    "m3.clip=25."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this notebook has a mess of some things going under 'all/' others not, so a little hack here\n",
    "!ln -sf ../all/models/adam3_20_enc.h5 {PATH}models/adam3_20_enc.h5\n",
    "m3.load_encoder(f'adam3_20_enc')\n",
    "lrs=np.array([1e-4,1e-3,1e-3,1e-2,3e-2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "feeaf546b6ed445d91985561124afc37",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       0.47654  0.44322  0.78525]                         \n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23ab581c7a264367be284e88b28890b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       0.43033  0.40192  0.80087]                        \n",
      "\n"
     ]
    }
   ],
   "source": [
    "m3.freeze_to(-1)\n",
    "m3.fit(lrs/2, 1, metrics=[accuracy])\n",
    "m3.unfreeze()\n",
    "m3.fit(lrs, 1, metrics=[accuracy], cycle_len=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be7735a6a5354c279c4d6b8280bcda6e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       0.42236  0.39006  0.8194 ]                        \n",
      "[ 1.       0.39477  0.37063  0.82086]                        \n",
      "[ 2.       0.39389  0.37082  0.82449]                        \n",
      "[ 3.       0.40728  0.36999  0.82195]                        \n",
      "[ 4.       0.39308  0.3675   0.81977]                        \n",
      "[ 5.       0.38662  0.36737  0.8234 ]                        \n",
      "[ 6.       0.39259  0.36512  0.82486]                        \n",
      "[ 7.       0.38047  0.36538  0.82522]                        \n",
      "\n"
     ]
    }
   ],
   "source": [
    "m3.fit(lrs, 2, metrics=[accuracy], cycle_len=4, cycle_save_name='imdb2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.659305993691\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.65930599369085174"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prec_at_6(*m3.predict_with_targs())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57d2433c996b48a99f86329570a9885a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       0.38752  0.36351  0.82486]                        \n",
      "[ 1.       0.38664  0.36123  0.82558]                        \n",
      "[ 2.       0.3904   0.36098  0.82486]                        \n",
      "[ 3.       0.37319  0.36144  0.82486]                        \n",
      "[ 4.       0.38074  0.36334  0.82595]                        \n",
      "[ 5.       0.36405  0.3594   0.82413]                        \n",
      "[ 6.       0.38781  0.35914  0.82522]                        \n",
      "[ 7.       0.37722  0.357    0.82631]                        \n",
      "\n"
     ]
    }
   ],
   "source": [
    "m3.fit(lrs, 4, metrics=[accuracy], cycle_len=2, cycle_save_name='imdb2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.695583596215\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.69558359621451105"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prec_at_6(*m3.predict_with_targs())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}