{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai.imports import *\n", "from fastai.torch_imports import *\n", "from fastai.core import *\n", "from fastai.model import fit\n", "from fastai.dataset import *\n", "\n", "import torchtext\n", "from torchtext import vocab, data\n", "from torchtext.datasets import language_modeling\n", "\n", "from fastai.rnn_reg import *\n", "from fastai.rnn_train import *\n", "from fastai.nlp import *\n", "from fastai.lm_rnn import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Language modeling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[0m\u001b[01;34mmodels\u001b[0m/ \u001b[01;34mtmp\u001b[0m/ wiki.test.tokens wiki.train.tokens wiki.valid.tokens\r\n" ] } ], "source": [ "# Folder with the WikiText-2 files; the listing recorded for this cell shows models/, tmp/ and the three wiki.*.tokens files.\n", "PATH='data/wikitext-2/'\n", "%ls {PATH}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \r\n", " = Valkyria Chronicles III = \r\n", " \r\n", " Senjō no Valkyria 3 : Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the \" Nameless \" , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit \" Raven \" . 
\r\n", " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more for series newcomers . Character designer Honjou and composer Hitoshi Sakimoto both returned from previous entries , along with Valkyria Chronicles II director Takeshi Ozawa . A large team of writers handled the script . The game 's opening theme was sung by May 'n . \r\n" ] } ], "source": [ "# Print the first 5 lines of the training file.\n", "!head -5 {PATH}wiki.train.tokens" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 36718 2051910 10797148 data/wikitext-2/wiki.train.tokens\r\n" ] } ], "source": [ "# Line, word and byte counts of the training file.\n", "!wc -lwc {PATH}wiki.train.tokens" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 3760 213886 1121681 data/wikitext-2/wiki.valid.tokens\r\n" ] } ], "source": [ "# Same counts for the validation file.\n", "!wc -lwc {PATH}wiki.valid.tokens" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(371, 12981, 1, 2088628)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Text field that lower-cases its input (lower=True).\n", "TEXT = data.Field(lower=True)\n", "# Split name to file name mapping; unpacked into LanguageModelData as keyword arguments.\n", "FILES = dict(train='wiki.train.tokens', validation='wiki.valid.tokens', test='wiki.test.tokens')\n", "# bs and bptt are forwarded to LanguageModelData; presumably batch size and tokens per sequence chunk - TODO confirm.\n", "bs,bptt = 80,70\n", "# NOTE(review): min_freq=10 presumably keeps only words seen at least 10 times in the vocabulary - confirm against the fastai/torchtext API.\n", "md = LanguageModelData(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)\n", "# Recorded result: 371 training batches, md.nt == 12981, one item in trn_ds whose text holds 2088628 entries.\n", "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Disabled inspection cell: would show the first 12 entries of the training text and one batch from md.trn_dl.\n", "#md.trn_ds[0].text[:12], next(iter(md.trn_dl))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sizes passed to md.get_model below - presumably embedding width, hidden units and layer count; TODO confirm.\n", "em_sz = 200\n", "nh = 500\n", "nl = 3\n", "learner = 
md.get_model(SGD_Momentum(0.7), bs, em_sz, nh, nl)\n", "# seq2seq_reg with alpha=2 and beta=1 fixed; passed as reg_fn to every learner.fit call in this notebook.\n", "reg_fn=partial(seq2seq_reg, alpha=2, beta=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a88f0d1d9d1d443e81d3cdc61027f8cf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 5.8902 5.5658] \n", "\n" ] } ], "source": [ "# clip is passed to every learner.fit call; presumably a gradient-clipping threshold - TODO confirm.\n", "clip=0.3\n", "# First run; its recorded log has a single row (epoch 0).\n", "learner.fit(10, 1, wds=1e-6, reg_fn=reg_fn, clip=clip)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "21dd231e1f8e48419cc3c17cf43f0ecb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 5.5214 5.3178] \n", "[ 1. 5.36 5.1435] \n", "[ 2. 5.2619 5.0643] \n", "[ 3. 5.166 4.9541] \n", "[ 4. 5.0275 4.8416] \n", "[ 5. 4.9453 4.7837] \n", "[ 6. 4.9306 4.7692] \n", "[ 7. 4.9195 4.7462] \n", "[ 8. 4.8391 4.6762] \n", "[ 9. 4.767 4.6317] \n", "[ 10. 4.718 4.5859] \n", "[ 11. 4.6623 4.5497] \n", "[ 12. 4.6262 4.5275] \n", "[ 13. 4.6103 4.5196] \n", "[ 14. 4.6206 4.5137] \n", "[ 15. 4.6716 4.5614] \n", "[ 16. 4.6237 4.5333] \n", "[ 17. 4.5918 4.4998] \n", "[ 18. 4.5519 4.47 ] \n", "[ 19. 4.5188 4.4514] \n", "[ 20. 4.4907 4.4342] \n", "[ 21. 4.464 4.4094] \n", "[ 22. 4.4473 4.4016] \n", "[ 23. 4.4196 4.3848] \n", "[ 24. 4.3956 4.3653] \n", "[ 25. 4.384 4.3596] \n", "[ 26. 4.3692 4.3493] \n", "[ 27. 4.3546 4.3422] \n", "[ 28. 4.3548 4.34 ] \n", "[ 29. 4.3457 4.3364] \n", "[ 30. 4.3517 4.3371] \n", "[ 31. 4.4461 4.401 ] \n", "[ 32. 4.4343 4.39 ] \n", "[ 33. 4.4064 4.3759] \n", "[ 34. 4.401 4.374] \n", "[ 35. 4.3683 4.3583] \n", "[ 36. 4.3535 4.3478] \n", "[ 37. 4.3403 4.3402] \n", "[ 38. 
4.3303 4.3356] \n", "[ 39. 4.3214 4.3346] \n", "[ 40. 4.2934 4.3226] \n", "[ 41. 4.2919 4.3188] \n", "[ 42. 4.2742 4.3055] \n", "[ 43. 4.2566 4.3009] \n", "[ 44. 4.2374 4.2973] \n", "[ 45. 4.2297 4.2854] \n", "[ 46. 4.217 4.2823] \n", "[ 47. 4.2086 4.2732] \n", "[ 48. 4.1899 4.2642] \n", "[ 49. 4.196 4.2624] \n", "[ 50. 4.1726 4.2562] \n", "[ 51. 4.182 4.2528] \n", "[ 52. 4.1619 4.2527] \n", "[ 53. 4.1868 4.248 ] \n", "[ 54. 4.1704 4.2455] \n", "[ 55. 4.156 4.2425] \n", "[ 56. 4.138 4.2428] \n", "[ 57. 4.1476 4.2367] \n", "[ 58. 4.1454 4.2349] \n", "[ 59. 4.1273 4.2312] \n", "[ 60. 4.1331 4.2317] \n", "[ 61. 4.1471 4.231 ] \n", "[ 62. 4.1283 4.2335] \n", "\n" ] } ], "source": [ "# The recorded log has 63 rows (0-62), consistent with 6 cycles of 1, 2, 4, 8, 16 and 32 epochs (cycle_len=1, cycle_mult=2).\n", "learner.fit(10, 6, wds=1e-6, reg_fn=reg_fn, cycle_len=1, cycle_mult=2, clip=clip)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Checkpoint after the first 63-epoch run; the last recorded log row is [62, 4.1283, 4.2335].\n", "learner.save('lm_420')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0113c5be9cd24acd98dddb2de59966e2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 4.2051 4.242 ] \n", "[ 1. 4.2466 4.2633] \n", "[ 2. 4.2073 4.2377] \n", "[ 3. 4.2657 4.2765] \n", "[ 4. 4.2209 4.2627] \n", "[ 5. 4.1965 4.2458] \n", "[ 6. 4.1746 4.2335] \n", "[ 7. 4.2657 4.2897] \n", "[ 8. 4.2401 4.277 ] \n", "[ 9. 4.2266 4.269 ] \n", "[ 10. 4.189 4.2546] \n", "[ 11. 4.1639 4.2369] \n", "[ 12. 4.1531 4.2288] \n", "[ 13. 4.1233 4.2223] \n", "[ 14. 4.1328 4.2161] \n", "[ 15. 4.2515 4.2909] \n", "[ 16. 4.2246 4.2872] \n", "[ 17. 4.2081 4.2727] \n", "[ 18. 4.1989 4.2651] \n", "[ 19. 4.1836 4.2634] \n", "[ 20. 4.1853 4.2647] \n", "[ 21. 4.1435 4.2518] \n", "[ 22. 4.1407 4.2388] \n", "[ 23. 4.1227 4.2285] \n", "[ 24. 4.1024 4.2244] \n", "[ 25. 4.0866 4.218 ] \n", "[ 26. 4.0916 4.2135] \n", "[ 27. 
4.0695 4.2084] \n", "[ 28. 4.0615 4.2036] \n", "[ 29. 4.0672 4.2034] \n", "[ 30. 4.0859 4.2014] \n", "[ 31. 4.1937 4.2731] \n", "[ 32. 4.1961 4.2628] \n", "[ 33. 4.1725 4.2582] \n", "[ 34. 4.1909 4.2662] \n", "[ 35. 4.1622 4.2586] \n", "[ 36. 4.1525 4.2556] \n", "[ 37. 4.1601 4.2616] \n", "[ 38. 4.1521 4.2604] \n", "[ 39. 4.1471 4.2571] \n", "[ 40. 4.1131 4.2425] \n", "[ 41. 4.1072 4.2362] \n", "[ 42. 4.0949 4.2399] \n", "[ 43. 4.1014 4.2315] \n", "[ 44. 4.0926 4.2266] \n", "[ 45. 4.0721 4.2225] \n", "[ 46. 4.0694 4.219 ] \n", "[ 47. 4.0742 4.2188] \n", "[ 48. 4.0433 4.2181] \n", "[ 49. 4.0666 4.2083] \n", "[ 50. 4.0395 4.2066] \n", "[ 51. 4.0591 4.2007] \n", "[ 52. 4.0166 4.2046] \n", "[ 53. 4.0279 4.1973] \n", "[ 54. 4.0178 4.1957] \n", "[ 55. 4.0241 4.1937] \n", "[ 56. 3.9977 4.1901] \n", "[ 57. 3.997 4.1909] \n", "[ 58. 4.0194 4.189 ] \n", "[ 59. 3.9832 4.1833] \n", "[ 60. 4.0069 4.1856] \n", "[ 61. 3.9893 4.1873] \n", "[ 62. 3.986 4.1828] \n", "\n" ] } ], "source": [ "# Second run with the same schedule; the recorded log again has 63 rows (0-62), consistent with 6 cycles of 1, 2, 4, 8, 16 and 32 epochs.\n", "learner.fit(10, 6, wds=1e-6, reg_fn=reg_fn, cycle_len=1, cycle_mult=2, clip=clip)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Checkpoint after the second 63-epoch run; the last recorded log row is [62, 3.986, 4.1828].\n", "learner.save('lm_419')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ff63287f1ed47f2afe629863278982a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 4.0781 4.2001] \n", "[ 1. 4.1227 4.2212] \n", "[ 2. 4.1012 4.1981] \n", "[ 3. 4.1524 4.243 ] \n", "[ 4. 4.1137 4.2235] \n", "[ 5. 4.0761 4.1977] \n", "[ 6. 4.072 4.1966] \n", "[ 7. 4.1489 4.2491] \n", "[ 8. 4.1355 4.2431] \n", "[ 9. 4.1167 4.2323] \n", "[ 10. 4.0797 4.2189] \n", "[ 11. 4.048 4.2054] \n", "[ 12. 4.0284 4.2013] \n", "[ 13. 4.0436 4.1854] \n", "[ 14. 4.0416 4.1864] \n", "[ 15. 4.131 4.2443] \n", "[ 16. 
4.1205 4.2414] \n", "[ 17. 4.116 4.2427] \n", "[ 18. 4.1111 4.2443] \n", "[ 19. 4.1028 4.2313] \n", "[ 20. 4.0806 4.224 ] \n", "[ 21. 4.0879 4.2192] \n", "[ 22. 4.0452 4.2083] \n", "[ 23. 4.0272 4.2042] \n", "[ 24. 4.0196 4.2019] \n", "[ 25. 4.0273 4.1909] \n", "[ 26. 4.0178 4.1868] \n", "[ 27. 3.9833 4.1847] \n", "[ 28. 3.9777 4.1805] \n", "[ 29. 3.9841 4.1805] \n", "[ 30. 4.0016 4.181 ] \n", "[ 31. 4.1057 4.236 ] \n", "[ 32. 4.126 4.2456] \n", "[ 33. 4.1018 4.2358] \n", "[ 34. 4.0977 4.2421] \n", "[ 35. 4.1049 4.2377] \n", "[ 36. 4.1079 4.2429] \n", "[ 37. 4.0851 4.2365] \n", "[ 38. 4.0778 4.2304] \n", "[ 39. 4.0849 4.2331] \n", "[ 40. 4.058 4.2278] \n", "[ 41. 4.047 4.2225] \n", "[ 42. 4.0585 4.2251] \n", "[ 43. 4.0324 4.2226] \n", "[ 44. 4.0459 4.2161] \n", "[ 45. 4.0147 4.2054] \n", "[ 46. 4.0259 4.2097] \n", "[ 47. 3.9967 4.205 ] \n", "[ 48. 4.0014 4.1984] \n", "[ 49. 3.9926 4.1985] \n", "[ 50. 3.968 4.195] \n", "[ 51. 3.98 4.1888] \n", "[ 52. 3.9527 4.1912] \n", "[ 53. 3.9452 4.1864] \n", "[ 54. 3.9587 4.1848] \n", "[ 55. 3.9286 4.184 ] \n", "[ 56. 3.9302 4.1778] \n", "[ 57. 3.9271 4.1751] \n", "[ 58. 3.9452 4.1743] \n", "[ 59. 3.9443 4.1792] \n", "[ 60. 3.924 4.176] \n", "[ 61. 3.9347 4.1745] \n", "[ 62. 3.924 4.1718] \n", "\n" ] } ], "source": [ "# Third run with the same schedule; the recorded log again has 63 rows (0-62), consistent with 6 cycles of 1, 2, 4, 8, 16 and 32 epochs.\n", "learner.fit(10, 6, wds=1e-6, reg_fn=reg_fn, cycle_len=1, cycle_mult=2, clip=clip)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Checkpoint after the third 63-epoch run; the last recorded log row is [62, 3.924, 4.1718].\n", "learner.save('lm_418')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "64.71545210740304" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# NOTE(review): 4.17 is typed by hand; it matches the last column of the final log row (4.1718) truncated.\n", "# Presumably that column is the validation loss, making exp() of it the perplexity - TODO confirm.\n", "math.exp(4.17)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m=learner.model\n", "s=[\"\"\". 
The game began development in 2010 , carrying over a large portion of the work \n", "done on Valkyria Chronicles II . While it retained the standard features of \"\"\".split()]\n", "# s is a one-element list holding the whitespace-split prompt; TEXT.numericalize presumably maps the tokens to vocabulary indices - TODO confirm.\n", "t=TEXT.numericalize(s)\n", "\n", "# Set bs on the first sub-module to 1, reset the model, then run the prompt through it.\n", "# NOTE(review): the meaning of the False argument to reset() is not visible in this notebook - confirm.\n", "m[0].bs=1\n", "m.reset(False)\n", "res,*_ = m(t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['the', '', 'a', '\"', 'an', 'its', 'this', 'their', 'all', 'other']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# torch.topk(...)[1] holds the indices of the 10 largest values in the last row of res; itos maps each index back to its token.\n", "nexts = torch.topk(res[-1], 10)[1]\n", "[TEXT.vocab.itos[o] for o in to_np(nexts)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the film , and the first of the two @-@ year @-@ old , the \" black \" , a " ] } ], "source": [ "# Print 20 tokens, feeding each chosen token back into the model to get the next prediction.\n", "# When the best candidate is index 0 the runner-up is used instead; index 0 is presumably the unknown-word token - TODO confirm.\n", "for i in range(20):\n", " n=res[-1].topk(2)[1]\n", " n = n[1] if n.data[0]==0 else n[0]\n", " print(TEXT.vocab.itos[n.data[0]], end=' ')\n", " res,*_ = m(n[0].unsqueeze(0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### End" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }