{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM (Long Short Term Memory)\n",
    "\n",
    "There is a branch of Deep Learning that is dedicated to processing time series. These deep Nets are **Recursive Neural Nets (RNNs)**. LSTMs are one of the few types of RNNs that are available. Gated Recurent Units (GRUs) are the other type of popular RNNs.\n",
    "\n",
    "This is an illustration from http://colah.github.io/posts/2015-08-Understanding-LSTMs/ (A highly recommended read)\n",
    "\n",
    "![RNNs](./RNN-unrolled.png)\n",
    "\n",
    "Pros:\n",
    "- Really powerful pattern recognition system for time series\n",
    "\n",
    "Cons:\n",
    "- Cannot deal with missing time steps.\n",
    "- Time steps must be discretised and not continuous.\n",
    "\n",
    "![trump](./images/trump.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import re\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization, LSTM, Embedding, TimeDistributed\n",
    "from keras.models import load_model, model_from_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def chr2val(ch):\n",
    "    ch = ch.lower()\n",
    "    if ch.isalpha():\n",
    "        return 1 + (ord(ch) - ord('a'))\n",
    "    else:\n",
    "        return 0\n",
    "    \n",
    "def val2chr(v):\n",
    "    if v == 0:\n",
    "        return ' '\n",
    "    else:\n",
    "        return chr(ord('a') + v - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>text</th>\n",
       "      <th>created_at</th>\n",
       "      <th>favorite_count</th>\n",
       "      <th>is_retweet</th>\n",
       "      <th>id_str</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>i think senator blumenthal should take a nice ...</td>\n",
       "      <td>08-07-2017 20:48:54</td>\n",
       "      <td>61446</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946617e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>how much longer will the failing nytimes with ...</td>\n",
       "      <td>08-07-2017 20:39:46</td>\n",
       "      <td>42235</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946594e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>the fake news media will not talk about the im...</td>\n",
       "      <td>08-07-2017 20:15:18</td>\n",
       "      <td>45050</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946532e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>on #purpleheartday💜i thank all the brave men a...</td>\n",
       "      <td>08-07-2017 18:03:42</td>\n",
       "      <td>48472</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946201e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>...conquests how brave he was and it was all a...</td>\n",
       "      <td>08-07-2017 12:01:20</td>\n",
       "      <td>59253</td>\n",
       "      <td>false</td>\n",
       "      <td>8.945289e+17</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               source                                               text  \\\n",
       "0  Twitter for iPhone  i think senator blumenthal should take a nice ...   \n",
       "1  Twitter for iPhone  how much longer will the failing nytimes with ...   \n",
       "2  Twitter for iPhone  the fake news media will not talk about the im...   \n",
       "4  Twitter for iPhone  on #purpleheartday💜i thank all the brave men a...   \n",
       "5  Twitter for iPhone  ...conquests how brave he was and it was all a...   \n",
       "\n",
       "            created_at favorite_count is_retweet        id_str  \n",
       "0  08-07-2017 20:48:54          61446      false  8.946617e+17  \n",
       "1  08-07-2017 20:39:46          42235      false  8.946594e+17  \n",
       "2  08-07-2017 20:15:18          45050      false  8.946532e+17  \n",
       "4  08-07-2017 18:03:42          48472      false  8.946201e+17  \n",
       "5  08-07-2017 12:01:20          59253      false  8.945289e+17  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('trump.csv')\n",
    "df = df[df.is_retweet=='false']\n",
    "df.text = df.text.str.replace(r'http[\\w:/\\.]+','') # remove urls\n",
    "df.text = df.text.str.lower()\n",
    "df = df[[len(t)<180 for t in df.text.values]]\n",
    "df = df[[len(t)>50 for t in df.text.values]]\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(23938, 6)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remove emojis, flags etc from tweets. Also notice how I have used `[::-1]` to indicate that I want the tweets in chrnological order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['be sure to tune in and watch donald trump on late night with david letterman as he presents the top ten list tonight!',\n",
       " 'donald trump will be appearing on the view tomorrow morning to discuss celebrity apprentice and his new book think like a champion!',\n",
       " 'donald trump reads top ten financial tips on late show with david letterman:  - very funny!',\n",
       " 'new blog post: celebrity apprentice finale and lessons learned along the way: ',\n",
       " 'my persona will never be that of a wallflower - i’d rather build walls than cling to them --donald j. trump']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# remove emojis and flags\n",
    "emoji_pattern = re.compile(\"[\"\n",
    "        u\"\\U0001F600-\\U0001F64F\"  # emoticons\n",
    "        u\"\\U0001F300-\\U0001F5FF\"  # symbols & pictographs\n",
    "        u\"\\U0001F680-\\U0001F6FF\"  # transport & map symbols\n",
    "        u\"\\U0001F1E0-\\U0001F1FF\"  # flags (iOS)\n",
    "                           \"]+\", flags=re.UNICODE)\n",
    "trump_tweets = [emoji_pattern.sub(r' ', text) for text in df.text.values[::-1]]\n",
    "trump_tweets[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a dictionary to convert letters to numbers and vice versa."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_tweets = ''.join(trump_tweets)\n",
    "char2int = dict(zip(set(all_tweets), range(len(set(all_tweets)))))\n",
    "char2int['<END>'] = len(char2int)\n",
    "char2int['<GO>'] = len(char2int)\n",
    "char2int['<PAD>'] = len(char2int)\n",
    "int2char = dict(zip(char2int.values(), char2int.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "text_num = [[char2int['<GO>']]+[char2int[c] for c in tweet]+ [char2int['<END>']] for tweet in trump_tweets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAD8CAYAAACRkhiPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD01JREFUeJzt3X+s3Xddx/Hni1bmwCCbLXV0xVZTTDoSB9Q5fwacYcMZ\nO/8hJSolTmpkEjVE0kEiatKkoEIkcdPKBkVhS4PgmsCU0RiJf4xxwcHWjWaVdqy1W4tEh5oMOt7+\ncT5jh5t7e3+c23vuuZ/nI/nmfM/n+/2e+/6kvfd1Pp/v93xPqgpJUp+eM+4CJEnjYwhIUscMAUnq\nmCEgSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOrZ23AXMZd26dbV58+ZxlyFJE+Xzn//816pq/Vz7\nrfgQ2Lx5M1NTU+MuQ5ImSpJH57Of00GS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCk\njhkCktSxFf+JYUkrx+Y9n5ix/cS+65e5Ei0VRwKS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY4aA\nJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6pghIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNAkjpmCEhS\nxwwBSeqYISBJHTMEJKljc4ZAkk1J/jnJQ0mOJPnd1n5pknuSPNIeLxk65uYkx5IcTXLtUPsrkzzQ\ntr0vSS5MtyRJ8zGfkcA54K1VtQ24GrgpyTZgD3C4qrYCh9tz2radwBXAdcAtSda017oVeBOwtS3X\nLWFfJEkLNGcIVNXpqvpCW/8G8DCwEdgBHGi7HQBuaOs7gDur6qmqOg4cA65Kchnwgqq6t6oK+NDQ\nMZKkMVjQOYEkm4GXA58FNlTV6bbpcWBDW98IPDZ02MnWtrGtT2+f6efsTjKVZOrs2bMLKVGStADz\nDoEk3wf8PfB7VfXk8Lb2zr6Wqqiq2l9V26tq+/r165fqZSVJ08wrBJJ8D4MA+HBVfaw1P9GmeGiP\nZ1r7KWDT0OGXt7ZTbX16uyRpTOZzdVCA24CHq+o9Q5sOAbva+i7grqH2nUkuSrKFwQng+9rU0ZNJ\nrm6v+YahYyRJY7B2Hvv8NPDrwANJ7m9tbwf2AQeT3Ag8CrwOoKqOJDkIPMTgyqKbqurpdtybgQ8C\nFwN3t0WSNCZzhkBV/Ssw2/X818xyzF5g7wztU8DLFlKgJOnC8RPDktQxQ0CSOmYISFLHDAFJ6pgh\nIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS\n1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCkjq0ddwGSJt/mPZ+Ysf3EvuuXuRItlCMB\nSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY3OGQJLbk5xJ\n8uBQ2x8lOZXk/rb84tC2m5McS3I0ybVD7a9M8kDb9r4kWfruSJIWYj4jgQ8C183Q/t6qurItnwRI\nsg3YCVzRjrklyZq2/63Am4CtbZnpNSVJy2jOEKiqzwBfn+fr7QDurKqnquo4cAy4KsllwAuq6t6q\nKuBDwA2LLVqStDRGOSfwliRfatNFl7S2jcBjQ/ucbG0b2/r09hkl2Z1kKsnU2bNnRyhRknQ+iw2B\nW4EfBq4ETgN/vmQVAVW1v6q2V9X29evXL+VLS5KGLCoEquqJqnq6qr4N/A1wVdt0Ctg0tOvlre1U\nW5/eLkkao0WFQJvjf8avAM9cOXQI2JnkoiRbGJwAvq+qTgNPJrm6XRX0BuCuEeqWJC2BOb9eMskd\nwKuAdUlOAu8EXpXkSqCAE8BvAVTVkSQHgYeAc8BNVfV0e6k3M7jS6GLg7rZIksZozhCoqtfP0Hzb\nefbfC+ydoX0KeNmCqpMkXVB+YliSOmYISFLHDAFJ6pghIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNA\nkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOzfl9ApKWxuY9n5ix/cS+65e5EulZjgQkqWOG\ngCR1zBCQpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOee8gqWPez0iOBCSp\nY4aAJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6pifE5AWaTVfYz9b37T6zDkSSHJ7kjNJHhxquzTJ\nPUkeaY+XDG27OcmxJEeTXDvU/sokD7Rt70uSpe+OJGkh5jMd9EHgumlte4DDVbUVONyek2QbsBO4\noh1zS5I17ZhbgTcBW9sy/TUlSctszhCoqs8AX5/WvAM40NYPADcMtd9ZVU9V1XHgGHBVksuAF1TV\nvVVVwIeGjpEkjcliTwxvqKrTbf1xYENb3wg8NrTfyda2sa1Pb5ckjdHIVwe1d/a1BLV8R5LdSaaS\nTJ09e3YpX1qSNGSxIfBEm+KhPZ5p7aeATUP7Xd7aTrX16e0zqqr9VbW9qravX79+kSVKkuay2BA4\nBOxq67uAu4badya5KMkWBieA72tTR08mubpdFfSGoWMkSWMy5+cEktwBvApYl+Qk8E5gH3AwyY3A\no8DrAKrqSJKDwEPAOeCmqnq6vdSbGVxpdDFwd1skSWM0ZwhU1etn2XTNLPvvBfbO0D4FvGxB1UmS\nLihvGyFJHTMEJKljhoAkdcwQkKSOGQKS1DFDQJI6ZghIUscMAUnqmN8sJi0xv5VLk8SRgCR1zBCQ\npI4ZApLUMUNAkjrmiWFJF8xsJ8lP7Lt+mSvRbBwJSFLHHAlIq4jvvLVQjgQkqWOGgCR1zBCQpI55\nTkBaoc53+wnn+LVUDAF1x5On0rOcDpKkjhkCktQxp4M08ZzekRbPENCK4slQaXk5HSRJHXMkII2Z\n30SmcTIEdEE5Xy+tbE4HSVLHDAFJ6pjTQdIcnLPXauZIQJI65khA6oCjGc3GkYAkdcwQkKSOGQKS\n1LGRQiDJiSQPJLk/yVRruzTJPUkeaY+XDO1/c5JjSY4muXbU4iVJo1mKkcCrq+rKqtrenu8BDlfV\nVuBwe06SbcBO4ArgOuCWJGuW4OdLkhbpQkwH7QAOtPUDwA1D7XdW1VNVdRw4Blx1AX6+JGmeRr1E\ntIBPJ3ka+Ouq2g9sqKrTbfvjwIa2vhG4d+jYk61NY7TQe/us5nsBeRmlejRqCPxMVZ1K8iLgniRf\nHt5YVZWkFvqiSXYDuwFe8pKXjFiiJGk2I00HVdWp9ngG+DiD6Z0nklwG0B7PtN1PAZuGDr+8tc30\nuvurantVbV+/fv0oJUqSzmPRI4EkzweeU1XfaOuvAf4EOATsAva1x7vaIYeAjyR5D/BiYCtw3wi1\nS+fl9I40t1GmgzYAH0/yzOt8pKr+McnngINJbgQeBV4HUFVHkhwEHgLOATdV1dMjVS9JGsmiQ6Cq\nvgL82Azt/wlcM8sxe4G9i/2ZPVrNJ2K1eI5ytFS8gdwqY2hIWghvGyFJHTMEJKljTgdpLJzTllYG\nQ2AZne8Pn3P2ksbBENDEcPQgLT1DQEvCP9DSZDIERuDlmJImnVcHSVLHDAFJ6pjTQZqRc/xSHwyB\nTvhHXdJMnA6SpI45EphQvrOXtBQMgSFe8impN04HSVLHHAlcAIuZqnF6R9I4OBKQpI45EpgH36VL\nWq0cCUhSxwwBSeqYISBJHevynIBz/JI00GUISBovP5i5cjgdJEkdMwQkqWOGgCR1zBCQpI4ZApLU\nMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOLfsN5JJcB/wFsAZ4f1XtW+4aJK1M3lhu\n+S3rSCDJGuAvgdcC24DXJ9m2nDVIkp613NNBVwHHquorVfVN4E5gxzLXIElqlns6aCPw2NDzk8BP\nXKgf5pfHSKuD00QXzor8Upkku4Hd7en/JDk6znoWaB3wtXEXMaLV0AdYHf2wD+eRd12IV53RJP47\n/NB8dlruEDgFbBp6fnlr+y5VtR/Yv1xFLaUkU1W1fdx1jGI19AFWRz/sw8qwGvowm+U+J/A5YGuS\nLUmeC+wEDi1zDZKkZllHAlV1LsnvAP/E4BLR26vqyHLWIEl61rKfE6iqTwKfXO6fu4wmchprmtXQ\nB1gd/bAPK8Nq6MOMUlXjrkGSNCbeNkKSOmYIjCjJC5N8NMmXkzyc5CeTXJrkniSPtMdLxl3n+ST5\n/SRHkjyY5I4k37vS+5Dk9iRnkjw41DZrzUluTnIsydEk146n6u82Sx/+tP1f+lKSjyd54dC2iejD\n0La3Jqkk64baJqYPSd7S/i2OJHn3UPuK68NIqsplhAU4APxmW38u8ELg3cCe1rYHeNe46zxP/RuB\n48DF7flB4I0rvQ/AzwGvAB4capuxZga3KPkicBGwBfh3YM0K7cNrgLVt/V2T2IfWvonBBSCPAusm\nrQ/Aq4FPAxe15y9ayX0YZXEkMIIk38/gP9BtAFX1zar6Lwa3wjjQdjsA3DCeCudtLXBxkrXA84D/\nYIX3oao+A3x9WvNsNe8A7qyqp6rqOHCMwS1MxmqmPlTVp6rqXHt6L4PP0sAE9aF5L/A2YPik4yT1\n4beBfVX1VNvnTGtfkX0YhSEwmi3AWeADSf4tyfuTPB/YUFWn2z6PAxvGVuEcquoU8GfAV4HTwH9X\n1aeYoD4Mma3mmW5XsnE5C1uk3wDubusT04ckO4BTVfXFaZsmpg/AS4GfTfLZJP+S5Mdb+yT1YV4M\ngdGsZTCMvLWqXg78L4NpiO+owRhyxV6C1ebNdzAItBcDz0/ya8P7rPQ+zGQSax6W5B3AOeDD465l\nIZI8D3g78IfjrmVEa4FLgauBPwAOJsl4S7owDIHRnAROVtVn2/OPMgiFJ5JcBtAez8xy/ErwC8Dx\nqjpbVd8CPgb8FJPVh2fMVvO8bleyUiR5I/BLwK+2MIPJ6cOPMHhD8cUkJxjU+YUkP8jk9AEGv9sf\nq4H7gG8zuH/QJPVhXgyBEVTV48BjSX60NV0DPMTgVhi7Wtsu4K4xlDdfXwWuTvK89k7nGuBhJqsP\nz5it5kPAziQXJdkCbAXuG0N9c2pfuvQ24Jer6v+GNk1EH6rqgap6UVVtrqrNDP6YvqL9rkxEH5p/\nYHBymCQvZXDRx9eYrD7Mz7jPTE/6AlwJTAFfYvAf5xLgB4DDwCMMrjC4dNx1ztGHPwa+DDwI/C2D\nKx9WdB+AOxicw/gWgz80N56vZuAdDK7kOAq8dtz1n6cPxxjMOd/flr+atD5M236CdnXQJPWBwR/9\nv2u/E18Afn4l92GUxU8MS1LHnA6SpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdez/\nAaVklWe744EkAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f41169b1ef0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist([len(t) for t in trump_tweets],50)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Concatenate all the tweets\n",
    "int_text = []\n",
    "for t in text_num:\n",
    "    int_text += t\n",
    "\n",
    "len_vocab = len(char2int)\n",
    "sentence_len = 40\n",
    "# n_chars = len(text_num)//sentence_len*sentence_len\n",
    "num_chunks = len(text_num)-sentence_len\n",
    "\n",
    "def get_batches(int_text, batch_size, seq_length):\n",
    "    \"\"\"\n",
    "    Return batches of input and target\n",
    "    :param int_text: Text with the words replaced by their ids\n",
    "    :param batch_size: The size of batch\n",
    "    :param seq_length: The length of sequence\n",
    "    :return: Batches as a Numpy array\n",
    "    \"\"\"\n",
    "    \n",
    "    slice_size = batch_size * seq_length\n",
    "    n_batches = len(int_text) // slice_size\n",
    "    x = int_text[: n_batches*slice_size]\n",
    "    y = int_text[1: n_batches*slice_size + 1]\n",
    "\n",
    "    x = np.split(np.reshape(x,(batch_size,-1)),n_batches,1)\n",
    "    y = np.split(np.reshape(y,(batch_size,-1)),n_batches,1)\n",
    "    \n",
    "    x = np.vstack(x)\n",
    "    y = np.vstack(y)\n",
    "    y = y.reshape(y.shape+(1,))\n",
    "    return x, y\n",
    "\n",
    "batch_size = 128\n",
    "x, y = get_batches(int_text, batch_size, sentence_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice what the `get_batches` function looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[ 0,  1,  2,  3],\n",
       "        [ 8,  9, 10, 11],\n",
       "        [ 4,  5,  6,  7],\n",
       "        [12, 13, 14, 15]]), array([[ 1,  2,  3,  4],\n",
       "        [ 9, 10, 11, 12],\n",
       "        [ 5,  6,  7,  8],\n",
       "        [13, 14, 15, 16]]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_batches(np.arange(20), 2,4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Many to Many LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (128, None, 64)           8512      \n",
      "_________________________________________________________________\n",
      "lstm_1 (LSTM)                (128, None, 64)           33024     \n",
      "_________________________________________________________________\n",
      "lstm_2 (LSTM)                (128, None, 64)           33024     \n",
      "_________________________________________________________________\n",
      "time_distributed_1 (TimeDist (128, None, 133)          8645      \n",
      "=================================================================\n",
      "Total params: 83,205.0\n",
      "Trainable params: 83,205\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "# TODO: \n",
    "# 1. Add an embedding layer\n",
    "# 2. Add a LSTM layer, set `return_sequences=True` and stateful = True\n",
    "# 3. Add another LSTM layer, set `return_sequences=True` and stateful = True\n",
    "# 4. Add a `TimeDistributed(Dense(...)) layer (connected to how many outputs? What is the activation?)\n",
    "model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pay special attention to how the probabilites are taken. p is of shape `(1, sequence_len, len(char2int))` where len(char2int) is the number of available characters. The 1 is there because we are only predicting one feature, `y`. We are only concerned about the last prediction probability of the sequence. This is due to the fact that all other letters have already been appended. Hence we predict a letter from the distribution `p[0][-1]`.\n",
    "\n",
    "Why did we keep appending to the sequence and predicting? Why not use simply the last letter. If we were to do this, we would lose information that comes from the previous letter via the hidden state and cell memory. Keep in mind that each LSTM unit has 3 inputs, the x, the hidden state, and the cell memory. \n",
    "\n",
    "Also important to notice that the Cell Memory is not used in connecting to the Dense layer, only the hidden state.\n",
    "\n",
    "### Stateful training:\n",
    "\n",
    "What happens when `stateful=True` is that the last cell memory state computed at the i-th example in the n-th batch gets passed on to i-th sample in the n+1-th batch. This is one way of _seeing_ patterns beyond the sentence length specified. Which in this case is 40.\n",
    "\n",
    "#### Note\n",
    "1. Really important that when I `fit` the model I set `shuffle=False` when training stateful models.\n",
    "2. I need to copy the weights to a non-stateful model because the original only takes `batch_size` inputs at a time. Whereas, I only want to predict one character at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model2 = Sequential()\n",
    "# TODO\n",
    "# Create the same model architecture as above, however, no need to compile or create the summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>teradisid @realdonaldtrump i!.<END><GO>shaoning!<END><GO>dong drindter town be pabitate.<END><GO>americas you honely thi\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 90s - loss: 1.8014    \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 107s - loss: 1.7881   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 109s - loss: 1.7763   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 106s - loss: 1.7657   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 105s - loss: 1.7560   - ETA: 3s\n",
      "<GO>ohamerible anergigh manigion (316      9 thinks? fambew losonised made up plcind.  the beaded?thing\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 107s - loss: 1.7471   - ETA: 1s - loss\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 105s - loss: 1.7389   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 102s - loss: 1.7314   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 98s - loss: 1.7241    \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 99s - loss: 1.7172    \n",
      "<GO>fund having abifationally diand have as prokervia @lanneals<END><GO>say<END><GO>musts country on my two amerhiera\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 98s - loss: 1.7106    \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 113s - loss: 1.7043   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 113s - loss: 1.6985   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 112s - loss: 1.6930   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 116s - loss: 1.6879   \n",
      "<GO>je @realdonaldtrump segew interview- see. years <END><GO>@foxer:  @an79: #patc @unclifferi how the (contra\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 114s - loss: 1.6831   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 115s - loss: 1.6784   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 106s - loss: 1.6739   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 106s - loss: 1.6695   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 106s - loss: 1.6652   \n",
      "<GO>woues donald trump yourspire roling more cothanks?<END><GO>“@courst_cmingc  in than same than courterful s\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 105s - loss: 1.6609   - ETA: 0s - loss: 1.\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 106s - loss: 1.6567   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 106s - loss: 1.6528   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 116s - loss: 1.6490   \n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 120s - loss: 1.6455   \n",
      "<GO>10 mittry was speed focus to u wornd &amp; the provings losbacces has carolina<END><GO>for the great the c\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "66560/66560 [==============================] - 116s - loss: 1.6421   \n",
      "Epoch 1/1\n",
      " 6528/66560 [=>............................] - ETA: 95s - loss: 1.6425"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-72-3f7e4b2e195c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'='\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m!=\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m         \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_states\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/models.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, **kwargs)\u001b[0m\n\u001b[1;32m    843\u001b[0m                               \u001b[0mclass_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclass_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    844\u001b[0m                               \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 845\u001b[0;31m                               initial_epoch=initial_epoch)\n\u001b[0m\u001b[1;32m    846\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    847\u001b[0m     def evaluate(self, x, y, batch_size=32, verbose=1,\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, **kwargs)\u001b[0m\n\u001b[1;32m   1483\u001b[0m                               \u001b[0mval_f\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mval_f\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_ins\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mval_ins\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1484\u001b[0m                               \u001b[0mcallback_metrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallback_metrics\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1485\u001b[0;31m                               initial_epoch=initial_epoch)\n\u001b[0m\u001b[1;32m   1486\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1487\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36m_fit_loop\u001b[0;34m(self, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch)\u001b[0m\n\u001b[1;32m   1138\u001b[0m                 \u001b[0mbatch_logs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'size'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_ids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1139\u001b[0m                 \u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_logs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1140\u001b[0;31m                 \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1141\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1142\u001b[0m                     \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m   2071\u001b[0m         \u001b[0msession\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_session\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2072\u001b[0m         updated = session.run(self.outputs + [self.updates_op],\n\u001b[0;32m-> 2073\u001b[0;31m                               feed_dict=feed_dict)\n\u001b[0m\u001b[1;32m   2074\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mupdated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2075\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    776\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    777\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 778\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    779\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    780\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    980\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    981\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 982\u001b[0;31m                              feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m    983\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    984\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1030\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1031\u001b[0m       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1032\u001b[0;31m                            target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m   1033\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1034\u001b[0m       return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m   1037\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1038\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1040\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1041\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m   1019\u001b[0m         return tf_session.TF_Run(session, options,\n\u001b[1;32m   1020\u001b[0m                                  \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1021\u001b[0;31m                                  status, run_metadata)\n\u001b[0m\u001b[1;32m   1022\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1023\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "n_epochs = 50\n",
    "for i in range(n_epochs+1):\n",
    "    if i%5==0:\n",
    "        sentence = []\n",
    "        letter = [char2int['<GO>']] #choose a random letter\n",
    "        for i in range(100):\n",
    "            sentence.append(int2char[letter[-1]])\n",
    "            model2.set_weights(model.get_weights())\n",
    "            p = model2.predict(np.array(letter)[None,:])\n",
    "            letter.append(np.random.choice(len(char2int),1,p=p[0][-1])[0])\n",
    "        print(''.join(sentence))\n",
    "        print('='*100)\n",
    "    if i!=n_epochs:\n",
    "        model.fit(x,y, batch_size=batch_size, epochs=1, shuffle=False)\n",
    "        model.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('model_struct.json','w') as f:\n",
    "    f.write(model.to_json())\n",
    "model.save_weights('model_weights.h5')\n",
    "model.save('model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# if not 'model' in vars():\n",
    "# #     model = load_model('model.h5') # This doesn't seem to work for some odd reason\n",
    "#     with open('model_struct.json','r') as f:\n",
    "#         model = model_from_json(f.read())\n",
    "#     model.load_weights('model_weights.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_4 (Embedding)      (None, None, 64)          8512      \n",
      "_________________________________________________________________\n",
      "lstm_7 (LSTM)                (None, None, 64)          33024     \n",
      "_________________________________________________________________\n",
      "time_distributed_4 (TimeDist (None, None, 133)         8645      \n",
      "=================================================================\n",
      "Total params: 50,181.0\n",
      "Trainable params: 50,181\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(len_vocab, 64)) # , batch_size=batch_size\n",
    "model.add(LSTM(64, return_sequences=True)) # , stateful=True\n",
    "model.add(TimeDistributed(Dense(len_vocab, activation='softmax')))\n",
    "\n",
    "model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')\n",
    "model.summary()\n",
    "model.load_weights('model_weights.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>❤ q✨1iqáiq​it iw6♥❤/wi☁6 iq✨i‏☁q✨/d‼i~wqá6w☁(i​☁´t6i​/ q✨1i ´♥iá/❤\\iw♥✨☝iq✨​áw♥‏áq´✨i\"´wiá☁q​iá´i16~(i/✨☝i✨´ái☺❤/♥✨i☺❤/✨​i/☝☉di16w/❤​i\"´☝ig✨61w6ái‏/✨\n",
      "====================================================================================================\n",
      "<GO>it/\\6i ´♥i☁6iá☁6iq✨á6w6☝=​i/✨☝i~q❤❤it/‏\\/✨☝♡{i~q❤w6​i/✨☝i/i☺6w​´✨i☁/☉6i♥✨\"áiáqt6☝☞i‏w´´\\6w(ii❌áw♥t☺‏´❤6‏á (i‼☁/☺☺6✨​iq​i/9​´w iá´i‏w6/á6i❤´´\\​i☝´{igw\n",
      "====================================================================================================\n",
      "<GO>~☁ i☝´✨=ái96\"´w6iá☁6i♥☺iá´i16ááq✨1i/✨‏t​‏♥á(igá6☉6✨☝\"´~á❤☝1☁á​iq​iq✨i/‏‏6​​i~☁ i‏☁/❤❤6/á​i‏/t☺/q1✨=ái☺´❤❤(i/i1´☉6w✨61/ i​☁♥ái´♥wi☺´❤qá´✨i~3)♥✨☝❤i~☁ i\n",
      "====================================================================================================\n",
      "<GO>g/á‏❤♥9​​(i96‏´♥w✨6wi‏´ttiá´i»☺6ww (i/​i‏6áá6w(i/i☝/☉6☝i9w/q✨​á6wi ´♥i♥✨qá (i❌1w6✨✨á6w´​tq​áq´✨i~6/☺6‏´✨‏6☝i​/q☝i ´♥iá☁/✨iá☁6=​i✨/áq´✨​(i​/ ​iá´iá☁6i\n",
      "====================================================================================================\n",
      "<GO>gá☁6á6wt♥☺6iq✨i6☝q✨6​(iá☁6 i/✨☝i❤´´\\(i‏❤/✨it/☝6i/✨ ´✨6i«-s-i6%☺/✨☝ i\"´wi´♥wi❌´♥á/1/✨i☁/✨☝​i\"´wi/i❤´✨1(i‏´♥✨áw (igá´tt6☝/16(i6/​ #i/✨i6%6‏♥áq☉6i☺❤/‏6​\n",
      "====================================================================================================\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-20-73ecae28f52a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      6\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0msentence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'<END>'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m             \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m         \u001b[0mp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mletter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m         \u001b[0mletter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchar2int\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/models.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose)\u001b[0m\n\u001b[1;32m    889\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    890\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 891\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    892\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    893\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mpredict_on_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose)\u001b[0m\n\u001b[1;32m   1570\u001b[0m         \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1571\u001b[0m         return self._predict_loop(f, ins,\n\u001b[0;32m-> 1572\u001b[0;31m                                   batch_size=batch_size, verbose=verbose)\n\u001b[0m\u001b[1;32m   1573\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1574\u001b[0m     def train_on_batch(self, x, y,\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36m_predict_loop\u001b[0;34m(self, f, ins, batch_size, verbose)\u001b[0m\n\u001b[1;32m   1200\u001b[0m                 \u001b[0mins_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_slice_arrays\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_ids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1201\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1202\u001b[0;31m             \u001b[0mbatch_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1203\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_outs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1204\u001b[0m                 \u001b[0mbatch_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mbatch_outs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m   2071\u001b[0m         \u001b[0msession\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_session\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2072\u001b[0m         updated = session.run(self.outputs + [self.updates_op],\n\u001b[0;32m-> 2073\u001b[0;31m                               feed_dict=feed_dict)\n\u001b[0m\u001b[1;32m   2074\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mupdated\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2075\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    776\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    777\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 778\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    779\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    780\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    980\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    981\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m--> 982\u001b[0;31m                              feed_dict_string, options, run_metadata)\n\u001b[0m\u001b[1;32m    983\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    984\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1030\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1031\u001b[0m       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,\n\u001b[0;32m-> 1032\u001b[0;31m                            target_list, options, run_metadata)\n\u001b[0m\u001b[1;32m   1033\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1034\u001b[0m       return self._do_call(_prun_fn, self._session, handle, feed_dict,\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m   1037\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1038\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1040\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1041\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m   1019\u001b[0m         return tf_session.TF_Run(session, options,\n\u001b[1;32m   1020\u001b[0m                                  \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1021\u001b[0;31m                                  status, run_metadata)\n\u001b[0m\u001b[1;32m   1022\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1023\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for j in range(10):\n",
    "    sentence = []\n",
    "    letter = [char2int['<GO>']] #choose a random letter\n",
    "    for i in range(150):\n",
    "        sentence.append(int2char[letter[-1]])\n",
    "        if sentence[-1]=='<END>':\n",
    "            break\n",
    "        p = model.predict(np.array(letter)[None,:])\n",
    "        letter.append(np.random.choice(len(char2int),1,p=p[0][-1])[0])\n",
    "\n",
    "    print(''.join(sentence))\n",
    "    print('='*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}