{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural Machine Translation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Translation files" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from fastai.text import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Download dataset**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "French/English parallel texts from http://www.statmt.org/wmt15/translation-task.html. The corpus was created by Chris Callison-Burch, who crawled millions of web pages and then used *a set of simple heuristics to transform French URLs onto English URLs (i.e. replacing \"fr\" with \"en\" and about 40 other hand-written rules), and assume that these documents are translations of each other*." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ubuntu/data\n" ] } ], "source": [ "%cd data\n", "%mkdir translate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The download takes about 20 minutes at 1.5 MB/s." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[#48abf9 32KiB/2.4GiB(0%) CN:5 DL:\u001b[32m73KiB\u001b[0m ETA:\u001b[33m9h34m4s\u001b[0m]\u001b[0m \n", "07/02 05:11:30 [\u001b[1;31mERROR\u001b[0m] CUID#10 - Download aborted. URI=http://www.statmt.org/wmt10/training-giga-fren.tar\n", "Exception: [AbstractCommand.cc:350] errorCode=29 URI=http://www.statmt.org/wmt10/training-giga-fren.tar\n", " -> [HttpSkipResponseCommand.cc:224] errorCode=29 The response status is not successful. 
status=503\n", "\n", "07/02 05:11:30 [\u001b[1;31mERROR\u001b[0m] CUID#8 - Download aborted. URI=http://www.statmt.org/wmt10/training-giga-fren.tar\n", "Exception: [AbstractCommand.cc:350] errorCode=29 URI=http://www.statmt.org/wmt10/training-giga-fren.tar\n", " -> [HttpSkipResponseCommand.cc:224] errorCode=29 The response status is not successful. status=503\n", "\n", "07/02 05:11:30 [\u001b[1;31mERROR\u001b[0m] CUID#7 - Download aborted. URI=http://www.statmt.org/wmt10/training-giga-fren.tar\n", "Exception: [AbstractCommand.cc:350] errorCode=29 URI=http://www.statmt.org/wmt10/training-giga-fren.tar\n", " -> [HttpSkipResponseCommand.cc:224] errorCode=29 The response status is not successful. status=503\n", " *** Download Progress Summary as of Mon Jul 2 05:12:30 2018 *** \n", "===============================================================================\n", "[#48abf9 95MiB/2.4GiB(3%) CN:2 DL:1.6MiB ETA:24m43s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:13:30 2018 *** \n", "===============================================================================\n", "[#48abf9 206MiB/2.4GiB(8%) CN:2 DL:1.9MiB ETA:19m40s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:14:31 2018 *** \n", "===============================================================================\n", "[#48abf9 322MiB/2.4GiB(13%) CN:2 DL:1.9MiB ETA:18m41s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:15:31 2018 *** \n", "===============================================================================\n", "[#48abf9 439MiB/2.4GiB(17%) CN:2 DL:1.9MiB 
ETA:17m40s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:16:32 2018 *** \n", "===============================================================================\n", "[#48abf9 576MiB/2.4GiB(23%) CN:2 DL:2.3MiB ETA:13m14s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:17:32 2018 *** \n", "===============================================================================\n", "[#48abf9 721MiB/2.4GiB(29%) CN:2 DL:2.3MiB ETA:12m12s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:18:33 2018 *** \n", "===============================================================================\n", "[#48abf9 866MiB/2.4GiB(34%) CN:2 DL:2.3MiB ETA:11m13s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:19:34 2018 *** \n", "===============================================================================\n", "[#48abf9 0.9GiB/2.4GiB(40%) CN:2 DL:2.3MiB ETA:10m12s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:20:34 2018 *** \n", "===============================================================================\n", "[#48abf9 1.1GiB/2.4GiB(46%) CN:2 DL:2.3MiB ETA:9m11s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 
2 05:21:35 2018 *** \n", "===============================================================================\n", "[#48abf9 1.2GiB/2.4GiB(50%) CN:2 DL:1.2MiB ETA:15m59s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:22:35 2018 *** \n", "===============================================================================\n", "[#48abf9 1.2GiB/2.4GiB(53%) CN:2 DL:1.0MiB ETA:17m51s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:23:36 2018 *** \n", "===============================================================================\n", "[#48abf9 1.3GiB/2.4GiB(55%) CN:2 DL:917KiB ETA:20m17s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:24:37 2018 *** \n", "===============================================================================\n", "[#48abf9 1.4GiB/2.4GiB(58%) CN:2 DL:901KiB ETA:19m38s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:25:37 2018 *** \n", "===============================================================================\n", "[#48abf9 1.4GiB/2.4GiB(60%) CN:2 DL:901KiB ETA:18m38s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:26:38 2018 *** \n", "===============================================================================\n", "[#48abf9 1.5GiB/2.4GiB(62%) CN:2 DL:1.3MiB ETA:11m23s]\n", "FILE: 
/home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:27:39 2018 *** \n", "===============================================================================\n", "[#48abf9 1.5GiB/2.4GiB(66%) CN:2 DL:1.3MiB ETA:10m7s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:28:40 2018 *** \n", "===============================================================================\n", "[#48abf9 1.6GiB/2.4GiB(70%) CN:2 DL:1.9MiB ETA:6m24s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:29:40 2018 *** \n", "===============================================================================\n", "[#48abf9 1.8GiB/2.4GiB(74%) CN:2 DL:1.9MiB ETA:5m24s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:30:41 2018 *** \n", "===============================================================================\n", "[#48abf9 1.9GiB/2.4GiB(79%) CN:2 DL:1.9MiB ETA:4m23s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:31:41 2018 *** \n", "===============================================================================\n", "[#48abf9 2.0GiB/2.4GiB(84%) CN:2 DL:1.8MiB ETA:3m29s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:32:42 2018 *** \n", 
"===============================================================================\n", "[#48abf9 2.1GiB/2.4GiB(88%) CN:2 DL:1.6MiB ETA:3m3s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " *** Download Progress Summary as of Mon Jul 2 05:33:42 2018 *** \n", "===============================================================================\n", "[#48abf9 2.2GiB/2.4GiB(92%) CN:2 DL:1.6MiB ETA:2m3s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:34:43 2018 *** \n", "===============================================================================\n", "[#48abf9 2.3GiB/2.4GiB(95%) CN:2 DL:1.5MiB ETA:1m5s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 05:35:44 2018 *** \n", "===============================================================================\n", "[#48abf9 2.4GiB/2.4GiB(99%) CN:2 DL:1.7MiB ETA:1s]\n", "FILE: /home/ubuntu/data/training-giga-fren.tar\n", "-------------------------------------------------------------------------------\n", "\n", "[#48abf9 2.4GiB/2.4GiB(99%) CN:2 DL:\u001b[32m1.7MiB\u001b[0m]\u001b[0m \n", "07/02 05:35:45 [\u001b[1;32mNOTICE\u001b[0m] Download complete: /home/ubuntu/data/training-giga-fren.tar\n", "\n", "Download Results:\n", "gid |stat|avg speed |path/URI\n", "======+====+===========+=======================================================\n", "48abf9|\u001b[1;32mOK\u001b[0m | 1.7MiB/s|/home/ubuntu/data/training-giga-fren.tar\n", "\n", "Status Legend:\n", "(OK):download completed.\n" ] } ], "source": [ "!aria2c --file-allocation=none -c -x 5 -s 5 
http://www.statmt.org/wmt10/training-giga-fren.tar" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "!tar -xf training-giga-fren.tar" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "%mv giga-fren.release2.fixed.en.gz giga-fren.release2.fixed.fr.gz training-giga-fren.tar translate/" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ubuntu/data/translate\n" ] } ], "source": [ "%cd translate/" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tar: This does not look like a tar archive\n", "tar: Skipping to next header\n", "tar: Exiting with failure status due to previous errors\n" ] } ], "source": [ "# Fails: the .gz file is a single gzip-compressed text file, not a tar archive\n", "!tar -xzf giga-fren.release2.fixed.en.gz" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Decompress the plain .gz file with gunzip instead\n", "!gunzip giga-fren.release2.fixed.en.gz" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "!gunzip giga-fren.release2.fixed.fr.gz" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ubuntu\n" ] } ], "source": [ "%cd ../.." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Set up the directories and files**" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "PATH = Path('data/translate')\n", "TMP_PATH = PATH / 'tmp'\n", "TMP_PATH.mkdir(exist_ok=True)\n", "fname = 'giga-fren.release2.fixed'\n", "en_fname = PATH / f'{fname}.en'\n", "fr_fname = PATH / f'{fname}.fr'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tokenizing and Pre-processing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training a neural translation model takes a long time:\n", "\n", "- Google's model has 8 layers.\n", "- We are going to build a simpler one.\n", "- Instead of a general model, we will only translate French questions into English." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# Regex filters: the line must start with a question ('Wh...?' in English, any '...?' in French)\n", "re_eq = re.compile('^(Wh[^?.!]+\\?)')\n", "re_fq = re.compile('^([^?.!]+\\?)')\n", "\n", "# lazily pair up the aligned English and French lines and apply the filters\n", "lines = ( (re_eq.search(eq), re_fq.search(fq))\n", "          for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8')))\n", "\n", "# keep only the pairs where both sides matched\n", "qs = [(e.group(), f.group()) for e, f in lines if e and f]" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "# save the questions for later\n", "pickle.dump(qs, (PATH / 'fr-en-qs.pkl').open('wb'))" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# load in pickled questions\n", "qs = pickle.load((PATH / 'fr-en-qs.pkl').open('rb'))" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "52331\n", "[('What is light ?', 'Qu’est-ce que la lumière?'), ('Who are we?', 
'Où sommes-nous?'), ('Where did we come from?', \"D'où venons-nous?\"), ('What would we do without it?', 'Que ferions-nous sans elle ?'), 
('What is the absolute location (latitude and longitude) of Badger, Newfoundland and Labrador?', 'Quelle sont les coordonnées (latitude et longitude) de Badger, à Terre-Neuve-etLabrador?')]\n" ] } ], "source": [ "# ======================================== START DEBUG ========================================\n", "print(len(qs))\n", "print(qs[:5])\n", "# ======================================== END DEBUG ========================================" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('x', 3), ('y', 4), ('z', 5)]\n", "('x', 'y', 'z')\n", "(3, 4, 5)\n" ] } ], "source": [ "# ======================================== START DEBUG ========================================\n", "# Python zip method: https://www.programiz.com/python-programming/methods/built-in/zip\n", "# What is iterable, iterator: https://stackoverflow.com/questions/9884132/what-exactly-are-iterator-iterable-and-iteration\n", "coord = ['x', 'y', 'z']\n", "value = [3, 4, 5, 0, 9]\n", "result = zip(coord, value)\n", "result_list = list(result)\n", "print(result_list)\n", "# unzip result_list\n", "c, v = zip(*result_list)\n", "print(c)\n", "print(v)\n", "# ======================================== END DEBUG ========================================" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tokenize all the questions." 
] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "# unzip the list of (English, French) pairs into two tuples\n", "en_qs, fr_qs = zip(*qs)" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "en_tok = Tokenizer.proc_all_mp(partition_by_cores(en_qs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Note: tokenizing French is quite different from tokenizing English, so we pass the language code `'fr'` to the tokenizer, which requires spaCy's French model._" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-2.0.0/fr_core_news_sm-2.0.0.tar.gz\n", " Downloading https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-2.0.0/fr_core_news_sm-2.0.0.tar.gz (39.8MB)\n", "\u001b[K 100% |████████████████████████████████| 39.8MB 71.0MB/s ta 0:00:01\n", "\u001b[?25hInstalling collected packages: fr-core-news-sm\n", " Running setup.py install for fr-core-news-sm ... 
\u001b[?25ldone\n", "\u001b[?25hSuccessfully installed fr-core-news-sm-2.0.0\n", "\u001b[33mYou are using pip version 9.0.3, however version 10.0.1 is available.\n", "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n", "\n", "\u001b[93m Linking successful\u001b[0m\n", " /home/ubuntu/anaconda3/envs/fastai/lib/python3.6/site-packages/fr_core_news_sm\n", " -->\n", " /home/ubuntu/anaconda3/envs/fastai/lib/python3.6/site-packages/spacy/data/fr\n", "\n", " You can now load the model via spacy.load('fr')\n", "\n" ] } ], "source": [ "# Download spaCy's 'fr' model. Otherwise, you'll encounter the error \"OSError: [E050] Can't find model 'fr'...\"\n", "!python -m spacy download fr" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "scrolled": true }, "outputs": [], "source": [ "fr_tok = Tokenizer.proc_all_mp(partition_by_cores(fr_qs), 'fr')" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([['what', 'is', 'light', '?'],\n", " ['who', 'are', 'we', '?'],\n", " ['where', 'did', 'we', 'come', 'from', '?']],\n", " [['qu’', 'est', '-ce', 'que', 'la', 'lumière', '?'],\n", " ['où', 'sommes', '-', 'nous', '?'],\n", " [\"d'\", 'où', 'venons', '-', 'nous', '?']])" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en_tok[:3], fr_tok[:3]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check statistics for sentence length (in tokens)." ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(23.0, 28.0)" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 90th percentile of English and French sentence lengths\n", "np.percentile([len(o) for o in en_tok], 90), np.percentile([len(o) for o in fr_tok], 90)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We keep only the sentence pairs whose English sentence has fewer than 30 tokens. 
The filter is computed on the English side, and the same pairs are kept for French." ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "keep = np.array([len(o) < 30 for o in en_tok])" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "en_tok = np.array(en_tok)[keep]\n", "fr_tok = np.array(fr_tok)[keep]" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "# save our work\n", "pickle.dump(en_tok, (PATH / 'en_tok.pkl').open('wb'))\n", "pickle.dump(fr_tok, (PATH / 'fr_tok.pkl').open('wb'))" ] }, { "cell_type": "code", "execution_count": 101, "metadata": {}, "outputs": [], "source": [ "def toks2ids(tok, pre):\n", "    \"\"\"\n", "    Numericalize tokens to integer ids.\n", "    \n", "    Arguments:\n", "        tok: list of tokenized sentences (each a list of tokens)\n", "        pre: prefix for the saved files, e.g. 'en' or 'fr'\n", "    \"\"\"\n", "    freq = Counter(p for o in tok for p in o)\n", "    itos = [o for o, c in freq.most_common(40000)]  # 40k most common words\n", "    itos.insert(0, '_bos_')\n", "    itos.insert(1, '_pad_')\n", "    itos.insert(2, '_eos_')\n", "    itos.insert(3, '_unk_')\n", "    # reverse mapping (string -> int); unknown words map to 3 (_unk_)\n", "    stoi = collections.defaultdict(lambda: 3, { v: k for k, v in enumerate(itos) })\n", "    # append _eos_ (id 2) to every sentence\n", "    ids = np.array([ ([stoi[o] for o in p] + [2]) for p in tok ])\n", "    np.save(TMP_PATH / f'{pre}_ids.npy', ids)\n", "    pickle.dump(itos, open(TMP_PATH / f'{pre}_itos.pkl', 'wb'))\n", "    return ids, itos, stoi" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [], "source": [ "en_ids, en_itos, en_stoi = toks2ids(en_tok, 'en')\n", "fr_ids, fr_itos, fr_stoi = toks2ids(fr_tok, 'fr')" ] }, { "cell_type": "code", "execution_count": 110, "metadata": {}, "outputs": [], "source": [ "def load_ids(pre):\n", "    ids = np.load(TMP_PATH / f'{pre}_ids.npy')\n", "    itos = 
pickle.load(open(TMP_PATH / f'{pre}_itos.pkl', 'rb'))\n", "    stoi = collections.defaultdict(lambda: 3, { v: k for k, v in enumerate(itos) })\n", "    return ids, itos, stoi" ] }, { "cell_type": "code", 
"execution_count": 111, "metadata": {}, "outputs": [], "source": [ "en_ids, en_itos, en_stoi = load_ids('en')\n", "fr_ids, fr_itos, fr_stoi = load_ids('fr')" ] }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(['qu’', 'est', '-ce', 'que', 'la', 'lumière', '?', '_eos_'], 17573, 24793)" ] }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sanity check\n", "[fr_itos[o] for o in fr_ids[0]], len(en_itos), len(fr_itos)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Word vectors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Facebook's fastText word vectors are available from https://fasttext.cc/docs/en/english-vectors.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download word vectors:\n", "\n", "We use the pre-trained fastText word vectors for English and French, trained on Wikipedia. These 300-dimensional vectors were obtained using the skip-gram model: https://fasttext.cc/docs/en/pretrained-vectors.html" ] }, { "cell_type": "code", "execution_count": 118, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " *** Download Progress Summary as of Mon Jul 2 14:55:13 2018 *** \n", "===============================================================================\n", "[#fe81f7 2.9GiB/9.6GiB(30%) CN:5 DL:55MiB ETA:2m5s]\n", "FILE: data/translate/wiki.en.zip\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 14:56:14 2018 *** \n", "===============================================================================\n", "[#fe81f7 5.8GiB/9.6GiB(60%) CN:5 DL:52MiB ETA:1m14s]\n", "FILE: data/translate/wiki.en.zip\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 14:57:14 2018 *** \n", 
"===============================================================================\n", "[#fe81f7 8.6GiB/9.6GiB(89%) CN:5 DL:52MiB ETA:19s]\n", "FILE: data/translate/wiki.en.zip\n", "-------------------------------------------------------------------------------\n", "\n", "[#fe81f7 9.6GiB/9.6GiB(99%) CN:2 DL:\u001b[32m39MiB\u001b[0m]\u001b[0m \n", "07/02 14:57:39 [\u001b[1;32mNOTICE\u001b[0m] Download complete: data/translate/wiki.en.zip\n", "\n", "Download Results:\n", "gid |stat|avg speed |path/URI\n", "======+====+===========+=======================================================\n", "fe81f7|\u001b[1;32mOK\u001b[0m | 47MiB/s|data/translate/wiki.en.zip\n", "\n", "Status Legend:\n", "(OK):download completed.\n" ] } ], "source": [ "!aria2c --file-allocation=none -c -x 5 -s 5 -d data/translate https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.zip" ] }, { "cell_type": "code", "execution_count": 119, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " *** Download Progress Summary as of Mon Jul 2 15:01:29 2018 *** \n", "===============================================================================\n", "[#2e6d55 2.5GiB/5.5GiB(45%) CN:5 DL:34MiB ETA:1m29s]\n", "FILE: data/translate/wiki.fr.zip\n", "-------------------------------------------------------------------------------\n", "\n", " *** Download Progress Summary as of Mon Jul 2 15:02:29 2018 *** \n", "===============================================================================\n", "[#2e6d55 5.3GiB/5.5GiB(95%) CN:5 DL:42MiB ETA:5s]\n", "FILE: data/translate/wiki.fr.zip\n", "-------------------------------------------------------------------------------\n", "\n", "[#2e6d55 5.5GiB/5.5GiB(99%) CN:1 DL:\u001b[32m11MiB\u001b[0m]\u001b[0m \n", "07/02 15:02:47 [\u001b[1;32mNOTICE\u001b[0m] Download complete: data/translate/wiki.fr.zip\n", "\n", "Download Results:\n", "gid |stat|avg speed |path/URI\n", 
"======+====+===========+=======================================================\n", "2e6d55|\u001b[1;32mOK\u001b[0m | 41MiB/s|data/translate/wiki.fr.zip\n", "\n", "Status Legend:\n", "(OK):download completed.\n" ] } ], "source": [ "!aria2c --file-allocation=none -c -x 5 -s 5 -d data/translate https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.fr.zip" ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Archive: data/translate/wiki.en.zip\n", " inflating: data/translate/wiki.en.vec \n", " inflating: data/translate/wiki.en.bin \n" ] } ], "source": [ "!unzip data/translate/wiki.en.zip -d data/translate/" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Archive: data/translate/wiki.fr.zip\n", " inflating: data/translate/wiki.fr.vec \n", " inflating: data/translate/wiki.fr.bin \n" ] } ], "source": [ "!unzip data/translate/wiki.fr.zip -d data/translate/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use the fastText library, you'll need to download [fasttext word vectors](https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md) for your language (download the 'bin plus text' ones)." 
] }, { "cell_type": "code", "execution_count": 125, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting git+https://github.com/facebookresearch/fastText.git\n", " Cloning https://github.com/facebookresearch/fastText.git to /tmp/pip-9tychtgg-build\n", "Collecting pybind11>=2.2 (from fasttext==0.8.22)\n", " Downloading https://files.pythonhosted.org/packages/12/90/0f92a575dc60c8fba6d0c91d6b45abdb1058da9ebed40400cbcfad2ac0a7/pybind11-2.2.3-py2.py3-none-any.whl (144kB)\n", "\u001b[K 100% |████████████████████████████████| 153kB 1.8MB/s ta 0:00:01\n", "\u001b[?25hRequirement already satisfied: setuptools>=0.7.0 in ./anaconda3/envs/fastai/lib/python3.6/site-packages (from fasttext==0.8.22)\n", "Requirement already satisfied: numpy in ./anaconda3/envs/fastai/lib/python3.6/site-packages (from fasttext==0.8.22)\n", "Installing collected packages: pybind11, fasttext\n", " Running setup.py install for fasttext ... \u001b[?25ldone\n", "\u001b[?25hSuccessfully installed fasttext-0.8.22 pybind11-2.2.3\n", "\u001b[33mYou are using pip version 9.0.3, however version 10.0.1 is available.\n", "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n" ] } ], "source": [ "!pip install git+https://github.com/facebookresearch/fastText.git" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [], "source": [ "import fastText as ft" ] }, { "cell_type": "code", "execution_count": 127, "metadata": {}, "outputs": [], "source": [ "en_vecs = ft.load_model(str((PATH / 'wiki.en.bin')))" ] }, { "cell_type": "code", "execution_count": 128, "metadata": {}, "outputs": [], "source": [ "fr_vecs = ft.load_model(str((PATH / 'wiki.fr.bin')))" ] }, { "cell_type": "code", "execution_count": 130, "metadata": {}, "outputs": [], "source": [ "def get_vecs(lang, ft_vecs):\n", " \"\"\"\n", " Convert fastText word vectors into a standard Python dictionary to make it a bit easier to work with.\n", " This is just 
going through each word with a dictionary comprehension and saving the result as a pickled dictionary.\n", " \n", " get_word_vector:\n", " [method] get the vector representation of a word.\n", " get_words:\n", " [method] get the entire list of words in the dictionary, optionally\n", " including the frequency of the individual words. This\n", " does not include any subwords. \n", " \"\"\"\n", " vecd = { w: ft_vecs.get_word_vector(w) for w in ft_vecs.get_words() }\n", " pickle.dump(vecd, open(PATH / f'wiki.{lang}.pkl', 'wb'))\n", " return vecd" ] }, { "cell_type": "code", "execution_count": 131, "metadata": {}, "outputs": [], "source": [ "en_vecd = get_vecs('en', en_vecs)\n", "fr_vecd = get_vecs('fr', fr_vecs)" ] }, { "cell_type": "code", "execution_count": 132, "metadata": {}, "outputs": [], "source": [ "en_vecd = pickle.load(open(PATH / 'wiki.en.pkl', 'rb'))\n", "fr_vecd = pickle.load(open(PATH / 'wiki.fr.pkl', 'rb'))" ] }, { "cell_type": "code", "execution_count": 133, "metadata": {}, "outputs": [], "source": [ "# DEBUG\n", "ft_vecs = en_vecs" ] }, { "cell_type": "code", "execution_count": 136, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2519370" ] }, "execution_count": 136, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# DEBUG\n", "ft_words = ft_vecs.get_words(include_freq=True)\n", "ft_word_dict = { k: v for k, v in zip(*ft_words) }\n", "ft_words = sorted(ft_word_dict.keys(), key=lambda x: ft_word_dict[x])\n", "len(ft_words)" ] }, { "cell_type": "code", "execution_count": 149, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(300, 300)" ] }, "execution_count": 149, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dim_en_vec = len(en_vecd[','])\n", "dim_fr_vec = len(fr_vecd[','])\n", "dim_en_vec, dim_fr_vec" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the mean and standard deviation of our vectors: the mean is about zero and the standard deviation is about 0.3." 
] }, { "cell_type": "code", "execution_count": 179, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.0075652334, 0.29283327)" ] }, "execution_count": 179, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# en_vecd type is dict\n", "en_vecs = np.stack(list(en_vecd.values())) # convert dict_values to list and then stack it\n", "en_vecs.mean(), en_vecs.std()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exclude the extreme cases**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Often corpuses have a pretty long tailed distribution of sequence length and it's the longest sequences that tend to overwhelm how long things take, how much memory is used, etc. So in this case, we are going to grab 99th to 97th percentile of the English and French and truncate them to that amount. Originally Jeremy was using 90 percentiles (hence the variable name):" ] }, { "cell_type": "code", "execution_count": 157, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(29, 38)" ] }, "execution_count": 157, "metadata": {}, "output_type": "execute_result" } ], "source": [ "enlen_90 = int(np.percentile([len(o) for o in en_ids], 99))\n", "frlen_90 = int(np.percentile([len(o) for o in fr_ids], 99))\n", "enlen_90, frlen_90" ] }, { "cell_type": "code", "execution_count": 158, "metadata": {}, "outputs": [], "source": [ "en_ids_tr = np.array([o[:enlen_90] for o in en_ids])\n", "fr_ids_tr = np.array([o[:frlen_90] for o in fr_ids])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Create our Dataset, DataLoaders**" ] }, { "cell_type": "code", "execution_count": 159, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqDataset(Dataset):\n", " def __init__(self, x, y):\n", " self.x, self.y = x, y\n", " \n", " def __getitem__(self, idx):\n", " return A(self.x[idx], self.y[idx]) # A for Arrays\n", " \n", " def __len__(self):\n", " return 
len(self.x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Split into training and validation sets**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is an easy way to get training and validation sets. Draw one random number for each row of your data and check whether it is greater than 0.1. That gives you an array of booleans, roughly 90% of them `True`. Index into your arrays with that boolean mask to get the training set. Index with its negation (`~trn_keep`) to get the validation set." ] }, { "cell_type": "code", "execution_count": 160, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(45219, 5041)" ] }, "execution_count": 160, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.random.seed(42)\n", "trn_keep = np.random.rand(len(en_ids_tr)) > 0.1\n", "en_trn, fr_trn = en_ids_tr[trn_keep], fr_ids_tr[trn_keep] # training set\n", "en_val, fr_val = en_ids_tr[~trn_keep], fr_ids_tr[~trn_keep] # validation set\n", "len(en_trn), len(en_val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Create training and validation datasets**" ] }, { "cell_type": "code", "execution_count": 162, "metadata": {}, "outputs": [], "source": [ "trn_ds = Seq2SeqDataset(fr_trn, en_trn)\n", "val_ds = Seq2SeqDataset(fr_val, en_val)" ] }, { "cell_type": "code", "execution_count": 163, "metadata": {}, "outputs": [], "source": [ "# Set batch size\n", "bs = 125" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Most of our preprocessing is already done, so setting `num_workers=1` will save you some time.\n", "- Padding extends the shorter phrases in a batch (with token 1, `pad_idx=1`) so they all have the same length.\n", "- Classifier → padding at the beginning.\n", "- Decoder → padding at the end (hence `pre_pad=False`).\n", "- Sampler → keeps sentences of similar length together (sorted by length), so each batch needs less padding."
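] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is a minimal, framework-free sketch of the padding and length-sorting ideas above. `pad_batch` and `length_sorted_batches` are hypothetical helpers written for illustration only. They are not the fastai `DataLoader` or `SortishSampler` implementations. For example, `SortishSampler` also adds randomness to the ordering, and this sketch omits that." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def pad_batch(seqs, pad_idx=1, pre_pad=False):\n", "    # Pad variable-length lists of token ids into one (batch, max_len) array.\n", "    # pre_pad=True pads at the start (classifier style);\n", "    # pre_pad=False pads at the end (decoder style).\n", "    max_len = max(len(s) for s in seqs)\n", "    out = np.full((len(seqs), max_len), pad_idx, dtype=np.int64)\n", "    for i, s in enumerate(seqs):\n", "        if pre_pad:\n", "            out[i, max_len - len(s):] = s\n", "        else:\n", "            out[i, :len(s)] = s\n", "    return out\n", "\n", "def length_sorted_batches(seqs, bs):\n", "    # Sort indices by sequence length (longest first) and cut them into batches,\n", "    # so each batch holds sentences of similar length and needs little padding.\n", "    idx = sorted(range(len(seqs)), key=lambda i: len(seqs[i]), reverse=True)\n", "    return [idx[i:i + bs] for i in range(0, len(idx), bs)]\n", "\n", "seqs = [[5, 6, 7], [8], [9, 10]]\n", "print(pad_batch(seqs))\n", "print(pad_batch(seqs, pre_pad=True))\n", "print(length_sorted_batches(seqs, bs=2))"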
] }, { "cell_type": "code", "execution_count": 165, "metadata": {}, "outputs": [], "source": [ "# arrange sentences so that similar lengths are close to each other\n", "# (SortishSampler adds some randomness for training; SortSampler sorts strictly)\n", "trn_samp = SortishSampler(en_trn, key=lambda x: len(en_trn[x]), bs=bs)\n", "val_samp = SortSampler(en_val, key=lambda x: len(en_val[x]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Create DataLoaders**" ] }, { "cell_type": "code", "execution_count": 167, "metadata": {}, "outputs": [], "source": [ "trn_dl = DataLoader(trn_ds, bs, transpose=True, transpose_y=True, num_workers=1,\n", " pad_idx=1, pre_pad=False, sampler=trn_samp)\n", "val_dl = DataLoader(val_ds, int(bs * 1.6), transpose=True, transpose_y=True, num_workers=1,\n", " pad_idx=1, pre_pad=False, sampler=val_samp)\n", "md = ModelData(PATH, trn_dl, val_dl)" ] }, { "cell_type": "code", "execution_count": 169, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(38, 29), (21, 7), (21, 8), (38, 13), (38, 21)]" ] }, "execution_count": 169, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test - inspect\n", "\n", "it = iter(trn_dl) # trn_dl is an iterable; iter() turns it into an iterator\n", "# take the first 5 batches and show (source length, target length) for each\n", "its = [next(it) for i in range(5)]\n", "[(len(x), len(y)) for x, y in its]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Architecture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initial model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Architecture diagram](https://s15.postimg.cc/x710hbkdn/1_1f_KDa_Dsww_Vu3w2_Zt_Cg-_Uow.png)\n", "\n", "- The architecture takes our sequence of tokens.\n", "- It feeds them into an encoder (a.k.a. backbone).\n", "- The encoder outputs its final hidden state, which for each sentence is just a single vector.\n", "- That vector is then passed to a decoder, which generates the output words one by one."
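] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is a toy NumPy sketch that makes the encoder/decoder flow concrete. It rests on two simplifying assumptions that do not match the model below. First, the weights are random and untrained. Second, a plain tanh RNN cell stands in for the GRUs. The encoder squeezes each source sentence into a single hidden vector. The decoder starts from token 0 and feeds its own highest-scoring prediction back in at each step. It stops early once every sentence in the batch predicts the padding token 1, which mirrors the loop in `Seq2SeqRNN.forward`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "rng = np.random.default_rng(0)\n", "vocab, emb_dim, hid = 10, 4, 6\n", "\n", "# Random, untrained toy parameters (hypothetical stand-ins for learned weights)\n", "E_enc = rng.normal(size=(vocab, emb_dim))\n", "W_enc = rng.normal(size=(emb_dim + hid, hid)) * 0.3\n", "E_dec = rng.normal(size=(vocab, emb_dim))\n", "W_dec = rng.normal(size=(emb_dim + hid, hid)) * 0.3\n", "W_out = rng.normal(size=(hid, vocab))\n", "\n", "def rnn_step(W, x, h):\n", "    # Plain tanh RNN cell, used instead of a GRU to keep the sketch short\n", "    return np.tanh(np.concatenate([x, h], axis=-1) @ W)\n", "\n", "def encode(src):\n", "    # src: (seq_len, batch) token ids -> one hidden vector per sentence\n", "    h = np.zeros((src.shape[1], hid))\n", "    for tok in src:\n", "        h = rnn_step(W_enc, E_enc[tok], h)\n", "    return h\n", "\n", "def greedy_decode(h, max_len, bos=0, pad=1):\n", "    dec_inp = np.full(h.shape[0], bos)\n", "    outs = []\n", "    for _ in range(max_len):\n", "        h = rnn_step(W_dec, E_dec[dec_inp], h)\n", "        logits = h @ W_out\n", "        outs.append(logits)\n", "        dec_inp = logits.argmax(axis=1)  # feed back the highest-scoring word\n", "        if (dec_inp == pad).all():       # every sentence predicted padding\n", "            break\n", "    return np.stack(outs)  # (out_len, batch, vocab)\n", "\n", "src = rng.integers(2, vocab, size=(5, 3))  # 5 tokens, batch of 3\n", "h = encode(src)\n", "out = greedy_decode(h, max_len=7)\n", "print(h.shape, out.shape)"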
] }, { "cell_type": "code", "execution_count": 170, "metadata": {}, "outputs": [], "source": [ "def create_emb(vecs, itos, em_sz):\n", " \"\"\"\n", " Create an embedding layer with:\n", " 1. rows = vocabulary size (len(itos))\n", " 2. cols = embedding dimension (em_sz)\n", " \n", " Weights start randomly initialized; rows for words found in vecs\n", " are then overwritten with the pretrained word vectors.\n", " \"\"\"\n", " emb = nn.Embedding(len(itos), em_sz, padding_idx=1)\n", " wgts = emb.weight.data\n", " miss = []\n", " \n", " # go through the vocabulary and replace the randomly\n", " # initialized weights with the existing word vectors;\n", " # multiply by 3 to compensate for their stdev of ~0.3\n", " # (nn.Embedding initializes its weights with stdev 1)\n", " for i, w in enumerate(itos):\n", " try:\n", " wgts[i] = torch.from_numpy(vecs[w] * 3)\n", " except KeyError: # word has no pretrained vector\n", " miss.append(w)\n", " print(len(miss), miss[5:10])\n", " return emb" ] }, { "cell_type": "code", "execution_count": 171, "metadata": {}, "outputs": [], "source": [ "nh, nl = 256, 2" ] }, { "cell_type": "code", "execution_count": 172, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN(nn.Module):\n", " def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):\n", " super().__init__()\n", " \n", " # encoder (enc)\n", " self.nl, self.nh, self.out_sl = nl, nh, out_sl\n", " \n", " # for each word, look up its 300-dimensional pretrained vector to build the embedding\n", " self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)\n", " self.emb_enc_drop = nn.Dropout(0.15)\n", " \n", " # GRU - similar to an LSTM\n", " self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)\n", " # project the encoder hidden state (size nh) to the decoder hidden size (em_sz_dec)\n", " self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)\n", " \n", " # decoder (dec)\n", " self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)\n", " self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)\n", " self.out_drop = nn.Dropout(0.35)\n", " self.out = nn.Linear(em_sz_dec, len(itos_dec))\n", " # tie the output layer's weights to the decoder embedding\n", " self.out.weight.data = self.emb_dec.weight.data\n", " \n", " def forward(self, inp):\n", " sl, bs = inp.size()\n", "\n", " # ==================================================\n", " # Encoder\n", " # ==================================================\n", " \n", " # initialize the hidden layer\n", " h = self.initHidden(bs)\n", " \n", " # run the input through our embeddings + apply dropout\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " \n", " # run it through the RNN layer\n", " enc_out, h = self.gru_enc(emb, h)\n", " \n", " # run the hidden state through our linear layer\n", " # note: we are only using the last hidden state to 'decode' into another phrase\n", " h = self.out_enc(h)\n", " \n", " # ==================================================\n", " # Decoder\n", " # ==================================================\n", " \n", " # start with token 0, the beginning-of-string token (_BOS_)\n", " dec_inp = V(torch.zeros(bs).long())\n", " res = []\n", " \n", " # loop for at most out_sl steps (enlen_90, the truncated English sentence length)\n", " for i in range(self.out_sl):\n", " \n", " # embed one token per sentence; .unsqueeze(0) adds the\n", " # sequence dimension (of length 1) that the GRU expects\n", " emb = self.emb_dec(dec_inp).unsqueeze(0)\n", " \n", " # an RNN usually processes whole phrases, but here we pass\n", " # only 1 step at a time in a loop\n", " outp, h = self.gru_dec(emb, h)\n", " \n", " # apply dropout, then project to vocabulary-sized scores\n", " outp = self.out(self.out_drop(outp[0]))\n", " \n", " res.append(outp)\n", " \n", " # feed the highest-probability word back in as the next input\n", " dec_inp = V(outp.data.max(1)[1])\n", " \n", " # if every sentence predicted padding, we are at the end\n", " if (dec_inp == 1).all():\n", " break\n", "\n", " # stack the output into a single tensor\n", " return torch.stack(res)\n", "\n", " def initHidden(self, bs):\n", " return V(torch.zeros(self.nl, bs, self.nh))\n", " " ] }, { "cell_type": "code", "execution_count": 173, "metadata": {}, "outputs": [], "source": [ "def seq2seq_loss(input, target):\n", " \"\"\"\n", " Loss function - cross entropy, after padding/truncating the\n", " predictions to the length of the target sequence\n", " \"\"\"\n", " sl, bs = target.size()\n", " sl_in, bs_in, nc =
input.size()\n", " \n", " # sequence length could be shorter than the original\n", " # need to add padding to even out the size\n", " if sl > sl_in:\n", " input = F.pad(input, (0, 0, 0, 0, 0, sl - sl_in))\n", " input = input[:sl]\n", " return F.cross_entropy(input.view(-1, nc), target.view(-1))#, ignore_index=1)" ] }, { "cell_type": "code", "execution_count": 174, "metadata": {}, "outputs": [], "source": [ "opt_fn = partial(optim.Adam, betas=(0.8, 0.99))" ] }, { "cell_type": "code", "execution_count": 175, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3097 ['l’', \"d'\", 't_up', 'd’', \"qu'\"]\n", "1285 [\"'s\", '’s', \"n't\", 'n’t', ':']\n" ] } ], "source": [ "rnn = Seq2SeqRNN(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)\n", "learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)\n", "learn.crit = seq2seq_loss" ] }, { "cell_type": "code", "execution_count": 176, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6e45ba3d891a43c8bc1cb451c9afa881", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 65%|██████▍ | 235/362 [01:11<00:38, 3.28it/s, loss=29.5]" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Find the learning rate\n", "learn.lr_find()\n", "learn.sched.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Fit the model (15-20 mins to train)**" ] }, { "cell_type": "code", "execution_count": 177, "metadata": {}, "outputs": [], "source": [ "lr = 3e-3" ] }, { "cell_type": "code", "execution_count": 178, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fa19dcfa8db64e78baf878c7ba7fdada", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=12), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 5.209042 5.980303 \n", " 1 4.513244 4.566818 \n", " 2 4.056711 4.515142 \n", " 3 3.775803 4.026515 \n", " 4 3.595237 3.857968 \n", " 5 3.519258 3.773164 \n", " 6 3.160189 3.705156 \n", " 7 3.108818 3.66531 \n", " 8 3.142783 3.613333 \n", " 9 3.192778 3.680305 \n", " 10 2.844773 3.637095 \n", " 11 2.857365 3.5963 \n" ] }, { "data": { "text/plain": [ "[array([3.5963])]" ] }, "execution_count": 178, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, cycle_len=12, use_clr=(20, 10))" ] }, { "cell_type": "code", "execution_count": 180, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.sched.plot_loss()" ] }, { "cell_type": "code", "execution_count": 181, "metadata": {}, "outputs": [], "source": [ "learn.save('initial')" ] }, { "cell_type": "code", "execution_count": 182, "metadata": {}, "outputs": [], "source": [ "learn.load('initial')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test" ] }, { "cell_type": "code", "execution_count": 188, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "quelles composantes des différents aspects de la performance devraient être mesurées , quelles données pertinentes recueillir et comment ? _eos_\n", "which components within various performance areas should be measured , whatkinds of data are appropriate to collect , and how should this be done ? _eos_\n", "what aspects of the and and be be be be be be be be be ? ? _eos_\n", "\n", "le premier ministre doit - il nommer un ministre d’ état à la santé mentale , à la maladie mentale et à la toxicomanie ? _eos_\n", "what role can the federal government play to ensure that individuals with mental illness and addiction have access to the drug therapy they need ? _eos_\n", "what minister the minister minister minister minister minister , , , , health health and health ? ? ? ? _eos_\n", "\n", "quelles sont les conséquences de la hausse des formes d’ emploi non conformes aux normes chez les travailleurs hautement qualifiés et chez ceux qui occupent des emplois plus marginaux ? _eos_\n", "what is the impact of growing forms of non - standard employment for highly skilled workers and for those employed in more marginal occupations ? _eos_\n", "what are the consequences of workers workers workers workers workers workers and and workers and workers workers workers workers workers workers ? ? ? 
_eos_ _eos_\n", "\n", "que se produit - il si le gestionnaire n’ est pas en mesure de donner à l’ employé nommé pour une période déterminée un préavis de cessation d’ emploi d’ un mois ou s’ il néglige de le\n", "what happens if the manager is unable to or neglects to give a term employee the one - month notice of non - renewal ? _eos_\n", "what happens the the employee employee employee employee employee the the the the the or or the the the ? ? _eos_\n", "\n", "quelles personnes , communautés ou entités sont considérées comme potentiels i ) bénéficiaires de la protection et ii ) titulaires de droits ? _eos_\n", "which persons , communities or entities are identified as potential ( i ) beneficiaries of protection and / or ( ii ) rights holders ? _eos_\n", "who , , , , , or or or or or or or or protection ? ? ? ? _eos_\n", "\n", "quelles conditions particulières doivent être remplies pendant l’ examen préliminaire international en ce qui concerne les listages des séquences de nucléotides ou d’ acides aminés ou les tableaux y relatifs ? _eos_\n", "what special requirements apply during the international preliminary examination to nucleotide and / or amino acid sequence listings and / or tables related thereto ? _eos_\n", "what specific must be be be be sequence sequence or or or or sequence or or sequence or sequence or sequence in in ? ? ? ? _eos_ _eos_\n", "\n", "pourquoi cette soudaine réticence à promouvoir l’ égalité des genres et à protéger les femmes de ce que , dans la plupart des cas , on peut qualifier de violations grossières des droits humains ? _eos_\n", "why this sudden reluctance to effectively promote gender equality and protect women from what are – in many cases – egregious human rights violations ? _eos_\n", "why is the so for such of of of of of and rights and rights rights of rights rights ? ? ? ? 
_eos_ _eos_\n", "\n", "pouvez - vous dire comment votre bagage culturel vous a aidée à aborder votre nouvelle vie au canada ( à vous adapter au mode de vie canadien ) ? _eos_\n", "what are some things from your cultural background that have helped you navigate canadian life ( helped you adjust to life in canada ) ? _eos_\n", "what are you new to to to to to to to to life life life life ? ? ? ? _eos_ _eos_\n", "\n", "selon vous , quels seront , dans les dix prochaines années , les cinq enjeux les plus urgents en matière d' environnement et d' avenir viable pour vous et votre région ? _eos_\n", "which do you think will be the five most pressing environmental and sustainability issues for you and your region in the next ten years ? _eos_\n", "what do you see the next priorities priorities next the next the and and in in in in in in ? ? ? ? ? _eos_ _eos_\n", "\n", "dans quelle mesure l’ expert est-il motivé et capable de partager ses connaissances , et dans quelle mesure son successeur est-il motivé et capable de recevoir ce savoir ? _eos_\n", "what is the expert ’s level of motivation and capability for sharing knowledge , and the successor ’s motivation and capability of acquiring it ? _eos_\n", "what is the nature and and and and and and and and and and and and and to to to ? ? ? ? _eos_ _eos_ _eos_\n", "\n" ] } ], "source": [ "x, y = next(iter(val_dl))\n", "probs = learn.model(V(x))\n", "preds = to_np(probs.max(2)[1])\n", "\n", "for i in range(180, 190):\n", " print(' '.join([ fr_itos[o] for o in x[:, i] if o != 1 ]))\n", " print(' '.join([ en_itos[o] for o in y[:, i] if o != 1 ]))\n", " print(' '.join([ en_itos[o] for o in preds[:, i] if o != 1 ]))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bi-direction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Take all your sequences and reverse them and make a \"backwards model\" then average the predictions. 
Note that with deeper models, not every layer has to be bidirectional. In the encoder below, `nn.GRU(..., bidirectional=True)` reads the French sentence both forwards and backwards. For each layer, the final hidden states of the two directions are concatenated (hence `nh * 2` in `out_enc`)." ] }, { "cell_type": "code", "execution_count": 191, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN_Bidir(nn.Module):\n", " def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):\n", " super().__init__()\n", " self.nl, self.nh, self.out_sl = nl, nh, out_sl\n", " self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)\n", " self.emb_enc_drop = nn.Dropout(0.15)\n", " self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25, bidirectional=True) # for bidir, bidirectional=True\n", " self.out_enc = nn.Linear(nh * 2, em_sz_dec, bias=False) # for bidir, nh * 2\n", " self.drop_enc = nn.Dropout(0.05) # additional for bidir\n", " \n", " self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)\n", " self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)\n", " self.out_drop = nn.Dropout(0.35)\n", " self.out = nn.Linear(em_sz_dec, len(itos_dec))\n", " self.out.weight.data = self.emb_dec.weight.data\n", "\n", " def forward(self, inp):\n", " sl, bs = inp.size()\n", "\n", " # ==================================================\n", " # Encoder\n", " # ==================================================\n", "\n", " # initialize the hidden layer\n", " h = self.initHidden(bs)\n", "\n", " # run the input through our embeddings + apply dropout\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", "\n", " # run it through the RNN layer\n", " enc_out, h = self.gru_enc(emb, h)\n", "\n", " # new for bidir: h has shape (nl * 2, bs, nh), ordered layer by layer with\n", " # the forward and backward directions interleaved; regroup it to\n", " # (nl, bs, nh * 2) so each layer's two directions are concatenated\n", " h = h.view(self.nl, 2, bs, -1).permute(0, 2, 1, 3).contiguous().view(self.nl, bs, -1)\n", " \n", " # run the hidden state through our linear layer\n", " h = self.out_enc(self.drop_enc(h)) # new for bidir: apply dropout to the hidden state first\n", "\n", " # ==================================================\n", " # Decoder\n", " # ==================================================\n", "\n", " # start with token 0, the beginning-of-string token (_BOS_)\n", " dec_inp = V(torch.zeros(bs).long())\n", " res = []\n", "\n", " # loop for at most out_sl steps (enlen_90, the truncated English sentence length)\n", " for i in range(self.out_sl):\n", "\n", " # embed one token per sentence; .unsqueeze(0) adds the\n", " # sequence dimension (of length 1) that the GRU expects\n", " emb = self.emb_dec(dec_inp).unsqueeze(0)\n", "\n", " # an RNN usually processes whole phrases, but here we pass\n", " # only 1 step at a time in a loop\n", " outp, h = self.gru_dec(emb, h)\n", "\n", " # apply dropout, then project to vocabulary-sized scores\n", " outp = self.out(self.out_drop(outp[0]))\n", "\n", " res.append(outp)\n", "\n", " # feed the highest-probability word back in as the next input\n", " dec_inp = V(outp.data.max(1)[1])\n", "\n", " # if every sentence predicted padding, we are at the end\n", " if (dec_inp == 1).all():\n", " break\n", "\n", " # stack the output into a single tensor\n", " return torch.stack(res)\n", "\n", " def initHidden(self, bs):\n", " return V(torch.zeros(self.nl * 2, bs, self.nh)) # for bidir, self.nl * 2" ] }, { "cell_type": "code", "execution_count": 192, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3097 ['l’', \"d'\", 't_up', 'd’', \"qu'\"]\n", "1285 [\"'s\", '’s', \"n't\", 'n’t', ':']\n" ] } ], "source": [ "rnn = Seq2SeqRNN_Bidir(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)\n", "learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)\n", "learn.crit = seq2seq_loss" ] }, { "cell_type": "code", "execution_count": 193, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c58d98c9ed43427687de852a42598a7f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=12), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 4.766771 4.495123 \n", " 1 3.918195 4.018911 \n", " 2 3.682928 3.852527 \n", " 3 3.654867 3.653316 \n", " 4 3.540806
3.581977 \n", " 5 3.38937 3.518663 \n", " 6 3.337964 3.461221 \n", " 7 2.868424 3.439734 \n", " 8 2.783658 3.426322 \n", " 9 2.743709 3.375462 \n", " 10 2.662714 3.39351 \n", " 11 2.551906 3.373751 \n" ] }, { "data": { "text/plain": [ "[array([3.37375])]" ] }, "execution_count": 193, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, cycle_len=12, use_clr=(20, 10))" ] }, { "cell_type": "code", "execution_count": 196, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.sched.plot_loss()" ] }, { "cell_type": "code", "execution_count": 194, "metadata": {}, "outputs": [], "source": [ "learn.save('bidir')" ] }, { "cell_type": "code", "execution_count": 195, "metadata": {}, "outputs": [], "source": [ "learn.load('bidir')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Test**" ] }, { "cell_type": "code", "execution_count": 197, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "quelles composantes des différents aspects de la performance devraient être mesurées , quelles données pertinentes recueillir et comment ? _eos_\n", "which components within various performance areas should be measured , whatkinds of data are appropriate to collect , and how should this be done ? _eos_\n", "which aspects of should should should be be and and how how how be be be ? ? _eos_ _eos_\n", "\n", "le premier ministre doit - il nommer un ministre d’ état à la santé mentale , à la maladie mentale et à la toxicomanie ? _eos_\n", "what role can the federal government play to ensure that individuals with mental illness and addiction have access to the drug therapy they need ? _eos_\n", "who is the minister minister minister to minister mental mental mental mental mental health health ? ? ? _eos_\n", "\n", "quelles sont les conséquences de la hausse des formes d’ emploi non conformes aux normes chez les travailleurs hautement qualifiés et chez ceux qui occupent des emplois plus marginaux ? _eos_\n", "what is the impact of growing forms of non - standard employment for highly skilled workers and for those employed in more marginal occupations ? _eos_\n", "what are the implications of of of of of workers workers workers workers workers workers workers in less workers ? ? ? 
_eos_ _eos_\n", "\n", "que se produit - il si le gestionnaire n’ est pas en mesure de donner à l’ employé nommé pour une période déterminée un préavis de cessation d’ emploi d’ un mois ou s’ il néglige de le\n", "what happens if the manager is unable to or neglects to give a term employee the one - month notice of non - renewal ? _eos_\n", "what happens if the employee of the the the the of of of or or or or of of of\n", "\n", "quelles personnes , communautés ou entités sont considérées comme potentiels i ) bénéficiaires de la protection et ii ) titulaires de droits ? _eos_\n", "which persons , communities or entities are identified as potential ( i ) beneficiaries of protection and / or ( ii ) rights holders ? _eos_\n", "which communities are are or as or or or or or , , of ? ? ? ? ?\n", "\n", "quelles conditions particulières doivent être remplies pendant l’ examen préliminaire international en ce qui concerne les listages des séquences de nucléotides ou d’ acides aminés ou les tableaux y relatifs ? _eos_\n", "what special requirements apply during the international preliminary examination to nucleotide and / or amino acid sequence listings and / or tables related thereto ? _eos_\n", "what special requirements requirements be be for for for / sequence / or or sequence or or sequence sequence or or sequence sequence ? ? _eos_ _eos_\n", "\n", "pourquoi cette soudaine réticence à promouvoir l’ égalité des genres et à protéger les femmes de ce que , dans la plupart des cas , on peut qualifier de violations grossières des droits humains ? _eos_\n", "why this sudden reluctance to effectively promote gender equality and protect women from what are – in many cases – egregious human rights violations ? _eos_\n", "why is such such such women women of women , , , , rights rights ? ? ? ? ? _eos_ _eos_\n", "\n", "pouvez - vous dire comment votre bagage culturel vous a aidée à aborder votre nouvelle vie au canada ( à vous adapter au mode de vie canadien ) ? 
_eos_\n", "what are some things from your cultural background that have helped you navigate canadian life ( helped you adjust to life in canada ) ? _eos_\n", "what is your you to you you to to to to to life life life life in life canada ? ? ? ? _eos_\n", "\n", "selon vous , quels seront , dans les dix prochaines années , les cinq enjeux les plus urgents en matière d' environnement et d' avenir viable pour vous et votre région ? _eos_\n", "which do you think will be the five most pressing environmental and sustainability issues for you and your region in the next ten years ? _eos_\n", "what do you see the the the the the the , , future and and and and and future future future future future ? ? ? ? _eos_\n", "\n", "dans quelle mesure l’ expert est-il motivé et capable de partager ses connaissances , et dans quelle mesure son successeur est-il motivé et capable de recevoir ce savoir ? _eos_\n", "what is the expert ’s level of motivation and capability for sharing knowledge , and the successor ’s motivation and capability of acquiring it ? _eos_\n", "what is is expertise of the of and and and and and and and and and and and and and and and and ? ? ?\n", "\n" ] } ], "source": [ "x, y = next(iter(val_dl))\n", "probs = learn.model(V(x))\n", "preds = to_np(probs.max(2)[1])\n", "\n", "for i in range(180, 190):\n", " print(' '.join([ fr_itos[o] for o in x[:, i] if o != 1 ]))\n", " print(' '.join([ en_itos[o] for o in y[:, i] if o != 1 ]))\n", " print(' '.join([ en_itos[o] for o in preds[:, i] if o != 1 ]))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Teacher forcing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the model starts learning, it knows nothing about either language. It eventually gets better. Early on, though, the decoder keeps feeding its own mostly wrong predictions back in as the next input, so it doesn't have a lot to work with.\n", "\n", "**Idea:** what if, early in training, we force-feed the decoder the correct previous word instead of its own prediction? That is teacher forcing. In the code below, the probability of doing so (`pr_force`) starts at 1 and decays linearly to 0 by epoch 10."
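] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell sketches the two pieces teacher forcing needs, as hypothetical plain-Python helpers written for illustration. The first is the linear schedule that `Seq2SeqStepper` uses for `pr_force`: 1.0 at epoch 0, dropping by 0.1 per epoch, and 0 from epoch 10 on. The second is the per-step coin flip that decides whether the decoder's next input is the ground-truth word or its own prediction." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "def teacher_forcing_prob(epoch, n_epochs=10):\n", "    # Same schedule as Seq2SeqStepper: (10 - epoch) * 0.1, then 0 from epoch 10 on\n", "    return (n_epochs - epoch) / n_epochs if epoch < n_epochs else 0.0\n", "\n", "def next_decoder_input(pred_tok, target_tok, pr_force, rng=random):\n", "    # With probability pr_force use the ground-truth token, else the model's own prediction\n", "    return target_tok if rng.random() < pr_force else pred_tok\n", "\n", "print([round(teacher_forcing_prob(e), 1) for e in range(12)])\n", "rng = random.Random(0)\n", "forced = sum(next_decoder_input(0, 1, 0.5, rng) for _ in range(1000))\n", "print(forced)  # roughly half of the 1000 steps use the target token"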
] }, { "cell_type": "code", "execution_count": 198, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqStepper(Stepper):\n", " def step(self, xs, y, epoch):\n", " self.m.pr_force = (10 - epoch) * 0.1 if epoch < 10 else 0\n", " xtra = []\n", " output = self.m(*xs, y)\n", " if isinstance(output, tuple):\n", " output, *xtra = output\n", " self.opt.zero_grad()\n", " loss = raw_loss = self.crit(output, y)\n", " if self.reg_fn:\n", " loss = self.reg_fn(output, xtra, raw_loss)\n", " loss.backward()\n", " if self.clip: # gradient clipping\n", " nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip)\n", " self.opt.step()\n", " \n", " return raw_loss.data[0]" ] }, { "cell_type": "code", "execution_count": 199, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN_TeacherForcing(nn.Module):\n", " def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):\n", " super().__init__()\n", " self.nl, self.nh, self.out_sl = nl, nh, out_sl\n", " self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)\n", " self.emb_enc_drop = nn.Dropout(0.15)\n", " self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)\n", " self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)\n", " self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)\n", " self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)\n", " self.out_drop = nn.Dropout(0.35)\n", " self.out = nn.Linear(em_sz_dec, len(itos_dec))\n", " self.out.weight.data = self.emb_dec.weight.data\n", " self.pr_force = 1. 
# new for teacher forcing\n", "\n", " def forward(self, inp, y=None): # argument y is new for teacher forcing\n", " sl, bs = inp.size()\n", " h = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, h = self.gru_enc(emb, h)\n", " h = self.out_enc(h)\n", " \n", " dec_inp = V(torch.zeros(bs).long())\n", " res = []\n", " \n", " for i in range(self.out_sl):\n", " emb = self.emb_dec(dec_inp).unsqueeze(0)\n", " outp, h = self.gru_dec(emb, h)\n", " outp = self.out(self.out_drop(outp[0]))\n", " res.append(outp)\n", " dec_inp = V(outp.data.max(1)[1])\n", " \n", " if (dec_inp == 1).all():\n", " break\n", " if (y is not None) and (random.random() < self.pr_force): # new for teacher forcing\n", " if i >= len(y):\n", " break\n", " dec_inp = y[i]\n", " return torch.stack(res)\n", "\n", " def initHidden(self, bs):\n", " return V(torch.zeros(self.nl, bs, self.nh))" ] }, { "cell_type": "code", "execution_count": 200, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3097 ['l’', \"d'\", 't_up', 'd’', \"qu'\"]\n", "1285 [\"'s\", '’s', \"n't\", 'n’t', ':']\n" ] } ], "source": [ "rnn = Seq2SeqRNN_TeacherForcing(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)\n", "learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)\n", "learn.crit = seq2seq_loss" ] }, { "cell_type": "code", "execution_count": 201, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "beee982a4da04610b5522175ee71409b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=12), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 3.972275 11.894288 \n", " 1 3.75144 8.904335 \n", " 2 3.147096 5.737202 \n", " 3 3.205919 4.434411 \n", " 4 2.89941 4.337346 \n", " 5 2.837049 4.195613 \n", " 6 2.9374 3.801485 
\n", " 7 2.919509 3.679037 \n", " 8 2.974855 3.600216 \n", " 9 2.98231 3.551779 \n", " 10 2.871864 3.418646 \n", " 11 2.674465 3.432893 \n" ] }, { "data": { "text/plain": [ "[array([3.43289])]" ] }, "execution_count": 201, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, cycle_len=12, use_clr=(20, 10), stepper=Seq2SeqStepper)" ] }, { "cell_type": "code", "execution_count": 202, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.sched.plot_loss()" ] }, { "cell_type": "code", "execution_count": 203, "metadata": {}, "outputs": [], "source": [ "learn.save('forcing')\n", "learn.load('forcing')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Test**" ] }, { "cell_type": "code", "execution_count": 204, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "quelles composantes des différents aspects de la performance devraient être mesurées , quelles données pertinentes recueillir et comment ? _eos_\n", "which components within various performance areas should be measured , whatkinds of data are appropriate to collect , and how should this be done ? _eos_\n", "what elements of the should be be be be and and and and ? ? ? ?\n", "\n", "le premier ministre doit - il nommer un ministre d’ état à la santé mentale , à la maladie mentale et à la toxicomanie ? _eos_\n", "what role can the federal government play to ensure that individuals with mental illness and addiction have access to the drug therapy they need ? _eos_\n", "what is the minister of the the of the and and and and and and mental health ? ? ? _eos_\n", "\n", "quelles sont les conséquences de la hausse des formes d’ emploi non conformes aux normes chez les travailleurs hautement qualifiés et chez ceux qui occupent des emplois plus marginaux ? _eos_\n", "what is the impact of growing forms of non - standard employment for highly skilled workers and for those employed in more marginal occupations ? _eos_\n", "what are the implications of of of of of of workers in in in and and and workers and workers and workers ? ? 
_eos_ _eos_\n", "\n", "que se produit - il si le gestionnaire n’ est pas en mesure de donner à l’ employé nommé pour une période déterminée un préavis de cessation d’ emploi d’ un mois ou s’ il néglige de le\n", "what happens if the manager is unable to or neglects to give a term employee the one - month notice of non - renewal ? _eos_\n", "what if if not is not a a or or or or or or or ? ? ? ? ? ? ?\n", "\n", "quelles personnes , communautés ou entités sont considérées comme potentiels i ) bénéficiaires de la protection et ii ) titulaires de droits ? _eos_\n", "which persons , communities or entities are identified as potential ( i ) beneficiaries of protection and / or ( ii ) rights holders ? _eos_\n", "who communities or persons , as as as as ( ( ( , protection and ? ? ? _eos_\n", "\n", "quelles conditions particulières doivent être remplies pendant l’ examen préliminaire international en ce qui concerne les listages des séquences de nucléotides ou d’ acides aminés ou les tableaux y relatifs ? _eos_\n", "what special requirements apply during the international preliminary examination to nucleotide and / or amino acid sequence listings and / or tables related thereto ? _eos_\n", "what special conditions to to to to the the / / / / sequence sequence sequence of of / / / / ? ? ? ? ? _eos_\n", "\n", "pourquoi cette soudaine réticence à promouvoir l’ égalité des genres et à protéger les femmes de ce que , dans la plupart des cas , on peut qualifier de violations grossières des droits humains ? _eos_\n", "why this sudden reluctance to effectively promote gender equality and protect women from what are – in many cases – egregious human rights violations ? _eos_\n", "why encourage such such such such such such as as human human human ? ? ? _eos_ _eos_\n", "\n", "pouvez - vous dire comment votre bagage culturel vous a aidée à aborder votre nouvelle vie au canada ( à vous adapter au mode de vie canadien ) ? 
_eos_\n", "what are some things from your cultural background that have helped you navigate canadian life ( helped you adjust to life in canada ) ? _eos_\n", "what are the you you you to to to to to to to to your your in in in in canada ? ? ? _eos_\n", "\n", "selon vous , quels seront , dans les dix prochaines années , les cinq enjeux les plus urgents en matière d' environnement et d' avenir viable pour vous et votre région ? _eos_\n", "which do you think will be the five most pressing environmental and sustainability issues for you and your region in the next ten years ? _eos_\n", "what do you see as the most most most important future and and and and future future future ? ? ? ? _eos_\n", "\n", "dans quelle mesure l’ expert est-il motivé et capable de partager ses connaissances , et dans quelle mesure son successeur est-il motivé et capable de recevoir ce savoir ? _eos_\n", "what is the expert ’s level of motivation and capability for sharing knowledge , and the successor ’s motivation and capability of acquiring it ? _eos_\n", "what is the expert of and and and and and and and and and and and and and ? ? ? ? ?\n", "\n" ] } ], "source": [ "x, y = next(iter(val_dl))\n", "probs = learn.model(V(x))\n", "preds = to_np(probs.max(2)[1])\n", "\n", "# print source, target and prediction for 10 validation examples, skipping index 1 (padding)\n", "for i in range(180, 190):\n", " print(' '.join([fr_itos[o] for o in x[:, i] if o != 1]))\n", " print(' '.join([en_itos[o] for o in y[:, i] if o != 1]))\n", " print(' '.join([en_itos[o] for o in preds[:, i] if o != 1]))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Attentional model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our RNN encoder returns the hidden state at every time step, along with the hidden state at the last time step. 
So far we have only used the LAST hidden state to 'decode' into another phrase.\n", "\n", "Can we use the rest of those hidden states?\n", "\n", "**Goal:** take a weighted average of all the encoder hidden states, where the weights (percentages that sum to 1) are computed by a small trainable network, so the decoder can learn which source words to focus on at each output step.\n", "\n", "**Idea:** expecting the entire sentence to be summarized into a single vector is asking a lot. Instead of relying only on the hidden state at the end of the phrase, we can use the hidden state after every single word. So how do we use the hidden information after every word?" ] }, { "cell_type": "code", "execution_count": 205, "metadata": {}, "outputs": [], "source": [ "def rand_t(*sz):\n", " # random normal tensor scaled by 1/sqrt(sz[0])\n", " return torch.randn(sz) / math.sqrt(sz[0])\n", "\n", "def rand_p(*sz):\n", " # the same tensor wrapped in nn.Parameter so that it is trained\n", " return nn.Parameter(rand_t(*sz))" ] }, { "cell_type": "code", "execution_count": 220, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqAttnRNN(nn.Module):\n", " def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):\n", " super().__init__()\n", " self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)\n", " # pr_force is the teacher-forcing probability; default 1, the training stepper can overwrite it\n", " self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1\n", " self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)\n", " self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)\n", " self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)\n", " self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)\n", " self.emb_enc_drop = nn.Dropout(0.15)\n", " self.out_drop = nn.Dropout(0.35)\n", " self.out = nn.Linear(em_sz_dec, len(itos_dec))\n", " self.out.weight.data = 
self.emb_dec.weight.data\n", "\n", " # these 4 lines are the additions for attention\n", " self.W1 = rand_p(nh, em_sz_dec) # projects encoder outputs (nh) to em_sz_dec; a random matrix wrapped in nn.Parameter\n", " self.l2 = nn.Linear(em_sz_dec, em_sz_dec) # projects the decoder hidden state; W1, l2 and V form the mini NN that computes the attention weights\n", " self.l3 = nn.Linear(em_sz_dec + nh, em_sz_dec) # combines the decoder input embedding with the attention context vector\n", " self.V = rand_p(em_sz_dec) # reduces each source position to a single attention score\n", "\n", " def forward(self, inp, y=None, ret_attn=False):\n", " sl, bs 
= inp.size()\n", " h = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, h = self.gru_enc(emb, h)\n", " h = self.out_enc(h)\n", "\n", " dec_inp = V(torch.zeros(bs).long())\n", " res, attns = [], [] # attns collects the attention weights (new for attention)\n", " # new for attention: project all encoder outputs once, outside the loop; (sl, bs, nh) @ (nh, em_sz_dec) -> (sl, bs, em_sz_dec)\n", " w1e = enc_out @ self.W1\n", "\n", " for i in range(self.out_sl):\n", " # new for attention: a little neural network scores every encoder position,\n", " # and softmax turns the scores into weights that sum to 1\n", " w2h = self.l2(h[-1]) # last layer's decoder hidden state through a linear layer: (bs, em_sz_dec)\n", " u = F.tanh(w1e + w2h) # broadcast over the sl positions, then tanh: (sl, bs, em_sz_dec)\n", " a = F.softmax(u @ self.V, 0) # one score per position (sl, bs), softmax over the source dimension\n", " attns.append(a)\n", " # weighted average of ALL the encoder states, using the weights from the mini NN: (bs, nh)\n", " Xa = (a.unsqueeze(2) * enc_out).sum(0)\n", " \n", " emb = self.emb_dec(dec_inp)\n", " # new for attention: concatenate the decoder input embedding with the context vector Xa, project back to em_sz_dec\n", " wgt_enc = self.l3(torch.cat([emb, Xa], 1))\n", " \n", " outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h) # changed for attention: the GRU input is wgt_enc instead of emb\n", " outp = self.out(self.out_drop(outp[0]))\n", " res.append(outp)\n", " dec_inp = V(outp.data.max(1)[1])\n", " if (dec_inp==1).all():\n", " break\n", 
" # teacher forcing is kept on purpose: during training, feed the true word y[i] with probability pr_force\n", " if (y is not None) and (random.random() < self.pr_force):\n", " if i >= len(y):\n", " break\n", " dec_inp = y[i]\n", "\n", " res = torch.stack(res)\n", " if ret_attn:\n", " res = res, torch.stack(attns) # return a (predictions, attention weights) tuple for visualization\n", " return res\n", "\n", " def initHidden(self, bs):\n", " return V(torch.zeros(self.nl, bs, self.nh))" ] }, { "cell_type": "code", "execution_count": 224, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3097 ['l’', \"d'\", 't_up', 'd’', \"qu'\"]\n", "1285 [\"'s\", '’s', \"n't\", 'n’t', ':']\n" ] } ], "source": [ "rnn = Seq2SeqAttnRNN(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)\n", "learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)\n", "learn.crit = seq2seq_loss" ] }, { "cell_type": "code", "execution_count": 208, "metadata": {}, "outputs": [], "source": [ "lr = 2e-3" ] }, { "cell_type": "code", "execution_count": 209, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "17cb27cf39134f16ae7c87780a8764a2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=15), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 3.780541 14.757052 \n", " 1 3.221531 5.661915 \n", " 2 2.901307 4.924356 \n", " 3 2.875144 4.647381 \n", " 4 2.704298 3.912943 \n", " 5 2.69899 4.401953 \n", " 6 2.78165 3.864044 \n", " 7 2.765688 3.614325 \n", " 8 2.873574 3.417437 \n", " 9 2.826172 3.370511 \n", " 10 2.845763 3.293398 \n", " 11 2.66649 3.300835 \n", " 12 2.697862 3.258844 \n", " 13 2.659374 3.267969 \n", " 14 2.585613 3.240595 \n" ] }, { "data": { "text/plain": [ "[array([3.24059])]" ] }, "execution_count": 209, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, cycle_len=15, use_clr=(20, 10), stepper=Seq2SeqStepper)" ] }, { "cell_type": "code", "execution_count": 210, "metadata": {}, "outputs": [ { 
"data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.sched.plot_loss()" ] }, { "cell_type": "code", "execution_count": 211, "metadata": {}, "outputs": [], "source": [ "learn.save('attn')" ] }, { "cell_type": "code", "execution_count": 226, "metadata": {}, "outputs": [], "source": [ "learn.load('attn')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Test**" ] }, { "cell_type": "code", "execution_count": 228, "metadata": {}, "outputs": [], "source": [ "x, y = next(iter(val_dl))\n", "probs, attns = learn.model(V(x), ret_attn=True)\n", "preds = to_np(probs.max(2)[1])" ] }, { "cell_type": "code", "execution_count": 236, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "quelles composantes des différents aspects de la performance devraient être mesurées , quelles données pertinentes recueillir et comment ? _eos_\n", "which components within various performance areas should be measured , whatkinds of data are appropriate to collect , and how should this be done ? _eos_\n", "what components of the performance should be be be data be and and how ? ? _eos_ ?\n", "\n", "le premier ministre doit - il nommer un ministre d’ état à la santé mentale , à la maladie mentale et à la toxicomanie ? _eos_\n", "what role can the federal government play to ensure that individuals with mental illness and addiction have access to the drug therapy they need ? _eos_\n", "what is the minister minister ’s minister minister to to minister to health health ? and mental mental health _eos_ _eos_ mental _eos_\n", "\n", "quelles sont les conséquences de la hausse des formes d’ emploi non conformes aux normes chez les travailleurs hautement qualifiés et chez ceux qui occupent des emplois plus marginaux ? _eos_\n", "what is the impact of growing forms of non - standard employment for highly skilled workers and for those employed in more marginal occupations ? 
_eos_\n", "what are the implications of of - statistics - workers - workers workers and and skilled workers workers workers older workers _eos_ ? workers ? _eos_ _eos_\n", "\n", "que se produit - il si le gestionnaire n’ est pas en mesure de donner à l’ employé nommé pour une période déterminée un préavis de cessation d’ emploi d’ un mois ou s’ il néglige de le\n", "what happens if the manager is unable to or neglects to give a term employee the one - month notice of non - renewal ? _eos_\n", "what if the manager is not to to employee employee employee a employee the employee for retirement time hours employee after a employee of ? after _eos_\n", "\n", "quelles personnes , communautés ou entités sont considérées comme potentiels i ) bénéficiaires de la protection et ii ) titulaires de droits ? _eos_\n", "which persons , communities or entities are identified as potential ( i ) beneficiaries of protection and / or ( ii ) rights holders ? _eos_\n", "who , or or or or considered as as recipients of of of protection protection protection _eos_ ? _eos_ _eos_\n", "\n", "quelles conditions particulières doivent être remplies pendant l’ examen préliminaire international en ce qui concerne les listages des séquences de nucléotides ou d’ acides aminés ou les tableaux y relatifs ? _eos_\n", "what special requirements apply during the international preliminary examination to nucleotide and / or amino acid sequence listings and / or tables related thereto ? _eos_\n", "what specific conditions conditions be be during the international examination examination in the for nucleotide or amino amino / or or ? _eos_ ? ? _eos_ tables _eos_ ?\n", "\n", "pourquoi cette soudaine réticence à promouvoir l’ égalité des genres et à protéger les femmes de ce que , dans la plupart des cas , on peut qualifier de violations grossières des droits humains ? 
_eos_\n", "why this sudden reluctance to effectively promote gender equality and protect women from what are – in many cases – egregious human rights violations ? _eos_\n", "why this this to to to to to to women to and and and women to , of _eos_ of many people ? ? of _eos_ ? human human\n", "\n", "pouvez - vous dire comment votre bagage culturel vous a aidée à aborder votre nouvelle vie au canada ( à vous adapter au mode de vie canadien ) ? _eos_\n", "what are some things from your cultural background that have helped you navigate canadian life ( helped you adjust to life in canada ) ? _eos_\n", "what is your your of your you to you to to in canada canada canada life canada canada canada _eos_ _eos_ _eos_ _eos_ _eos_\n", "\n", "selon vous , quels seront , dans les dix prochaines années , les cinq enjeux les plus urgents en matière d' environnement et d' avenir viable pour vous et votre région ? _eos_\n", "which do you think will be the five most pressing environmental and sustainability issues for you and your region in the next ten years ? _eos_\n", "what do you think in the next five five next , , next and and and and and and you and in ? _eos_ ? ? _eos_ ?\n", "\n", "dans quelle mesure l’ expert est-il motivé et capable de partager ses connaissances , et dans quelle mesure son successeur est-il motivé et capable de recevoir ce savoir ? _eos_\n", "what is the expert ’s level of motivation and capability for sharing knowledge , and the successor ’s motivation and capability of acquiring it ? _eos_\n", "what is the the of the the and and and and and and and to and to and and ? ? ? 
_eos_ _eos_\n", "\n" ] } ], "source": [ "for i in range(180, 190):\n", " print(' '.join([fr_itos[o] for o in x[:, i] if o != 1]))\n", " print(' '.join([en_itos[o] for o in y[:, i] if o != 1]))\n", " print(' '.join([en_itos[o] for o in preds[:, i] if o != 1]))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Visualization" ] }, { "cell_type": "code", "execution_count": 237, "metadata": {}, "outputs": [], "source": [ "# attention weights for batch item 180: (output steps, source positions)\n", "attn = to_np(attns[..., 180])" ] }, { "cell_type": "code", "execution_count": 248, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(29, 38)\n", "(38,)\n", "[0.00093 0.38696 0.50663 0.05482 0.00831 0.0026 0.00418 0.00047 0.00101 0.00141]\n" ] } ], "source": [ "# shape check: (output steps, source positions)\n", "print(attn.shape)\n", "\n", "# weights that the first output word puts on the source positions (first 10 shown)\n", "print(attn[0].shape)\n", "print(attn[0][:10])" ] }, { "cell_type": "code", "execution_count": 241, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# one panel per output word: attention over the source positions for the first 9 output words\n", "fig, axes = plt.subplots(3, 3, figsize=(15, 10))\n", "for i, ax in enumerate(axes.flat):\n", " ax.plot(attn[i])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## All (seq2seq + bidirectional + attention)" ] }, { "cell_type": "code", "execution_count": 249, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN_All(nn.Module):\n", " def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):\n", " super().__init__()\n", " self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)\n", " # pr_force is the teacher-forcing probability; default 1, the training stepper can overwrite it\n", " self.nl, self.nh, self.out_sl, self.pr_force = nl, nh, out_sl, 1\n", " # bidirectional encoder: each output concatenates the forward and backward states (nh * 2)\n", " self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25, bidirectional=True)\n", " self.out_enc = nn.Linear(nh * 2, em_sz_dec, bias=False)\n", " self.drop_enc = nn.Dropout(0.25)\n", " self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)\n", " self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)\n", " self.emb_enc_drop = nn.Dropout(0.15)\n", " self.out_drop = nn.Dropout(0.35)\n", " self.out = nn.Linear(em_sz_dec, len(itos_dec))\n", " self.out.weight.data = self.emb_dec.weight.data\n", "\n", " # attention layers, sized for the nh * 2 bidirectional encoder outputs\n", " self.W1 = rand_p(nh * 2, em_sz_dec)\n", " self.l2 = nn.Linear(em_sz_dec, em_sz_dec)\n", " self.l3 = nn.Linear(em_sz_dec + nh * 2, em_sz_dec)\n", " self.V = rand_p(em_sz_dec)\n", "\n", " def forward(self, inp, y=None):\n", " sl, bs = inp.size()\n", " h = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, h = self.gru_enc(emb, h)\n", " # h is (nl * 2, bs, nh); concatenate each layer's forward and 
backward states -> (nl, bs, nh * 2)\n", " h = h.view(self.nl, 2, bs, -1).permute(0, 2, 1, 3).contiguous().view(self.nl, bs, -1)\n", " h = self.out_enc(self.drop_enc(h))\n", "\n", " dec_inp = V(torch.zeros(bs).long())\n", " res, attns = [], []\n", " w1e = enc_out @ self.W1\n", " for i in range(self.out_sl):\n", " w2h = self.l2(h[-1])\n", " u = F.tanh(w1e + w2h)\n", " a = F.softmax(u @ self.V, 0)\n", " attns.append(a)\n", " Xa = (a.unsqueeze(2) * enc_out).sum(0)\n", " emb = 
self.emb_dec(dec_inp)\n", " wgt_enc = self.l3(torch.cat([emb, Xa], 1))\n", " \n", " outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)\n", " outp = self.out(self.out_drop(outp[0]))\n", " res.append(outp)\n", " dec_inp = V(outp.data.max(1)[1])\n", " if (dec_inp == 1).all():\n", " break\n", " if (y is not None) and (random.random() < self.pr_force):\n", " if i >= len(y):\n", " break\n", " dec_inp = y[i]\n", " return torch.stack(res)\n", "\n", " def initHidden(self, bs):\n", " return V(torch.zeros(self.nl * 2, bs, self.nh))" ] }, { "cell_type": "code", "execution_count": 250, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3097 ['l’', \"d'\", 't_up', 'd’', \"qu'\"]\n", "1285 [\"'s\", '’s', \"n't\", 'n’t', ':']\n" ] } ], "source": [ "rnn = Seq2SeqRNN_All(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)\n", "learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)\n", "learn.crit = seq2seq_loss" ] }, { "cell_type": "code", "execution_count": 251, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5b77ca806c5432a92631c29ed5642b4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=15), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 3.848361 13.931393 \n", " 1 3.392174 6.029147 \n", " 2 3.092935 4.806828 \n", " 3 2.786907 4.489991 \n", " 4 2.630582 4.586274 \n", " 5 2.865972 4.124985 \n", " 6 2.795535 3.689954 \n", " 7 2.744906 3.453802 \n", " 8 2.713428 3.494858 \n", " 9 2.882232 3.303313 \n", " 10 2.613028 3.27942 \n", " 11 2.597443 3.218975 \n", " 13 2.469049 3.22439 \n", " 14 2.278426 3.235787 \n" ] }, { "data": { "text/plain": [ "[array([3.23579])]" ] }, "execution_count": 251, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit(lr, 1, cycle_len=15, 
use_clr=(20, 10), stepper=Seq2SeqStepper)" ] }, { "cell_type": "code", "execution_count": 252, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.sched.plot_loss()" ] }, { "cell_type": "code", "execution_count": 253, "metadata": {}, "outputs": [], "source": [ "learn.save('all')" ] }, { "cell_type": "code", "execution_count": 254, "metadata": {}, "outputs": [], "source": [ "learn.load('all')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Test**" ] }, { "cell_type": "code", "execution_count": 255, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "quelles composantes des différents aspects de la performance devraient être mesurées , quelles données pertinentes recueillir et comment ? _eos_\n", "which components within various performance areas should be measured , whatkinds of data are appropriate to collect , and how should this be done ? _eos_\n", "what components of the different aspects of should be be measured , and and how how ? _eos_\n", "\n", "le premier ministre doit - il nommer un ministre d’ état à la santé mentale , à la maladie mentale et à la toxicomanie ? _eos_\n", "what role can the federal government play to ensure that individuals with mental illness and addiction have access to the drug therapy they need ? _eos_\n", "who is the minister minister to minister minister to mental health mental and mental ? ? ? _eos_\n", "\n", "quelles sont les conséquences de la hausse des formes d’ emploi non conformes aux normes chez les travailleurs hautement qualifiés et chez ceux qui occupent des emplois plus marginaux ? _eos_\n", "what is the impact of growing forms of non - standard employment for highly skilled workers and for those employed in more marginal occupations ? _eos_\n", "what are the implications of increasing employment forms of workers workers workers workers workers workers workers workers workers workers workers workers more more ? 
_eos_ _eos_ _eos_\n", "\n", "que se produit - il si le gestionnaire n’ est pas en mesure de donner à l’ employé nommé pour une période déterminée un préavis de cessation d’ emploi d’ un mois ou s’ il néglige de le\n", "what happens if the manager is unable to or neglects to give a term employee the one - month notice of non - renewal ? _eos_\n", "what happens the manager does not to to the employee employee employee employee a a employee a employee or employee or or or or or or or or ?\n", "\n", "quelles personnes , communautés ou entités sont considérées comme potentiels i ) bénéficiaires de la protection et ii ) titulaires de droits ? _eos_\n", "which persons , communities or entities are identified as potential ( i ) beneficiaries of protection and / or ( ii ) rights holders ? _eos_\n", "who , communities communities or entities as as potential as beneficiaries of ( protection and protection protection protection ? ? _eos_ _eos_ _eos_\n", "\n", "quelles conditions particulières doivent être remplies pendant l’ examen préliminaire international en ce qui concerne les listages des séquences de nucléotides ou d’ acides aminés ou les tableaux y relatifs ? _eos_\n", "what special requirements apply during the international preliminary examination to nucleotide and / or amino acid sequence listings and / or tables related thereto ? _eos_\n", "what special conditions must be required during the international preliminary preliminary in for nucleotide or sequence amino or sequence or or or or tables ? ? _eos_ _eos_ _eos_\n", "\n", "pourquoi cette soudaine réticence à promouvoir l’ égalité des genres et à protéger les femmes de ce que , dans la plupart des cas , on peut qualifier de violations grossières des droits humains ? _eos_\n", "why this sudden reluctance to effectively promote gender equality and protect women from what are – in many cases – egregious human rights violations ? 
_eos_\n", "why this sudden effect of to to women women women and and of of of , , of of of human human human human ? _eos_ _eos_ _eos_ _eos_\n", "\n", "pouvez - vous dire comment votre bagage culturel vous a aidée à aborder votre nouvelle vie au canada ( à vous adapter au mode de vie canadien ) ? _eos_\n", "what are some things from your cultural background that have helped you navigate canadian life ( helped you adjust to life in canada ) ? _eos_\n", "what can you you your your cultural your your you to to to canada canada canada life life life life canada ? _eos_\n", "\n", "selon vous , quels seront , dans les dix prochaines années , les cinq enjeux les plus urgents en matière d' environnement et d' avenir viable pour vous et votre région ? _eos_\n", "which do you think will be the five most pressing environmental and sustainability issues for you and your region in the next ten years ? _eos_\n", "what do you see be be the the next five five five , and and and and and and and your your ? ? ? _eos_ _eos_ _eos_\n", "\n", "dans quelle mesure l’ expert est-il motivé et capable de partager ses connaissances , et dans quelle mesure son successeur est-il motivé et capable de recevoir ce savoir ? _eos_\n", "what is the expert ’s level of motivation and capability for sharing knowledge , and the successor ’s motivation and capability of acquiring it ? _eos_\n", "what is the expert ’s and and knowledge knowledge knowledge knowledge and and and and and and and and and and and ? ? ? 
_eos_ _eos_\n", "\n" ] } ], "source": [ "x, y = next(iter(val_dl))\n", "probs = learn.model(V(x))\n", "preds = to_np(probs.max(2)[1])\n", "\n", "for i in range(180, 190):\n", " print(' '.join([fr_itos[o] for o in x[:, i] if o != 1]))\n", " print(' '.join([en_itos[o] for o in y[:, i] if o != 1]))\n", " print(' '.join([en_itos[o] for o in preds[:, i] if o != 1]))\n", " print()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }