{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Language Modeling & Sentiment Analysis of IMDB movie reviews" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai import *\n", "from fastai.text import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Batch size used for the databunches; the learning rates further down are rescaled by bs/48.\n", "bs=128" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Root folder of the IMDB dataset, as returned by untar_data.\n", "path = untar_data(URLs.IMDB)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Language model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Building the language-model databunch is expensive, so only do it when no saved\n", "# copy exists yet under `path`; a Restart & Run All then goes straight to the load below.\n", "if not (path/'lm_databunch').exists():\n", "    (TextList.from_folder(path)\n", "     .filter_by_folder(include=['train', 'test', 'unsup'])\n", "     .split_by_rand_pct(0.1, seed=42)  # 10% random validation split, fixed seed\n", "     .label_for_lm()\n", "     .databunch(bs=bs, num_workers=1)\n", "     .save('lm_databunch'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Always work from the saved databunch so fresh and cached runs share the same object.\n", "data_lm = load_data(path, 'lm_databunch', bs=bs)\n", "\n", "len(data_lm.vocab.itos), len(data_lm.train_ds)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# AWD_LSTM language-model learner, converted with to_fp16().\n", "learn_lm = language_model_learner(data_lm, AWD_LSTM, drop_mult=1.).to_fp16()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Base learning rate, scaled linearly by the ratio of the current batch size to 48.\n", "lr = 1e-2\n", "lr *= bs/48" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
04.6040464.1890020.27826518:36
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.fit_one_cycle(1, lr, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Unfreeze the whole language model and continue for 10 epochs at a tenth of the learning rate.\n", "learn_lm.unfreeze()\n", "learn_lm.fit_one_cycle(10, lr/10, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Save the full model and, separately, the encoder that the classifier loads later.\n", "learn_lm.save('fine_tuned_10')\n", "learn_lm.save_encoder('fine_tuned_enc_10')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classifier" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Build the classifier databunch only when no saved copy exists yet under `path`.\n", "# It reuses the language-model vocab (data_lm.vocab) and validates on the 'test' folder.\n", "if not (path/'imdb_textlist_class').exists():\n", "    (TextList.from_folder(path, vocab=data_lm.vocab)\n", "     .split_by_folder(valid='test')\n", "     .label_from_folder(classes=['neg', 'pos'])\n", "     .databunch(bs=bs, num_workers=1)\n", "     .save('imdb_textlist_class'))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "data_clas = load_data(path, 'imdb_textlist_class', bs=bs, num_workers=1)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Classifier on top of the fine-tuned encoder saved above; start with the learner frozen.\n", "learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5).to_fp16()\n", "learn_c.load_encoder('fine_tuned_enc_10')\n", "learn_c.freeze()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# NOTE: rebinds `lr` for the classifier stage (the language-model cells use the same name),\n", "# again scaled by bs/48.\n", "lr=2e-2\n", "lr *= bs/48" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.2415230.1901280.92660001:16
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Stage 1: a single one-cycle epoch at the full learning rate.\n", "learn_c.fit_one_cycle(cyc_len=1, max_lr=lr, moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# Checkpoint after stage 1.\n", "learn_c.save('1')" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.2048180.1616750.93864002:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Stage 2: freeze_to(-2), then one epoch with learning rates spread from lr/2.6**4 up to lr.\n", "learn_c.freeze_to(-2)\n", "learn_c.fit_one_cycle(cyc_len=1, max_lr=slice(lr / 2.6**4, lr), moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Checkpoint after stage 2.\n", "learn_c.save('2nd')" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.1794510.1440470.94584002:56
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Stage 3: freeze_to(-3), then one epoch at half the learning rate, same 2.6**4 spread.\n", "learn_c.freeze_to(-3)\n", "learn_c.fit_one_cycle(cyc_len=1, max_lr=slice(lr / 2 / 2.6**4, lr / 2), moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Checkpoint after stage 3.\n", "learn_c.save('3rd')" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.1200630.1457010.94700003:24
10.0873030.1529430.94808003:09
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Final stage: unfreeze everything and train two epochs at a tenth of the learning rate,\n", "# keeping the same 2.6**4 spread between the lowest and highest rate.\n", "learn_c.unfreeze()\n", "learn_c.fit_one_cycle(cyc_len=2, max_lr=slice(lr / 10 / 2.6**4, lr / 10), moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# Save the fully fine-tuned classifier.\n", "learn_c.save('clas')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }