{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import * # Quick access to NLP functionality"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example of creating a language model and then transferring to a classifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('/home/sgugger/.fastai/data/imdb_sample')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fetch the IMDB sample dataset and echo the local directory it ends up in.\n",
"path = untar_data(URLs.IMDB_SAMPLE)\n",
"path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Open and view the independent and dependent variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" label | \n",
" text | \n",
" is_valid | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" negative | \n",
" Un-bleeping-believable! Meg Ryan doesn't even ... | \n",
" False | \n",
"
\n",
" \n",
" | 1 | \n",
" positive | \n",
" This is a extremely well-made film. The acting... | \n",
" False | \n",
"
\n",
" \n",
" | 2 | \n",
" negative | \n",
" Every once in a long while a movie will come a... | \n",
" False | \n",
"
\n",
" \n",
" | 3 | \n",
" positive | \n",
" Name just says it all. I watched this movie wi... | \n",
" False | \n",
"
\n",
" \n",
" | 4 | \n",
" negative | \n",
" This movie succeeds at being one of the most u... | \n",
" False | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" label text is_valid\n",
"0 negative Un-bleeping-believable! Meg Ryan doesn't even ... False\n",
"1 positive This is a extremely well-made film. The acting... False\n",
"2 negative Every once in a long while a movie will come a... False\n",
"3 positive Name just says it all. I watched this movie wi... False\n",
"4 negative This movie succeeds at being one of the most u... False"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the sample's CSV into a DataFrame and preview its first rows.\n",
"csv_file = path/'texts.csv'\n",
"df = pd.read_csv(csv_file)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a `DataBunch` for each of the language model and the classifier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Language-model data: built straight from the CSV.\n",
"data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')\n",
"\n",
"# Classifier data: same CSV, but reusing the language model's vocabulary.\n",
"data_clas = TextClasDataBunch.from_csv(\n",
"    path,\n",
"    'texts.csv',\n",
"    vocab=data_lm.train_ds.vocab,\n",
"    bs=42,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll fine-tune the language model. [fast.ai](http://www.fast.ai/) has a pre-trained English model available that we can download; we just have to specify it like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Momentum pair handed to every `fit_one_cycle` call below via `moms=`.\n",
"moms = (0.8, 0.7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 4.427422 | \n",
" 3.872353 | \n",
" 0.290179 | \n",
" 00:04 | \n",
"
\n",
" \n",
" | 1 | \n",
" 4.153738 | \n",
" 3.806826 | \n",
" 0.294167 | \n",
" 00:04 | \n",
"
\n",
" \n",
" | 2 | \n",
" 3.835191 | \n",
" 3.787578 | \n",
" 0.295491 | \n",
" 00:04 | \n",
"
\n",
" \n",
" | 3 | \n",
" 3.566909 | \n",
" 3.791415 | \n",
" 0.296696 | \n",
" 00:04 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Build the language-model learner with the AWD_LSTM architecture, unfreeze it,\n",
"# then run a 4-epoch one-cycle schedule.\n",
"learn = language_model_learner(data_lm, AWD_LSTM)\n",
"learn.unfreeze()\n",
"learn.fit_one_cycle(cyc_len=4, max_lr=slice(1e-2), moms=moms)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Save our language model's encoder:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Persist the encoder under the name 'enc'; the classifier cell reloads it with `load_encoder('enc')`.\n",
"learn.save_encoder('enc')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fine-tune it to create a classifier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.676608 | \n",
" 0.588255 | \n",
" 0.791045 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.640127 | \n",
" 0.512341 | \n",
" 0.796020 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.583452 | \n",
" 0.452867 | \n",
" 0.796020 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.550518 | \n",
" 0.450967 | \n",
" 0.786070 | \n",
" 00:05 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Start a classifier learner, plug in the encoder saved as 'enc', and train for 4 epochs.\n",
"learn = text_classifier_learner(data_clas, AWD_LSTM)\n",
"learn.load_encoder('enc')\n",
"learn.fit_one_cycle(cyc_len=4, moms=moms)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Checkpoint the classifier before it is unfrozen and trained further.\n",
"learn.save('stage1-clas')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.502358 | \n",
" 0.430561 | \n",
" 0.800995 | \n",
" 00:08 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.475305 | \n",
" 0.442821 | \n",
" 0.796020 | \n",
" 00:08 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.468599 | \n",
" 0.427160 | \n",
" 0.805970 | \n",
" 00:07 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.462368 | \n",
" 0.384489 | \n",
" 0.845771 | \n",
" 00:08 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.462626 | \n",
" 0.379667 | \n",
" 0.845771 | \n",
" 00:07 | \n",
"
\n",
" \n",
" | 5 | \n",
" 0.444405 | \n",
" 0.380510 | \n",
" 0.835821 | \n",
" 00:07 | \n",
"
\n",
" \n",
" | 6 | \n",
" 0.422357 | \n",
" 0.372341 | \n",
" 0.860696 | \n",
" 00:08 | \n",
"
\n",
" \n",
" | 7 | \n",
" 0.416024 | \n",
" 0.380486 | \n",
" 0.830846 | \n",
" 00:07 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Unfreeze the classifier and train 8 more epochs with learning rates given by slice(1e-5, 1e-3).\n",
"learn.unfreeze()\n",
"learn.fit_one_cycle(cyc_len=8, max_lr=slice(1e-5, 1e-3), moms=moms)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Checkpoint the classifier after the unfrozen training run.\n",
"learn.save('stage2-clas')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Restore the 'stage1-clas' checkpoint; the trailing `;` keeps the cell from echoing a result.\n",
"learn.load('stage1-clas');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Parameter containing:\n",
" tensor([[ 0.0739, 0.0123, 0.0579, ..., 0.0617, 0.0304, 0.0275],\n",
" [-0.0487, 0.0562, 0.0039, ..., 0.0523, -0.0193, 0.0294],\n",
" [-0.0179, -0.1158, 0.1973, ..., 0.0256, -0.0063, -0.0337],\n",
" ...,\n",
" [ 0.0111, -0.0447, -0.0007, ..., -0.0460, -0.0016, 0.0070],\n",
" [-0.0272, 0.0378, 0.0377, ..., -0.0205, 0.1363, -0.0199],\n",
" [-0.0088, -0.0115, -0.0832, ..., -0.0684, 0.1311, -0.0668]],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
" tensor([[-0.0664, -0.1465, -0.0776, ..., 0.1152, 0.0886, 0.0717],\n",
" [ 0.0177, 0.1248, -0.0452, ..., -0.0159, -0.0884, -0.0310],\n",
" [ 0.0828, -0.0289, -0.0932, ..., 0.1441, 0.1289, 0.0946],\n",
" ...,\n",
" [-0.1123, -0.0756, 0.3082, ..., -0.0644, -0.0201, 0.0431],\n",
" [ 0.0530, 0.0738, 0.0781, ..., 0.0096, 0.2213, -0.0149],\n",
" [ 0.2115, -0.0221, 0.0563, ..., -0.2186, 0.0302, -0.0570]],\n",
" device='cuda:0', requires_grad=True), Parameter containing:\n",
" tensor([0.2707, 0.1091, 0.1264, ..., 0.2405, 0.1249, 0.3700], device='cuda:0',\n",
" requires_grad=True), Parameter containing:\n",
" tensor([0.2707, 0.1091, 0.1264, ..., 0.2405, 0.1249, 0.3700], device='cuda:0',\n",
" requires_grad=True)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dump the parameters of the last entry in `learn.model[0].module.rnns` for the currently loaded weights.\n",
"last_rnn = learn.model[0].module.rnns[-1]\n",
"list(last_rnn.parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Restore the 'stage2-clas' checkpoint; the trailing `;` keeps the cell from echoing a result.\n",
"learn.load('stage2-clas');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('weight_hh_l0_raw', Parameter containing:\n",
" tensor([[ 7.0162e-02, 1.2104e-02, 5.8028e-02, ..., 6.0046e-02,\n",
" 3.0242e-02, 2.7227e-02],\n",
" [-4.6466e-02, 5.5549e-02, 5.8688e-03, ..., 5.3150e-02,\n",
" -1.5318e-02, 2.5991e-02],\n",
" [-1.2340e-02, -1.1872e-01, 1.9423e-01, ..., 2.8058e-02,\n",
" -1.0008e-02, -3.6459e-02],\n",
" ...,\n",
" [ 1.4048e-02, -4.4230e-02, -1.3452e-03, ..., -4.7389e-02,\n",
" -2.4187e-05, 5.1805e-03],\n",
" [-2.0523e-02, 3.7637e-02, 4.2203e-02, ..., -2.0026e-02,\n",
" 1.4004e-01, -1.9072e-02],\n",
" [-8.9586e-03, -6.4311e-03, -8.2364e-02, ..., -6.6814e-02,\n",
" 1.3448e-01, -6.9005e-02]], device='cuda:0', requires_grad=True)),\n",
" ('module.weight_ih_l0', Parameter containing:\n",
" tensor([[-0.0669, -0.1461, -0.0792, ..., 0.1177, 0.0882, 0.0660],\n",
" [ 0.0178, 0.1236, -0.0439, ..., -0.0163, -0.0862, -0.0382],\n",
" [ 0.0770, -0.0259, -0.0918, ..., 0.1413, 0.1248, 0.0883],\n",
" ...,\n",
" [-0.1126, -0.0728, 0.3100, ..., -0.0681, -0.0215, 0.0429],\n",
" [ 0.0536, 0.0708, 0.0750, ..., 0.0100, 0.2192, -0.0167],\n",
" [ 0.2099, -0.0273, 0.0531, ..., -0.2179, 0.0290, -0.0598]],\n",
" device='cuda:0', requires_grad=True)),\n",
" ('module.bias_ih_l0', Parameter containing:\n",
" tensor([0.2622, 0.0981, 0.1163, ..., 0.2334, 0.1223, 0.3612], device='cuda:0',\n",
" requires_grad=True)),\n",
" ('module.bias_hh_l0', Parameter containing:\n",
" tensor([0.2622, 0.0981, 0.1163, ..., 0.2334, 0.1223, 0.3612], device='cuda:0',\n",
" requires_grad=True))]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dump (name, parameter) pairs of the last entry in `learn.model[0].module.rnns` for the currently loaded weights.\n",
"last_rnn = learn.model[0].module.rnns[-1]\n",
"list(last_rnn.named_parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}