{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from fastai import * # Quick access to most common functionality\n", "from fastai.text import * # Quick access to NLP functionality\n", "from fastai.docs import * # Access to example data provided with fastai" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Text example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An example of fine-tuning a pre-trained language model and then transferring its encoder to a text classifier." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('../data/imdb_sample')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "untar_data(IMDB_PATH)\n", "IMDB_PATH" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Open the training CSV and view the dependent variable (the label, column 0) and the independent variable (the review text, column 1):" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01
00Un-bleeping-believable! Meg Ryan doesn't even ...
11This is a extremely well-made film. The acting...
20Every once in a long while a movie will come a...
31Name just says it all. I watched this movie wi...
40This movie succeeds at being one of the most u...
\n", "
" ], "text/plain": [ " 0 1\n", "0 0 Un-bleeping-believable! Meg Ryan doesn't even ...\n", "1 1 This is a extremely well-made film. The acting...\n", "2 0 Every once in a long while a movie will come a...\n", "3 1 Name just says it all. I watched this movie wi...\n", "4 0 This movie succeeds at being one of the most u..." ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df = pd.read_csv(IMDB_PATH/'train.csv', header=None)\n", "train_df.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('negative', 'positive')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classes = read_classes(IMDB_PATH/'classes.txt')\n", "classes[0], classes[1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create one `DataBunch` for the language model and one for the classifier:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tokenizing train.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Numericalizing train.\n", "Tokenizing valid.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Numericalizing valid.\n" ] } ], "source": [ "data_lm = text_data_from_csv(IMDB_PATH, data_func=lm_data)\n", "# Share the language model's vocab so the transferred encoder sees the same token ids\n", "data_clas = text_data_from_csv(IMDB_PATH, data_func=classifier_data, vocab=data_lm.train_ds.vocab)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[fast.ai](http://www.fast.ai/) provides a pre-trained English language model (`wt103`) that we can download:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ef712593ca4c437b848ff6f02ac401d8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=221972701), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "58572e2d9cf341a988ce1e92bf477118", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=1027972), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "download_wt103_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll fine-tune the language model:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=2), HTML(value='0.00% [0/2 00:00<00:00]'))), HTML(value…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 23:49\n", "epoch train loss valid loss accuracy\n", "0 4.871409 
4.130483 0.251005 (12:35)\n", "1 4.607700 4.064432 0.257633 (11:13)\n", "\n" ] } ], "source": [ "learn_lm = RNNLearner.language_model(data_lm, pretrained_fnames=['lstm_wt103', 'itos_wt103'])\n", "learn_lm.unfreeze()\n", "learn_lm.fit(2, slice(1e-4,1e-2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save the fine-tuned language model's encoder so the classifier can reuse it:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "learn_lm.save_encoder('enc')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a classifier, load the saved encoder into it, and fine-tune it:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=3), HTML(value='0.00% [0/3 00:00<00:00]'))), HTML(value…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 14:47\n", "epoch train loss valid loss accuracy\n", "0 0.686279 0.670379 0.690000 (04:45)\n", "1 0.657580 0.624417 0.710000 (05:01)\n", "2 0.648344 0.578787 0.705000 (05:00)\n", "\n" ] } ], "source": [ "learn_clas = RNNLearner.classifier(data_clas)\n", "learn_clas.load_encoder('enc')\n", "learn_clas.fit(3, 1e-3)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }