{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Date Converter\n", "\n", "We will translate dates from one format to another. To do this we connect two sets of LSTMs (RNNs), as shown in the diagram below. Within each set the cells share weights (i.e. all 4 green cells have the same weights, and likewise for the blue cells). The first set (green) is a many-to-one LSTM, which summarises the question in its last hidden state (and cell state).\n", "\n", "The second set (blue) is a many-to-many LSTM with its own weights, separate from those of the first set. Its input is the answer sentence and its output is the same sentence shifted by one step. Of course, the `answer` is not available at test time, so it is only fed in during training; at test time the decoder is fed its own previous predictions instead.\n", "\n", "![seq2seq_diagram](https://i.stack.imgur.com/YjlBt.png)\n", "\n", "**20th January 2017 => 20th January 2009**\n", "![troll](./images/troll_face.png)\n", "\n", "## References:\n", "1. Plotting a TensorFlow graph in Jupyter: https://stackoverflow.com/questions/38189119/simple-way-to-visualize-a-tensorflow-graph-in-jupyter/38192374#38192374\n", "2. The date-generation code was taken from: https://github.com/datalogue/keras-attention/blob/master/data/generate.py\n", "3. Sutskever, Vinyals and Le (2014), Sequence to Sequence Learning with Neural Networks: https://arxiv.org/pdf/1409.3215.pdf" ] }, 
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: faker in /root/miniconda3/lib/python3.6/site-packages\r\n", "Requirement already satisfied: babel in /root/miniconda3/lib/python3.6/site-packages\r\n", "Requirement already satisfied: email-validator==1.0.2 in /root/miniconda3/lib/python3.6/site-packages (from faker)\r\n", "Requirement already satisfied: six in /root/miniconda3/lib/python3.6/site-packages (from faker)\r\n", "Requirement already satisfied: python-dateutil>=2.4 in /root/miniconda3/lib/python3.6/site-packages (from faker)\r\n", "Requirement already satisfied: pytz>=0a in /root/miniconda3/lib/python3.6/site-packages (from babel)\r\n", "Requirement already satisfied: idna>=2.0.0 in /root/miniconda3/lib/python3.6/site-packages (from email-validator==1.0.2->faker)\r\n", "Requirement already satisfied: dnspython>=1.15.0 in /root/miniconda3/lib/python3.6/site-packages (from email-validator==1.0.2->faker)\r\n" ] } ], "source": [ "!pip install faker babel" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/jpeg": 
"/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEABALDBoYFhsaGRoeHRsfICYlISIiIiUlJSIlLicxMC0n\nLTA1PVBCNThLOS0tRWFFS1NWW1xbMkFlbWRYbVBZW1cBERISGRYZLxsbMFc9NzZXV1dXV1dXV1dX\nV1dXV1dXV1dXY1dXV1dXV1dXV1dXV1dfV11XV2FXV1dXWldXV1djV//AABEIAWgB4AMBIgACEQED\nEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAABAUBAgMGB//EAEgQAAIBAgIEBwoNAwQDAQEAAAABAgMR\nBCEFEhMxFkFRU5GT0RQiMlJUYXGSsdIGFTM0QmNyc4Gho7LBI2KCJEPh8LPC8aKD/8QAGQEBAQEB\nAQEAAAAAAAAAAAAAAAECAwQF/8QAIxEBAAIDAAMAAgIDAAAAAAAAAAERAhIxAyEyFGFBUQQiQv/a\nAAwDAQACEQMRAD8A+fgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA+oVfgtoyFtaio33XqVfeNn8EtHJ22H6lX3i5xGDjUtrxbt\nuN3Ru76ruYnb+G/9KjtqGPwW0a5OKopyW9bSr7x04H6Ps3sEkt7dSol+4t4YOKm5qL1nvZKUHq2V\nr781dGouvbnjf/Sgj8DMA1dUE197V94zwKwPk66yr2l1Ww+vFJzcXnnC0d6a4723nD4sjz1b115v\nN5iqq+BWB8nXW1feHArA+TrrKvaW9bR8Jw1HKW5q91rWclLf+CN8ThI1Uk5SjZ3Ti81lbK+4Cl4F\nYHyddbV94zwKwPk66yr2lk9Fwe+pOStlrNPite7V/wDvot0oYBQldVaj35SldZ8nIBUcCsD5Ousq\n9o4FYHyddZV7S3qYCEpuevJOTTdmleytZ5Xf4kinBRSzbtyu4Hn5/A3ARzeH/Uq+8ceC2jr22GfJ\ntal/3HpdktVxvvd83fMibGd9yt6M+m5JWFNwU0f5P+pU94cFNH+T/qVPeLCei7tvWqK7bdpcrb/k\n64fBunrWc3d376V7ZWyKiq4KaP8AJ/1KnvDgpo/yf9Sp7xd7KXINnLkApOCmj/J/1KnvDgpo/wAn\n/Uqe8XezlyDZy5AKTgpo/wAn/Uqe8OCmj/J/1KnvF3s5chpPDt2yeTvk/aBTy+Cmjkr7D9Sp7xhf\nBXR73Ye//wDSp7xdSpS4kc6uElOKTcovPONlvTXn5Sfyqq4KaP8AJv1KnvDgpo/yf9Sp7xO+Knzl\nb115uw6xwFmn310mldrdnx7+P2chUVnBTR/k/wCpU94cFNH+TfqVPeLKro9ybbc1fkaXEl/COK0P\na2rOolfNa2VuRcgEPgpo/wAn/Uqe8OCmj/Jv1KnvE74p/vq+uvP2nfD4J072c5X8Zp9AFVwU0f5P\n+pU94yvgno/yf9Sp7xdbOXIaQw7imknvvmwKbgto69u58+Ta1L/uM8FNH+TfqVPeLjYzvuVvRn03\nI89GNu+vVXmUkl4Te78SQs0r+Cej/Jv1KnvDgpo/yb9Sp7xYQ0ZbfKpLJrOS401/P5I60cFqbr8X\nIt3o9JUVXBTR/k36lT3hwU0f5N+pU94n/FSTbTqK/EpZX5bcvYa/FH99Xz9+s92/oAhcFNH+TfqV\nPeHBTR/k36lT3iZ8TK91KqnyqSu/SdaOj5RSvKbfG72T/ACu4KaP8m/Uqe8bL4JaP5j9Sp7xbPDt\n235O+/2m+ylbcBTr4I6P5j9Sp7w4IaP8n/Uq+8XMackzpZ8gF
FwQ0f5P+pV94cENH+T/AKlX3i9s\n+QWfIBRcENH+T/qVfeHBDR/k/wCpV94vdV8g1HyMCi4IaP8AJ/1KvvETFfBXBRl3tDK3OT949RqP\nkZSaZ0nsKuo4X71Pe087+Yk89Cpfwawt/kFb7yp2mr+DmE5l+vPtO0NPKU1HZtX852+NafHkY9tQ\ngT+DuFSb2PE/pz7SnrYLCJ2UbNLO7nn+Z6iOLjUjLV5LFZWw8I3na942fmfGS3TGFD3HRldKDXI9\nZ7+k9Fo/4PYSpSpTlRzlFN9/U/HjIVJq3FluPQaETjS75Nd93q5FkLJxeWw2hIPEpOlei6jXhPde\n3LfkPUQ+CeAa+Q/Uqe8VuFq/1kr/AO/7W+z8z1lLcatiVPwSwHMfqVPeMw+COAd/6H6lT3i6NqcL\nttIsMpjuZYotThGcZa0ZJNNcae5m+oaRyqN6rtvNaGtZ6zb77K6Sdst/58h31BswOEJVNo076ufE\nrWztZ8u7pZGWKxWrnQTlblWb6d3/AHiLDZjZgQe6MUm/6UZXllnay8+f5+Y6UqtdxvKCUtZZebVV\n+N8dyTsxswOOvW5PyXabbSr4iz/I6bMbMDiqtW+cFb/uZl1avNrc+Pj4jtszGzA57Wp4nH6LR7TC\nq1ebXSddmNmBy16vi8lrfnfMa9W17Z5/x/yddmNmBxU63HFXut27z8Zlzq6qair3zXG1/B12Y2YH\nFzq52XLxf8+jpG1q8cOjN7jtsxswOO0qr6Ce7zcWY2tXPvfRlud+PM7bMbMDm5VbOyV75dn/ACau\npV4o8m9Lfb0+noO2zGzA5bSr4l9+/K++3H6DCqVeOPGrWz9NztsxswOSqVbeCn/3PjMupV8RHTZj\nZgco1Kr+gle34GFOrxxz48l5t2Z22Y2YHOU6ve96t2ds2ma69bLveTi83p9J22Y2YHLaVfE/4/Mb\nWrbwF0nXZjZgc5VqieULrL/vSbU51Na0oq1t/wD3oNtmNmB0BpszGzA6g5bMbMDqDlsxswOoOWzG\nzA2q62q9Xea0XLV77wrvfb8N3msNmNmByoSqOL2itmrfzx7uL8CJXnWVS0b21oaqteLj9Nt8X/ws\ndmNQkw1jNKXTVSvGUNltdX6uMZXd90r7l/yQtL4F1aylON3s4p99ZJ2d+J3zZ6Zw855L4QRqqcFU\njUqSjDOdGpsrtt/R1l5uUlLOd4xjXESl8HpxettU3Z/Ra4jjLQla+6Pp1lYizxM4zSUsZG6btraz\nytuvvtf8yzozqKn39So3JXSm1eOWSy4xTMMUI06EXF99Le2r28xFxedprJLNJNpJkinQ325Fb8De\nnhl30c9SWa5V6Cat4zUqNSu3a2e+x6KOkHFJOnnZcfm/4I9HRMYScrtrfnZL/k6Rpa0r8vsW/wDM\nRiuUoNDDtVlO6ttFO3mu+09bRfeoppYdLNEzR1XLU5N3oLMMSsTvhuP8CJrkrA5634fyIZTYQUUo\nxSUVkklZJGTINAAABgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAI2O0hSw6TqytrO0Uk5Sk+RJZshcI8\nN9d1FX3QLYFTwjw313UVfdHCPDfW9RV90C2BU8I8N9b1FX3Rwjw31vUVfdAtgVPCPDfXdRV90cI8\nN9d1FX3QLYFTwjw313UVfdHCPDfW9RV90C2BU8I8N9b1FX3Rwjw31vUVfdAtgVPCPDfW9RV90cI8\nN9b1FX3QLYFTwjw31vUVfdHCPDfW9RV90C2BU8I8N9b1FX3Rwjw31vUVfdAtgVPCPDfW9RV90cI8\nN9b1FX3QLYFTwjw31vUVfdHCPDfW9RV90C2PFfCyNPuy8m09nHdJrjeZf8I8N9b1FX3TwPwxqwxW\nkIyg5KGyim5RlC1m75SS5RKwzh5pV1OMpyhBNXlJvN2yX/eQtNvGorrd+aKR1oJKEbKKIeMx01lT\nunxtL8jnEzMtzjEQ9Xoyp
rX83YTIu0UnxN/8Hm9A4zUSU5JXit7L3u2lZ9/DP+5HRl3qVJS71WX/\nAHeawaV7bsor0LecJ46lGDanDWeS75XODxcI2WvHJeMs3xgTauIWvGLaV1J+z/k4z0goVI2j3qeb\n8xEjXpzjO843ytdrev8A6R6leEo+FG/pRjKZhvGIl6lVU1dPIsNETvr/AOP8njcDpGMVqSmrLc7r\noLvQ2mqNPaa0pO+rbUhKfLv1U7CJYmKeqABtkAAGAZAGAAAAAAAAADIGAAAAAAAAAAAAAAAAR+5I\n7fbPOShqRXiq93b05dCJAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGTAGQAAAMAZBg\nyBgGTAAAyBgAAAAAAAAAAAAAAOGPqOFGpKO9RdvTYDFXH0oPVlUimt63temxotJ0OcX43RRxikrI\nyeOf8mf6ez8aP7emTvmtwKvQcnapD6MWnFcl1mulfmWh6sctot5csdZoABpkAAFdpHSEsPXoayj3\nPUepKed4VH4N+LVe702ItH4Qx7+pUVqUpuGHUIznUqqPhTsuK/mJmncI6+DxFKMVKU6clFO3hWy3\n+exU6T0PUXcc6UJtUaTpyp0qipTSajnF7t8dwFlV0/howpzU5TVS+ooQnOT1fC71K6txlhRqxnGM\n45xkk1k1kzzWI0Q4UKSp4atrpznrQrxVWlOTu++eUk+MsdF4vEKdPD14qVSOHjOrUT+m5NaqVrcT\nAtwQu7avktT1qXvDu2r5LV9al7wE0ELu2r5LV9al7xju2r5LV9al7wE4EfC4yNXWVpQnG2tCStJX\n3Pkaeea5CQAAAETE6To0nUU5WdKntJ5PKF2rrl3PcRNKachSoVZ0mp1IUVVSaerqydot+k4ae0RP\nEV8PKFtTOFdZZ0tZSt0xt+LIFHQNeOBxlKSUqk0qdJXWdKGULv8AFgXmA0zQxDlGEpa0YqTUoSi3\nF/SSazXnRxl8IaDp1XHaa1Om56s6VSLkuVJq7V+Q46RwGIlXdShaMu4504yulao5JpfkyvwWia7x\nDqOlUhF4WpTvVr7VucmvO7LICxwemKtfD4apSp3lUnCNXWhNKK1daUlfi5HuJmE0zQrVXTpuTa1r\nS1JqDcXaVpWs7ek10BCpDCUqdWm6c6cYws5KV9VJa2RWaPweIp4xOnRnQouU3Wi6sZ0pXTs4R3p3\nz4gPSAjV8TUjK0aE5q29Sgl6M2jn3bV8lqetS94CaCF3bV8lq+tS9474XFRqpuN007SjJWlF8jQH\nYGQBgGTAHB4uG22N/wCpqa9rPwb2vf0kKfwiw0Y0pa0ntYylTUac5SkouzskrnHSNGvTxkcRSo7a\nMqDpSipKLi9ZSUs+Ip8HSr4WrgIbHaVYYaupwUop2c4u6byfEBfVsdWrU6c8DsZQndupUcrRtxaq\ns73vxq1jbQWkZYmlOU1FShUlTbg24TcfpRb4s/yIEdGVY6Nq0pU9epVnOUqcZqNlUqXcdbzJl9Qo\nxpwjCEVGKVkkrJAdAAAAAGDJgAZAAAAAAAAAAAAAAAAAAAiaU+b1fsMlkTSnzer9hknix1Ryds3k\nDhpCiqlCpByUVKLTk9y85XaOxkp1Ix7rp1MvB2Ti5rzO+f4HzYxuLfTnKpp6nQnhVf8AH+S2KnQn\nhVf8f5LY9/i+IfP8v3IADo5gAAGDJgCNg8fSruoqU1LZzcJ24pLiJJA0VSjF4jVil/WluSX0Yk8A\nAZAwAAIMssbDz0J3/Ccbe1k4hVPntP7mp++BOAwDIAwZAAw2RsBj6WJp7SjLWheSv507MktX3lfo\nKnGOHSiklr1dyt/uSAsAZAAAAYINDLGV1xOlRl+N6i9iRPIFH57W+5o/vqgTwAAAAGDnWlCCdSdk\noxd5PijveZ1OWIinTkmk04u6ea3Aa4LFwr0oVab1oTV4vzHYiaIglhaCSSWyhklb6KJgGLg
87DHU\nXNxdCnZW1rWco33Nq1uhlo8FS5qHqozsTE49Tri5A7ipc1D1UO4qXNQ9VE3S08EDuKlzUPVRnuOj\nzUPVQ2LTgQIU1TqU9Raqm3GUVu8FtO3LkTpSStdpXdl53yGom1ZBX0aUZxU5pSlLPPO19yXIjEqV\nJO2zh0IzslrEEFYak/8Abj6qHctLm4dCG5acCBLDUkvk4dCNI0qTfycOhDctZAg9y0ubj6qHctLm\n4+qhuWnAg9y0+bj6qHctPm4+qhuWnAg9y0+bj6qNsPHUqqMcoyi3q8SaazXJvLGVlppE0p83q/YZ\nKbPP6XrVq93h6lsPTT2rsrVH4sHbi43uzLlxrHqBpOEZYerGctSLi7y5FykPY1XUoOvOkowl3mom\nnOWq0t+7K+R107h6UqE51YayhFtWdn0nFYXD0KmHaorWnKyes3qvVbuuXc8z5+PH0cuvT6E8Kr/j\n/JbFFo2lUlKpqVdTKN+9Ur7+UnOnioZqpTq2+jKGo3/km7dB7fF8Q8Pl+5ThcqKWPp4uTjFu1Nf1\nKb3xnfwZcTtn5jrPD019CPQjrEOMzSyuLlZCjTf+3HoRv3LT8SPQi6pusLi5X9y0/Ej0I5ulTvbU\nj0Iam7vo3fiPv5ftiTStjhaVrqnHP+1Ge5afNx6ENTdYgru5afiR6EO5afiR9VDU3WIK7uWn4keh\nGlajGEZSglGUU2mstyvZ8q8w1N3efz2n9zU/fAmkKd+6qcrd7sZ3fEu+hx9PQa1rVKjv30Ixi0uJ\nt3z85lqZpPBWzoUl/tw6EYhQpP8A24dCNUmyzFyv7lp83H1UO5qfNx9VDVN4WBB0L83X26v/AJZH\nHZUr22cOhG0cLTSypw9VDU2WQK7uanzcPVQ7lp83D1UNTdYgru5afNw9VDuWnzcPVQ1N1iQaPz2t\n9zR/fVOcqap2nBKLTW7JNXV0+U6QWri603lHY0ld5K6lUvn+K6STFNRNpwuV84qpObmtZKWrFPck\nks7cuZpOlTX+3DoQpNlncXK2NCm/9uPQjPc1Pm49CLqmyxuaVn3kvQ/YQe5qfiR6Ec9nTvbZx6EN\nTZL0V81ofdQ/aiUVywlLm4eqh3LT5uPQhqbqDCYGi8ZWopU41dWLqOKzcXvUcsm+PPjRefFdDxP/\nANS7TWkv9ZWy/wBql+6oTTztZZzl1E+K6Hif/qXaaVsFhqcdacdWK43KfaTyBpfAuvScYtXtJWe5\n6yt+DIkd9q/EU6MtTZ0allNOT1mrxzus5eglUKWEqPVUHGXJJzV7cjvZlHX0VWUo68pxi0oKKjJp\nvVskmnyq5YaN0RVjUcpXUXLWd97ak3kr5clyu2WOFdWM9GUdel3m+b+lLxJec64rQGGqqKlB2jJS\nVpy3rdxnap8pR+2/2SJp0x45QrsJRjsoZfRXG+QVYJOK8+RvhPkofZXsN574+n+GcmWNkuNfmxsY\n8ntNwBpsVyHOlTTbWWTzJCNIb5en+AGyjyfmxso8n5s3AHOUIrf7WaOMbrJ24zetC6/A4bN34/MF\ndowg+L82cqmCp1KsFON1qze9/wBp0pUmt5vH5aH2Z+2JcekdaLQ2G5qL8zbafpTZvpKKWGqJJJKD\nsluRLIulPm9X7LOuXG8ew89ja0adKpOa1oqLut91yFRo90qdRXw6p1NrslaWvq3hr5X3ZZZFxiou\nVOaUVNtNasnZS8zZTaPWHp4iMJU5068m9VTm6ivazad3xLez5+HzL6OX1D12hPCq/wCP8lqVWhfC\nq/4/yWx7fF8Q8Pm+5VcsFShXlKFOMHOCcnFJazvvdt50nTVssjpW+W/w/wDYHeOPPl1woQTinvXE\nddmuT82YpeCvQblZabNcn5s5VIRTjkt+R3NZ74+n+GVGFSXGZ2a5PabgDTZx5PzZo1HiTfSdWsiP\nKlIK2io2zWf4mMTTWynl9CXG+RmI03kzeurUpr+yXsZ
CnCpo6i60aWp3s6M2++le6cUuP+5jB6Np\nYdypU4tQjGFrybfGSX87pfcT/dA2n8tU9EP5MY9dMuOVWCUW9ytmZpwTSe9cRvU8F+hm0dyOjm12\nceT2jZrkNgBHlBKSWV3ex12a5PaJeEvxNwNNmuT82HTXJ+bNzEldWCOUkmskZSjye05um15ln+Zt\nGk7kWmMVRi6ck1k7Le+VHCroqhWq1sPKH9PZU3lKV++lNPj/ALUS8R4D9Mfajaj89rfc0f3VTOTe\nHHDCYSFJSpwTUYysrtvKy43mbYiCS5FxnWHhVPtv2IVfBZqGZ6xGmuPMzso8n5s2RkqNNlHk9py2\na17ZXsSDT6X4AY2UeT82NmuT82bgDy+GqVFXqPvrK2zlbObvmm+PN2XLc9YVNGVCOKm1qKKpwcXx\nKTc9a3n3FjDE05XtOLsrvPcuU8cu+eW3IYxOIjSjrSu80kkruTe5JcpDxmNxUaUpU8LrSytF1I3e\nfIu0xjMTTdXDPXTUZybd8l/TklfpLJO6uuMW5vK6QrVp1ablF317VI211SSatqvlz/NviPQ6Om5U\nk2285JN8aTaTOOkqEG6TcU26sYt8sc8nyonpWyWSRbdMs4mIinOp8pR+2/2SJpCqeHR+2/2SJp0x\n4zCDhPkofZXsN574+n+GaYT5KH2V7DabV4+n+GcmW5jWzyVzOsuU0pPL8X7QMSlLxcuW5zpN3/H8\niRrI5wsm/wDvEB1MSlYay5TRyWss+JgJTlxR/M4uT1vb5mSjnNd9H8fYBtF5IxH5aH2Z+2JuaR+W\nh9mftiax6sJZF0r83q/ZZKIulPm9X7LOmXJbx7ChrwUoSjJtRazadsvTxHLDYOlRypwjG/Hxv8d7\nOGmcIqtGV60qUVF6zXgtf3EHB7KdWl/WxNRxd4KcJKF7PO+quK586I9PpTNTx63QvhVf8f5LYqdC\n+FV/x/ktT3eL4h4fN9yiVvlv8P8A2ArfLf4f+xhs7w82XWtLwV6DMpWMU33q9BhyWss+J/wVklKV\nso/mcpttr8/MSNZGk7XXpA2g8jITRrUfev0MIKTe5ZdBynOVt1nxcdzunkJbgrlR/gzifk5/Yl7G\nb01kvQaYn5Of2JexgZfzul9xP90DaXy1T0Q/kw/ndL7if7oGZfLVPRD+TEddMuMVPBfoZsty9BrU\neT9BmLyOjmSlbzvkNZTkrWjx8oUlrPPiNwI13rZ//GSIvJGkray/E3uBluxq5Pij+Ziru6PabgRq\nk5NZqz5DtSYrLvX6DdIDnifAfpXtRtR+e1vuaP7qprifAfpX7kbUfntb7mj+6qYybw4xDwqn237E\nZq+CxDwqn237EKjyZqGZ62NZStkldm2suU0jJXefIVGJTlxR/M5pvWJFzTLW/Ag3MN2VzKZrNrL0\nlHnZVsbKMN0b6ilZRus3eW/kUcvOWVGcpQ7+2v3PPWStk+9uaxlFYidR4WrqunGK/pR3qUm+Pzok\nxxcVuw1Zeikl/J4I8dTb0TncUg2rba7n/SzyUo6urqKytv1ta75LFxhvkqf2I+xEKnR2tWEnQ2dO\nnrPv1FOcmrbluSTe8sS446s55bImkP8AZ++h7GTCNjqUpQThZzhOM0m7J24r8WVzTu98dCunyaif\n5pmmHep4dH7b/ZImlPUx/f0v6FfKfif2S85vi9NOmotYbES1pqNlDPPj3nTHjUO+E+Sh9lew1knx\nf9zGDqf0oZS8FcXmNp5td69/J5jllDI934sxC98+U6a39r6DSm7LwXvfF5wNJX4uQ1vN2O+t/a+g\nwnm+9fR5iUMcT/A1zu7nTW/tfQIxu72sKHQ0n4Ufx9huaTTyaztxGlbmkflofYn7Yjaf2y6DhVxW\npVg9nUl3s1aMbvfE1j0hZFdpvG0qdKUJ1IxnUi1CLa1pPzLjNlpKTyjhq7fniorpbIeP0ftITr4h\nRdSEXsorNUuWz42
+N+ZHSeN49hTaZpuVCf8AU1Keq9fvNe68xX0NL6jhtMROUPPhpR1lbl/Mt9IQ\nhKhUVRuMHF6zW9LlKjDYmFatSjPFRqarbhCNNw1nqtZt+Zs8GHvH2+hn6y9PU6IxlKLm5VIR1lBq\n8krrPNXLeliKc/AnGX2ZJ+wqdB4amteOpG0VBRTSdlZ5FhX0dRms6cU+KUVqyXnTWaPZ4viHi8v3\nLWv8t/h/7Gs3ycjK/A1MUqtSOJhlFJU5qzc438KVuPoJ+uvFfQd448+XWIefzmM7rPLIzCVklqvo\nMN5p6rtZ8XoFM2y92W+xq5Sub668V9BrJ7u9e/kLRZBu2fmMzvxeczrrxX0GJPWVknnxslDeHH6T\nL3MJAoxT8FehGmK+SqfYl7GZjJpJNPLkNMTP+lPJ+BLi8wG7+d0vuKn7oGZ/LT9EP5I08X/UjV2N\nfvaUo22e+7i73v8A2/ma4HHOveq6VSnrRg9Was1vMR10y4lVN+Xm9piKyd99hKWT719BlT/tfQam\nHNrnffkZnfi5EYv319V2tyG2v/a+gUW560jenfK+/wD4Dlmu9fQba39r6BQ0d8rbjrHcaPPK1s97\nOhVaVfBfoNjEldNcprrvjT/AqMYnwH6V+5G1H57W+5o/uqnHF1LU29WTtbJLPejjiNIbGVbEujXf\n9KC1dS3gOb3349b8jGTeCZDwqn237Eazvnbl/g5YXE7RSnqTjrSvaUbSWSya4jpN3Xgu/oLXpmes\nrdn5jGed2ba68V9BrF2b719ApGZvk33/AINNaZ0cl4r6DW6vfVe7kFDeD5d+ZzW9cnpRvr/2voFt\nZrKy84oaV8YqUHOpCpGEVdtqNkuk328uaq9EfeNfhBQhUwdaM4qS1b2fKtzJ9KCjFRirRSSS5Ety\nOWkOtIe2lzNXoj2jby5mr0R7SeBpBSBtpczV6I9o20uZq9Ee0ngaQUg04TnUjJwcIwbffWvJ2a3J\nvLNk0yCxFKgQjKmtTUlJLJONs1xXz3mdrLmqnQu0nAmsJSDtJc1U6F2jay5qp0LtJwJpBSDtJc1U\n6F2nOvjFTSc4zinJRTernJuyW8sir0/hadWlTVSCklWpWvxXmk/ybGkFO21fN1Ohdo2r5up0LtJo\nGkFIW1fNVOhdo2suaqdC7SaZGkFIO1lzVToXab4eEpT15RcUk1FO13e127btyJZgsYxBQRdKfN6v\n2GSyJpT5vV+wyzxrHqgruShLVajKzs5bl52VGFrd0VqeviaM9m3NRpxabdmt7e7PiLPSFFVKFSDl\nqqUWnLk85UYbEKvWoKVTDrZyvHZNuU3qtW3ZLzHz8OTL6OfYh6/QnhVf8P5LUqtC+FV/w/ktj2+L\n4h4fL9yiYqnLWU4rWyaaVr77pq5z15c3U6F2k4HW3KYtA15c3U6F2jXlzdToXaWALsmsK/XlzVTo\nXaNeXN1OhdpYGGr5DY0hV4fGKrFypxnJKTi2kt8XZrfynXXlzVToXacvg9hadKjNU4KCdetdLLdU\nkl+SS/AtBsaQga8uaqdC7SNpCFWpRnCmqlObtqySXetNPl8xcAbGsK2lOpqx1qc9ayvZLfx8ZmcZ\nzi4KnKOsrOUrJJPe9+bLIDY1hAmv9ZS+4qfvpm2IpyVRzUXKMkk7Wumr8u/eYqfPaf3NT98CcSJp\nZi1frS5up0R7Rry5up0LtLAF2lNIV+vLm6nQu0a0ubqdC7SwA2k0hXupJZ7Op0R7TlhcYq0FUpRl\nOEt0klZ525SyqU1OLjJXjJNNPjT4iv8Ag9h4UsMo04qK16mS3fKSXsSG0msN9eXNVOhdo15c3U6F\n2lgYGxrCBry5up0LtGvLmqnQu0sANpNYV2pOpaOpKMbpycrLJO9kr7zNH59W+5o/vqlgQKPz2t9z\nR/fVJM2sRTM4ShOTUHOMnfvbXTtZpp+gxry5qp0R7ScBaawg68uaqdC7RrvmqnQu0
ngWawga75qp\n0LtNKuI1IynKE4xim22lZJb3vLIjaQoQqUakJxUouLunudsxZrCNRxGvCM4QnKMkmmkrNPc95try\n5up0LtN9DUowwlCMEox2cXZbs1dk0uxrCDpr5pW+wyac6qhNOErNSycb7zeM09zTtvsZabA1U07p\nNNrf5jYAAAAAAAESOPg8RLD/AE4wU3utZtpL05MCWDVTTdrq63obSN7XV+S4GxB0v8nD76l/5Iky\nM01dNNeY0qRhUjnaUU09+V07r8wOhk02kfGXSNpGyesrPc77wNwY1lykfGYyFGm6ktZpWyhFyk7u\nyskBJBW4DTEK1R0nTq0aqjr6lWOq3G9tZWbTzLEDJE0p83q/YZLImlPm9X7DJPFx68tpelOdGUYy\npxg4tTc75LlTRjDqvBwVR4fVeXeqSby+jc30lo6OJhqylOOTXeya38q4znR0RThOE9apJwd461SU\nkna17M+dExrT6UxO1vQaE8Kr/h/JbFToXwqv+P8AJbHu8XxDweX7lgyAdHMBgxKaW9pel2A2BpKp\nFb5JX5WZjNPc0/Q7gQtDfJT+/r/+aRPOGEoxpxai7pznL8ZScmulnRVYt2Uk3yXQG4ObrRWTlFP0\no3jJNXTugMgACBWdsbRvulSqpenWg7dF+hk45YnDQqx1Zq6vdZtNPlTWaZH+LVztfrZATQQ/i1c7\nX62Q+LlztfrJATAQ/i1c7X62Q+LVztfrZATG7Zsg6F+bQfjOcl6JTlJfk0Zei4SynOrOPHGVSTi/\nSuMmpWVlkgMgAAAABX08sdUv9KhT1fPqzqX6NaPSWBHxWEhVtrJ3jnGUW4yj6GswO4Ifxaudr9bI\nfFy52v1sgJgIfxcudr9bIfFq52v1sgJhxxk1GlUlJ2ShJt/gcfi5c7X62Q+LKbac3UqWd0pzlKN+\nW25gdNHxccPRTVmqcE159VEkwZA8jDRqdPSVanTvittWVOdu+Xepd6+LeyPoPDXc54apRU1hpRdO\nlTnBuTXeud8tZNenM9qYSA8Z8FsOnXoyhUoxnCnLbQhTqRqSurPatuzkpZ55ntBYAAAAAAA8zVoU\nY6Ym3CCqzw8XSbjm6ic02ny2t+B6YWA8PoqNLXwSpRkscqn+qdnratpbTaPjTdrfhY1q6MpSo1qr\npray0i4udrS1HWUWk+RpvpPdWFgPG47CbL4wpYeDhS/00pQpprvG3tdVLjcVxEfGRounj+4o2wzw\nq1tVNQdXWy1Vy6u/8D3QsB5jSeh8PHuCjGjFU54i84pZSexlnLl3I5aVoYWninHF00sMsOlh1qtw\nUtaWuopfS8E9YGgPEV8PXjQwbantMRReFqN+HFTknCUvOo63SWPwWoz2tWVRP/TQjhYN8ag23L8V\nqdB6YAUvwdoKonjJVHVq1NaKk04qEFN2hGLSssl6bF0EjIA4Yyi6lKcFvlFpek7mAPL66u0+9kt8\nXk0Zc0s210npZ0oy8KKfpSZqsPBZqEU/so8v437ev8n9IWhqMlGc5JrXa1U99kt/5ssjBk9OOOsU\n82WW03IACsouOqSjFauSb76Vr6q5bHGvh23CcVGolG2rN7/7k+UnnCrg6c7NpppWTi3HLkyAi1VT\nnhZNQS1YySTS71rfY6zcaOHcoxSeqty3tqyJCw8FDZ2721rGZUItRTWUWmvStwEDRk4xlKnF3VlJ\nXTWdrS3kSMb0oQ2KUptqNR23338pdSpRclJrON7P07zR4WDgoW71O6zeTvfeBFx1CGvRvGLbnm7L\nPJ7zeu3TqQVPe8nTW7V8bzWJU6UZOLa8F3XpNtVXvbPlAyZAAwZMGQAAAAAAAAAAAAAAAAAAAAAA\nAAAAAwCB8eYPyqh1kO0fHuD8qodZDtAnggfHuD8qodZDtHx7g/KqHWQ7QJ4IHx7g/KqHWQ7R8e4P\nyqh1kO0CeCB8e4Pyqh1kO0fHuD8qodZDtAngg
fHuD8qodZDtHx7g/KqHWQ7QLAwQPj3B+VUOsh2j\n49wflVDrIdoE8ED49wflVDrIdo+PcH5VQ6yHaBYAr/j3B+VUOsh2j49wflVDrIdoE8ED49wflVDr\nIdo+PcH5VQ6yHaBPBA+PcH5VQ6yHaYwuk4167jQlGpShF6845rXbyinueV2/wAsDJgAZAAAAAAAA\nAAAAADBkAAAAAAGDIAAAAAAAAAAAAAAAAAAAAAAAAAAAAcu5qfiR9VDuan4kfVR0AHPuen4kfVQ7\nnp+JH1UdQBy7np+JH1UO5qfiR9VHUwBz7mp+JH1UO5qfiR9VHQyBy7mp+JH1UO5qfiR9VHUAcu56\nfiR9VDuen4kfVR1AHLuan4kfVQ7mp+JH1UdQBy7np+JH1UO56fiR9VHUAcu56fiR9VDuen4kfVR1\nAHLuen4kfVRvGKSskkuRGwAwZAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAByxNXUpzna+rFyty2Vw\nOoPL0fhNiVQhiauCthpRUnOnUU5Rj4zjZOxZ6X0xsMLHEUoqqpOCgtaykptJO/4gWoKjBYzHSqxj\nWwlOnTd9aarKTWWWVs87HPHaaq90Sw2EoKtUgk6jlNQhC+5N8oF2CBozHVKtKU69GWHlBtSUmmsl\n4UXxoqF8I8VUg8RQwTnhVe0nNRqTit8ow5APTAqcXpuEdHyxtJa8dRSSbtfNKz5CGvhDXpOEsZhN\nlSm0lVhUVSMW92tyID0QKfTGlqtGtQo0KMa06ym1eeolqpPfbznfRmJxU3LunDxopJarjUU7vj4s\ngLEHm6OncZWlW2GDhOFOrOnrOso3cXbc0WmhtKLF0nPUdOUZyhUg83GcXZq/GBYAADz2iNPTrYnY\nSlSmpU5TjKnGpG1mlZ629Z70RqVe+AwMorUUsVBWUpPLaSyu3d7i0wWhHSrRrTxFWtOFOVOOvqpK\nLafElnkjaGhKcaFGipS1aNRVIvK7ak5WeW7MCzMmDIAAAAAAAAAAAAAAAAAAADBkAAABgyAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAjaR+b1vu5/tZJNZRTTTV08mnxoDyGB03hqWh6cHVhKo6GoqUWpTlJ\nqyjqrPexpnCyo6Dw9Kq3GUXQUnezj3yvn5j0uH0VhqUtanh6MJcsacYvpSO2Iw1OrHUqwjUg/oyS\nksvMwKrQtTC0b06eNeIlOV0qlaM5btysRMBjKeF0hjadeUabqyhUpym0lOOqk0m+RlxQ0Phac1On\nhqMJrdKNOKa9DSO2LwNGukq1KFRLcpxUrdIEKri6eOwmJjhpqd4VKest2s48T496K3Qmn8LS0dT2\nlSEJUaepOnJpTUoqzWrvzsejo0YU4qEIxhFblFJJfgiPU0Xh51NrKhSlU8dwi5dIHlu5p0vg5UU4\n6rlCU9XxVKpdLoZ3+EOlaFbALC0akK1esoQhCDUne6bbtuSsz1VajGpFwnGM4PJxkk0/SmccLo6h\nRbdKjTpt73CEY+xAeb+EtGHdej4Va8qMVCqnUjNQatGPG+UvtE16GoqNHEKu4LNuopztfe7HfFYC\njWttqVOpbdrxUrX32uYwujqFFt0aNOm2rNwhGLa5MkB43CYXGbLH1sLiZQcMTXtSUItSalnm+Ox6\nP4K0qSwVOdKUpqrepOUraznLwr28+X4FpSoQhfUhGOs3KWqkrye9vlYoYeFOOrThGEbt2ikld73Z\nAdQAB5lfDzR/OT6uRnh5o/nZ9XLsPkgA+t8PNH87Pq5dg4eaP52fVy7D5IAPrfDzR/Oz6uXYOHmj\n+dn1cuw+SAD63w80fzs+rl2Dh5o/nZ9XLsPkgA+t8PNH87Pq5dg4eaP52fVy7D5IAPrfDzR/Oz6u\nXYOHmj+dn1cuw+SAD63w80fzs+rl2Dh5o/nZ9XLsPkgA+t8PNH87Pq5dg4eaP52fVy7D5IAPrfDz\nR/Oz6uXYO
Hmj+dn1cuw+SAD61w80fzk+rl2Dh5o/nJ9XLsPkoA+tcPNH85Pq5dg4eaP5yfVy7D5K\nAPrXDzR/OT6uXYOHmj+cn1cuw+SgD61w80fzk+rl2Dh5o/nJ9XLsPkoA+tcPNH85Pq5dg4eaP5yf\nVy7D5KAPrXDzR/OT6uXYZ4eaP52fVy7D5IAPrfDzR/Oz6uXYOHmj+dn1cuw+SAD63w80fzs+rl2D\nh5o/nZ9XLsPkgA+t8PNH87Pq5dg4eaP52fVy7D5IAPrfDzR/Oz6uXYOHmj+dn1cuw+SAD63w80fz\ns+rl2Dh5o/nZ9XLsPkgA+t8PNH87Pq5dg4eaP52fVy7D5IAPrfDzR/Oz6uXYOHmj+dn1cuw+SAD6\n3w80fzs+rl2Dh5o/nZ9XLsPkgA+t8PNH87Pq5dg4eaP52fVy7D5IAPrfDzR/Oz6uXYOHmj+dn1cu\nw+SAD63w80fzs+rl2Dh5o/nZ9XLsPkgA+t8PNH87Pq5dg4eaP52fVy7D5IAPrfDzR/Oz6uXYOHmj\n+dn1cuw+SAD63w80fzs+rl2Dh5o/nZ9XLsPkgA+t8PNH87Pq5dhh/DzR/OT6uR8lAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAH//2Q==\n", "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import YouTubeVideo\n", "YouTubeVideo(\"_Sm0q_FckM8\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "import random\n", "import json\n", "import os\n", "import time\n", "\n", "from faker import Faker\n", "import babel\n", "from babel.dates import format_date\n", "\n", "import tensorflow as tf\n", "\n", "from keras.models import Sequential\n", "from keras.layers import LSTM, Embedding\n", "\n", "import tensorflow.contrib.legacy_seq2seq as seq2seq\n", "from utilities import show_graph\n", "\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "fake = Faker()\n", "fake.seed(42)\n", "random.seed(42)\n", "\n", "FORMATS 
= ['short',\n", "           'medium',\n", "           'long',\n", "           'full',\n", "           'd MMM yyy',\n", "           'd MMMM yyy',\n", "           'dd MMM yyy',\n", "           'd MMM, yyy',\n", "           'd MMMM, yyy',\n", "           'dd, MMM yyy',\n", "           'd MM yy',\n", "           'd MMMM yyy',\n", "           'MMMM d yyy',\n", "           'MMMM d, yyy',\n", "           'dd.MM.yy',\n", "           ]\n", "# lower-case 'y' is the calendar year; upper-case 'Y' would be the week-based\n", "# year, which differs from the calendar year for some dates around New Year\n", "\n", "# change this if you want it to work with only a single language\n", "LOCALES = babel.localedata.locale_identifiers()\n", "LOCALES = [lang for lang in LOCALES if 'en' in str(lang)]" ] }, 
{ "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def create_date():\n", "    \"\"\"\n", "    Creates a fake date.\n", "    :returns: tuple containing\n", "        1. human-readable formatted string\n", "        2. machine-readable (ISO 8601) string\n", "    or (None, None) if formatting fails.\n", "    \"\"\"\n", "    dt = fake.date_object()\n", "\n", "    # wrapped in a try/except because some locale/format combinations\n", "    # fail (e.g. locale 'vo' with format 'full')\n", "    try:\n", "        human = format_date(dt,\n", "                            format=random.choice(FORMATS),\n", "                            locale=random.choice(LOCALES))\n", "\n", "        case_change = random.randint(0, 3)  # 1 or 2 => 1/2 chance of a case change\n", "        if case_change == 1:\n", "            human = human.upper()\n", "        elif case_change == 2:\n", "            human = human.lower()\n", "\n", "        machine = dt.isoformat()\n", "    except AttributeError:\n", "        return None, None\n", "\n", "    return human, machine\n", "\n", "data = [create_date() for _ in range(50000)]\n", "# drop any samples whose formatting failed\n", "data = [(h, m) for h, m in data if h is not None]" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "See below what we are trying to do in this lesson. 
We are taking dates of various formats and converting them into a standard date format:" ] }, 
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('7 07 13', '2013-07-07'),\n", " ('30 JULY 1977', '1977-07-30'),\n", " ('Tuesday, September 14, 1971', '1971-09-14'),\n", " ('18 09 88', '1988-09-18'),\n", " ('31, Aug 1986', '1986-08-31')]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[:5]" ] }, 
{ "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x = [x for x, y in data]\n", "y = [y for x, y in data]\n", "\n", "# map each input character to an integer id\n", "u_characters = set(' '.join(x))\n", "char2numX = dict(zip(u_characters, range(len(u_characters))))\n", "\n", "# and likewise for the output characters\n", "u_characters = set(' '.join(y))\n", "char2numY = dict(zip(u_characters, range(len(u_characters))))" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Left-pad every input sequence with a padding token (the empty string `''`) up to the length of the longest input. The targets (ISO dates) all have the same length, so they are not padded; instead a start (GO) token, also stored under the key `''` in `char2numY`, is prepended to each one. Since both tokens are empty strings, they do not show up in the printed samples." ] }, 
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "31, Aug 1986\n" ] } ], "source": [ "char2numX[''] = len(char2numX)\n", "num2charX = dict(zip(char2numX.values(), char2numX.keys()))\n", "max_len = max([len(date) for date in x])\n", "\n", "x = [[char2numX['']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x]\n", "print(''.join([num2charX[x_] for x_ in x[4]]))\n", "x = np.array(x)" ] }, 
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1986-08-31\n" ] } ], "source": [ "char2numY[''] = len(char2numY)\n", "num2charY = dict(zip(char2numY.values(), char2numY.keys()))\n", "\n", "y = [[char2numY['']] + [char2numY[y_] for y_ in date] for date in y]\n", "print(''.join([num2charY[y_] for y_ in y[4]]))\n", "y = np.array(y)" ] }, 
{ "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x_seq_length = len(x[0])\n", "y_seq_length = len(y[0]) - 1  # exclude the GO token" ] }, 
{ "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def batch_data(x, y, batch_size):\n", "    # shuffle, then yield only full batches (a final partial batch is dropped)\n", "    shuffle = np.random.permutation(len(x))\n", "    start = 0\n", "    x = x[shuffle]\n", "    y = y[shuffle]\n", "    while start + batch_size <= len(x):\n", "        yield x[start:start+batch_size], y[start:start+batch_size]\n", "        start += batch_size" ] }, 
{ "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "epochs = 2\n", "batch_size = 128\n", "nodes = 32\n", "embed_size = 10\n", "\n", "tf.reset_default_graph()\n", "sess = tf.InteractiveSession()\n", "\n", "# Placeholders through which data is fed into the graph\n", "inputs = tf.placeholder(tf.int32, (None, x_seq_length), 'inputs')\n", "outputs = tf.placeholder(tf.int32, (None, None), 'output')\n", "targets = tf.placeholder(tf.int32, (None, None), 'targets')\n", "\n", "# Embedding layers\n", "input_embedding = tf.Variable(tf.random_uniform((len(char2numX), embed_size), -1.0, 1.0), name='enc_embedding')\n", "output_embedding = tf.Variable(tf.random_uniform((len(char2numY), embed_size), -1.0, 1.0), name='dec_embedding')\n", "\n", "# Look up the embedding vector of each character id\n", "date_input_embed = tf.nn.embedding_lookup(input_embedding, inputs)\n", "date_output_embed = tf.nn.embedding_lookup(output_embedding, outputs)\n", "\n", "with tf.variable_scope(\"encoding\") as encoding_scope:\n", "    lstm_enc = tf.contrib.rnn.BasicLSTMCell(nodes)\n", "    _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=date_input_embed, dtype=tf.float32)\n", "\n", "with tf.variable_scope(\"decoding\") as decoding_scope:\n", "    # the decoder has its own weights but starts from the encoder's final state\n", "    lstm_dec = tf.contrib.rnn.BasicLSTMCell(nodes)\n", "    dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=date_output_embed, initial_state=last_state)\n", "\n", "# Project the decoder outputs to logits over the output characters\n", "logits = tf.contrib.layers.fully_connected(dec_outputs, num_outputs=len(char2numY), activation_fn=None)\n", "with 
tf.name_scope(\"optimization\"):\n", "    # Loss function: softmax cross-entropy averaged over characters and batch;\n", "    # the all-ones weight matrix has a fixed batch_size, so every training batch\n", "    # must contain exactly batch_size examples\n", "    loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([batch_size, y_seq_length]))\n", "    # Optimizer\n", "    optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss)" ] }, 
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[None, None, 32]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dec_outputs.get_shape().as_list()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[None, 32]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "last_state[0].get_shape().as_list()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[None, 29]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs.get_shape().as_list()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[None, 29, 10]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "date_input_embed.get_shape().as_list()" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Visualise the graph defined above, then train it:" ] }, 
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_graph(tf.get_default_graph().as_graph_def())" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 Loss: 1.281 Accuracy: 0.5523 Epoch duration: 
6.781s\n", "Epoch 1 Loss: 0.800 Accuracy: 0.6977 Epoch duration: 7.516s\n", "Epoch 2 Loss: 0.627 Accuracy: 0.7812 Epoch duration: 6.762s\n", "Epoch 3 Loss: 0.541 Accuracy: 0.7898 Epoch duration: 7.199s\n", "Epoch 4 Loss: 0.467 Accuracy: 0.8266 Epoch duration: 6.352s\n", "Epoch 5 Loss: 0.368 Accuracy: 0.8781 Epoch duration: 6.993s\n", "Epoch 6 Loss: 0.318 Accuracy: 0.8938 Epoch duration: 8.078s\n", "Epoch 7 Loss: 0.283 Accuracy: 0.9055 Epoch duration: 7.166s\n", "Epoch 8 Loss: 0.242 Accuracy: 0.9227 Epoch duration: 5.982s\n", "Epoch 9 Loss: 0.241 Accuracy: 0.9055 Epoch duration: 7.145s\n" ] } ], "source": [ "sess.run(tf.global_variables_initializer())\n", "epochs = 10\n", "for epoch_i in range(epochs):\n", "    start_time = time.time()\n", "    for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, y_train, batch_size)):\n", "        # teacher forcing: the decoder input is the target shifted right by one (starting with GO)\n", "        _, batch_loss, batch_logits = sess.run([optimizer, loss, logits],\n", "            feed_dict = {inputs: source_batch,\n", "                         outputs: target_batch[:, :-1],\n", "                         targets: target_batch[:, 1:]})\n", "    # loss and accuracy are reported for the last batch of the epoch\n", "    accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:])\n", "    print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f} Epoch duration: {:>6.3f}s'.format(epoch_i, batch_loss,\n", "          accuracy, time.time() - start_time))" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Translate the test set greedily: start each sequence with the GO token, then repeatedly feed the decoder its own predictions, one character at a time." ] }, 
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on test set is: 0.882\n" ] } ], "source": [ "source_batch, target_batch = next(batch_data(X_test, y_test, batch_size))\n", "\n", "# start every sequence with the GO token, then append the argmax of the last step\n", "dec_input = np.zeros((len(source_batch), 1)) + char2numY['']\n", "for i in range(y_seq_length):\n", "    batch_logits = sess.run(logits,\n", "                            feed_dict = {inputs: source_batch,\n", "                                         outputs: dec_input})\n", "    prediction = batch_logits[:,-1].argmax(axis=-1)\n", "    dec_input = np.hstack([dec_input, prediction[:,None]])\n", "\n", "# note: this comparison also counts the leading GO column, which always matches\n", "print('Accuracy on test set is: 
{:>6.3f}'.format(np.mean(dec_input == target_batch)))" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's take the first two examples from this (shuffled) test batch and see what the model outputs:" ] }, 
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "25 Nov 2008 => 2008-11-25\n", "october 5 1995 => 1995-10-05\n" ] } ], "source": [ "num_preds = 2\n", "# drop the padding tokens from the inputs and the GO token from the predictions\n", "source_chars = [[num2charX[l] for l in sent if num2charX[l]!=\"\"] for sent in source_batch[:num_preds]]\n", "dest_chars = [[num2charY[l] for l in sent] for sent in dec_input[:num_preds, 1:]]\n", "\n", "for date_in, date_out in zip(source_chars, dest_chars):\n", "    print(''.join(date_in)+' => '+''.join(date_out))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 2 }