{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import * # Quick access to NLP functionality"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example of creating a language model and then transfering to a classifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('/home/sgugger/.fastai/data/imdb_sample')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.IMDB_SAMPLE)\n",
"path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Open and view the independent and dependent variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" label | \n",
" text | \n",
" is_valid | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" negative | \n",
" Un-bleeping-believable! Meg Ryan doesn't even ... | \n",
" False | \n",
"
\n",
" \n",
" 1 | \n",
" positive | \n",
" This is a extremely well-made film. The acting... | \n",
" False | \n",
"
\n",
" \n",
" 2 | \n",
" negative | \n",
" Every once in a long while a movie will come a... | \n",
" False | \n",
"
\n",
" \n",
" 3 | \n",
" positive | \n",
" Name just says it all. I watched this movie wi... | \n",
" False | \n",
"
\n",
" \n",
" 4 | \n",
" negative | \n",
" This movie succeeds at being one of the most u... | \n",
" False | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" label text is_valid\n",
"0 negative Un-bleeping-believable! Meg Ryan doesn't even ... False\n",
"1 positive This is a extremely well-made film. The acting... False\n",
"2 negative Every once in a long while a movie will come a... False\n",
"3 positive Name just says it all. I watched this movie wi... False\n",
"4 negative This movie succeeds at being one of the most u... False"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(path/'texts.csv')\n",
"df.head()"
]
},
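{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (an addition to the original walkthrough), plain pandas can show how the labels are distributed in the sample:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count reviews per value of the `label` column shown above\n",
"df['label'].value_counts()"
]
},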
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a `DataBunch` for each of the language model and the classifier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')\n",
"data_clas = TextClasDataBunch.from_csv(path, 'texts.csv', vocab=data_lm.train_ds.vocab, bs=42)"
]
},
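{
"cell_type": "markdown",
"metadata": {},
"source": [
"Building these `DataBunch` objects tokenizes and numericalizes the whole CSV, which gets slow on bigger datasets. A minimal caching sketch, assuming a fastai v1 release that provides `DataBunch.save` and `load_data` (the `.pkl` file names are our choice):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Cache the processed data so the next run can skip tokenization\n",
"data_lm.save('data_lm.pkl')\n",
"data_clas.save('data_clas.pkl')\n",
"\n",
"# Later, reload from disk instead of rebuilding from the CSV\n",
"data_lm = load_data(path, 'data_lm.pkl')\n",
"data_clas = load_data(path, 'data_clas.pkl', bs=42)"
]
},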
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll fine-tune the language model. [fast.ai](http://www.fast.ai/) has a pre-trained English model available that we can download, we just have to specify it like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"moms = (0.8,0.7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 4.414052 | \n",
" 3.939605 | \n",
" 0.279167 | \n",
" 00:05 | \n",
"
\n",
" \n",
" 1 | \n",
" 4.152833 | \n",
" 3.875656 | \n",
" 0.284345 | \n",
" 00:05 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.832567 | \n",
" 3.848873 | \n",
" 0.286280 | \n",
" 00:05 | \n",
"
\n",
" \n",
" 3 | \n",
" 3.561787 | \n",
" 3.856220 | \n",
" 0.286399 | \n",
" 00:05 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = language_model_learner(data_lm, AWD_LSTM)\n",
"learn.unfreeze()\n",
"learn.fit_one_cycle(4, slice(1e-2), moms=moms)"
]
},
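{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick qualitative check (not part of the original flow), a `language_model_learner` can generate text with `predict`; if fine-tuning worked, the output should read vaguely like a movie review. The prompt and temperature below are arbitrary choices:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sample 30 words from the fine-tuned language model\n",
"learn.predict(\"This movie is about\", n_words=30, temperature=0.75)"
]
},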
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Save our language model's encoder:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.save_encoder('enc')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fine tune it to create a classifier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.659827 | \n",
" 0.600592 | \n",
" 0.766169 | \n",
" 00:04 | \n",
"
\n",
" \n",
" 1 | \n",
" 0.599001 | \n",
" 0.520201 | \n",
" 0.756219 | \n",
" 00:05 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.564309 | \n",
" 0.494556 | \n",
" 0.796020 | \n",
" 00:04 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.520831 | \n",
" 0.495697 | \n",
" 0.776119 | \n",
" 00:04 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = text_classifier_learner(data_clas, AWD_LSTM)\n",
"learn.load_encoder('enc')\n",
"learn.fit_one_cycle(4, moms=moms)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn.save('stage1-clas')"
]
},
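{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below, the whole model is unfrozen at once. The full ULMFiT recipe works up to that gradually, one layer group at a time; a sketch of that intermediate step, with illustrative (untuned) learning rates:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Gradual unfreezing: train the last two layer groups first...\n",
"learn.freeze_to(-2)\n",
"learn.fit_one_cycle(1, slice(1e-3/2., 1e-3), moms=moms)\n",
"\n",
"# ...then the last three, before unfreezing everything below\n",
"learn.freeze_to(-3)\n",
"learn.fit_one_cycle(1, slice(5e-4/2., 5e-4), moms=moms)"
]
},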
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.470689 | \n",
" 0.488138 | \n",
" 0.786070 | \n",
" 00:08 | \n",
"
\n",
" \n",
" 1 | \n",
" 0.455899 | \n",
" 0.468737 | \n",
" 0.786070 | \n",
" 00:07 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.474349 | \n",
" 0.498394 | \n",
" 0.771144 | \n",
" 00:08 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.466920 | \n",
" 0.477338 | \n",
" 0.766169 | \n",
" 00:08 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.459592 | \n",
" 0.462194 | \n",
" 0.805970 | \n",
" 00:08 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.431064 | \n",
" 0.472223 | \n",
" 0.786070 | \n",
" 00:08 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.427589 | \n",
" 0.466315 | \n",
" 0.796020 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.417917 | \n",
" 0.461701 | \n",
" 0.786070 | \n",
" 00:08 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.unfreeze()\n",
"learn.fit_one_cycle(8, slice(1e-5,1e-3), moms=moms)"
]
},
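{
"cell_type": "markdown",
"metadata": {},
"source": [
"To recompute the final validation loss and accuracy in one call (an addition to the original flow):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Returns [valid_loss, accuracy] on the validation set\n",
"learn.validate()"
]
},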
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Category tensor(1), tensor(1), tensor([0.0666, 0.9334]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.predict(\"I really liked this movie!\")"
]
},
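{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tuple above is (predicted category, class index, per-class probabilities). Unpacking it makes single predictions easier to read; the example sentence here is ours, not from the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Unpack the prediction and pull out the winning class's probability\n",
"pred_class, pred_idx, probs = learn.predict(\"What a waste of two hours.\")\n",
"pred_class, probs[pred_idx]"
]
}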
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}