{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository because it will load fastai-0.7.x**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai.model import fit\n", "from fastai.dataset import *\n", "\n", "import torchtext\n", "from torchtext import vocab, data\n", "from torchtext.datasets import language_modeling\n", "\n", "from fastai.rnn_reg import *\n", "from fastai.rnn_train import *\n", "from fastai.nlp import *\n", "from fastai.lm_rnn import *\n", "\n", "import dill as pickle\n", "import random" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs,bptt = 64,70" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Language modeling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, requests, time\n", "# feedparser isn't a fastai dependency so you may need to install it.\n", "import feedparser\n", "import pandas as pd\n", "\n", "\n", "class GetArXiv(object):\n", " def __init__(self, pickle_path, categories=list()):\n", " \"\"\"\n", " :param pickle_path (str): path to pickle data file to save/load\n", " :param pickle_name (str): file name to save pickle to path\n", " :param categories (list): arXiv categories to query\n", " \"\"\"\n", " if os.path.isdir(pickle_path):\n", " pickle_path = f\"{pickle_path}{'' if pickle_path[-1] == '/' else '/'}all_arxiv.pkl\"\n", " if len(categories) < 1:\n", " categories = ['cs*', 'cond-mat.dis-nn', 'q-bio.NC', 'stat.CO', 'stat.ML']\n", " # categories += ['cs.CV', 'cs.AI', 'cs.LG', 'cs.CL']\n", "\n", " self.categories = categories\n", " self.pickle_path = pickle_path\n", " self.base_url = 'http://export.arxiv.org/api/query'\n", "\n", " @staticmethod\n", " def build_qs(categories):\n", " \"\"\"Build query string from categories\"\"\"\n", " return '+OR+'.join(['cat:'+c for c in categories])\n", "\n", " @staticmethod\n", " def get_entry_dict(entry):\n", " \"\"\"Return a dictionary with the items we want from a feedparser entry\"\"\"\n", " try:\n", " return dict(title=entry['title'], authors=[a['name'] for a in entry['authors']],\n", " published=pd.Timestamp(entry['published']), summary=entry['summary'],\n", " link=entry['link'], category=entry['category'])\n", " except KeyError:\n", " print('Missing keys in row: {}'.format(entry))\n", " return None\n", "\n", " @staticmethod\n", " def strip_version(link):\n", " \"\"\"Strip version number from arXiv paper link\"\"\"\n", " return link[:-2]\n", "\n", " def fetch_updated_data(self, max_retry=5, pg_offset=0, pg_size=1000, wait_time=15):\n", " \"\"\"\n", " Get new papers from arXiv server\n", " :param max_retry: max number of time to retry request\n", " :param pg_offset: number of pages to offset\n", " :param pg_size: num abstracts to fetch per request\n", " :param wait_time: num seconds to wait between requests\n", " \"\"\"\n", " i, retry = pg_offset, 0\n", " df = pd.DataFrame()\n", " past_links = []\n", " if os.path.isfile(self.pickle_path):\n", " df = pd.read_pickle(self.pickle_path)\n", " df.reset_index()\n", " if len(df) > 0: past_links = df.link.apply(self.strip_version)\n", "\n", " while True:\n", " params = dict(search_query=self.build_qs(self.categories),\n", " sortBy='submittedDate', start=pg_size*i, max_results=pg_size)\n", " response = requests.get(self.base_url, params='&'.join([f'{k}={v}' for k, v in params.items()]))\n", " entries = feedparser.parse(response.text).entries\n", " if len(entries) < 1:\n", " if retry < max_retry:\n", " retry += 1\n", " time.sleep(wait_time)\n", " continue\n", " break\n", "\n", " results_df = pd.DataFrame([self.get_entry_dict(e) for e in entries])\n", " max_date = results_df.published.max().date()\n", " new_links = ~results_df.link.apply(self.strip_version).isin(past_links)\n", " print(f'{i}. Fetched {len(results_df)} abstracts published {max_date} and earlier')\n", " if not new_links.any():\n", " break\n", "\n", " df = pd.concat((df, results_df.loc[new_links]), ignore_index=True)\n", " i += 1\n", " retry = 0\n", " time.sleep(wait_time)\n", "\n", " print(f'Downloaded {len(df)-len(past_links)} new abstracts')\n", " df.sort_values('published', ascending=False).groupby('link').first().reset_index()\n", " df.to_pickle(self.pickle_path)\n", " return df\n", "\n", " @classmethod\n", " def load(cls, pickle_path):\n", " \"\"\"Load data from pickle and remove duplicates\"\"\"\n", " return pd.read_pickle(cls(pickle_path).pickle_path)\n", "\n", " @classmethod\n", " def update(cls, pickle_path, categories=list(), **kwargs):\n", " \"\"\"\n", " Update arXiv data pickle with the latest abstracts\n", " \"\"\"\n", " cls(pickle_path, categories).fetch_updated_data(**kwargs)\n", " return True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH='data/arxiv/'\n", "\n", "ALL_ARXIV = f'{PATH}all_arxiv.pkl'\n", "\n", "# all_arxiv.pkl: if arxiv hasn't been downloaded yet, it'll take some time to get it - go get some coffee\n", "if not os.path.exists(ALL_ARXIV): GetArXiv.update(ALL_ARXIV)\n", "\n", "# arxiv.csv: see dl1/nlp-arxiv.ipynb to get this one\n", "df_mb = pd.read_csv(f'{PATH}arxiv.csv')\n", "df_all = pd.read_pickle(ALL_ARXIV)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "49600" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def get_txt(df):\n", " return ' ' + df.category.str.replace(r'[\\.\\-]','') + ' ' + df.summary + ' ' + df.title\n", "df_mb['txt'] = get_txt(df_mb)\n", "df_all['txt'] = get_txt(df_all)\n", "n=len(df_all); n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.makedirs(f'{PATH}trn/yes', exist_ok=True)\n", "os.makedirs(f'{PATH}val/yes', exist_ok=True)\n", "os.makedirs(f'{PATH}trn/no', exist_ok=True)\n", "os.makedirs(f'{PATH}val/no', exist_ok=True)\n", "os.makedirs(f'{PATH}all/trn', exist_ok=True)\n", "os.makedirs(f'{PATH}all/val', exist_ok=True)\n", "os.makedirs(f'{PATH}models', exist_ok=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for (i,(_,r)) in enumerate(df_all.iterrows()):\n", " dset = 'trn' if random.random()>0.1 else 'val'\n", " open(f'{PATH}all/{dset}/{i}.txt', 'w').write(r['txt'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for (i,(_,r)) in enumerate(df_mb.iterrows()):\n", " lbl = 'yes' if r.tweeted else 'no'\n", " dset = 'trn' if random.random()>0.1 else 'val'\n", " open(f'{PATH}{dset}/{lbl}/{i}.txt', 'w').write(r['txt'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from spacy.symbols import ORTH\n", "\n", "# install the 'en' model if the next line of code fails by running:\n", "#python -m spacy download en # default English model (~50MB)\n", "#python -m spacy download en_core_web_md # larger English model (~1GB)\n", "my_tok = spacy.load('en')\n", "\n", "my_tok.tokenizer.add_special_case('<SUMM>', [{ORTH: '<SUMM>'}])\n", "my_tok.tokenizer.add_special_case('<CAT>', [{ORTH: '<CAT>'}])\n", "my_tok.tokenizer.add_special_case('<TITLE>', [{ORTH: '<TITLE>'}])\n", "my_tok.tokenizer.add_special_case('<BR />', [{ORTH: '<BR />'}])\n", "my_tok.tokenizer.add_special_case('<BR>', [{ORTH: '<BR>'}])\n", "\n", "def my_spacy_tok(x): return [tok.text for tok in my_tok.tokenizer(x)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "TEXT = data.Field(lower=True, tokenize=my_spacy_tok)\n", "FILES = dict(train='trn', validation='val', test='val')\n", "md = LanguageModelData.from_text_files(f'{PATH}all/', TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)\n", "pickle.dump(TEXT, open(f'{PATH}models/TEXT.pkl','wb'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2129, 17951, 1, 9543290)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['<unk>', '<pad>', '\\n', 'the', ',', '.', 'of', '-', 'and', 'a', 'to', 'in']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "TEXT.vocab.itos[:12]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'<cat> csni <summ> the exploitation of mm - wave bands is one of the key - enabler for 5 g mobile \\n radio networks . however , the introduction of mm - wave technologies in cellular \\n networks is not straightforward due to harsh propagation conditions that limit \\n the mm - wave access availability . mm - wave technologies require high - gain antenna \\n systems to compensate for high path loss and limited power . as a consequence , \\n directional transmissions must be used for cell discovery and synchronization \\n processes : this can lead to a non - negligible access delay caused by the \\n exploration of the cell area with multiple transmissions along different \\n directions . \\n the integration of mm - wave technologies and conventional wireless access \\n networks with the objective of speeding up the cell search process requires new \\n'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join(md.trn_ds[0].text[:150])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "em_sz = 200\n", "nh = 500\n", "nl = 3\n", "opt_fn = partial(optim.Adam, betas=(0.7, 0.99))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner = md.get_model(opt_fn, em_sz, nh, nl,\n", " dropout=0.05, dropouth=0.1, dropouti=0.05, dropoute=0.02, wdrop=0.2)\n", "# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5\n", "# dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)\n", "learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)\n", "learner.clip=0.3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "76e5d02069da4309a92517b46519b2c5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 4.36189 4.2185 ] \n", "\n" ] } ], "source": [ "learner.fit(3e-3, 1, wds=1e-6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "040b60fa7679477caea6460046ba9577", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 4.11236 3.99207] \n", "[ 1. 4.03207 3.89298] \n", "[ 2. 3.91653 3.81915] \n", "[ 3. 3.97808 3.8428 ] \n", "[ 4. 3.88482 3.76226] \n", "[ 5. 3.79955 3.70472] \n", "[ 6. 3.75721 3.69048] \n", "\n" ] } ], "source": [ "learner.fit(3e-3, 3, wds=1e-6, cycle_len=1, cycle_mult=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.save_encoder('adam2_enc')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "21b13ef5eb7f4ba6bddb0059f5ec85d0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 3.89388 3.76575] \n", "[ 1. 3.82548 3.71875] \n", "[ 2. 3.76471 3.66974] \n", "[ 3. 3.71713 3.63861] \n", "[ 4. 3.67534 3.62983] \n", "[ 5. 3.83938 3.71551] \n", "[ 6. 3.78093 3.68056] \n", "[ 7. 3.72828 3.63638] \n", "[ 8. 3.66743 3.60355] \n", "[ 9. 3.65793 3.59448] \n", "[ 10. 3.80545 3.68213] \n", "[ 11. 3.75299 3.65219] \n", "[ 12. 3.7057 3.61324] \n", "[ 13. 3.63348 3.58048] \n", "[ 14. 3.62304 3.57257] \n", "[ 15. 3.78656 3.66324] \n", "[ 16. 3.73522 3.63348] \n", "[ 17. 3.67258 3.59369] \n", "[ 18. 3.6242 3.56674] \n", "[ 19. 3.61123 3.55783] \n", "[ 20. 3.77443 3.65472] \n", "[ 21. 3.7374 3.62169] \n", "[ 22. 3.68367 3.58247] \n", "[ 23. 3.61606 3.55567] \n", "[ 24. 3.58527 3.54725] \n", "[ 25. 3.75456 3.63861] \n", "[ 26. 3.72061 3.61084] \n", "[ 27. 3.65141 3.57073] \n", "[ 28. 3.59711 3.54414] \n", "[ 29. 3.57052 3.53622] \n", "[ 30. 3.75229 3.62935] \n", "[ 31. 3.70693 3.60198] \n", "[ 32. 3.65193 3.56444] \n", "[ 33. 3.59173 3.53742] \n", "[ 34. 3.58699 3.53152] \n", "[ 35. 3.74211 3.62154] \n", "[ 36. 3.70016 3.59831] \n", "[ 37. 3.64095 3.55689] \n", "[ 38. 3.60686 3.53296] \n", "[ 39. 3.5627 3.523 ] \n", "[ 40. 3.72944 3.61956] \n", "[ 41. 3.68161 3.58779] \n", "[ 42. 3.62305 3.55187] \n", "[ 43. 3.58559 3.52524] \n", "[ 44. 3.56087 3.51682] \n", "[ 45. 3.72533 3.61458] \n", "[ 46. 3.68025 3.58452] \n", "[ 47. 3.64447 3.55002] \n", "[ 48. 3.575 3.52066] \n", "[ 49. 3.54424 3.5133 ] \n", "\n" ] } ], "source": [ "learner.fit(3e-3, 10, wds=1e-6, cycle_len=5, cycle_save_name='adam3_10')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.save_encoder('adam3_10_enc')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9df7476b9cc14a5892bc72f30f34e856", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 3.70587 3.61666] \n", "[ 1. 3.71738 3.61174] \n", "[ 2. 3.68606 3.59661] \n", "[ 3. 3.65407 3.5742 ] \n", "[ 4. 3.62901 3.55795] \n", "[ 5. 3.59921 3.53632] \n", "[ 6. 3.58401 3.52149] \n", "[ 7. 3.55126 3.50797] \n", "[ 8. 3.52965 3.50178] \n", "[ 9. 3.52336 3.49997] \n", "[ 10. 3.7109 3.60817] \n", "[ 11. 3.69879 3.60047] \n", "[ 12. 3.6735 3.58623] \n", "[ 13. 3.64365 3.56568] \n", "[ 14. 3.6099 3.54776] \n", "[ 15. 3.58244 3.52829] \n", "[ 16. 3.54894 3.51071] \n", "[ 17. 3.52702 3.50173] \n", "[ 18. 3.51357 3.49522] \n", "[ 19. 3.50302 3.49272] \n", "[ 20. 3.72505 3.60198] \n", "[ 21. 3.70037 3.59914] \n", "[ 22. 3.68386 3.58279] \n", "[ 23. 3.64176 3.56435] \n", "[ 24. 3.60259 3.54304] \n", "[ 25. 3.58669 3.52432] \n", "[ 26. 3.54075 3.50703] \n", "[ 27. 3.50951 3.49534] \n", "[ 28. 3.51915 3.4896 ] \n", "[ 29. 3.48695 3.48968] \n", "[ 30. 3.70563 3.59631] \n", "[ 31. 3.68822 3.58723] \n", "[ 32. 3.67549 3.58141] \n", "[ 33. 3.63267 3.55537] \n", "[ 34. 3.60638 3.5386 ] \n", "[ 35. 3.58803 3.52154] \n", "[ 36. 3.53987 3.50394] \n", "[ 37. 3.51036 3.49244] \n", "[ 38. 3.48651 3.48652] \n", "[ 39. 3.49061 3.48673] \n", "[ 40. 3.70093 3.59211] \n", "[ 41. 3.67371 3.58516] \n", "[ 42. 3.66558 3.57032] \n", "[ 43. 3.65089 3.55939] \n", "[ 44. 3.59885 3.53445] \n", "[ 45. 3.56369 3.51585] \n", "[ 46. 3.55304 3.50237] \n", "[ 47. 3.50469 3.48919] \n", "[ 48. 3.49559 3.48289] \n", "[ 49. 3.50912 3.48136] \n", "[ 50. 3.70603 3.59182] \n", "[ 51. 3.669 3.58069] \n", "[ 52. 3.64965 3.56896] \n", "[ 53. 3.62839 3.55251] \n", "[ 54. 3.59578 3.53297] \n", "[ 55. 3.55814 3.51205] \n", "[ 56. 3.53653 3.49682] \n", "[ 57. 3.50043 3.48502] \n", "[ 58. 3.49535 3.4797 ] \n", "[ 59. 3.48039 3.47882] \n", "[ 60. 3.68319 3.58874] \n", "[ 61. 3.68893 3.58173] \n", "[ 62. 3.6516 3.56403] \n", "[ 63. 3.63432 3.55047] \n", "[ 64. 3.59697 3.52815] \n", "[ 65. 3.55784 3.50832] \n", "[ 66. 3.52815 3.49319] \n", "[ 67. 3.50618 3.48222] \n", "[ 68. 3.48319 3.47645] \n", "[ 69. 3.49879 3.47596] \n", "[ 70. 3.68466 3.58318] \n", "[ 71. 3.67045 3.57351] \n", "[ 72. 3.64409 3.5606 ] \n", "[ 73. 3.61991 3.54552] \n", "[ 74. 3.60503 3.52782] \n", "[ 75. 3.56681 3.50743] \n", "[ 76. 3.52401 3.49046] \n", "[ 77. 3.50519 3.47875] \n", "[ 78. 3.49343 3.47452] \n", "[ 79. 3.49275 3.47175] \n", "\n" ] } ], "source": [ "learner.fit(3e-3, 8, wds=1e-6, cycle_len=10, cycle_save_name='adam3_5')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "717717efd1e441029d5618ed0be24416", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 3.47841 3.4751 ] \n", "[ 1. 3.69717 3.57883] \n", "[ 2. 3.68267 3.57793] \n", "[ 3. 3.66797 3.57299] \n", "[ 4. 3.66805 3.56847] \n", "[ 5. 3.63489 3.56238] \n", "[ 6. 3.62479 3.54928] \n", "[ 7. 3.60663 3.53879] \n", "[ 8. 3.59124 3.53175] \n", "[ 9. 3.58617 3.52009] \n", "[ 10. 3.56924 3.51174] \n", "[ 11. 3.5509 3.49974] \n", "[ 12. 3.51595 3.49008] \n", "[ 13. 3.50939 3.48222] \n", "[ 14. 3.48886 3.47952] \n", "[ 15. 3.4676 3.47311] \n", "[ 16. 3.4856 3.46577] \n", "[ 17. 3.44909 3.46499] \n", "[ 18. 3.46791 3.46314] \n", "[ 19. 3.44028 3.46231] \n", "\n" ] } ], "source": [ "learner.fit(3e-3, 1, wds=1e-6, cycle_len=20, cycle_save_name='adam3_20')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.save_encoder('adam3_20_enc')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner.save('adam3_20')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def proc_str(s): return TEXT.preprocess(TEXT.tokenize(s))\n", "def num_str(s): return TEXT.numericalize([proc_str(s)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m=learner.model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s=\"\"\"<CAT> cscv <SUMM> algorithms that\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sample_model(m, s, l=50):\n", " t = num_str(s)\n", " m[0].bs=1\n", " m.eval()\n", " m.reset()\n", " res,*_ = m(t)\n", " print('...', end='')\n", "\n", " for i in range(l):\n", " n=res[-1].topk(2)[1]\n", " n = n[1] if n.data[0]==0 else n[0]\n", " word = TEXT.vocab.itos[n.data[0]]\n", " print(word, end=' ')\n", " if word=='<eos>': break\n", " res,*_ = m(n[0].unsqueeze(0))\n", "\n", " m[0].bs=bs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "...use the same network as a single node are not able to \n", " achieve the same performance as the traditional network - based routing \n", " algorithms . in this paper , we propose a novel routing scheme for routing \n", " protocols in wireless networks . the proposed scheme is based ..." ] } ], "source": [ "sample_model(m,\"<CAT> csni <SUMM> algorithms that\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "...use the same data to perform image classification are \n", " increasingly being used to improve the performance of image classification \n", " algorithms . in this paper , we propose a novel method for image classification \n", " using a deep convolutional neural network ( cnn ) . the proposed method is ..." ] } ], "source": [ "sample_model(m,\"<CAT> cscv <SUMM> algorithms that\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "...the performance of deep learning for image classification <eos> " ] } ], "source": [ "sample_model(m,\"<CAT> cscv <SUMM> algorithms. <TITLE> on \")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "...the performance of wireless networks <eos> " ] } ], "source": [ "sample_model(m,\"<CAT> csni <SUMM> algorithms. <TITLE> on \")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "...a new approach to image classification <eos> " ] } ], "source": [ "sample_model(m,\"<CAT> cscv <SUMM> algorithms. <TITLE> towards \")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "...a new approach to the analysis of wireless networks <eos> " ] } ], "source": [ "sample_model(m,\"<CAT> csni <SUMM> algorithms. <TITLE> towards \")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sentiment" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "TEXT = pickle.load(open(f'{PATH}models/TEXT.pkl','rb'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ArxivDataset(torchtext.data.Dataset):\n", " def __init__(self, path, text_field, label_field, **kwargs):\n", " fields = [('text', text_field), ('label', label_field)]\n", " examples = []\n", " for label in ['yes', 'no']:\n", " fnames = glob(os.path.join(path, label, '*.txt'));\n", " assert fnames, f\"can't find 'yes.txt' or 'no.txt' under {path}/{label}\"\n", " for fname in fnames:\n", " with open(fname, 'r') as f: text = f.readline()\n", " examples.append(data.Example.fromlist([text, label], fields))\n", " super().__init__(examples, fields, **kwargs)\n", "\n", " @staticmethod\n", " def sort_key(ex): return len(ex.text)\n", " \n", " @classmethod\n", " def splits(cls, text_field, label_field, root='.data',\n", " train='train', test='test', **kwargs):\n", " return super().splits(\n", " root, text_field=text_field, label_field=label_field,\n", " train=train, validation=None, test=test, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ARX_LABEL = data.Field(sequential=False)\n", "splits = ArxivDataset.splits(TEXT, ARX_LABEL, PATH, train='trn', test='val')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "md2 = TextData.from_splits(PATH, splits, bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# dropout=0.3, dropouti=0.4, wdrop=0.3, dropoute=0.05, dropouth=0.2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import precision_recall_curve\n", "import matplotlib.pyplot as plt\n", "\n", "def prec_at_6(preds,targs):\n", " precision, recall, _ = precision_recall_curve(targs==2, preds[:,2])\n", " print(recall[precision>=0.6][0])\n", " return recall[precision>=0.6][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5\n", "m3 = md2.get_model(opt_fn, 1500, bptt, emb_sz=em_sz, n_hid=nh, n_layers=nl, \n", " dropout=0.1, dropouti=0.65, wdrop=0.5, dropoute=0.1, dropouth=0.3)\n", "m3.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)\n", "m3.clip=25." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# this notebook has a mess of some things going under 'all/' others not, so a little hack here\n", "!ln -sf ../all/models/adam3_20_enc.h5 {PATH}models/adam3_20_enc.h5\n", "m3.load_encoder(f'adam3_20_enc')\n", "lrs=np.array([1e-4,1e-3,1e-3,1e-2,3e-2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "feeaf546b6ed445d91985561124afc37", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.47654 0.44322 0.78525] \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "23ab581c7a264367be284e88b28890b2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.43033 0.40192 0.80087] \n", "\n" ] } ], "source": [ "m3.freeze_to(-1)\n", "m3.fit(lrs/2, 1, metrics=[accuracy])\n", "m3.unfreeze()\n", "m3.fit(lrs, 1, metrics=[accuracy], cycle_len=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "be7735a6a5354c279c4d6b8280bcda6e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.42236 0.39006 0.8194 ] \n", "[ 1. 0.39477 0.37063 0.82086] \n", "[ 2. 0.39389 0.37082 0.82449] \n", "[ 3. 0.40728 0.36999 0.82195] \n", "[ 4. 0.39308 0.3675 0.81977] \n", "[ 5. 0.38662 0.36737 0.8234 ] \n", "[ 6. 0.39259 0.36512 0.82486] \n", "[ 7. 0.38047 0.36538 0.82522] \n", "\n" ] } ], "source": [ "m3.fit(lrs, 2, metrics=[accuracy], cycle_len=4, cycle_save_name='imdb2')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.659305993691\n" ] }, { "data": { "text/plain": [ "0.65930599369085174" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prec_at_6(*m3.predict_with_targs())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "57d2433c996b48a99f86329570a9885a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.38752 0.36351 0.82486] \n", "[ 1. 0.38664 0.36123 0.82558] \n", "[ 2. 0.3904 0.36098 0.82486] \n", "[ 3. 0.37319 0.36144 0.82486] \n", "[ 4. 0.38074 0.36334 0.82595] \n", "[ 5. 0.36405 0.3594 0.82413] \n", "[ 6. 0.38781 0.35914 0.82522] \n", "[ 7. 0.37722 0.357 0.82631] \n", "\n" ] } ], "source": [ "m3.fit(lrs, 4, metrics=[accuracy], cycle_len=2, cycle_save_name='imdb2')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.695583596215\n" ] }, { "data": { "text/plain": [ "0.69558359621451105" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prec_at_6(*m3.predict_with_targs())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }