{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Vietnamese ULMFiT from scratch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai import *\n", "from fastai.text import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# bs=48\n", "# bs=24\n", "bs=128" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "torch.cuda.set_device(2)" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [], "source": [ "data_path = Config.data_path()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will create a `viwiki` folder, containing a `viwiki` text file with the wikipedia contents. (For other languages, replace `vi` with the appropriate code from the [list of wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "lang = 'vi'\n", "# lang = 'zh'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "name = f'{lang}wiki'\n", "path = data_path/name\n", "path.mkdir(exist_ok=True, parents=True)\n", "lm_fns = [f'{lang}_wt', f'{lang}_wt_vocab']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vietnamese wikipedia model" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Download data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "from nlputils import split_wiki,get_wiki" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "hidden": true }, "outputs": [], "source": [ "get_wiki(path,lang)" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/jhoward/data/zhwiki/docs'),\n", " PosixPath('/home/jhoward/data/zhwiki/zhwiki-latest-pages-articles.xml.bz2'),\n", " PosixPath('/home/jhoward/data/zhwiki/zh.cnf'),\n", " PosixPath('/home/jhoward/data/zhwiki/log'),\n", " PosixPath('/home/jhoward/data/zhwiki/zhwiki'),\n", " PosixPath('/home/jhoward/data/zhwiki/zhwiki-latest-pages-articles.xml'),\n", " PosixPath('/home/jhoward/data/zhwiki/wikiextractor')]" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Tiếng Việt\r\n", "\r\n", "Tiếng Việt, còn gọi tiếng Việt Nam hay Việt ngữ, là ngôn ngữ của người Việt (người Kinh) và là ngôn ngữ chính thức tại Việt Nam. Đây là tiếng mẹ đẻ của khoảng 85% dân cư Việt Nam, cùng với hơn 4 triệu Việt kiều. Tiếng Việt còn là ngôn ngữ thứ hai của các dân tộc thiểu số tại Việt Nam. Mặc dù tiếng Việt có một số từ vựng vay mượn từ tiếng Hán và trước đây dùng chữ Nôm – một hệ chữ viết dựa trên chữ Hán – để viết nhưng tiếng Việt được coi là một trong số các ngôn ngữ thuộc ngữ hệ Nam Á có số người nói nhiều nhất (nhiều hơn một số lần so với các ngôn ngữ khác cùng hệ cộng lại). Ngày nay, tiếng Việt dùng bảng chữ cái Latinh, gọi là chữ Quốc ngữ, cùng các dấu thanh để viết.\r\n" ] } ], "source": [ "!head -n4 {path}/{name}" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "This function splits the single wikipedia file into a separate file per article. This is often easier to work with." ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "hidden": true }, "outputs": [], "source": [ "dest = split_wiki(path,lang)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/jhoward/data/viwiki/docs/Luis Suárez.txt'),\n", " PosixPath('/home/jhoward/data/viwiki/docs/Vitas.txt'),\n", " PosixPath('/home/jhoward/data/viwiki/docs/Chùa Hà.txt'),\n", " PosixPath('/home/jhoward/data/viwiki/docs/Đại Phái bộ Sứ thần.txt'),\n", " PosixPath('/home/jhoward/data/viwiki/docs/2 Broke Girls.txt')]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dest.ls()[:5]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hidden": true }, "outputs": [], "source": [ "# Use this to convert Chinese traditional to simplified characters\n", "# ls *.txt | parallel -I% opencc -i % -o ../zhsdocs/% -c t2s.json" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Create pretrained model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hidden": true }, "outputs": [], "source": [ "data = (TextList.from_folder(dest)\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_for_lm() \n", " .databunch(bs=bs, num_workers=1))\n", "\n", "data.save(f'{lang}_databunch')\n", "len(data.vocab.itos),len(data.train_ds)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hidden": true }, "outputs": [], "source": [ "data = load_data(path, f'{lang}_databunch', bs=bs)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "hidden": true, "scrolled": false }, "outputs": [], "source": [ "learn = language_model_learner(data, AWD_LSTM, drop_mult=0.5, pretrained=False).to_fp16()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "hidden": true }, "outputs": [], "source": [ "lr = 1e-2\n", "lr *= bs/48 # Scale learning rate by batch size" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.4361133.4914340.36692528:52
13.4412403.5441180.36132628:33
23.5717663.5569320.35843828:31
33.5105403.5192430.36227828:27
43.4476393.4493200.36940428:29
53.4122843.4063760.37502228:20
63.2867543.2553090.39187428:19
73.1724973.1285220.40680328:37
83.1268673.0252490.41988228:36
93.1287932.9910770.42462228:39
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.unfreeze()\n", "learn.fit_one_cycle(10, lr, moms=(0.8,0.7))" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Save the pretrained model and vocab:" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "hidden": true }, "outputs": [], "source": [ "mdl_path = path/'models'\n", "mdl_path.mkdir(exist_ok=True)\n", "learn.to_fp32().save(mdl_path/lm_fns[0], with_opt=False)\n", "learn.data.vocab.save(mdl_path/(lm_fns[1] + '.pkl'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vietnamese sentiment analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Language model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- [Data](https://github.com/ngxbac/aivivn_phanloaisacthaibinhluan/tree/master/data)\n", "- [Competition details](https://www.aivivn.com/contests/1)\n", "- Top 3 f1 scores: 0.900, 0.897, 0.897" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idcommentlabel
0train_000000Dung dc sp tot cam on \\nshop Đóng gói sản phẩm...0
1train_000001Chất lượng sản phẩm tuyệt vời . Son mịn nhưng...0
2train_000002Chất lượng sản phẩm tuyệt vời nhưng k có hộp ...0
3train_000003:(( Mình hơi thất vọng 1 chút vì mình đã kỳ vọ...1
4train_000004Lần trước mình mua áo gió màu hồng rất ok mà đ...1
\n", "
" ], "text/plain": [ " id comment label\n", "0 train_000000 Dung dc sp tot cam on \\nshop Đóng gói sản phẩm... 0\n", "1 train_000001 Chất lượng sản phẩm tuyệt vời . Son mịn nhưng... 0\n", "2 train_000002 Chất lượng sản phẩm tuyệt vời nhưng k có hộp ... 0\n", "3 train_000003 :(( Mình hơi thất vọng 1 chút vì mình đã kỳ vọ... 1\n", "4 train_000004 Lần trước mình mua áo gió màu hồng rất ok mà đ... 1" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df = pd.read_csv(path/'train.csv')\n", "train_df.loc[pd.isna(train_df.comment),'comment']='NA'\n", "train_df.head()" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idcomment
0test_000000Chưa dùng thử nên chưa biết
1test_000001Không đáng tiềnVì ngay đợt sale nên mới mua n...
2test_000002Cám ơn shop. Đóng gói sản phẩm rất đẹp và chắc...
3test_000003Vải đẹp.phom oki luôn.quá ưng
4test_000004Chuẩn hàng đóng gói đẹp
\n", "
" ], "text/plain": [ " id comment\n", "0 test_000000 Chưa dùng thử nên chưa biết\n", "1 test_000001 Không đáng tiềnVì ngay đợt sale nên mới mua n...\n", "2 test_000002 Cám ơn shop. Đóng gói sản phẩm rất đẹp và chắc...\n", "3 test_000003 Vải đẹp.phom oki luôn.quá ưng\n", "4 test_000004 Chuẩn hàng đóng gói đẹp" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_df = pd.read_csv(path/'test.csv')\n", "test_df.loc[pd.isna(test_df.comment),'comment']='NA'\n", "test_df.head()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "df = pd.concat([train_df,test_df], sort=False)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "data_lm = (TextList.from_df(df, path, cols='comment')\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_for_lm() \n", " .databunch(bs=bs, num_workers=1))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "learn_lm = language_model_learner(data_lm, AWD_LSTM, pretrained_fnames=lm_fns, drop_mult=1.0)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "lr = 1e-3\n", "lr *= bs/48" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
04.9750804.1385850.31777300:07
14.4086354.0254890.32642300:07
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.fit_one_cycle(2, lr*10, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
04.1421143.9282780.33623000:09
14.0108353.7935830.34997200:09
23.8736173.6947020.35724000:09
33.7613773.6321860.36464800:09
43.6790173.5956010.36696400:09
53.6145483.5763860.36922400:09
63.5758953.5674960.37028500:09
73.5602783.5665250.37017300:10
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.unfreeze()\n", "learn_lm.fit_one_cycle(8, lr, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "learn_lm.save(f'{lang}fine_tuned')\n", "learn_lm.save_encoder(f'{lang}fine_tuned_enc')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Classifier" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "data_clas = (TextList.from_df(train_df, path, vocab=data_lm.vocab, cols='comment')\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_from_df(cols='label')\n", " .databunch(bs=bs, num_workers=1))\n", "\n", "data_clas.save(f'{lang}_textlist_class')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "data_clas = load_data(path, f'{lang}_textlist_class', bs=bs, num_workers=1)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score\n", "\n", "@np_func\n", "def f1(inp,targ): return f1_score(targ, np.argmax(inp, axis=-1))" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()\n", "learn_c.load_encoder(f'{lang}fine_tuned_enc')\n", "learn_c.freeze()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "lr=2e-2\n", "lr *= bs/48" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy_innertime
00.3381500.2752980.8998760.87843000:02
10.3023020.2459490.9029850.87722600:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy_innertime
00.3217680.2554570.8992540.87136700:02
10.3059340.2508880.8949010.87202100:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy_innertime
00.3009390.2610800.8936570.86620100:03
10.2637900.2202070.9067160.88611500:03
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.freeze_to(-2)\n", "learn_c.fit_one_cycle(2, slice(lr/(2.6**4),lr), moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy_innertime
00.2828880.2382030.9054730.88648300:04
10.2485990.2164890.9185320.90155000:04
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.freeze_to(-3)\n", "learn_c.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy_innertime
00.2015080.2171760.9110700.89008400:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.unfreeze()\n", "learn_c.fit_one_cycle(1, slice(lr/10/(2.6**4),lr/10), moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Competition top 3 f1 scores: 0.90, 0.89, 0.89. Winner used an ensemble of 4 models: TextCNN, VDCNN, HARNN, and SARNN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Ensemble" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "data_clas = load_data(path, f'{lang}_textlist_class', bs=bs, num_workers=1)\n", "learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()\n", "learn_c.load(f'{lang}clas', purge=False);" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9111), tensor(0.8952))" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds,targs = learn_c.get_preds(ordered=True)\n", "accuracy(preds,targs),f1(preds,targs)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "data_clas_bwd = load_data(path, f'{lang}_textlist_class_bwd', bs=bs, num_workers=1, backwards=True)\n", "learn_c_bwd = text_classifier_learner(data_clas_bwd, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()\n", "learn_c_bwd.load(f'{lang}clas_bwd', purge=False);" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9092), tensor(0.8957))" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds_b,targs_b = learn_c_bwd.get_preds(ordered=True)\n", "accuracy(preds_b,targs_b),f1(preds_b,targs_b)" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "preds_avg = (preds+preds_b)/2" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9154), tensor(0.9016))" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy(preds_avg,targs_b),f1(preds_avg,targs_b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }