{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Language Modeling & Sentiment Analysis of IMDB movie reviews"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"from fastai import *\n",
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Batch size. The learning rates further down are scaled by bs/48,\n",
"# i.e. relative to a batch size of 48 (use bs = 48 to leave them unscaled).\n",
"bs = 128"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.IMDB)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Language model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Build the language-model data from the review text files under `path`.\n",
"data_lm = (\n",
"    TextList.from_folder(path)\n",
"    # keep the files in the 'train', 'test' and 'unsup' folders\n",
"    .filter_by_folder(include=['train', 'test', 'unsup'])\n",
"    # random 10% held out for validation; the fixed seed keeps the split repeatable\n",
"    .split_by_rand_pct(0.1, seed=42)\n",
"    # label for language modelling\n",
"    .label_for_lm()\n",
"    .databunch(bs=bs, num_workers=1)\n",
")\n",
"\n",
"# Displayed result: (vocabulary size, number of training items).\n",
"len(data_lm.vocab.itos), len(data_lm.train_ds)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"data_lm.save('lm_databunch')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"data_lm = load_data(path, 'lm_databunch', bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# AWD_LSTM language-model learner on the IMDB text, then switched to mixed precision.\n",
"learn_lm = language_model_learner(data_lm, AWD_LSTM, drop_mult=1.0)\n",
"learn_lm = learn_lm.to_fp16()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Base rate 1e-2, scaled linearly by bs/48 (i.e. relative to a batch size of 48).\n",
"lr = 1e-2 * (bs/48)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
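"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>4.604046</td>\n",
"      <td>4.189002</td>\n",
"      <td>0.278265</td>\n",
"      <td>18:36</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"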
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn_lm.fit_one_cycle(1, lr, moms=(0.8,0.7))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"learn_lm.unfreeze()\n",
"learn_lm.fit_one_cycle(10, lr/10, moms=(0.8,0.7))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"learn_lm.save('fine_tuned_10')\n",
"learn_lm.save_encoder('fine_tuned_enc_10')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classifier"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Classification data, built with the language model's vocabulary (`data_lm.vocab`).\n",
"data_clas = (\n",
"    TextList.from_folder(path, vocab=data_lm.vocab)\n",
"    # the 'test' folder becomes the validation set (training folder left at the library default)\n",
"    .split_by_folder(valid='test')\n",
"    # labels come from the containing folder, restricted to 'neg' and 'pos'\n",
"    .label_from_folder(classes=['neg', 'pos'])\n",
"    .databunch(bs=bs, num_workers=1)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"data_clas.save('imdb_textlist_class')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"data_clas = load_data(path, 'imdb_textlist_class', bs=bs, num_workers=1)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# AWD_LSTM classifier in mixed precision, initialised with the encoder saved as\n",
"# 'fine_tuned_enc_10' by the language-model section, then frozen for the first stage.\n",
"learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)\n",
"learn_c = learn_c.to_fp16()\n",
"learn_c.load_encoder('fine_tuned_enc_10')\n",
"learn_c.freeze()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Classifier base rate 2e-2, scaled linearly by bs/48 (i.e. relative to a batch size of 48).\n",
"lr = 2e-2 * (bs/48)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
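"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>0.241523</td>\n",
"      <td>0.190128</td>\n",
"      <td>0.926600</td>\n",
"      <td>01:16</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"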
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn_c.fit_one_cycle(1, lr, moms=(0.8,0.7))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"learn_c.save('1')"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
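"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>0.204818</td>\n",
"      <td>0.161675</td>\n",
"      <td>0.938640</td>\n",
"      <td>02:00</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"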
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn_c.freeze_to(-2)\n",
"learn_c.fit_one_cycle(1, slice(lr/(2.6**4),lr), moms=(0.8,0.7))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"learn_c.save('2nd')"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
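"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>0.179451</td>\n",
"      <td>0.144047</td>\n",
"      <td>0.945840</td>\n",
"      <td>02:56</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"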
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn_c.freeze_to(-3)\n",
"learn_c.fit_one_cycle(1, slice(lr/2/(2.6**4),lr/2), moms=(0.8,0.7))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"learn_c.save('3rd')"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
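"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>epoch</th>\n",
"      <th>train_loss</th>\n",
"      <th>valid_loss</th>\n",
"      <th>accuracy</th>\n",
"      <th>time</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>0</td>\n",
"      <td>0.120063</td>\n",
"      <td>0.145701</td>\n",
"      <td>0.947000</td>\n",
"      <td>03:24</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>0.087303</td>\n",
"      <td>0.152943</td>\n",
"      <td>0.948080</td>\n",
"      <td>03:09</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"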
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn_c.unfreeze()\n",
"learn_c.fit_one_cycle(2, slice(lr/10/(2.6**4),lr/10), moms=(0.8,0.7))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"learn_c.save('clas')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}