{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.text import * # Quick access to NLP functionality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Text example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An example of creating a language model and then transferring to a classifier." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/sgugger/.fastai/data/imdb_sample')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.IMDB_SAMPLE)\n", "path" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Open and view the independent and dependent variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
labeltextis_valid
0negativeUn-bleeping-believable! Meg Ryan doesn't even ...False
1positiveThis is a extremely well-made film. The acting...False
2negativeEvery once in a long while a movie will come a...False
3positiveName just says it all. I watched this movie wi...False
4negativeThis movie succeeds at being one of the most u...False
\n", "
" ], "text/plain": [ " label text is_valid\n", "0 negative Un-bleeping-believable! Meg Ryan doesn't even ... False\n", "1 positive This is a extremely well-made film. The acting... False\n", "2 negative Every once in a long while a movie will come a... False\n", "3 positive Name just says it all. I watched this movie wi... False\n", "4 negative This movie succeeds at being one of the most u... False" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(path/'texts.csv')\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a `DataBunch` for each of the language model and the classifier:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')\n", "data_clas = TextClasDataBunch.from_csv(path, 'texts.csv', vocab=data_lm.train_ds.vocab, bs=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll fine-tune the language model. [fast.ai](http://www.fast.ai/) has a pre-trained English model available that we can download, we just have to specify it like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "moms = (0.8,0.7)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
04.4274223.8723530.29017900:04
14.1537383.8068260.29416700:04
23.8351913.7875780.29549100:04
33.5669093.7914150.29669600:04
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = language_model_learner(data_lm, AWD_LSTM)\n", "learn.unfreeze()\n", "learn.fit_one_cycle(4, slice(1e-2), moms=moms)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save our language model's encoder:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save_encoder('enc')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fine tune it to create a classifier:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.6766080.5882550.79104500:05
10.6401270.5123410.79602000:05
20.5834520.4528670.79602000:05
30.5505180.4509670.78607000:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = text_classifier_learner(data_clas, AWD_LSTM)\n", "learn.load_encoder('enc')\n", "learn.fit_one_cycle(4, moms=moms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('stage1-clas')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.5023580.4305610.80099500:08
10.4753050.4428210.79602000:08
20.4685990.4271600.80597000:07
30.4623680.3844890.84577100:08
40.4626260.3796670.84577100:07
50.4444050.3805100.83582100:07
60.4223570.3723410.86069600:08
70.4160240.3804860.83084600:07
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.unfreeze()\n", "learn.fit_one_cycle(8, slice(1e-5,1e-3), moms=moms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('stage2-clas')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('stage1-clas');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Parameter containing:\n", " tensor([[ 0.0739, 0.0123, 0.0579, ..., 0.0617, 0.0304, 0.0275],\n", " [-0.0487, 0.0562, 0.0039, ..., 0.0523, -0.0193, 0.0294],\n", " [-0.0179, -0.1158, 0.1973, ..., 0.0256, -0.0063, -0.0337],\n", " ...,\n", " [ 0.0111, -0.0447, -0.0007, ..., -0.0460, -0.0016, 0.0070],\n", " [-0.0272, 0.0378, 0.0377, ..., -0.0205, 0.1363, -0.0199],\n", " [-0.0088, -0.0115, -0.0832, ..., -0.0684, 0.1311, -0.0668]],\n", " device='cuda:0', requires_grad=True), Parameter containing:\n", " tensor([[-0.0664, -0.1465, -0.0776, ..., 0.1152, 0.0886, 0.0717],\n", " [ 0.0177, 0.1248, -0.0452, ..., -0.0159, -0.0884, -0.0310],\n", " [ 0.0828, -0.0289, -0.0932, ..., 0.1441, 0.1289, 0.0946],\n", " ...,\n", " [-0.1123, -0.0756, 0.3082, ..., -0.0644, -0.0201, 0.0431],\n", " [ 0.0530, 0.0738, 0.0781, ..., 0.0096, 0.2213, -0.0149],\n", " [ 0.2115, -0.0221, 0.0563, ..., -0.2186, 0.0302, -0.0570]],\n", " device='cuda:0', requires_grad=True), Parameter containing:\n", " tensor([0.2707, 0.1091, 0.1264, ..., 0.2405, 0.1249, 0.3700], device='cuda:0',\n", " requires_grad=True), Parameter containing:\n", " tensor([0.2707, 0.1091, 0.1264, ..., 0.2405, 0.1249, 0.3700], device='cuda:0',\n", " requires_grad=True)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(learn.model[0].module.rnns[-1].parameters())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"learn.load('stage2-clas');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('weight_hh_l0_raw', Parameter containing:\n", " tensor([[ 7.0162e-02, 1.2104e-02, 5.8028e-02, ..., 6.0046e-02,\n", " 3.0242e-02, 2.7227e-02],\n", " [-4.6466e-02, 5.5549e-02, 5.8688e-03, ..., 5.3150e-02,\n", " -1.5318e-02, 2.5991e-02],\n", " [-1.2340e-02, -1.1872e-01, 1.9423e-01, ..., 2.8058e-02,\n", " -1.0008e-02, -3.6459e-02],\n", " ...,\n", " [ 1.4048e-02, -4.4230e-02, -1.3452e-03, ..., -4.7389e-02,\n", " -2.4187e-05, 5.1805e-03],\n", " [-2.0523e-02, 3.7637e-02, 4.2203e-02, ..., -2.0026e-02,\n", " 1.4004e-01, -1.9072e-02],\n", " [-8.9586e-03, -6.4311e-03, -8.2364e-02, ..., -6.6814e-02,\n", " 1.3448e-01, -6.9005e-02]], device='cuda:0', requires_grad=True)),\n", " ('module.weight_ih_l0', Parameter containing:\n", " tensor([[-0.0669, -0.1461, -0.0792, ..., 0.1177, 0.0882, 0.0660],\n", " [ 0.0178, 0.1236, -0.0439, ..., -0.0163, -0.0862, -0.0382],\n", " [ 0.0770, -0.0259, -0.0918, ..., 0.1413, 0.1248, 0.0883],\n", " ...,\n", " [-0.1126, -0.0728, 0.3100, ..., -0.0681, -0.0215, 0.0429],\n", " [ 0.0536, 0.0708, 0.0750, ..., 0.0100, 0.2192, -0.0167],\n", " [ 0.2099, -0.0273, 0.0531, ..., -0.2179, 0.0290, -0.0598]],\n", " device='cuda:0', requires_grad=True)),\n", " ('module.bias_ih_l0', Parameter containing:\n", " tensor([0.2622, 0.0981, 0.1163, ..., 0.2334, 0.1223, 0.3612], device='cuda:0',\n", " requires_grad=True)),\n", " ('module.bias_hh_l0', Parameter containing:\n", " tensor([0.2622, 0.0981, 0.1163, ..., 0.2334, 0.1223, 0.3612], device='cuda:0',\n", " requires_grad=True))]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(learn.model[0].module.rnns[-1].named_parameters())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, 
"nbformat": 4, "nbformat_minor": 2 }