{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Vietnamese ULMFiT from scratch (backwards)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "# The star imports supply every fastai name used below (TextList, AWD_LSTM, ...)\n", "# and are also what bring `pd` and `np` into scope for later cells.\n", "from fastai import *\n", "from fastai.text import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "bs = 128  # batch size; the learning rates below are scaled by bs/48" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "data_path = Config.data_path()\n", "lang = 'vi'\n", "name = f'{lang}wiki'\n", "path = data_path/name  # base folder; train.csv / test.csv are read from here later\n", "dest = path/'docs'  # folder the Wikipedia databunch is built from\n", "# File names (without extension) for the saved LM weights and vocab\n", "lm_fns = [f'{lang}_wt_bwd', f'{lang}_wt_vocab_bwd']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vietnamese wikipedia model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Building the databunch reprocesses every file under `dest`, so only do it when\n", "# the saved copy (which the next cell reloads from `dest`) does not exist yet.\n", "databunch_fn = f'{lang}_databunch_bwd'\n", "if not (dest/databunch_fn).exists():\n", "    data = (TextList.from_folder(dest)\n", "            .split_by_rand_pct(0.1, seed=42)\n", "            .label_for_lm()\n", "            .databunch(bs=bs, num_workers=1, backwards=True))\n", "    data.save(databunch_fn)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:493: SourceChangeWarning: source code of class 'torch.nn.modules.loss.CrossEntropyLoss' has changed. 
you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", " warnings.warn(msg, SourceChangeWarning)\n" ] } ], "source": [ "data = load_data(dest, f'{lang}_databunch_bwd', bs=bs, backwards=True)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false }, "outputs": [], "source": [ "learn = language_model_learner(data, AWD_LSTM, drop_mult=0.5, pretrained=False).to_fp16()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "lr = 3e-3\n", "lr *= bs/48 # Scale learning rate by batch size" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.4458493.4245790.40132732:56
13.4208653.3839940.40284133:31
23.3746943.3306340.40780033:26
33.2731973.2571080.41604732:54
43.2230443.2006490.42269532:56
53.1343573.1328590.43072531:35
63.1356373.0570300.43973731:41
73.0804612.9923230.44793931:45
83.0750362.9436830.45449431:39
92.9479972.9292580.45650031:46
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Nothing is pretrained, so train every layer group: 10 epochs of 1cycle.\n", "learn.unfreeze()\n", "learn.fit_one_cycle(10, lr, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Save fp32 weights (without optimizer state) and the vocab under the names in\n", "# `lm_fns`; the fine-tuning learner below picks them up via `pretrained_fnames=lm_fns`.\n", "mdl_path = path/'models'\n", "mdl_path.mkdir(parents=True, exist_ok=True)\n", "learn.to_fp32().save(mdl_path/lm_fns[0], with_opt=False)\n", "learn.data.vocab.save(mdl_path/(lm_fns[1] + '.pkl'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vietnamese sentiment analysis" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Language model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hidden": true }, "outputs": [], "source": [ "def load_comments(csv_name):\n", "    'Read `path/csv_name` and replace missing `comment` values with the string NA.'\n", "    frame = pd.read_csv(path/csv_name)\n", "    frame['comment'] = frame['comment'].fillna('NA')\n", "    return frame\n", "\n", "train_df = load_comments('train.csv')\n", "test_df = load_comments('test.csv')\n", "test_df['label'] = 0  # placeholder value so the test frame also has a `label` column\n", "\n", "# The language model is fitted on all comments (train + test); labels are not used for it.\n", "# ignore_index gives the combined frame a unique 0..n-1 index instead of repeated labels.\n", "df = pd.concat([train_df, test_df], ignore_index=True)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data_lm = (TextList.from_df(df, path, cols='comment')\n", "           .split_by_rand_pct(0.1, seed=42)\n", "           .label_for_lm()\n", "           .databunch(bs=bs, num_workers=1, backwards=True))\n", "\n", "# Start from the Wikipedia weights and vocab saved under the names in `lm_fns`.\n", "# NOTE(review): n_hid presumably has to match the model those weights came from - confirm.\n", "learn_lm = language_model_learner(data_lm, AWD_LSTM, config={**awd_lstm_lm_config, 'n_hid': 1152},\n", "                                  pretrained_fnames=lm_fns, drop_mult=1.0)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "hidden": true }, "outputs": [], "source": [ "lr = 1e-3 * (bs/48)  # base rate scaled by batch size" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
04.7970524.0259010.32332600:07
14.2759753.9144500.33371900:06
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.fit_one_cycle(2, lr*10, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.9967703.8094890.34605200:09
13.8569593.6649190.36323900:09
23.7261433.5843030.36968500:09
33.6085693.5313900.37530700:09
43.5142653.5008260.37970100:09
53.4462923.4869310.38085900:09
63.3925423.4797320.38252000:09
73.3575023.4789300.38252000:09
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Second pass: all layer groups trainable, 8 epochs at the base rate.\n", "learn_lm.unfreeze()\n", "learn_lm.fit_one_cycle(8, lr, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "hidden": true }, "outputs": [], "source": [ "# Keep the full fine-tuned model and, separately, its encoder; the classifier\n", "# below loads the encoder by this name.\n", "learn_lm.save(f'{lang}fine_tuned_bwd')\n", "learn_lm.save_encoder(f'{lang}fine_tuned_enc_bwd')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Classifier" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Labelled training comments only, numericalised with the language model's vocab.\n", "data_clas = (TextList.from_df(train_df, path, vocab=data_lm.vocab, cols='comment')\n", "             .split_by_rand_pct(0.1, seed=42)\n", "             .label_from_df(cols='label')\n", "             .databunch(bs=bs, num_workers=1, backwards=True))\n", "\n", "data_clas.save(f'{lang}_textlist_class_bwd')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "data_clas = load_data(path, f'{lang}_textlist_class_bwd', bs=bs, num_workers=1, backwards=True)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score\n", "\n", "@np_func\n", "def f1(inp,targ):\n", "    'F1 score from sklearn; the predicted class is the argmax of `inp` over its last axis.'\n", "    return f1_score(targ, np.argmax(inp, axis=-1))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# pretrained=False: the encoder weights come from our own fine-tuned language model\n", "# (loaded on the next line), so there is no need to fetch and load the library's\n", "# default pretrained weights first.\n", "learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, pretrained=False,\n", "                                  metrics=[accuracy,f1]).to_fp16()\n", "learn_c.load_encoder(f'{lang}fine_tuned_enc_bwd')\n", "learn_c.freeze()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2 * (bs/48)  # base rate scaled by batch size" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracyf1time
00.3693000.3637690.8345770.82609800:03
10.3281920.2789860.8743780.85174700:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Stage 1: train with the model frozen via `learn_c.freeze()` above.\n", "learn_c.fit_one_cycle(2, max_lr=lr, moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracyf1time
00.3378750.3061320.8768660.86010700:03
10.2769820.2372600.9060950.88642700:03
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Stage 2: freeze up to layer group -2, then train with learning rates\n", "# spread from lr/2.6**4 up to lr.\n", "learn_c.freeze_to(-2)\n", "learn_c.fit_one_cycle(2, max_lr=slice(lr/(2.6**4), lr), moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracyf1time
00.2922970.2523930.8961440.87791600:04
10.2552840.2136550.9123130.89255100:04
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Stage 3: freeze up to layer group -3 and halve the learning-rate range.\n", "learn_c.freeze_to(-3)\n", "learn_c.fit_one_cycle(2, max_lr=slice(lr/2/(2.6**4), lr/2), moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracyf1time
00.1673760.2666330.9048510.88538600:04
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Final stage: all layer groups trainable, one epoch at a tenth of the base rate.\n", "learn_c.unfreeze()\n", "learn_c.fit_one_cycle(1, max_lr=slice(lr/10/(2.6**4), lr/10), moms=(0.8, 0.7))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# Persist the trained classifier.\n", "learn_c.save(f'{lang}clas_bwd')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }