{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.text import * # Quick access to NLP functionality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Text example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An example of creating a language model and then transferring to a classifier." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/ubuntu/.fastai/data/imdb_sample')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.IMDB_SAMPLE)\n", "path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Open and view the independent and dependent variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
012
0labeltextis_valid
1negativeUn-bleeping-believable! Meg Ryan doesn't even ...False
2positiveThis is a extremely well-made film. The acting...False
3negativeEvery once in a long while a movie will come a...False
4positiveName just says it all. I watched this movie wi...False
\n", "
" ], "text/plain": [ " 0 1 2\n", "0 label text is_valid\n", "1 negative Un-bleeping-believable! Meg Ryan doesn't even ... False\n", "2 positive This is a extremely well-made film. The acting... False\n", "3 negative Every once in a long while a movie will come a... False\n", "4 positive Name just says it all. I watched this movie wi... False" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(path/'texts.csv', header=None)\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a `DataBunch` for each of the language model and the classifier:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')\n", "data_clas = TextClasDataBunch.from_csv(path, 'texts.csv', vocab=data_lm.train_ds.vocab, bs=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll fine-tune the language model. [fast.ai](http://www.fast.ai/) has a pre-trained English model available that we can download, we just have to specify it like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "moms = (0.8,0.7)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:17

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
14.6396603.9142690.293896
24.2834203.7236000.302778
34.0325263.6894890.304384
43.8579303.6810900.304303
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = language_model_learner(data_lm, AWD_LSTM)\n", "learn.unfreeze()\n", "learn.fit_one_cycle(4, slice(1e-2), moms=moms)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save our language model's encoder:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save_encoder('enc')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fine-tune it to create a classifier:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:22

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
10.6683170.6043980.716418
20.6437910.5720270.701493
30.6229350.5628830.686567
40.6146690.5296850.736318
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = text_classifier_learner(data_clas, AWD_LSTM)\n", "learn.load_encoder('enc')\n", "learn.fit_one_cycle(4, moms=moms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 01:32

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
10.5889010.5452560.711443
20.6086160.4907640.781095
30.5989890.5728830.701493
40.5704600.4858500.776119
50.5485490.5051900.761194
60.5620360.4882970.771144
70.5454670.4818130.805970
80.5478700.4913840.766169
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.unfreeze()\n", "learn.fit_one_cycle(8, slice(1e-5,1e-3), moms=moms)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }