{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ULMFiT Language Model for the Malay Language" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Project\n", "\n", "### Background\n", "\n", "_This work is part of my project while studying [fast.ai's 2018 edition of Cutting Edge Deep Learning for Coders, Part 2](http://course.fast.ai/part2.html) course._\n", "\n", "I took this opportunity to work with the fast.ai community on implementing the [Universal Language Model Fine-tuning for Text Classification (ULMFiT) paper](http://nlp.fast.ai/classification/2018/05/15/introducting-ulmfit.html) in different languages. fast.ai will soon launch a [model zoo with pre-trained language models for many languages](http://forums.fast.ai/t/language-model-zoo-gorilla/14623). You can learn more about ULMFiT in lessons 4 and 10. From lesson 10 ([my notes](https://cedrickchee.gitbook.io/knowledge/courses/fast.ai/deep-learning-part-2-cutting-edge-deep-learning-for-coders/2018-edition/lesson-10-transfer-learning-nlp)), I learned:\n", "- how pre-training a full language model from scratch can greatly outperform previous approaches based on simple word vectors\n", "- how transfer learning for NLP, using this language model, achieves a new state-of-the-art result in text classification; in some sense, [NLP's ImageNet moment has arrived](http://ruder.io/nlp-imagenet/)\n", "\n", "---\n", "\n", "### Project Goal\n", "\n", "The goal of this project is to train Malay word embeddings with the fast.ai version of the [AWD-LSTM Language Model](https://arxiv.org/abs/1708.02182) by Salesforce Research (essentially an LSTM with several forms of dropout), using data from a [Malay Wikipedia dump](https://dumps.wikimedia.org/mswiki/20180901/mswiki-20180901-pages-articles.xml.bz2) (last updated Sept 2, 2018). 
The AWD-LSTM language model achieved state-of-the-art performance on English.\n", "\n", "Using a 90/10 train-validation split, I achieved a perplexity of **29.30245 with 60,002 embeddings at 400 dimensions**. For reference, the state of the art for English as of June 12, 2018 was **40.68 on WikiText-2 by [Yang et al. (2017)](https://arxiv.org/abs/1711.03953)** and **29.2 on WikiText-103 by [Rae et al. (2018)](https://arxiv.org/abs/1803.10049)**. To the best of my knowledge, there is no comparable published work on the Malay language at the time of writing (Sept 21, 2018).\n", "\n", "My workflow is as follows:\n", "- Perform a 90/10 train-validation split\n", "- Do minimal text cleaning and tokenization using our own tokenization functions\n", "- Train the language model\n", "- Evaluate the model based on perplexity and by manual inspection (eyeballing)\n", "- Extract the embeddings learned from the training set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Pre-trained model**\n", "\n", "You can download the files from [Google Drive](https://drive.google.com/drive/folders/1p5fsrD97iRD-Vz6C_ae5fo4c5wY0KrJd?usp=sharing):\n", "- Weights for the pre-trained model (lm_malay_final.h5.tar.gz)\n", "  - Uncompress the archive and put the weights (.h5 file) into `{project_root}/data/models/`.\n", "- Index-to-word mapping (itos.pkl.tar.gz)\n", "  - Uncompress the archive and put the pickled objects (.pkl files) into `{project_root}/data/model/malay/tmp/`.\n", "- Pre-processed Malay Wikipedia dataset:\n", "  - tokenized training text data (tok_trn.npy.tar.gz)\n", "  - tokenized validation text data (tok_val.npy.tar.gz)\n", "  - indexed representation of the train set (trn_ids.npy.tar.gz)\n", "  - indexed representation of the validation set (val_ids.npy.tar.gz)\n", "  - Uncompress the archives and put the NumPy array binaries (.npy files) into `{project_root}/data/model/malay/tmp/`.\n", "\n", "Note: the weights (model state dict) and the optimizer state were saved at the end of training.\n", "\n", "_The model was last trained 
on 2018-09-22 and the weights last updated on 2018-09-22._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import pathlib\n", "import html\n", "import re\n", "from collections import Counter\n", "import numpy as np\n", "import pandas as pd\n", "import sklearn.model_selection\n", "\n", "from fastai.text import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Standardize data format" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "BOS = \"xbos\" # beginning-of-text tag\n", "FLD = \"xfld\" # data field tag\n", "\n", "DATA_PATH = \"data\"\n", "EXTR_PATH = pathlib.Path(f\"{DATA_PATH}/wiki_extr/ms\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "LM_PATH = Path(f\"{DATA_PATH}/model/malay/\")\n", "LM_PATH.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "232\n" ] }, { "data": { "text/plain": [ "['data/wiki_extr/ms/AB/wiki_65',\n", " 'data/wiki_extr/ms/AB/wiki_11',\n", " 'data/wiki_extr/ms/AB/wiki_62',\n", " 'data/wiki_extr/ms/AB/wiki_84',\n", " 'data/wiki_extr/ms/AB/wiki_23']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LANG_FILENAMES = [str(f) for f in EXTR_PATH.rglob(\"*/*\")]\n", "print(len(LANG_FILENAMES))\n", "LANG_FILENAMES[0:5]" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT = []\n", "for i in LANG_FILENAMES:\n", "    with open(i, encoding=\"utf-8\") as f:\n", "        for line in f:  # one JSON-encoded article per line\n", "            LANG_TEXT.append(json.loads(line))\n", "\n", "LANG_TEXT = pd.DataFrame(LANG_TEXT)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { 
"text/html": [ "
" ], "text/plain": [ " id text \\\n", "0 666100 Senarai penyakit anjing\\n\\nSenarai penyakit an... \n", "1 666132 Son\\n\\nSon, adalah satu siri televisyen Turki ... \n", "2 666144 Kanashiki Amefuri / Adam to Eve no Dilemma\\n\\n... \n", "3 666153 Bulakan, Cibeber\\n\\nBulakan merupakan sebuah K... \n", "4 666155 Cibeber, Cibeber, Cilegon\\n\\nCibeber merupakan... \n", "\n", " title \\\n", "0 Senarai penyakit anjing \n", "1 Son \n", "2 Kanashiki Amefuri / Adam to Eve no Dilemma \n", "3 Bulakan, Cibeber \n", "4 Cibeber, Cibeber, Cilegon \n", "\n", " url \n", "0 https://ms.wikipedia.org/wiki?curid=666100 \n", "1 https://ms.wikipedia.org/wiki?curid=666132 \n", "2 https://ms.wikipedia.org/wiki?curid=666144 \n", "3 https://ms.wikipedia.org/wiki?curid=666153 \n", "4 https://ms.wikipedia.org/wiki?curid=666155 " ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LANG_TEXT.head()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "# Remove the title, which appears as the first paragraph of the text field\n", "def split_title_from_text(text):\n", "    paragraphs = text.split(\"\\n\\n\")\n", "    if len(paragraphs) >= 2:\n", "        return ''.join(paragraphs[1:])\n", "    else:\n", "        return ''.join(paragraphs)\n", "\n", "LANG_TEXT[\"text\"] = LANG_TEXT[\"text\"].apply(split_title_from_text)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " id text \\\n", "0 666100 Senarai penyakit anjing ialah pilihan penyakit... \n", "1 666132 Son, adalah satu siri televisyen Turki yang di... \n", "2 666144 Kanashiki Amefuri / Adam to Eve no Dilemma (悲し... \n", "3 666153 Bulakan merupakan sebuah Kelurahan yang terlet... \n", "4 666155 Cibeber merupakan sebuah desa yang terletak da... \n", "\n", " title \\\n", "0 Senarai penyakit anjing \n", "1 Son \n", "2 Kanashiki Amefuri / Adam to Eve no Dilemma \n", "3 Bulakan, Cibeber \n", "4 Cibeber, Cibeber, Cilegon \n", "\n", " url \n", "0 https://ms.wikipedia.org/wiki?curid=666100 \n", "1 https://ms.wikipedia.org/wiki?curid=666132 \n", "2 https://ms.wikipedia.org/wiki?curid=666144 \n", "3 https://ms.wikipedia.org/wiki?curid=666153 \n", "4 https://ms.wikipedia.org/wiki?curid=666155 " ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LANG_TEXT.head()" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT.to_csv(f\"{LM_PATH}/wiki_malay_corpus.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT = pd.read_csv(f\"{LM_PATH}/wiki_malay_corpus.csv\")" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT = LANG_TEXT.assign(length = 0)\n", "LANG_TEXT.columns = [\"id\", \"text\", \"title\", \"url\", \"length\"]\n", "LANG_TEXT = LANG_TEXT.assign(labels = 0).pipe(lambda x: x[[\"labels\", \"text\", \"length\"]])" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " labels text length\n", "0 0 Senarai penyakit anjing ialah pilihan penyakit... 0\n", "1 0 Son, adalah satu siri televisyen Turki yang di... 0\n", "2 0 Kanashiki Amefuri / Adam to Eve no Dilemma (悲し... 0\n", "3 0 Bulakan merupakan sebuah Kelurahan yang terlet... 0\n", "4 0 Cibeber merupakan sebuah desa yang terletak da... 0" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LANG_TEXT.head()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT[\"length\"] = LANG_TEXT[\"text\"].str.len()\n", "LANG_TEXT = LANG_TEXT.sort_values(by=[\"length\"], ascending=False)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " labels text length\n", "298556 0 Dalam agama Islam, Asmaul Husna () merupakan s... 200069.0\n", "287674 0 Tasbih Cinta merupakan sebuah Drama lipur lara... 154934.0\n", "289705 0 Ketuanan Melayu merupakan kontrak sosial yang ... 101706.0\n", "315332 0 Dato' Sri Siti Nurhaliza binti Tarudin (Jawi: ... 98675.0\n", "95126 0 Kebebasan beragama di Malaysia adalah tertaklu... 94949.0" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LANG_TEXT.head()" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT.to_csv(f\"{LM_PATH}/wiki_malay_corpus_sorted_by_len.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 134, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT = pd.read_csv(f\"{LM_PATH}/wiki_malay_corpus_sorted_by_len.csv\")" ] }, { "cell_type": "code", "execution_count": 135, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "319374" ] }, "execution_count": 135, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(LANG_TEXT)" ] }, { "cell_type": "code", "execution_count": 136, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " labels text length\n", "0 0 Dalam agama Islam, Asmaul Husna () merupakan s... 200069.0\n", "1 0 Tasbih Cinta merupakan sebuah Drama lipur lara... 154934.0\n", "2 0 Ketuanan Melayu merupakan kontrak sosial yang ... 101706.0\n", "3 0 Dato' Sri Siti Nurhaliza binti Tarudin (Jawi: ... 98675.0\n", "4 0 Kebebasan beragama di Malaysia adalah tertaklu... 94949.0" ] }, "execution_count": 136, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LANG_TEXT.head()" ] }, { "cell_type": "code", "execution_count": 137, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT = LANG_TEXT[LANG_TEXT['length'] > 10]" ] }, { "cell_type": "code", "execution_count": 138, "metadata": {}, "outputs": [], "source": [ "LANG_TEXT = LANG_TEXT.iloc[0:1000000]" ] }, { "cell_type": "code", "execution_count": 139, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "316081" ] }, "execution_count": 139, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(LANG_TEXT)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Some statistics of Malay Wikipedia" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Number of documents" ] }, { "cell_type": "code", "execution_count": 140, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 Dalam agama Islam, Asmaul Husna () merupakan s...\n", "1 Tasbih Cinta merupakan sebuah Drama lipur lara...\n", "2 Ketuanan Melayu merupakan kontrak sosial yang ...\n", "3 Dato' Sri Siti Nurhaliza binti Tarudin (Jawi: ...\n", "4 Kebebasan beragama di Malaysia adalah tertaklu...\n", "5 England sudah dihuni manusia sejak lebih darip...\n", "6 Perang Bosnia merupakan sebuah konflik bersenj...\n", "7 Masakan Zaman Pertengahan termasuk makanan, ta...\n", "8 Perang Vietnam, ada kalinya disebut juga \"Pera...\n", "9 Operasi Market-Garden adalah satu operasi sera...\n", "Name: text, dtype: object\n" ] }, { "data": { "text/plain": [ "(316081, 3)" ] }, "execution_count": 
140, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(LANG_TEXT[\"text\"][:10])\n", "LANG_TEXT.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Number of words in all the documents" ] }, { "cell_type": "code", "execution_count": 141, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "26509826" ] }, "execution_count": 141, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LANG_TEXT[\"text\"].apply(lambda x: len(x.split(\" \"))).sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Number of unique tokens across documents" ] }, { "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1742401" ] }, "execution_count": 142, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(set(\"\".join(LANG_TEXT[\"text\"].values).split(\" \")))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Text processing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We perform the following minimal text processing:\n", "\n", "- Remove HTML tags and unescape HTML entities\n", "- Mark the start of each text with the token `xbos`, since we will be concatenating all texts for language model training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clean plain text" ] }, { "cell_type": "code", "execution_count": 155, "metadata": {}, "outputs": [], "source": [ "re1 = re.compile(r' +')\n", "\n", "def fixup(x):\n", "    # Undo HTML/markup artifacts, then collapse runs of spaces\n", "    x = x.replace('#39;', \"'\").replace('amp;', '&').replace('#146;', \"'\").replace(\n", "        'nbsp;', ' ').replace('#36;', '$').replace('\\\\n', \"\\n\").replace('quot;', \"'\").replace(\n", "        '<br />', \"\\n\").replace('\\\\\"', '\"').replace('<unk>','u_n').replace(' @.@ ','.').replace(\n", "        ' @-@ ','-').replace('\\\\', ' \\\\ ')\n", "    return re1.sub(' ', html.unescape(x))" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Tokenize" ] }, { "cell_type": "code", "execution_count": 156, "metadata": {}, "outputs": [], "source": [ "def get_texts(df, n_lbls=1):\n", "    labels = df.iloc[:, range(n_lbls)].values.astype(np.int64)\n", "    # Prefix every text with the BOS tag and number each field with FLD\n", "    texts = f\"\\n{BOS} {FLD} 1 \" + df[n_lbls].astype(str)\n", "    for i in range(n_lbls + 1, len(df.columns)):\n", "        texts += f\" {FLD} {i-n_lbls} \" + df[i].astype(str)\n", "    texts = texts.apply(fixup).values.astype(str)\n", "\n", "    # partition_by_cores splits the list into sublists, one per CPU core\n", "    tok = Tokenizer().proc_all_mp(partition_by_cores(texts))\n", "    # Lowercasing (and t_up markers for all-caps words) happens inside the tokenizer\n", "    return tok, list(labels)\n", "\n", "def get_all(df, n_lbls):\n", "    # df is an iterator over DataFrame chunks (read_csv with chunksize)\n", "    tok, labels = [], []\n", "    for i, r in enumerate(df):\n", "        print(i)  # progress: index of the current chunk\n", "        tok_, labels_ = get_texts(r, n_lbls)\n", "        tok += tok_\n", "        labels += labels_\n", "    return tok, labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create validation set" ] }, { "cell_type": "code", "execution_count": 145, "metadata": {}, "outputs": [], "source": [ "# Split the data into train and validation sets,\n", "# holding out 10% for validation.\n", "trn_texts, val_texts = sklearn.model_selection.train_test_split(LANG_TEXT, test_size=0.1)" ] }, { "cell_type": "code", "execution_count": 166, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(284472, 31609)" ] }, "execution_count": 166, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(trn_texts), len(val_texts)" ] }, { "cell_type": "code", "execution_count": 149, "metadata": {}, "outputs": [], "source": [ "np.random.seed(42)\n", "\n", "trn_idx = np.random.permutation(len(trn_texts)) # generate a random ordering\n", "val_idx = np.random.permutation(len(val_texts))\n", "\n", "df_trn = trn_texts.iloc[trn_idx, :] # shuffle rows into a random order\n", "df_val = val_texts.iloc[val_idx, :] # shuffle rows into a random order\n", "\n", "df_trn.columns = [\"labels\", \"text\", \"length\"]\n", "df_val.columns = [\"labels\", \"text\", \"length\"]\n", "\n", "# Save both splits to disk in the new format (the validation set is saved as test.csv)\n", "df_trn.to_csv(LM_PATH / \"train.csv\", header=False, index=False)\n", "df_val.to_csv(LM_PATH / \"test.csv\", header=False, index=False)" ] }, 
{ "cell_type": "code", "execution_count": 150, "metadata": {}, "outputs": [], "source": [ "# Read the CSVs back lazily, in chunks of 10,000 rows\n", "chunksize = 10000\n", "df_trn = pd.read_csv(LM_PATH / \"train.csv\", header=None, chunksize=chunksize)\n", "df_val = pd.read_csv(LM_PATH / \"test.csv\", header=None, chunksize=chunksize)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preparation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the simple custom functions defined above to perform tokenization. Out of the 24,911,449 token occurrences in the training set, we keep a vocabulary of 60,000 tokens (plus one for unknown and one for padding): the most frequent tokens that appear more than twice in the training set, which filters out rare words and typos." 
] }, { "cell_type": "code", "execution_count": 158, "metadata": { "collapsed": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n", "1\n", "2\n", "3\n", "4\n", "5\n", "6\n", "7\n", "8\n", "9\n", "10\n", "11\n", "12\n", "13\n", "14\n", "15\n", "16\n", "17\n", "18\n", "19\n", "20\n", "21\n", "22\n", "0\n", "1\n", "2\n", "3\n" ] } ], "source": [ "# Finally, tokenize the text data\n", "tok_trn, trn_labels = get_all(df_trn, 1)\n", "tok_val, val_labels = get_all(df_val, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data loader" ] }, { "cell_type": "code", "execution_count": 159, "metadata": {}, "outputs": [], "source": [ "# Create a tmp directory to store the upcoming NumPy arrays\n", "(LM_PATH / \"tmp\").mkdir(exist_ok=True)\n", "\n", "# Save the train and validation tokens in the tmp directory\n", "np.save(LM_PATH / \"tmp\" / \"tok_trn.npy\", tok_trn)\n", "np.save(LM_PATH / \"tmp\" / \"tok_val.npy\", tok_val)" ] }, { "cell_type": "code", "execution_count": 160, "metadata": {}, "outputs": [], "source": [ "# allow_pickle=True is required to load object arrays on NumPy >= 1.16.3\n", "tok_trn = np.load(LM_PATH / \"tmp\" / \"tok_trn.npy\", allow_pickle=True)\n", "tok_val = np.load(LM_PATH / \"tmp\" / \"tok_val.npy\", allow_pickle=True)" ] }, { "cell_type": "code", "execution_count": 161, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([list(['\\n', 'xbos', 'xfld', '1', 'gräfenhainichen', 'merupakan', 'sebuah', 'bandar', 'terletak', 'di', 'daerah', 'wittenberg', ',', 'saxony', '-', 'anhalt', ',', 'jerman', '.', '\\n ', 'xfld', '1', '94.0']),\n", " list(['\\n', 'xbos', 'xfld', '1', '19', 'merupakan', 'tahun', 'biasa', 'yang', 'bermula', 'pada', 'hari', 'ahad', 'dalam', 'kalendar', 'gregory', '.', '\\n', '<', 'br', 'clear', '=', 'all', '>', '\\n ', 'xfld', '1', '92.0'])],\n", " dtype=object)" ] }, "execution_count": 161, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sanity check: show the tokens of the first 2 training documents\n", "tok_trn[:2]" ] }, { "cell_type": "code", "execution_count": 162, "metadata": {}, "outputs": 
[ { "data": { "text/html": [ "
" ], "text/plain": [ " 0\n", ", 1102901\n", ". 999701\n", "yang 548167\n", "di 532424\n", "dan 509869\n", "1 465667\n", "xfld 448944\n", "\\n 403986\n", "\" 378261\n", "- 355521\n", "pada 232686\n", "xbos 224472\n", "dalam 210035\n", "( 202054\n", ") 195057\n", "merupakan 189399\n", "t_up 188865\n", "sebuah 184610\n", "dengan 174237\n", "untuk 151121\n", "ini 145098\n", "terletak 138483\n", "\\n 134853\n", "tahun 115035\n", "dari 110815" ] }, "execution_count": 162, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Count token frequencies across all training documents\n", "# and identify the most common tokens\n", "freq = Counter(p for o in tok_trn for p in o)\n", "freqs = pd.DataFrame.from_dict(freq, orient=\"index\")\n", "freqs.sort_values(0, ascending=False).head(25)" ] }, { "cell_type": "code", "execution_count": 163, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "24911449" ] }, "execution_count": 163, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sanity check: total number of tokens in the training set\n", "len([p for o in tok_trn for p in o])" ] }, { "cell_type": "code", "execution_count": 164, "metadata": {}, "outputs": [], "source": [ "# cnt[k] = number of distinct tokens that appear at least k+1 times\n", "cnt = []\n", "for i in range(49):\n", "    row_cnt = freqs[freqs[0]>=i+1].shape[0]\n", "    cnt.append(row_cnt)" ] }, { "cell_type": "code", "execution_count": 165, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 165, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "
"iVBORw0KGgoAAAANSUhEUgAAAY0AAAD8CAYAAACLrvgBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3X2QHNV57/HvMzM7szsjidXLIrBekAAJjEkQsBHCOAaDLSRsLO6NA7jiIHMpyxVIYipQCVBJERvj2HXjgKkYxVQgCMcxqLAJgggrQkZ2fG0wwsgIkIFFBiRZb+htJa32/bl/9JnVaDU7MytpXqT5faq6uvvp031OJ2sene7TZ8zdERERKUWs2g0QEZFjh5KGiIiUTElDRERKpqQhIiIlU9IQEZGSKWmIiEjJiiYNMzvDzFbnLO1mdrOZjTGz5Wb2VliPDuXNzO4zszYze8XMzsu51vxQ/i0zm58TP9/M1oRz7jMzC/G8dYiISHUUTRru/oa7z3D3GcD5QAfwBHAbsMLdpwErwj7AXGBaWBYACyFKAMCdwAXATODOnCSwEPhCznlzQnyoOkREpAqG+3jqMuBtd38XmAcsCvFFwFVhex7wiEeeB5rN7GTgcmC5u+9w953AcmBOODbK3Z/36EvDRwZdK18dIiJSBYlhlr8W+H7YHu/um8L2ZmB82J4ArM85Z0OIFYpvyBMvVMeQxo0b51OmTCnlXo7cG29E6zPOqEx9IiJl8tJLL73v7i3FypWcNMwsCXwauH3wMXd3MyvrfCSF6jCzBUSPwpg8eTKrVq0qZ1MOuOSSaL1yZWXqExEpEzN7t5Ryw3k8NRf4lbtvCftbwqMlwnpriG8EJuWcNzHECsUn5okXquMg7v6Au7e6e2tLS9FEKSIih2k4SeOzHHg0BbAEyI6Amg88mRO/LoyimgXsDo+YlgGzzWx0eAE+G1gWjrWb2awwauq6QdfKV0dt+Nu/jRYRkTpR0uMpM8sAnwC+mBP+OrDYzG4A3gWuDvGlwBVAG9FIq+sB3H2Hmd0FvBjKfcXdd4TtG4GHgSbgmbAUqqM2fPzj1W6BiEhF2fE2NXpra6tX7J3G6tXResaMytQnIlImZvaSu7cWKzfc0VOS6+abo7VehItIndA0IiIiUjIlDRERKZmSRvDEyxv43gslDVMWEalbShrB07/exH+88F61myEiUtP0IjxIpxJ0dPcN76Svfa08jRERqVFKGkEmGWdfV+/wTvrwh8vTGBGRGqXHU0E6eRg9jZ//PFpEROqEehpBJhVnX3cv7k74Daji7rgjWus7DRGpE+ppBOlkAnfo7OmvdlNERGqWkkaQScUB2Nc9zPcaIiJ1REkjSCejJ3UdXcN8ryEiUkeUNIIR6mmIiBSlF+HBQE9jOEnj3nvL1BoRkdqkpBEMvNMYzuMpTYkuInVGj6eCbE9jWB/4PftstIiI1An1NIJMNmkM5wO/r341WusX/ESkTqinEaTD46lhvdMQEakzShrBQE9DQ25FRIakpBE0NsQwU09DRKQQJY3AzMgkE+ppiIgUoBfhOdLJ+PB6Gt/5TvkaIyJSg0rqaZhZs5k9bma/MbO1ZnahmY0xs+Vm9lZYjw5lzczuM7M2M3vFzM7Luc78UP4tM5ufEz/fzNaEc+6zMM3sUHWUSyaVGN7oqTPOiBYRkTpR6uOpbwE/cvczgXOAtcBtwAp3nwasCPsAc4FpYVkALIQoAQB3AhcAM4E7c5LAQuALOefNCfGh6iiLdDJOx3C+03jqqWgREakTRZOGmZ0AfBR4EMDdu919FzAPWBSKLQKuCtvzgEc88jzQbGYnA5cDy919h7vvBJYDc8KxUe7+vLs78Miga+WroywyycTw5p765jejRUSkTpTS05gKbAP+zcxeNrN/NbMMMN7dN4Uym4HxYXsCsD7n/A0hVii+IU+cAnUcxMwWmNkqM1u1bdu2Em4pv3QqPvxf7xMRqSOlJI0
EcB6w0N3PBfYx6DFR6CH40W9eaXW4+wPu3ururS0tLYddRzR6SkNuRUSGUkrS2ABscPcXwv7jRElkS3i0RFhvDcc3ApNyzp8YYoXiE/PEKVBHWUSjp9TTEBEZStGk4e6bgfVmlh0mdBnwOrAEyI6Amg88GbaXANeFUVSzgN3hEdMyYLaZjQ4vwGcDy8KxdjObFUZNXTfoWvnqKItMSj0NEZFCSv1O4y+A75lZElgHXE+UcBab2Q3Au8DVoexS4AqgDegIZXH3HWZ2F/BiKPcVd98Rtm8EHgaagGfCAvD1Ieooi2xPw90Jo34L++53y9kcEZGaU1LScPfVQGueQ5flKevATUNc5yHgoTzxVcDZeeLb89VRLplUgt5+p7uvn1QiXvyESZOKlxEROY5oGpEc6WSY6bbUqUQeeyxaRETqhKYRyZFJZX9To5fRmWTxExYujNbXXFPGVomI1A71NHJkBn4nXCOoRETyUdLIkR74nXCNoBIRyUdJI4d6GiIihSlp5Mi+CFdPQ0QkP70Iz5F9EV5yT+Pxx8vYGhGR2qOkkSOT7WmUOtPtuHFlbI2ISO3R46kc6eyQ21IfTz38cLSIiNQJJY0cTQ3ZdxolPp5S0hCROqOkkSMeM5oahvk74SIidURJY5BMKj683wkXEakjShqDpJOJ4f1OuIhIHVHSGCSdVE9DRGQoGnI7SCaVKP2dxtKl5W2MiEiNUdIYJJ2Ms6ezxKSRTpe3MSIiNUaPpwbJJIfR07j//mgREakTShqDpFPx0r/TWLw4WkRE6oSSxiDD6mmIiNQZJY1B0vpOQ0RkSEoag2SSCbp7++np6692U0REak5JScPM3jGzNWa22sxWhdgYM1tuZm+F9egQNzO7z8zazOwVMzsv5zrzQ/m3zGx+Tvz8cP22cK4VqqOchj09uohIHRlOT+Nj7j7D3VvD/m3ACnefBqwI+wBzgWlhWQAshCgBAHcCFwAzgTtzksBC4As5580pUkfZZKdHL+m9xsqV0SIiUieO5PHUPGBR2F4EXJUTf8QjzwPNZnYycDmw3N13uPtOYDkwJxwb5e7Pu7sDjwy6Vr46yubA9OjqaYiIDFZq0nDgv83sJTNbEGLj3X1T2N4MjA/bE4D1OeduCLFC8Q154oXqKJth9TT+8R+jRUSkTpT6RfhH3H2jmZ0ILDez3+QedHc3Mz/6zSutjpDIFgBMnjz5iOpJJ4fR03j66Wh9661HVKeIyLGipJ6Gu28M663AE0TvJLaER0uE9dZQfCMwKef0iSFWKD4xT5wCdQxu3wPu3ururS0tLaXc0pAyqWH0NERE6kzRpGFmGTMbmd0GZgOvAkuA7Aio+cCTYXsJcF0YRTUL2B0eMS0DZpvZ6PACfDawLBxrN7NZYdTUdYOula+OshnoaWj0lIjIIUp5PDUeeCKMgk0A/+HuPzKzF4HFZnYD8C5wdSi/FLgCaAM6gOsB3H2Hmd0FvBjKfcXdd4TtG4GHgSbgmbAAfH2IOspmoKeh39QQETlE0aTh7uuAc/LEtwOX5Yk7cNMQ13oIeChPfBVwdql1lNOwehpNTWVujYhIbdHU6IOkk8PoaTzzTPEyIiLHEU0jMkhDPEYyEdM7DRGRPJQ08sgk4+wrpadx113RIiJSJ5Q08kgnE+wrZcjtihXRIiJSJ5Q08sik4nRoGhERkUMoaeRRck9DRKTOKGnkkUnFNTW6iEgeGnKbRzqZYPvejuIFx44tf2NERGqIkkYemWSJPY0f/KD8jRERqSF6PJVHOpXQhIUiInkoaeQRfadRQk/j9tujRUSkTujxVB6ZVIL9PX309TvxmA1d8Be/qFyjRERqgHoaeWTCpIX7ezSCSkQkl5JGHmlNjy4ikpeSRh4Z/RCTiEheeqeRR3Z69KKTFk6cWPi4iMhxRkkjj0wq+j9L0W81/v3fK9AaEZHaocdTeQz0NPSthojIQZQ08hjoaRT7VuPmm6NFRKRO6PFUHiX3NFavrkBrRERqh3o
aeWRHT2nIrYjIwZQ08sh+p6EhtyIiBys5aZhZ3MxeNrOnw/5UM3vBzNrM7DEzS4Z4Kuy3heNTcq5xe4i/YWaX58TnhFibmd2WE89bR7kl4zESMdOkhSIigwynp/ElYG3O/jeAe9z9dGAncEOI3wDsDPF7QjnM7CzgWuBDwBzg/pCI4sC3gbnAWcBnQ9lCdZSVmZEuZdLC6dOjRUSkTpSUNMxsIvBJ4F/DvgGXAo+HIouAq8L2vLBPOH5ZKD8PeNTdu9z9t0AbMDMsbe6+zt27gUeBeUXqKLtMKdOjP/BAtIiI1IlSexr3An8N9If9scAud8/+V3UDMCFsTwDWA4Tju0P5gfigc4aKF6rjIGa2wMxWmdmqbdu2lXhLhZXU0xARqTNFk4aZfQrY6u4vVaA9h8XdH3D3VndvbWlpOSrXzKQSxYfcLlgQLSIidaKU7zQuAj5tZlcAjcAo4FtAs5klQk9gIrAxlN8ITAI2mFkCOAHYnhPPyj0nX3x7gTrKLp2MF/+47803K9MYEZEaUbSn4e63u/tEd59C9CL7x+7+J8BzwGdCsfnAk2F7SdgnHP+xu3uIXxtGV00FpgG/BF4EpoWRUslQx5JwzlB1lF0mWUJPQ0SkzhzJdxp/A/yVmbURvX94MMQfBMaG+F8BtwG4+2vAYuB14EfATe7eF3oRfw4sIxqdtTiULVRH2UW/E653GiIiuYY1jYi7rwRWhu11RCOfBpfpBP54iPPvBu7OE18KLM0Tz1tHJUS/E66ehohILs09NYRMKT2NGTMq0xgRkRqhpDGETDLOvu5e3J3ok5E87r23so0SEakyzT01hHQqgTt09vQXLywiUieUNIaQKWV69M99LlpEROqEHk8NIZ3M+SGmEUMU2rChcg0SEakB6mkMIZPST76KiAympDGEgZ6GkoaIyAAljSEM9DQ0aaGIyAC90xhCST2NCy+sUGtERGqDksYQsr8TXrCn8Q//UKHWiIjUBj2eGkL2d8L1TkNE5AAljSEM9DQKTSXyR38ULSIidUKPp4bQ2BDDDDoKTVq4fXvlGiQiUgPU0xiCmYXf1NDoKRGRLCWNAtLJuN5piIjkUNIoIJNK6DsNEZEceqdRQNGexmWXVa4xIiI1QEmjgEyySE/j7/6uco0REakBejxVQDoV14SFIiI5lDQKiHoaBZLG3LnRIiJSJ/R4qoBMKl74d8L3769cY0REakDRnoaZNZrZL83s12b2mpl9OcSnmtkLZtZmZo+ZWTLEU2G/LRyfknOt20P8DTO7PCc+J8TazOy2nHjeOiolXaynISJSZ0p5PNUFXOru5wAzgDlmNgv4BnCPu58O7ARuCOVvAHaG+D2hHGZ2FnAt8CFgDnC/mcXNLA58G5gLnAV8NpSlQB0Vke1puHslqxURqVlFk4ZH9obdhrA4cCnweIgvAq4K2/PCPuH4ZWZmIf6ou3e5+2+BNmBmWNrcfZ27dwOPAvPCOUPVURHpZILefqe7r7+S1YqI1KyS3mmE3sBLwOlEvYK3gV3unn12swGYELYnAOsB3L3XzHYDY0P8+ZzL5p6zflD8gnDOUHVURCYZZrrt6iOViB9a4FOfqmRzRESqrqSk4e59wAwzawaeAM4sa6uGycwWAAsAJk+efNSum05lZ7rtZXQmz+uUW289anWJiBwLhjXk1t13Ac8BFwLNZpZNOhOBjWF7IzAJIBw/AdieGx90zlDx7QXqGNyuB9y91d1bW1pahnNLBWUGfr1PU4mIiEBpo6daQg8DM2sCPgGsJUoenwnF5gNPhu0lYZ9w/McevUleAlwbRldNBaYBvwReBKaFkVJJopflS8I5Q9VREemB3wkfYgTVJZdEi4hInSjl8dTJwKLwXiMGLHb3p83sdeBRM/sq8DLwYCj/IPBdM2sDdhAlAdz9NTNbDLwO9AI3hcdemNmfA8uAOPCQu78WrvU3Q9RREeppiIgcrGjScPdXgHPzxNcRjXwaHO8E/niIa90N3J0nvhR
YWmodlZJOFulpiIjUGU0jUkAmpZ6GiEguJY0CskNuNWmhiEhEc08VkB1y2zHU9OhXX13B1oiIVJ+SRgFNDUV6GjfeWMHWiIhUnx5PFRCPGU0NBWa67eiIFhGROqGeRhGZVHzo0VNXXBGtV66sWHtERKpJPY0i0smERk+JiARKGkWkkwV6GiIidUZJo4hMSj0NEZEsJY0i0sm4vtMQEQn0IryITDLB1vau/Ac///mKtkVEpNqUNIrIpBLsHeqdhpKGiNQZPZ4qIvqd8CGSxvvvR4uISJ1QT6OIdDLBvqFehH8m/NSHvtMQkTqhnkYRmWSc7t5+evr6q90UEZGqU9IoIq3p0UVEBihpFJHRDzGJiAxQ0ihi8pg0AGs3tVe5JSIi1acX4UWcP2U0TQ1xfvLmNi774PiDD/7Zn1WnUSIiVaKkUUQqEefDp43lJ29uO/TgNddUvkEiIlWkx1Ml+Oj0Ft7d3sE77+87+MD69dEiIlInlDRKcPH0FoBDext/+qfRIiJSJ4omDTObZGbPmdnrZvaamX0pxMeY2XIzeyusR4e4mdl9ZtZmZq+Y2Xk515ofyr9lZvNz4ueb2Zpwzn1mZoXqqLQp4zKcMjbNT/M9ohIRqSOl9DR6gVvc/SxgFnCTmZ0F3AascPdpwIqwDzAXmBaWBcBCiBIAcCdwATATuDMnCSwEvpBz3pwQH6qOirt4egs/f3s7Xb36XkNE6lfRpOHum9z9V2F7D7AWmADMAxaFYouAq8L2POARjzwPNJvZycDlwHJ33+HuO4HlwJxwbJS7P+/uDjwy6Fr56qi4i6e3sL+nj1Xv7KxWE0REqm5Y7zTMbApwLvACMN7dN4VDm4HseNQJQO7b4Q0hVii+IU+cAnUMbtcCM1tlZqu2bSvPI6RZp44lGY/lH0UlIlInSk4aZjYC+AFws7sf9KVb6CH4UW7bQQrV4e4PuHuru7e2tLSUpf5MKsEfTB3NT97ISRq33BItIiJ1oqSkYWYNRAnje+7+wxDeEh4tEdZbQ3wjMCnn9IkhVig+MU+8UB1VcfH0Ft7YsodNu/dHgSuvjBYRkTpRyugpAx4E1rr7P+UcWgJkR0DNB57MiV8XRlHNAnaHR0zLgNlmNjq8AJ8NLAvH2s1sVqjrukHXyldHVVw8/USAA6Oo3ngjWkRE6kQpX4RfBPwpsMbMVofYHcDXgcVmdgPwLnB1OLYUuAJoAzqA6wHcfYeZ3QW8GMp9xd13hO0bgYeBJuCZsFCgjqqYPn4EJ41q5CdvbuOaP5gMX/xidEC/pyEidaJo0nD3nwE2xOHL8pR34KYhrvUQ8FCe+Crg7Dzx7fnqqBYz4+LpLSx9dRO9ff2ag0VE6o6+CB+mi89oYU9nL6vX76p2U0REKk5JY5guOn0c8Zhp6K2I1CUljWE6oamBcyc1K2mISF3SY/nD8NHpLdzz7Ju03/LXjGpKVrs5IiIVo57GYbh4egvu8Nykc+DjH692c0REKkZJ4zD83oQTGJNJ8tayn8Hq1cVPEBE5TihpHIZYzPjDaeP42L98Db/55mo3R0SkYpQ0DtPF01vo6euno6u32k0REakYJY3D9IfTookRt+3trnJLREQqR0njMLWMTNEyMsWW9k7atu6tdnNERCpCSeMITB6TJmbw90teI5o9RUTk+KbvNI5Awze+ztZXfsfP2t5n2WubmXP2ydVukohIWamncSQ+/GEu/8L/5syTRnLX02vZ363fDxeR45uSxpH4+c9JvPA8X/70h9i4az8LV7ZVu0UiImWlx1NH4o47ALhg5UrmzfgA//LTdXzm/ElMHpuucsNERMpDPY2j5I4rPkhDzPjK069XuykiImWjpHGUjB/VyF9eNo1n127hud9U9afMRUTKRknjKLr+oqmc2pLhy0+9RlevXoqLyPFHSeMoSiZi/P2VH+Kd7R088JN11W6OiMhRpxfhR+Leew8JfXR6C5/8/ZP55vI3GdGY4PqLplahYSIi5aGkcSRmzMg
b/uYfn0NvXz9ffup13t/bxa2zz8DMKtw4EZGjr+jjKTN7yMy2mtmrObExZrbczN4K69EhbmZ2n5m1mdkrZnZezjnzQ/m3zGx+Tvx8M1sTzrnPwn9dh6qjpjz7bLQM0tgQ5/4/OZ/PzpzEt597m9t/uIbevv4qNFBE5Ogq5Z3Gw8CcQbHbgBXuPg1YEfYB5gLTwrIAWAhRAgDuBC4AZgJ35iSBhcAXcs6bU6SO2vHVr0ZLHvGY8bX/9Xv85aWn8+iL67nxe7+is0cvx0Xk2FY0abj7T4Edg8LzgEVhexFwVU78EY88DzSb2cnA5cByd9/h7juB5cCccGyUuz/v0Yx/jwy6Vr46jhlmxl/NPoMvf/pDLF+7hese+iW79/dUu1kiIoftcEdPjXf3TWF7MzA+bE8A1ueU2xBiheIb8sQL1XHMmf/hKdx37bm8/N5OrvnOL1izYXe1myQicliOeMht6CGUdV7wYnWY2QIzW2Vmq7Zt21bOphy2K8/5AP/2+Zm8v7eLT3/7Z9yy+Ndsae+sdrNERIblcJPGlvBoibDOfgK9EZiUU25iiBWKT8wTL1THIdz9AXdvdffWlpaWw7yl8vvItHE8d+slfPGjp/HUr3/HJf93Jd969i3Njisix4zDTRpLgOwIqPnAkznx68IoqlnA7vCIaRkw28xGhxfgs4Fl4Vi7mc0Ko6auG3StfHXUju98J1qGYWRjA7fNPZMVt1zMpWeeyD3Pvsml31zJEy9voL9fP+QkIrXNiv3inJl9H7gEGAdsIRoF9Z/AYmAy8C5wtbvvCP/h/2eiEVAdwPXuvipc5/8Ad4TL3u3u/xbirUQjtJqAZ4C/cHc3s7H56ih2Q62trb5q1apS77/qfvnbHdz19Ous2bibU1syXH/RVP7ovAmkk/qERkQqx8xecvfWouWOt58prWjSeOqpaH3llUd0mf5+5+k1m3jwf9bx6w27GdWY4LMXTOa6C6cwobnpKDRURKQwJY1KuOSSaL1y5VG5nLvzq/d28dD/+y0/enUzAHPOPonPXXAKM6eOIR7TV+UiUh6lJg09A6khZsb5p4zm/FNGs3HXfh75xTt8/4X3+K9XNjFuRIrLPzSeuWefzKxTx5CIa65JEak8JY0aNaG5idvnfpAvXTaNFWu38qNXN/PDX23key+8x+h0A584azxzzj6JWaeO1fsPEakY/demxqWTCa485wNcec4H2N/dx0/e3MaPXt3E0jWbWbxqAw1xY8akZi48dSwXnjaOcyc309gQr3azReQ4paRxDGlKxplz9knMOfskunr7eGHdDn7+9nZ+sW47//xcG/f9uI1UIsb5p4ymdcoYzp3UzDmTmhmTSVa76SJynNCL8COxPsyMMmlS4XIV0N7Zw4u/DUnk7e38ZnM72c8+ThmbZsakZs6ZGCWRM08aSSalfy+IyAEaPVXn9nX1smbjblav38Xq93axev0uNodpS8xg6tgMH/zAKM46eRRnhfWJI1P63Q+ROqXRU5Xw2GPR+pprqtuOPDKpBLNOHcusU8cOxDbt3s+rG9t5/XftvL5pN69s2MV/vbJp4PjIxgSntYzgtJYRnH7iCE5ryXDaiSOYPCZNg0ZriQjqaRyZo/ydRjW0d/bwm017WLupnbe37aVt617e3raXLe1dA2XiMeMDzY1MGZvhlLFppozNMHlMmslj00xobmJkY0MV70BEjgb1NKQkoxobmDl1DDOnjjko3t7Zw7pt+2jbupd3t+/jne0dvLd9H0tW/472zt5B10gwYXSUQCY0NzJhdBMnndDEySc0ctKoRsaPaiSZUE9F5HigpCF5jWpsYMakZmZMaj7k2K6Obt7d3sF7OzrYuGs/G3fuZ+Ou/azf0cHz67azt6v3kHPGjUgyPiSQlhEpWkYevIwbkWLciCQjUgm9VxGpYUoaMmzN6STN6STn5EkoALv397ClvZNNuzvZsjtab27vZPPu/Wxp7+TVjbvZvq+bvjyz+qYSMcaNSDF
2RDJaZ5KMGZFkTDrJmMzBS3M6yahGJRmRSlLSkKPuhKYGTmhqYPr4kUOW6et3dnZ0s21PF+/v7RpYb9/bzbaw3tLeyeu/a2fHvm66+/rzXiceM05oaqC5qYHmdEOU0JoaGBWWEwYto5oSjGqMjmWScSUckWFS0jgSjz9e7RYcs+IxC4+kUkXLujv7uvvYsbebHR3d7NgXJZXd+3vY1dHDrv3d7OzoYXdHD1v3dPLmlj3s3t/Dns5DH5Pliln0+yajmhKMTDUwojHBqMYEI1IJRjQmGNnYwIhUgpGNCTLJEAvHMqmoXCaVIN0QJ6bJJKVOKGkciXHjqt2CumBm0X/IUwkmj02XfF5fv7Ons4fd+w8sezp7ad/fQ3tnD+37e8O6h71dvezp7OV3uzrZ29Ub9nvo6SttdGEmGR9IJOlUnEwyJJRktJ2NNSXjNDXESSfjOdsJmpIxGhui/aZknMZEtE4lYuoNSU1R0jgSDz8crT//+Wq2QoYQj9nA+5fD1dnTx76BJBKt92bXXb10dPeytysqky23v7uPvV29bNvTxb7uXjq6+tjXHR0/nB9nbGyI0dQQp/GgJTcWozERpzEkm1RDjFQiRioRJ5nIbsfCdjgej5FqiJGMHyg/cDyU10zKko+SxpFQ0jjuZf8jPbaEx2jFuDvdff3s7+6jIyzRdi+dvVG8qzeKdfb00dHTR2dPP109fezvCfFQrrMnWto7e+jsOXBuZ08/Xb19JfeQCokZJBMxkvEYyUScZNyi/ewSP5BosrFU/NDjDQPlou2GgbgdVCaK28B2QzzaTmS3YzEaEjESseg8PRKsDiUNkQoxs/Av+TjNpT9lOyz9/VGC6gpJpKu3Pyx9dIft7jyxrp6+nPP66e47UK4nbHfnxLt7+9nV0T1QtqsnlMs53ns43asSxGM2kFiS8RiJuJGIRQkmEY+SS0OIN8RixGMWbcdjA+fGYzEaQjwRz25H5yZyjsfjRiIW7UdrC2Xy78didlA8Wh9oQ974QfVE8ZhRc48nlTREjkOxmNEYi4dp8qv7xX42gWWTTk+fH5RYegaO+cB2dNzpHdiPjvX2HbhWb851stfu7Xd6+5ze/uic3r4o1tPXT29/P/vnk7tcAAAEq0lEQVR7nL6w39fvB47lOae33/MOC6+0Q5JPPEbMDuznLg/N/4Nhvfc7rPaU9eoiUvcOTmDHFncfSC7ZdW9OwsmN9fQ5/Z6NR4moL3t+3+Dr9B+IZ8/JSVZRHUTX6Q/X6QvH+qNj/aFsts7+fifVUP73UEoaIiJDMMs+uqp2S2qHksaRWLq02i0QEamomh9TZ2ZzzOwNM2szs9uq3Z6DpNPRIiJSJ2o6aZhZHPg2MBc4C/ismZ1V3VbluP/+aBERqRM1nTSAmUCbu69z927gUWBeldt0wOLF0SIiUidqPWlMANbn7G8IsYOY2QIzW2Vmq7Zt21axxomI1JtaTxolcfcH3L3V3VtbWlqq3RwRkeNWrSeNjcCknP2JISYiIlVQ60njRWCamU01syRwLbCkym0SEalb5l79z+QLMbMrgHuBOPCQu99dpPw24N3DrG4c8P5hnnusq+d7h/q+/3q+d6jv+8+991Pcvejz/ZpPGpVkZqvcvbXa7aiGer53qO/7r+d7h/q+/8O591p/PCUiIjVESUNEREqmpHGwB6rdgCqq53uH+r7/er53qO/7H/a9652GiIiUTD0NEREpmZJGUNOz6R5lZvaQmW01s1dzYmPMbLmZvRXWo6vZxnIxs0lm9pyZvW5mr5nZl0K8Xu6/0cx+aWa/Dvf/5RCfamYvhL//x8J3UcclM4ub2ctm9nTYr6d7f8fM1pjZajNbFWLD+ttX0uAYmE336HsYmDModhuwwt2nASvC/vGoF7jF3c8CZgE3hf9f18v9dwGXuvs5wAxgjpnNAr4B3OPupwM7gRuq2MZy+xKwNme/nu4d4GPuPiNnqO2w/vaVNCK1PZvuUebuPwV2DArPAxaF7UX
AVRVtVIW4+yZ3/1XY3kP0H48J1M/9u7vvDbsNYXHgUuDxED9u79/MJgKfBP417Bt1cu8FDOtvX0kjUtJsuse58e6+KWxvBsZXszGVYGZTgHOBF6ij+w+PZ1YDW4HlwNvALnfvDUWO57//e4G/BvrD/ljq594h+gfCf5vZS2a2IMSG9bevn3uVQ7i7m9lxPazOzEYAPwBudvf26B+ckeP9/t29D5hhZs3AE8CZVW5SRZjZp4Ct7v6SmV1S7fZUyUfcfaOZnQgsN7Pf5B4s5W9fPY2IZtOFLWZ2MkBYb61ye8rGzBqIEsb33P2HIVw395/l7ruA54ALgWYzy/4j8nj9+78I+LSZvUP0CPpS4FvUx70D4O4bw3or0T8YZjLMv30ljYhm043ud37Yng88WcW2lE14hv0gsNbd/ynnUL3cf0voYWBmTcAniN7rPAd8JhQ7Lu/f3W9394nuPoXof+M/dvc/oQ7uHcDMMmY2MrsNzAZeZZh/+/q4LxjubLrHMjP7PnAJ0QyXW4A7gf8EFgOTiWYJvtrdB78sP+aZ2UeA/wHWcOC59h1E7zXq4f5/n+hlZ5zoH42L3f0rZnYq0b++xwAvA59z967qtbS8wuOpW939U/Vy7+E+nwi7CeA/3P1uMxvLMP72lTRERKRkejwlIiIlU9IQEZGSKWmIiEjJlDRERKRkShoiIlIyJQ0RESmZkoaIiJRMSUNEREr2/wEpvU3CBmUuLAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(cnt)\n", "plt.axvline(x=2, color=\"red\", linestyle=\"--\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Numericalize the text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We trained the language model on 224,472 texts and validated it on 31,609 texts from Malay Wikipedia, using various tricks detailed in the [ULMFiT paper](https://arxiv.org/abs/1801.06146). The validation perplexity is 29.30245." ] }, { "cell_type": "code", "execution_count": 167, "metadata": {}, "outputs": [], "source": [ "# Truncate our vocab to ignore rare words: keep at most max_vocab words that occur more than min_freq times\n", "max_vocab = 60000\n", "min_freq = 5" ] }, { "cell_type": "code", "execution_count": 168, "metadata": {}, "outputs": [], "source": [ "itos = [o for o, c in freq.most_common(max_vocab) if c > min_freq] # getting rid of the rare words\n", "itos.insert(0, \"_pad_\")\n", "itos.insert(0, \"_unk_\") # itos is the list of all the strings in the vocab (index 0 = _unk_, index 1 = _pad_)" ] }, { "cell_type": "code", "execution_count": 169, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "60002" ] }, "execution_count": 169, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a string-to-index dictionary for our vocabulary; unknown words map to 0 (_unk_)\n", "stoi = collections.defaultdict(lambda: 0, {v:k for k, v in enumerate(itos)})\n", "len(itos)" ] }, { "cell_type": "code", "execution_count": 170, "metadata": {}, "outputs": [], "source": [ "# Create an index representation of our train and validation datasets\n", "# (dtype=object because the texts have different lengths)\n", "trn_lm = np.array([[stoi[o] for o in p] for p in tok_trn], dtype=object)\n", "val_lm = np.array([[stoi[o] for o in p] for p in tok_val], dtype=object)" ] }, { "cell_type": "code", "execution_count": 171, "metadata": {}, "outputs": [], "source": [ "# Save the indexed representation of our dataset to disk.\n", "# We also save the index-word mapping to retrieve the complete text representation from these NumPy arrays\n", "np.save(LM_PATH / \"tmp\" / \"trn_ids.npy\", 
trn_lm)\n", "np.save(LM_PATH / \"tmp\" / \"val_ids.npy\", val_lm)\n", "pickle.dump(itos, open(LM_PATH / \"tmp\" / \"itos.pkl\", \"wb\"))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Load the indexed representation of our dataset from disk.\n", "# We also load the index-word mapping to help us convert the indexes back to words, if need be.\n", "# allow_pickle=True is needed because the arrays hold Python lists (dtype=object).\n", "trn_lm = np.load(LM_PATH / \"tmp\" / \"trn_ids.npy\", allow_pickle=True)\n", "val_lm = np.load(LM_PATH / \"tmp\" / \"val_ids.npy\", allow_pickle=True)\n", "itos = pickle.load(open(LM_PATH / \"tmp\" / \"itos.pkl\", \"rb\"))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(60002, 224472, 31609)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check the vocabulary size and the number of train/validation texts\n", "vocab_size = len(itos)\n", "trn_set_size = len(trn_lm)\n", "val_set_size = len(val_lm)\n", "vocab_size, trn_set_size, val_set_size" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([list([9, 13, 8, 7, 0, 17, 19, 58, 23, 5, 43, 33863, 2, 3841, 11, 7682, 2, 146, 3, 24, 8, 7, 1675]),\n", " list([9, 13, 8, 7, 700, 17, 25, 298, 4, 225, 12, 119, 2533, 14, 978, 1386, 3, 9, 687, 1835, 2184, 669, 1215, 648, 24, 8, 7, 1587])],\n", " dtype=object)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sanity check: the first two texts as lists of token indexes\n", "trn_lm[:2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Language Model" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "em_sz = 400 # size of each embedding vector\n", "nh = 1150 # number of hidden activations per layer\n", "nl = 3 # number of layers\n", "\n", "wd = 1e-7 # weight decay\n", "bptt = 70 # back-propagation through time: length of each training sequence\n", "bs = 64 # batch size\n", "# opt_fn = partial(optim.Adam, betas=(0.8, 0.99))\n", "opt_fn = 
partial(optim.SGD, momentum=0.9)\n", "weight_factor = 0.3 # scales all five dropout probabilities below" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# Dropout multipliers: if you're overfitting, increase weight_factor; if underfitting, decrease it.\n", "drops = np.array([0.25, 0.1, 0.2, 0.02, 0.15]) * weight_factor" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 0 ns, sys: 0 ns, total: 0 ns\n", "Wall time: 6.44 µs\n" ] } ], "source": [ "%time\n", "trn_dl = LanguageModelLoader(np.concatenate(trn_lm), bs, bptt)\n", "val_dl = LanguageModelLoader(np.concatenate(val_lm), bs, bptt)\n", "md = LanguageModelData(DATA_PATH, pad_idx=1, n_tok=vocab_size, trn_dl=trn_dl, val_dl=val_dl, bs=bs, bptt=bptt)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "learner = md.get_model(opt_fn, em_sz, nh, nl, \n", " dropouti=drops[0], dropout=drops[1], wdrop=drops[2], dropoute=drops[3], dropouth=drops[4])\n", "\n", "learner.metrics = [accuracy]\n", "learner.clip = 0.2\n", "learner.unfreeze()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "collapsed": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# AWD-LSTM network\n", "learner.summary" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Hyper-parameter search**\n", "\n", "Properly setting the hyper-parameters of a neural network can be challenging; thankfully, there are some techniques that can help.\n", "\n", "We will speed up training using Leslie Smith's 1cycle policy, which he described along with the super-convergence phenomenon in this [paper](https://arxiv.org/abs/1708.07120). 
Here's an [application of super-convergence to win the DAWNBench challenge](http://www.fast.ai/2018/04/30/dawnbench-fastai/).\n", "\n", "In my own earlier experiments with this method, the AWD-LSTM model converged faster: it took just 10 epochs instead of 15.\n", "\n", "We will be using an [implementation of this method](http://forums.fast.ai/t/the-1cycle-policy-an-experiment-that-investigate-super-convergence-phenomenon-described-in-leslie-smiths-research/14737):\n", "- A simple guide on how to use the 1cycle policy with [Cyclical Learning Rate (CLR)](http://forums.fast.ai/t/using-use-clr-beta-and-new-plotting-tools/14702)\n", "- Some [tips on super-convergence(ish) on WikiText-2](http://forums.fast.ai/t/super-convergence-ish-on-wikitext-2/17091), an LM task similar to ours." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ec21c7f8b3654333a6ed135a80eedc15", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 9%|▉ | 500/5559 [09:32<1:32:51, 1.10s/it, loss=7.94]" ] } ], "source": [ "# Find a good learning rate (LR range test)\n", "learner.lr_find2(num_it=500)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learner.sched.plot()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "# Set learning rate\n", "lr = 8" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4f505918b0b34cb18bb9bbefb96e4d3d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 4.114716 3.936571 0.367859 \n", " 1 3.83864 3.711561 0.382893 \n", " 2 3.669321 3.603781 0.391633 \n", " 3 3.63252 3.560706 0.394518 \n", " 4 3.478959 3.513905 0.399009 \n", " 5 3.518267 3.480469 0.401523 \n", " 6 3.409158 3.465206 0.402808 \n", " 7 3.426483 3.437133 0.405097 \n", " 8 3.296175 3.409095 0.409595 \n", " 9 3.185208 3.377671 0.413643 \n" ] }, { "data": { "text/plain": [ "[array([3.37767]), 0.41364255045710613]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learner.fit(lr, 1, wds=wd, cycle_len=10, use_clr=(10,33,0.95,0.85), best_save_name='best_lm_malay_1cycle')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualize loss history and learning rate history" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learner.sched.plot_loss()" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learner.sched.plot_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Save weights" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "learner.save('lm_malay_final')" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "learner.save_encoder('lm_malay_enc_final')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how our model did. Cross-entropy losses after 10 epochs:\n", "- Training: 3.185208\n", "- Validation: **3.377671**" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "29.30244618263257" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Perplexity is the exponential of the validation cross-entropy loss\n", "np.exp(3.377671)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Analysis**:\n", "\n", "The **perplexity** of this language model after 10 epochs is **29.30245** (next-word accuracy of 41.3643%).\n", "\n", "It took me ~1 hour 24 minutes (5016.09s) to train 1 epoch on one Tesla K80 GPU, roughly 1.1 iterations/s. The full 10-epoch training took ~14 hours." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate text using the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We perform an eyeballing test by having the model \"fill in the blanks\": given a short prompt, it continues the text." 
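] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The decoding rule inside `generate_sentences` (defined below) is greedy with one twist: at each step it takes the two highest-scoring tokens and falls back to the runner-up when the best one is `_unk_` or `_pad_` (indices 0 and 1 in `itos`). The next cell is a minimal NumPy sketch of just that rule, run on a hypothetical five-word vocabulary (`toy_itos`) with made-up scores; `pick_next_token` is an illustrative helper, not a fastai function. Always taking the most likely word is one plausible reason for the repeated phrases in the generated samples." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical sketch of the greedy rule in generate_sentences (not a fastai API).\n", "# Index 0 is _unk_ and index 1 is _pad_, matching how itos was built.\n", "N_SPECIAL = 2\n", "\n", "def pick_next_token(scores):\n", "    # Indices of the two highest scores, best first\n", "    top2 = np.argsort(scores)[::-1][:2]\n", "    # Fall back to the runner-up if the best candidate is a special token\n", "    return int(top2[1] if top2[0] < N_SPECIAL else top2[0])\n", "\n", "toy_itos = [\"_unk_\", \"_pad_\", \"kuala\", \"lumpur\", \"ialah\"]  # hypothetical vocabulary\n", "scores = np.array([3.0, 0.1, 0.5, 2.0, 1.0])  # made-up scores; _unk_ is highest\n", "print(toy_itos[pick_next_token(scores)])  # lumpur" 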
] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# Load model from saved weights\n", "learner.load('lm_malay_final')" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "m = learner.model # the underlying PyTorch model (encoder + decoder)\n", "m.eval() # switch to inference mode (disables dropout)\n", "m[0].bs = 1 # generate one sequence at a time: set the encoder's batch size to 1" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "60002" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the index-word mapping from the pickled file on disk to help us convert the indexes to words.\n", "itos = pickle.load(open(LM_PATH / \"tmp\" / \"itos.pkl\", \"rb\"))\n", "\n", "# String-to-index mapping; unknown words map to index 0 (_unk_)\n", "stoi = collections.defaultdict(lambda: 0, {v:k for k, v in enumerate(itos)})\n", "len(itos)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "# Utility functions\n", "\n", "def gen_text(ss, topk):\n", "    # Return the topk most likely next words after the prompt ss\n", "    s = ss.strip().split(\" \")\n", "    t = LongTensor([stoi[i] for i in s]).view(-1, 1).cuda()\n", "    t = Variable(t, volatile=False)\n", "    m.reset()\n", "    pred, *_ = m(t)\n", "    pred_i = torch.topk(pred[-1], topk)[1]\n", "\n", "    return [itos[o] for o in to_np(pred_i)]\n", "\n", "def generate_sentences(ss, nb_words):\n", "    # Greedily extend the prompt ss by up to nb_words words\n", "    result = []\n", "    s = ss.strip().split(\" \")\n", "    t = LongTensor([stoi[i] for i in s]).view(-1, 1).cuda()\n", "    t = Variable(t, volatile=False)\n", "    m.reset()\n", "    pred, *_ = m(t)\n", "\n", "    for i in range(nb_words):\n", "        # Take the top-2 candidates and skip the first if it is _unk_ or _pad_ (index < 2)\n", "        pred_i = pred[-1].topk(2)[1]\n", "        pred_i = pred_i[1] if pred_i.data[0] < 2 else pred_i[0]\n", "        word = itos[pred_i.data[0]]\n", "        # Stop when the model predicts the beginning-of-text marker\n", "        if word != \"xbos\":\n", "            result.append(word)\n", "        else:\n", "            break\n", "        # Feed the chosen word back in as the next input\n", "        pred, *_ = m(pred_i[0].unsqueeze(0))\n", "\n", "    # Remove the space before '.' and ',' when joining the words\n", "    result = re.sub(r'\\s+([.,])', r'\\1', \"{} {}\".format(ss, \" \".join(result).rstrip()))\n", "\n", "    return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Inference" ] }, { "cell_type": "code", "execution_count": 138, "metadata": {}, "outputs": [], "source": [ "strings = [\n", "    \"Menara Petronas\",\n", "    \"Dr Mahathir merupakan Perdana Menteri\",\n", "    \"Tunku ialah Bapa Kemerdekaan\",\n", "    \"Syarikat penerbangan\",\n", "    \"Durian ialah buah\",\n", "    \"P Ramlee ialah seorang\",\n", "    \"Pemenang badminton Piala Thomas\",\n", "    \"Lee Chong Wei dan badminton\",\n", "    \"Jurulatih Rashid Sidek\",\n", "    \"Pokok getah\",\n", "    \"Industri kelapa sawit di Malaysia\",\n", "    \"Penyelidikan minyak sawit\",\n", "    \"Negara terbesar di Asia Tenggara ialah\",\n", "    \"Proton Saga adalah\",\n", "    \"Penyanyi terkenal\"\n", "]" ] }, { "cell_type": "code", "execution_count": 139, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Menara Petronas... \n", "menara petronas. \n", " pada tahun 2005, sebuah syarikat yang dikenali sebagai t_up petronas, yang dimiliki oleh t_up petronas, telah membeli sebuah syarikat yang dimiliki oleh t_up petronas, t_up petronas. pada tahun 2008, t_up petronas membeli saham mereka daripada t_up petronas, dan t_up petronas, dan t_up petronas, dan t_up petronas, dan t_up petronas, dan t_up petronas, dan t_up petronas, yang telah membeli saham mereka pada tahun 2008. t_up\n", "\n", "Dr Mahathir merupakan Perdana Menteri... \n", "dr mahathir merupakan perdana menteri malaysia yang pertama. beliau juga merupakan seorang ahli parlimen bagi kawasan t_up dun t_up n53. beliau juga merupakan ahli parlimen bagi kawasan t_up dun t_up n53. beliau juga merupakan ahli parlimen bagi kawasan t_up dun t_up n53. beliau juga merupakan ahli parlimen bagi kawasan t_up dun t_up n53. beliau juga merupakan ahli parlimen bagi kawasan t_up dun t_up n53. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, negeri sembilan. beliau juga\n", "\n", "Tunku ialah Bapa Kemerdekaan... 
\n", "tunku ialah bapa kemerdekaan malaysia. beliau juga merupakan seorang ahli politik yang pernah berkhidmat sebagai perdana menteri malaysia ke-6. beliau juga merupakan seorang ahli parlimen bagi kawasan bukit besi, malaysia. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, pulau pinang. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, pulau pinang. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, pulau pinang. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, pulau\n", "\n", "Syarikat penerbangan... \n", "syarikat penerbangan ini. \n", " pada tahun 2005, sebuah syarikat penerbangan antarabangsa, t_up bae systems, telah ditubuhkan. pada tahun 2005, syarikat penerbangan ini telah melancarkan operasi untuk melancarkan pesawat pejuang generasi baru, \" t_up atr 72 \". pada tahun 2009, syarikat penerbangan ini telah melancarkan operasi untuk melancarkan pesawat pejuang generasi baru, \" t_up atr 72 \". pada tahun 2014, syarikat penerbangan ini telah melancarkan operasi untuk melancarkan pesawat pejuang generasi\n", "\n", "Durian ialah buah... \n", "durian ialah buah - buahan yang ditanam di dalam hutan. \n", " xfld 1 1137.0\n", "\n", "P Ramlee ialah seorang... \n", "p ramlee ialah seorang ahli perniagaan yang berjaya. beliau juga merupakan seorang ahli perniagaan yang berjaya. beliau juga merupakan seorang ahli perniagaan yang berjaya. beliau juga merupakan seorang ahli perniagaan dan ahli perniagaan. beliau juga merupakan seorang ahli perniagaan dan ahli perniagaan. beliau juga merupakan seorang ahli perniagaan dan ahli perniagaan. beliau juga merupakan seorang ahli perniagaan dan ahli perniagaan. beliau juga merupakan seorang ahli perniagaan dan ahli perniagaan. beliau juga merupakan seorang ahli perniagaan dan\n", "\n", "Pemenang badminton Piala Thomas... \n", "pemenang badminton piala thomas, t_up pbb. 
\n", " pada tahun 2005, persatuan bola sepak malaysia ( t_up fam ) mengumumkan bahawa t_up fam akan menubuhkan persatuan bola sepak malaysia ( t_up fam ). pada tahun 2009, t_up fam mengumumkan bahawa t_up fam akan menyertai t_up aff pada tahun 2013. t_up fam juga telah mencadangkan bahawa t_up fam akan menyertai t_up aff pada tahun 2013. t_up fam juga telah bersetuju untuk menyertai t_up aff pada tahun 2014. t_up\n", "\n", "Lee Chong Wei dan badminton... \n", "lee chong wei dan badminton. \n", " pada tahun 2005, sebuah lagi acara sukan diadakan di stadium nasional bukit jalil, kuala lumpur. pada tahun 2009, persatuan bola sepak malaysia ( t_up fam ) telah mengadakan perlawanan persahabatan menentang t_up fam, t_up fam, t_up fam, t_up fam, t_up fam, t_up fam, t_up fam, t_up afc dan t_up fifa. pada tahun 2009, t_up fam mengumumkan bahawa mereka akan bertanding dalam piala malaysia pada tahun\n", "\n", "Jurulatih Rashid Sidek... \n", "jurulatih rashid sidek, seorang ahli politik yang telah berkhidmat sebagai perdana menteri malaysia. beliau juga merupakan seorang ahli parlimen bagi kawasan t_up dun t_up n53. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, negeri sembilan. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, negeri sembilan. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, negeri sembilan. beliau juga merupakan ahli parlimen bagi kawasan bukit gelugor, negeri sembilan. beliau juga merupakan\n", "\n", "Pokok getah... \n", "pokok getah. \n", " pada tahun 2011, sebuah lagi projek baru yang dikenali sebagai \" the new york times \" telah dilancarkan. pada tahun 2011, sebuah lagi projek baru yang dikenali sebagai \" the new york times \" telah dilancarkan. pada tahun 2013, sebuah siri \" the last airbender \" telah diterbitkan semula sebagai siri \" the last airbender \". pada tahun 2013, \" the new york times \" melaporkan bahawa \" the new york\n", "\n", "Industri kelapa sawit di Malaysia... 
\n", "industri kelapa sawit di malaysia. \n", " pada tahun 2005, sebuah syarikat yang dikenali sebagai t_up petronas, telah ditubuhkan untuk membangunkan dan memajukan t_up petronas sebagai sebuah syarikat yang bertanggungjawab untuk pembangunan dan pembangunan teknologi maklumat. t_up petronas telah ditubuhkan pada tahun 1990, dan pada tahun 1992, t_up petronas telah menjadi syarikat yang pertama untuk membangunkan dan menghasilkan produk - produk yang berkualiti. t_up petronas telah membeli dan membeli sebuah syarikat yang dikenali sebagai t_up petronas t_up petronas,\n", "\n", "Penyelidikan minyak sawit... \n", "penyelidikan minyak sawit, dan juga beberapa jenis produk yang berkaitan dengan industri. \n", " xfld 1 1137.0\n", "\n", "Negara terbesar di Asia Tenggara ialah... \n", "negara terbesar di asia tenggara ialah t_up unesco. \n", " sejarah. \n", " pada tahun 2011, sebuah lagi muzium sejarah di indonesia, iaitu muzium negara indonesia, telah dibuka di jakarta, indonesia. muzium ini telah dirasmikan oleh t_up yab perdana menteri malaysia, tun dr. mahathir bin mohamad pada 1 jun 2007. muzium ini mempamerkan koleksi seni bina yang unik dan menarik. muzium ini mempamerkan koleksi seni bina yang unik dan menarik. muzium ini mempamerkan koleksi seni bina\n", "\n", "Proton Saga adalah... \n", "proton saga adalah sebuah kereta kebal utama yang digunakan oleh tentera udara diraja malaysia. xfld 1 1137.0\n", "\n", "Penyanyi terkenal... \n", "penyanyi terkenal, penyanyi, penyanyi, penulis lagu, komposer, komposer, komposer, komposer, dan artis. \n", " lagu. \n", " lagu ini digubah oleh komposer terkenal, komposer terkenal, komposer terkenal, ahmad nawab. lagu ini digubah oleh ahmad nawab, dan liriknya ditulis oleh ahmad nawab. lagu ini digubah oleh ahmad nawab, dan liriknya ditulis oleh ahmad nawab. lagu ini digubah oleh ahmad nawab, dan dinyanyikan oleh p. ramlee.\n", "\n" ] } ], "source": [ "for s in strings:\n", " print(f\"{s}... 
\\n{generate_sentences(s.lower(), 80)}\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We extract the embedding layer of the encoder to be used in the same manner as word2vec. We can also create sentence vector by summing or averaging the vectors. For more details about word2vec use cases, see word2vec_examples.ipynb." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Extract" ] }, { "cell_type": "code", "execution_count": 140, "metadata": {}, "outputs": [], "source": [ "emb_weights = list(learner.model.named_parameters())[0][1]\n", "emb_np = to_np(emb_weights.data)" ] }, { "cell_type": "code", "execution_count": 148, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...390391392393394395396397398399
semua0.0130380.5153650.0506580.403895-0.076790-0.274187-0.294301-0.482214-0.942632-0.382407...0.0004830.049201-0.237344-0.1952820.318259-0.085800-0.082405-0.008426-0.548818-0.450443
kanak-0.2612320.1873450.054948-0.3889660.3476220.906619-0.085901-0.067504-0.2279470.385188...0.1337440.541352-0.1511850.1442950.4672010.0600890.7360180.3569200.567058-0.463647
album0.493681-0.060911-0.0299970.046737-0.5803960.530397-0.0357590.317195-0.0298520.295080...0.263408-0.2291090.182056-0.561614-0.144355-0.297672-0.0308240.4620250.1470160.053593
menteri-0.222996-0.0381490.061630-0.323144-0.1020950.199310-0.400978-0.040864-0.019013-0.120582...0.4126100.142238-0.3506350.059112-0.3009060.3964390.0354770.269761-0.146999-0.197764
turut-0.314148-0.3291030.0846000.258733-0.126275-0.426062-0.482648-0.1290920.2019550.143144...-0.096362-0.009566-0.6459490.033635-0.1200770.654878-0.3480520.138207-0.021835-0.045728
kira-0.989307-0.0557680.138539-0.709911-0.174012-0.861263-0.4912230.044874-0.4031550.109385...0.157908-0.0445330.267241-0.251010-0.141131-0.1712180.155060-0.2959760.203567-0.208005
murid0.248733-0.059398-0.023560-0.273422-0.2988700.082709-0.4143810.048601-0.3984220.156399...0.035820-0.1101920.2965250.256502-0.831308-0.100444-0.1418680.223650-0.9334870.428067
kuala0.151282-0.164024-0.214619-0.244270-0.234031-0.2989970.0219820.5464220.604748-0.260872...-0.3268030.072062-0.0171950.2665190.331591-0.0685820.1758150.1049830.1103670.049074
semula0.125103-0.1974200.751799-0.990424-0.5837070.293708-0.224855-0.778633-0.698253-0.438961...-0.2317470.0375550.303921-0.279731-0.9141020.3396980.3373180.274428-0.1032000.126072
perempuan0.3443720.385741-0.0956030.1613510.2593550.608875-0.228196-0.217302-0.5572340.143539...0.383469-0.1325930.0771430.672117-0.2788210.018481-0.1995140.8946650.1576680.372751
\n", "

10 rows × 400 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 \\\n", "semua 0.013038 0.515365 0.050658 0.403895 -0.076790 -0.274187 \n", "kanak -0.261232 0.187345 0.054948 -0.388966 0.347622 0.906619 \n", "album 0.493681 -0.060911 -0.029997 0.046737 -0.580396 0.530397 \n", "menteri -0.222996 -0.038149 0.061630 -0.323144 -0.102095 0.199310 \n", "turut -0.314148 -0.329103 0.084600 0.258733 -0.126275 -0.426062 \n", "kira -0.989307 -0.055768 0.138539 -0.709911 -0.174012 -0.861263 \n", "murid 0.248733 -0.059398 -0.023560 -0.273422 -0.298870 0.082709 \n", "kuala 0.151282 -0.164024 -0.214619 -0.244270 -0.234031 -0.298997 \n", "semula 0.125103 -0.197420 0.751799 -0.990424 -0.583707 0.293708 \n", "perempuan 0.344372 0.385741 -0.095603 0.161351 0.259355 0.608875 \n", "\n", " 6 7 8 9 ... 390 \\\n", "semua -0.294301 -0.482214 -0.942632 -0.382407 ... 0.000483 \n", "kanak -0.085901 -0.067504 -0.227947 0.385188 ... 0.133744 \n", "album -0.035759 0.317195 -0.029852 0.295080 ... 0.263408 \n", "menteri -0.400978 -0.040864 -0.019013 -0.120582 ... 0.412610 \n", "turut -0.482648 -0.129092 0.201955 0.143144 ... -0.096362 \n", "kira -0.491223 0.044874 -0.403155 0.109385 ... 0.157908 \n", "murid -0.414381 0.048601 -0.398422 0.156399 ... 0.035820 \n", "kuala 0.021982 0.546422 0.604748 -0.260872 ... -0.326803 \n", "semula -0.224855 -0.778633 -0.698253 -0.438961 ... -0.231747 \n", "perempuan -0.228196 -0.217302 -0.557234 0.143539 ... 
0.383469 \n", "\n", " 391 392 393 394 395 396 \\\n", "semua 0.049201 -0.237344 -0.195282 0.318259 -0.085800 -0.082405 \n", "kanak 0.541352 -0.151185 0.144295 0.467201 0.060089 0.736018 \n", "album -0.229109 0.182056 -0.561614 -0.144355 -0.297672 -0.030824 \n", "menteri 0.142238 -0.350635 0.059112 -0.300906 0.396439 0.035477 \n", "turut -0.009566 -0.645949 0.033635 -0.120077 0.654878 -0.348052 \n", "kira -0.044533 0.267241 -0.251010 -0.141131 -0.171218 0.155060 \n", "murid -0.110192 0.296525 0.256502 -0.831308 -0.100444 -0.141868 \n", "kuala 0.072062 -0.017195 0.266519 0.331591 -0.068582 0.175815 \n", "semula 0.037555 0.303921 -0.279731 -0.914102 0.339698 0.337318 \n", "perempuan -0.132593 0.077143 0.672117 -0.278821 0.018481 -0.199514 \n", "\n", " 397 398 399 \n", "semua -0.008426 -0.548818 -0.450443 \n", "kanak 0.356920 0.567058 -0.463647 \n", "album 0.462025 0.147016 0.053593 \n", "menteri 0.269761 -0.146999 -0.197764 \n", "turut 0.138207 -0.021835 -0.045728 \n", "kira -0.295976 0.203567 -0.208005 \n", "murid 0.223650 -0.933487 0.428067 \n", "kuala 0.104983 0.110367 0.049074 \n", "semula 0.274428 -0.103200 0.126072 \n", "perempuan 0.894665 0.157668 0.372751 \n", "\n", "[10 rows x 400 columns]" ] }, "execution_count": 148, "metadata": {}, "output_type": "execute_result" } ], "source": [ "malay2vec = pd.DataFrame(emb_np)\n", "# copy the list so that the original itos is not mutated below\n", "new_itos = list(itos)\n", "# replace the space token at index 2 with a visible '_space_' token\n", "new_itos[2] = '_space_'\n", "# replace spaces inside multi-word tokens (e.g. named entities) with _\n", "new_itos = [re.sub(' ', '_', i) for i in new_itos]\n", "malay2vec.index = new_itos\n", "malay2vec.iloc[200:210]" ] }, { "cell_type": "code", "execution_count": 151, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(60002, 400)" ] }, "execution_count": 151, "metadata": {}, "output_type": "execute_result" } ], "source": [ "malay2vec.to_csv(f\"{LM_PATH}/malay2vec_embeddings.csv\", sep=\" \", header=False, line_terminator=\"\\n\")\n", "malay2vec.shape" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "# from gensim.models import KeyedVectors\n", "\n", "# # the CSV written by to_csv has no \"<vocab_size> <dim>\" header line, so it is loaded\n", "# # with no_header=True (parameter available in gensim >= 4.0)\n", "# model = KeyedVectors.load_word2vec_format(f\"{LM_PATH}/malay2vec_embeddings.csv\", binary=False,\n", "#                                           no_header=True, unicode_errors=\"ignore\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# model.save_word2vec_format(f'{LM_PATH}/malay2vec.bin', binary=True)" ] } ], "metadata": { "gist": { "data": { "description": "mywork/Telugu_Language_Model.ipynb", "public": false }, "id": "" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }