{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vietnamese ULMFiT from scratch (backwards)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"from fastai import *\n",
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Batch size shared by every databunch in this notebook; the learning rates below are scaled by bs/48.\n",
"bs = 128"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Language code and the derived dataset name.\n",
"lang = 'vi'\n",
"name = f'{lang}wiki'\n",
"\n",
"# Working folders: `path` is the dataset root inside fastai's configured data directory,\n",
"# `dest` is the docs/ sub-folder that TextList.from_folder reads below.\n",
"data_path = Config.data_path()\n",
"path = data_path/name\n",
"dest = path/'docs'\n",
"\n",
"# File stems for the backwards wiki language model: [weights, vocab].\n",
"lm_fns = [f'{lang}_wt_bwd', f'{lang}_wt_vocab_bwd']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Vietnamese wikipedia model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Build the backwards language-model databunch from the files under `dest`,\n",
"# holding out a seeded random 10% for validation, then save it for later reloads.\n",
"data = (\n",
"    TextList.from_folder(dest)\n",
"            .split_by_rand_pct(valid_pct=0.1, seed=42)\n",
"            .label_for_lm()\n",
"            .databunch(bs=bs, num_workers=1, backwards=True)\n",
")\n",
"\n",
"data.save(f'{lang}_databunch_bwd')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:493: SourceChangeWarning: source code of class 'torch.nn.modules.loss.CrossEntropyLoss' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
" warnings.warn(msg, SourceChangeWarning)\n"
]
}
],
"source": [
"# Restore the databunch written by data.save(...) so the preprocessing need not be repeated.\n",
"data = load_data(\n",
"    dest,\n",
"    f'{lang}_databunch_bwd',\n",
"    bs=bs,\n",
"    backwards=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# AWD-LSTM language model with no pretrained weights (pretrained=False), converted with to_fp16().\n",
"learn = language_model_learner(\n",
"    data,\n",
"    AWD_LSTM,\n",
"    drop_mult=0.5,\n",
"    pretrained=False,\n",
").to_fp16()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Scale the base learning rate of 3e-3 by batch size (relative to a batch of 48).\n",
"lr = 3e-3 * (bs/48)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 3.445849 | \n",
" 3.424579 | \n",
" 0.401327 | \n",
" 32:56 | \n",
"
\n",
" \n",
" 1 | \n",
" 3.420865 | \n",
" 3.383994 | \n",
" 0.402841 | \n",
" 33:31 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.374694 | \n",
" 3.330634 | \n",
" 0.407800 | \n",
" 33:26 | \n",
"
\n",
" \n",
" 3 | \n",
" 3.273197 | \n",
" 3.257108 | \n",
" 0.416047 | \n",
" 32:54 | \n",
"
\n",
" \n",
" 4 | \n",
" 3.223044 | \n",
" 3.200649 | \n",
" 0.422695 | \n",
" 32:56 | \n",
"
\n",
" \n",
" 5 | \n",
" 3.134357 | \n",
" 3.132859 | \n",
" 0.430725 | \n",
" 31:35 | \n",
"
\n",
" \n",
" 6 | \n",
" 3.135637 | \n",
" 3.057030 | \n",
" 0.439737 | \n",
" 31:41 | \n",
"
\n",
" \n",
" 7 | \n",
" 3.080461 | \n",
" 2.992323 | \n",
" 0.447939 | \n",
" 31:45 | \n",
"
\n",
" \n",
" 8 | \n",
" 3.075036 | \n",
" 2.943683 | \n",
" 0.454494 | \n",
" 31:39 | \n",
"
\n",
" \n",
" 9 | \n",
" 2.947997 | \n",
" 2.929258 | \n",
" 0.456500 | \n",
" 31:46 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train every layer for 10 one-cycle epochs at the batch-size-scaled learning rate.\n",
"learn.unfreeze()\n",
"learn.fit_one_cycle(cyc_len=10, max_lr=lr, moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Write the trained wiki model (with_opt=False: no optimizer state) and its vocab to path/models,\n",
"# using the file stems from lm_fns.\n",
"mdl_path = path/'models'\n",
"mdl_path.mkdir(exist_ok=True)\n",
"\n",
"learn.to_fp32().save(mdl_path/lm_fns[0], with_opt=False)\n",
"learn.data.vocab.save(mdl_path/f'{lm_fns[1]}.pkl')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Vietnamese sentiment analysis"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"### Language model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"# Labelled training comments; missing comments are replaced by the literal string 'NA'.\n",
"train_df = pd.read_csv(path/'train.csv')\n",
"train_df['comment'] = train_df['comment'].fillna('NA')\n",
"\n",
"# Test comments get the same treatment plus a placeholder label of 0.\n",
"test_df = pd.read_csv(path/'test.csv')\n",
"test_df['comment'] = test_df['comment'].fillna('NA')\n",
"test_df['label'] = 0\n",
"\n",
"# Train and test text together form the corpus used for language-model fine-tuning below.\n",
"df = pd.concat([train_df, test_df])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Language-model databunch over the combined train+test comments (backwards, seeded 10% validation).\n",
"data_lm = (\n",
"    TextList.from_df(df, path, cols='comment')\n",
"            .split_by_rand_pct(valid_pct=0.1, seed=42)\n",
"            .label_for_lm()\n",
"            .databunch(bs=bs, num_workers=1, backwards=True)\n",
")\n",
"\n",
"# Start from the wiki model/vocab files named in lm_fns.\n",
"# NOTE(review): n_hid=1152 presumably has to match the hidden size the wiki model was trained with - confirm.\n",
"learn_lm = language_model_learner(\n",
"    data_lm,\n",
"    AWD_LSTM,\n",
"    config={**awd_lstm_lm_config, 'n_hid': 1152},\n",
"    pretrained_fnames=lm_fns,\n",
"    drop_mult=1.0,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"# Base rate for language-model fine-tuning, scaled by batch size (relative to a batch of 48).\n",
"lr = 1e-3 * (bs/48)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 4.797052 | \n",
" 4.025901 | \n",
" 0.323326 | \n",
" 00:07 | \n",
"
\n",
" \n",
" 1 | \n",
" 4.275975 | \n",
" 3.914450 | \n",
" 0.333719 | \n",
" 00:06 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Two initial epochs at ten times the base rate, before the model is unfrozen in the next cell.\n",
"learn_lm.fit_one_cycle(cyc_len=2, max_lr=lr*10, moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"hidden": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 3.996770 | \n",
" 3.809489 | \n",
" 0.346052 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 1 | \n",
" 3.856959 | \n",
" 3.664919 | \n",
" 0.363239 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.726143 | \n",
" 3.584303 | \n",
" 0.369685 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 3 | \n",
" 3.608569 | \n",
" 3.531390 | \n",
" 0.375307 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 4 | \n",
" 3.514265 | \n",
" 3.500826 | \n",
" 0.379701 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 5 | \n",
" 3.446292 | \n",
" 3.486931 | \n",
" 0.380859 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 6 | \n",
" 3.392542 | \n",
" 3.479732 | \n",
" 0.382520 | \n",
" 00:09 | \n",
"
\n",
" \n",
" 7 | \n",
" 3.357502 | \n",
" 3.478930 | \n",
" 0.382520 | \n",
" 00:09 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Unfreeze all layers and fine-tune for 8 epochs at the base rate.\n",
"learn_lm.unfreeze()\n",
"learn_lm.fit_one_cycle(cyc_len=8, max_lr=lr, moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"hidden": true
},
"outputs": [],
"source": [
"# Keep the encoder on its own (the classifier loads it below) as well as the full fine-tuned model.\n",
"learn_lm.save_encoder(f'{lang}fine_tuned_enc_bwd')\n",
"learn_lm.save(f'{lang}fine_tuned_bwd')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Classifier"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Classification databunch: labelled training comments only, reusing the language-model vocab\n",
"# (vocab=data_lm.vocab), with a seeded random 10% validation split.\n",
"data_clas = (\n",
"    TextList.from_df(train_df, path, vocab=data_lm.vocab, cols='comment')\n",
"            .split_by_rand_pct(valid_pct=0.1, seed=42)\n",
"            .label_from_df(cols='label')\n",
"            .databunch(bs=bs, num_workers=1, backwards=True)\n",
")\n",
"\n",
"data_clas.save(f'{lang}_textlist_class_bwd')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Restore the classification databunch written by data_clas.save(...) above.\n",
"data_clas = load_data(\n",
"    path,\n",
"    f'{lang}_textlist_class_bwd',\n",
"    bs=bs,\n",
"    num_workers=1,\n",
"    backwards=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import f1_score\n",
"\n",
"@np_func\n",
"def f1(inp, targ):\n",
"    '''F1 score of the arg-max (last axis) predictions in `inp` against the labels `targ`.'''\n",
"    preds = np.argmax(inp, axis=-1)\n",
"    return f1_score(targ, preds)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# pretrained=False: the encoder comes from the fine-tuned Vietnamese language model loaded on the\n",
"# next statement, so there is no reason to fetch and load fastai's default pretrained weights first.\n",
"learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1],\n",
"                                  pretrained=False).to_fp16()\n",
"learn_c.load_encoder(f'{lang}fine_tuned_enc_bwd')\n",
"# Freeze explicitly: start by training with the model frozen, then unfreeze in stages below.\n",
"learn_c.freeze()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Base rate for classifier training, scaled by batch size (relative to a batch of 48).\n",
"lr = 2e-2 * (bs/48)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" f1 | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.369300 | \n",
" 0.363769 | \n",
" 0.834577 | \n",
" 0.826098 | \n",
" 00:03 | \n",
"
\n",
" \n",
" 1 | \n",
" 0.328192 | \n",
" 0.278986 | \n",
" 0.874378 | \n",
" 0.851747 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Stage 1: two epochs with the model frozen (learn_c.freeze() above).\n",
"learn_c.fit_one_cycle(cyc_len=2, max_lr=lr, moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" f1 | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.337875 | \n",
" 0.306132 | \n",
" 0.876866 | \n",
" 0.860107 | \n",
" 00:03 | \n",
"
\n",
" \n",
" 1 | \n",
" 0.276982 | \n",
" 0.237260 | \n",
" 0.906095 | \n",
" 0.886427 | \n",
" 00:03 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Stage 2: freeze_to(-2), then two epochs with a learning-rate range from lr/2.6**4 up to lr.\n",
"learn_c.freeze_to(-2)\n",
"learn_c.fit_one_cycle(cyc_len=2, max_lr=slice(lr/(2.6**4), lr), moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" f1 | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.292297 | \n",
" 0.252393 | \n",
" 0.896144 | \n",
" 0.877916 | \n",
" 00:04 | \n",
"
\n",
" \n",
" 1 | \n",
" 0.255284 | \n",
" 0.213655 | \n",
" 0.912313 | \n",
" 0.892551 | \n",
" 00:04 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Stage 3: freeze_to(-3), then two epochs with the learning-rate range halved.\n",
"learn_c.freeze_to(-3)\n",
"learn_c.fit_one_cycle(cyc_len=2, max_lr=slice(lr/2/(2.6**4), lr/2), moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" f1 | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.167376 | \n",
" 0.266633 | \n",
" 0.904851 | \n",
" 0.885386 | \n",
" 00:04 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Stage 4: unfreeze everything and run one epoch with the learning-rate range divided by 10.\n",
"learn_c.unfreeze()\n",
"learn_c.fit_one_cycle(cyc_len=1, max_lr=slice(lr/10/(2.6**4), lr/10), moms=(0.8, 0.7))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Save the trained backwards classifier.\n",
"learn_c.save(\n",
"    f'{lang}clas_bwd'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}