{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wikitext 103"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = Config().data_path()/'wikitext-103'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def istitle(line):\n",
    "    \"\"\"Return True if `line` is a top-level article title of the form ' = Title = '.\n",
    "\n",
    "    Sub-section headings (' = = Section = = ') contain '=' between the outer\n",
    "    markers and therefore do not match.\n",
    "    \"\"\"\n",
    "    return re.search(r'^ = [^=]* = $', line) is not None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_file(filename):\n",
    "    \"\"\"Read a WikiText file and split it into one string per article.\n",
    "\n",
    "    An article ends at a line that is followed by a whitespace-only line and\n",
    "    then a top-level title line (see `istitle`). The dataset's literal '<unk>'\n",
    "    token is replaced by `UNK` in every article.\n",
    "\n",
    "    Returns a numpy array of article strings; the text after the last detected\n",
    "    boundary is always appended as the final article.\n",
    "    \"\"\"\n",
    "    with open(filename, encoding='utf8') as f:\n",
    "        lines = f.readlines()\n",
    "    articles, current = [], []\n",
    "    for i, line in enumerate(lines):\n",
    "        current.append(line)\n",
    "        # Look two lines ahead: blank separator line, then the next article's title.\n",
    "        if i < len(lines)-2 and lines[i+1] == ' \\n' and istitle(lines[i+2]):\n",
    "            articles.append(''.join(current).replace('<unk>', UNK))\n",
    "            current = []\n",
    "    # Flush whatever remains after the last boundary.\n",
    "    articles.append(''.join(current).replace('<unk>', UNK))\n",
    "    return np.array(articles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = read_file(path/'train.txt')\n",
    "valid = read_file(path/'valid.txt')\n",
    "test = read_file(path/'test.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train), len(valid), len(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation articles are placed first so they can be selected by index when splitting.\n",
    "n_valid = len(valid)\n",
    "all_texts = np.concatenate([valid, train, test])\n",
    "df = pd.DataFrame({'texts':all_texts})\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the per-split arrays; their contents now live in `df`.\n",
    "del train\n",
    "del valid\n",
    "del test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first `n_valid` rows of `df` are the validation articles (see the concatenation order above).\n",
    "data = (TextList.from_df(df, path, cols='texts')\n",
    "                .split_by_idx(range(n_valid))\n",
    "                .label_for_lm()\n",
    "                .databunch(bs=100, bptt=70))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show_batch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): load_data is called without bs/bptt, so the reloaded DataBunch may not\n",
    "# use the bs=100, bptt=70 chosen when it was built - confirm this is intended.\n",
    "data = load_data(path)\n",
    "data.show_batch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = language_model_learner(data, AWD_LSTM, drop_mult=0.1, pretrained=False, clip=0.1,\n",
    "                               metrics=[accuracy, Perplexity()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(10, 5e-3, moms=(0.8,0.7), div_factor=10, wd=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('lstm', with_opt=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.data.vocab.save(path/'vocab.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}