{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM (Long Short-Term Memory)\n",
    "\n",
    "One branch of deep learning is dedicated to processing sequential data such as time series and text. These networks are called **Recurrent Neural Networks (RNNs)**. LSTMs are one of the most widely used types of RNN; Gated Recurrent Units (GRUs) are another popular type.\n",
    "\n",
    "This is an illustration from http://colah.github.io/posts/2015-08-Understanding-LSTMs/ (a highly recommended read):\n",
    "\n",
    "![RNNs](./images/RNN-unrolled.png)\n",
    "\n",
    "Pros:\n",
    "- Very good at recognising patterns in time series.\n",
    "\n",
    "Cons:\n",
    "- Cannot deal with missing time steps.\n",
    "- Time must be discretised into steps; continuous time is not handled.\n",
    "\n",
    "![trump](./images/trump.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import re\n",
    "import pickle  # needed later to save and load the character dictionaries\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization, LSTM, Embedding, TimeDistributed\n",
    "from keras.models import load_model, model_from_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>text</th>\n",
       "      <th>created_at</th>\n",
       "      <th>favorite_count</th>\n",
       "      <th>is_retweet</th>\n",
       "      <th>id_str</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>i think senator blumenthal should take a nice ...</td>\n",
       "      <td>08-07-2017 20:48:54</td>\n",
       "      <td>61446</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946617e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>how much longer will the failing nytimes with ...</td>\n",
       "      <td>08-07-2017 20:39:46</td>\n",
       "      <td>42235</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946594e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>the fake news media will not talk about the im...</td>\n",
       "      <td>08-07-2017 20:15:18</td>\n",
       "      <td>45050</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946532e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>on #purpleheartday💜i thank all the brave men a...</td>\n",
       "      <td>08-07-2017 18:03:42</td>\n",
       "      <td>48472</td>\n",
       "      <td>false</td>\n",
       "      <td>8.946201e+17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Twitter for iPhone</td>\n",
       "      <td>...conquests how brave he was and it was all a...</td>\n",
       "      <td>08-07-2017 12:01:20</td>\n",
       "      <td>59253</td>\n",
       "      <td>false</td>\n",
       "      <td>8.945289e+17</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               source                                               text  \\\n",
       "0  Twitter for iPhone  i think senator blumenthal should take a nice ...   \n",
       "1  Twitter for iPhone  how much longer will the failing nytimes with ...   \n",
       "2  Twitter for iPhone  the fake news media will not talk about the im...   \n",
       "4  Twitter for iPhone  on #purpleheartday💜i thank all the brave men a...   \n",
       "5  Twitter for iPhone  ...conquests how brave he was and it was all a...   \n",
       "\n",
       "            created_at favorite_count is_retweet        id_str  \n",
       "0  08-07-2017 20:48:54          61446      false  8.946617e+17  \n",
       "1  08-07-2017 20:39:46          42235      false  8.946594e+17  \n",
       "2  08-07-2017 20:15:18          45050      false  8.946532e+17  \n",
       "4  08-07-2017 18:03:42          48472      false  8.946201e+17  \n",
       "5  08-07-2017 12:01:20          59253      false  8.945289e+17  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('data/trump.csv') # might need to change location if on Floydhub\n",
    "df = df[df.is_retweet=='false']\n",
    "df.text = df.text.str.lower()\n",
    "# regex=True is passed explicitly because newer pandas versions treat the pattern as a literal string by default\n",
    "df.text = df.text.str.replace(r'http[\\w:/\\.]+', '', regex=True) # remove urls\n",
    "df.text = df.text.str.replace(r'[^!\\'\"#$%&\\()*+,-./:;<=>?@_’`{|}~\\w\\s]', ' ', regex=True) # replace everything except word characters, whitespace and punctuation with a space\n",
    "df.text = df.text.str.replace(r'\\s\\s+', ' ', regex=True) # replace multiple whitespace characters with a single space\n",
    "df = df[[len(t)<180 for t in df.text.values]]\n",
    "df = df[[len(t)>50 for t in df.text.values]]\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(23938, 6)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "trump_tweets = [text for text in df.text.values[::-1]]\n",
    "trump_tweets[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a dictionary to convert letters to numbers and vice versa."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "all_tweets = ''.join(trump_tweets)\n",
    "char2int = dict(zip(set(all_tweets), range(len(set(all_tweets)))))\n",
    "char2int['<END>'] = len(char2int)\n",
    "char2int['<GO>'] = len(char2int)\n",
    "char2int['<PAD>'] = len(char2int)\n",
    "int2char = dict(zip(char2int.values(), char2int.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "text_num = [[char2int['<GO>']]+[char2int[c] for c in tweet]+ [char2int['<END>']] for tweet in trump_tweets]"
   ]
  },
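  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encoding step can be illustrated on a toy string. This is only a sketch: `toy_text`, `toy_c2i` and `toy_i2c` are made-up names that are not used elsewhere in the notebook, and the characters are sorted here so that the indices are the same on every run (the notebook's `char2int` is built from a plain `set` instead)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_text = 'hello'\n",
    "\n",
    "# one index per distinct character, plus the two special tokens\n",
    "toy_c2i = {c: i for i, c in enumerate(sorted(set(toy_text)))}\n",
    "toy_c2i['<END>'] = len(toy_c2i)\n",
    "toy_c2i['<GO>'] = len(toy_c2i)\n",
    "toy_i2c = {i: c for c, i in toy_c2i.items()}\n",
    "\n",
    "# wrap the encoded characters in <GO> ... <END>, as done for the tweets\n",
    "toy_encoded = [toy_c2i['<GO>']] + [toy_c2i[c] for c in toy_text] + [toy_c2i['<END>']]\n",
    "# drop the two special tokens and map the indices back to characters\n",
    "toy_decoded = ''.join(toy_i2c[i] for i in toy_encoded[1:-1])\n",
    "print(toy_encoded, toy_decoded)"
   ]
  },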
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAD8CAYAAACRkhiPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD01JREFUeJzt3X+s3Xddx/Hni1bmwCCbLXV0xVZTTDoSB9Q5fwacYcMZ\nO/8hJSolTmpkEjVE0kEiatKkoEIkcdPKBkVhS4PgmsCU0RiJf4xxwcHWjWaVdqy1W4tEh5oMOt7+\ncT5jh5t7e3+c23vuuZ/nI/nmfM/n+/2e+/6kvfd1Pp/v93xPqgpJUp+eM+4CJEnjYwhIUscMAUnq\nmCEgSR0zBCSpY4aAJHXMEJCkjhkCktQxQ0CSOrZ23AXMZd26dbV58+ZxlyFJE+Xzn//816pq/Vz7\nrfgQ2Lx5M1NTU+MuQ5ImSpJH57Of00GS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCk\njhkCktSxFf+JYUkrx+Y9n5ix/cS+65e5Ei0VRwKS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY4aA\nJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6pghIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNAkjpmCEhS\nxwwBSeqYISBJHTMEJKljc4ZAkk1J/jnJQ0mOJPnd1n5pknuSPNIeLxk65uYkx5IcTXLtUPsrkzzQ\ntr0vSS5MtyRJ8zGfkcA54K1VtQ24GrgpyTZgD3C4qrYCh9tz2radwBXAdcAtSda017oVeBOwtS3X\nLWFfJEkLNGcIVNXpqvpCW/8G8DCwEdgBHGi7HQBuaOs7gDur6qmqOg4cA65Kchnwgqq6t6oK+NDQ\nMZKkMVjQOYEkm4GXA58FNlTV6bbpcWBDW98IPDZ02MnWtrGtT2+f6efsTjKVZOrs2bMLKVGStADz\nDoEk3wf8PfB7VfXk8Lb2zr6Wqqiq2l9V26tq+/r165fqZSVJ08wrBJJ8D4MA+HBVfaw1P9GmeGiP\nZ1r7KWDT0OGXt7ZTbX16uyRpTOZzdVCA24CHq+o9Q5sOAbva+i7grqH2nUkuSrKFwQng+9rU0ZNJ\nrm6v+YahYyRJY7B2Hvv8NPDrwANJ7m9tbwf2AQeT3Ag8CrwOoKqOJDkIPMTgyqKbqurpdtybgQ8C\nFwN3t0WSNCZzhkBV/Ssw2/X818xyzF5g7wztU8DLFlKgJOnC8RPDktQxQ0CSOmYISFLHDAFJ6pgh\nIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS\n1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY4aAJHXMEJCkjq0ddwGSJt/mPZ+Ysf3EvuuXuRItlCMB\nSeqYISBJHTMEJKljhoAkdcwQkKSOGQKS1DFDQJI6ZghIUscMAUnqmCEgSR0zBCSpY3OGQJLbk5xJ\n8uBQ2x8lOZXk/rb84tC2m5McS3I0ybVD7a9M8kDb9r4kWfruSJIWYj4jgQ8C183Q/t6qurItnwRI\nsg3YCVzRjrklyZq2/63Am4CtbZnpNSVJy2jOEKiqzwBfn+fr7QDurKqnquo4cAy4KsllwAuq6t6q\nKuBDwA2LLVqStDRGOSfwliRfatNFl7S2jcBjQ/ucbG0b2/r09hkl2Z1kKsnU2bNnRyhRknQ+iw2B\nW4EfBq4ETgN/vmQVAVW1v6q2V9X29evXL+VLS5KGLCoEquqJqnq6qr4N/A1wVdt0Ctg0tOvlre1U\nW5/eLkkao0WFQJvjf8avAM9cOXQI2JnkoiRbGJwAvq+qTgNPJrm6XRX0BuCuEeqWJC2BOb9eMskd\nwKuAdUlOAu8EXpXkSqCAE8BvAVTVkSQHgYeAc8BNVfV0e6k3M7jS6GLg7rZIksZozhCoqtfP0Hzb\nefbfC+ydoX0KeNmCqpMkXVB+YliSOm
YISFLHDAFJ6pghIEkdMwQkqWOGgCR1zBCQpI4ZApLUMUNA\nkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOzfl9ApKWxuY9n5ix/cS+65e5EulZjgQkqWOG\ngCR1zBCQpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOee8gqWPez0iOBCSp\nY4aAJHXMEJCkjhkCktQxQ0CSOmYISFLHDAFJ6pifE5AWaTVfYz9b37T6zDkSSHJ7kjNJHhxquzTJ\nPUkeaY+XDG27OcmxJEeTXDvU/sokD7Rt70uSpe+OJGkh5jMd9EHgumlte4DDVbUVONyek2QbsBO4\noh1zS5I17ZhbgTcBW9sy/TUlSctszhCoqs8AX5/WvAM40NYPADcMtd9ZVU9V1XHgGHBVksuAF1TV\nvVVVwIeGjpEkjcliTwxvqKrTbf1xYENb3wg8NrTfyda2sa1Pb5ckjdHIVwe1d/a1BLV8R5LdSaaS\nTJ09e3YpX1qSNGSxIfBEm+KhPZ5p7aeATUP7Xd7aTrX16e0zqqr9VbW9qravX79+kSVKkuay2BA4\nBOxq67uAu4badya5KMkWBieA72tTR08mubpdFfSGoWMkSWMy5+cEktwBvApYl+Qk8E5gH3AwyY3A\no8DrAKrqSJKDwEPAOeCmqnq6vdSbGVxpdDFwd1skSWM0ZwhU1etn2XTNLPvvBfbO0D4FvGxB1UmS\nLihvGyFJHTMEJKljhoAkdcwQkKSOGQKS1DFDQJI6ZghIUscMAUnqmN8sJi0xv5VLk8SRgCR1zBCQ\npI4ZApLUMUNAkjrmiWFJF8xsJ8lP7Lt+mSvRbBwJSFLHHAlIq4jvvLVQjgQkqWOGgCR1zBCQpI55\nTkBaoc53+wnn+LVUDAF1x5On0rOcDpKkjhkCktQxp4M08ZzekRbPENCK4slQaXk5HSRJHXMkII2Z\n30SmcTIEdEE5Xy+tbE4HSVLHDAFJ6pjTQdIcnLPXauZIQJI65khA6oCjGc3GkYAkdcwQkKSOGQKS\n1LGRQiDJiSQPJLk/yVRruzTJPUkeaY+XDO1/c5JjSY4muXbU4iVJo1mKkcCrq+rKqtrenu8BDlfV\nVuBwe06SbcBO4ArgOuCWJGuW4OdLkhbpQkwH7QAOtPUDwA1D7XdW1VNVdRw4Blx1AX6+JGmeRr1E\ntIBPJ3ka+Ouq2g9sqKrTbfvjwIa2vhG4d+jYk61NY7TQe/us5nsBeRmlejRqCPxMVZ1K8iLgniRf\nHt5YVZWkFvqiSXYDuwFe8pKXjFiiJGk2I00HVdWp9ngG+DiD6Z0nklwG0B7PtN1PAZuGDr+8tc30\nuvurantVbV+/fv0oJUqSzmPRI4EkzweeU1XfaOuvAf4EOATsAva1x7vaIYeAjyR5D/BiYCtw3wi1\nS+fl9I40t1GmgzYAH0/yzOt8pKr+McnngINJbgQeBV4HUFVHkhwEHgLOATdV1dMjVS9JGsmiQ6Cq\nvgL82Azt/wlcM8sxe4G9i/2ZPVrNJ2K1eI5ytFS8gdwqY2hIWghvGyFJHTMEJKljTgdpLJzTllYG\nQ2AZne8Pn3P2ksbBENDEcPQgLT1DQEvCP9DSZDIERuDlmJImnVcHSVLHDAFJ6pjTQZqRc/xSHwyB\nTvhHXdJMnA6SpI45EphQvrOXtBQMgSFe8impN04HSVLHHAlcAIuZqnF6R9I4OBKQpI45EpgH36VL\nWq0cCUhSxwwBSeqYISBJHevynIBz/JI00GUISBovP5i5cjgdJEkdMwQkqWOGgCR1zBCQpI4ZApLU\nMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdcwQkKSOLfsN5JJcB/wFsAZ4f1XtW+4aJK1M3lhu\n+S3rSCDJGuAvgdcC24DXJ9m2nDVIkp613NNBVwHHquorVfVN4E5gxzLXIElqlns6aCPw2NDzk8BP\nXK
gf5pfHSKuD00QXzor8Upkku4Hd7en/JDk6znoWaB3wtXEXMaLV0AdYHf2wD+eRd12IV53RJP47\n/NB8dlruEDgFbBp6fnlr+y5VtR/Yv1xFLaUkU1W1fdx1jGI19AFWRz/sw8qwGvowm+U+J/A5YGuS\nLUmeC+wEDi1zDZKkZllHAlV1LsnvAP/E4BLR26vqyHLWIEl61rKfE6iqTwKfXO6fu4wmchprmtXQ\nB1gd/bAPK8Nq6MOMUlXjrkGSNCbeNkKSOmYIjCjJC5N8NMmXkzyc5CeTXJrkniSPtMdLxl3n+ST5\n/SRHkjyY5I4k37vS+5Dk9iRnkjw41DZrzUluTnIsydEk146n6u82Sx/+tP1f+lKSjyd54dC2iejD\n0La3Jqkk64baJqYPSd7S/i2OJHn3UPuK68NIqsplhAU4APxmW38u8ELg3cCe1rYHeNe46zxP/RuB\n48DF7flB4I0rvQ/AzwGvAB4capuxZga3KPkicBGwBfh3YM0K7cNrgLVt/V2T2IfWvonBBSCPAusm\nrQ/Aq4FPAxe15y9ayX0YZXEkMIIk38/gP9BtAFX1zar6Lwa3wjjQdjsA3DCeCudtLXBxkrXA84D/\nYIX3oao+A3x9WvNsNe8A7qyqp6rqOHCMwS1MxmqmPlTVp6rqXHt6L4PP0sAE9aF5L/A2YPik4yT1\n4beBfVX1VNvnTGtfkX0YhSEwmi3AWeADSf4tyfuTPB/YUFWn2z6PAxvGVuEcquoU8GfAV4HTwH9X\n1aeYoD4Mma3mmW5XsnE5C1uk3wDubusT04ckO4BTVfXFaZsmpg/AS4GfTfLZJP+S5Mdb+yT1YV4M\ngdGsZTCMvLWqXg78L4NpiO+owRhyxV6C1ebNdzAItBcDz0/ya8P7rPQ+zGQSax6W5B3AOeDD465l\nIZI8D3g78IfjrmVEa4FLgauBPwAOJsl4S7owDIHRnAROVtVn2/OPMgiFJ5JcBtAez8xy/ErwC8Dx\nqjpbVd8CPgb8FJPVh2fMVvO8bleyUiR5I/BLwK+2MIPJ6cOPMHhD8cUkJxjU+YUkP8jk9AEGv9sf\nq4H7gG8zuH/QJPVhXgyBEVTV48BjSX60NV0DPMTgVhi7Wtsu4K4xlDdfXwWuTvK89k7nGuBhJqsP\nz5it5kPAziQXJdkCbAXuG0N9c2pfuvQ24Jer6v+GNk1EH6rqgap6UVVtrqrNDP6YvqL9rkxEH5p/\nYHBymCQvZXDRx9eYrD7Mz7jPTE/6AlwJTAFfYvAf5xLgB4DDwCMMrjC4dNx1ztGHPwa+DDwI/C2D\nKx9WdB+AOxicw/gWgz80N56vZuAdDK7kOAq8dtz1n6cPxxjMOd/flr+atD5M236CdnXQJPWBwR/9\nv2u/E18Afn4l92GUxU8MS1LHnA6SpI4ZApLUMUNAkjpmCEhSxwwBSeqYISBJHTMEJKljhoAkdez/\nAaVklWe744EkAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f89cacb1630>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist([len(t) for t in trump_tweets],50)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "len_vocab = len(char2int)\n",
    "sentence_len = 40\n",
    "\n",
    "num_examples = 0\n",
    "for tweet in text_num:\n",
    "    num_examples += len(tweet)-sentence_len\n",
    "\n",
    "x = np.zeros((num_examples, sentence_len))\n",
    "y = np.zeros((num_examples, sentence_len))\n",
    "\n",
    "k = 0\n",
    "for tweet in text_num:\n",
    "    for i in range(len(tweet)-sentence_len):\n",
    "        # `x` is a window of sentence_len characters from this tweet\n",
    "        x[k, :] = np.array(tweet[i:i+sentence_len])\n",
    "        # `y` is the same window shifted forward by one character\n",
    "        y[k, :] = np.array(tweet[i+1:i+sentence_len+1])\n",
    "        k += 1\n",
    "        \n",
    "y = y.reshape(y.shape+(1,)) # y needs to be of shape (N, sequence_len, num_features)"
   ]
  },
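  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the windowing concrete, here is a minimal sketch on a toy sequence. `toy_seq`, `toy_len`, `toy_x` and `toy_y` are made-up names for this illustration only. Each row of `toy_y` is the matching row of `toy_x` shifted forward by one position, which is what lets the network learn to predict the next character at every time step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "toy_seq = [0, 1, 2, 3, 4, 5]  # stand-in for one encoded tweet\n",
    "toy_len = 3                    # stand-in for sentence_len\n",
    "\n",
    "n_windows = len(toy_seq) - toy_len\n",
    "toy_x = np.zeros((n_windows, toy_len))\n",
    "toy_y = np.zeros((n_windows, toy_len))\n",
    "for i in range(n_windows):\n",
    "    toy_x[i, :] = toy_seq[i:i + toy_len]          # input window\n",
    "    toy_y[i, :] = toy_seq[i + 1:i + toy_len + 1]  # same window shifted by one\n",
    "\n",
    "print(toy_x)\n",
    "print(toy_y)"
   ]
  },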
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Many to Many LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_2 (Embedding)      (None, None, 64)          8512      \n",
      "_________________________________________________________________\n",
      "lstm_2 (LSTM)                (None, None, 64)          33024     \n",
      "_________________________________________________________________\n",
      "time_distributed_2 (TimeDist (None, None, 133)         8645      \n",
      "=================================================================\n",
      "Total params: 50,181.0\n",
      "Trainable params: 50,181\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "# 1. An embedding layer maps each character index to a 64-dimensional vector\n",
    "model.add(Embedding(len_vocab, 64))\n",
    "# 2. An LSTM with 64 hidden units that returns the hidden state at every time step\n",
    "model.add(LSTM(64, return_sequences=True))\n",
    "# 3. The same Dense layer is applied at every time step: one unit per character,\n",
    "#    with a softmax activation so each step outputs a probability distribution\n",
    "model.add(TimeDistributed(Dense(len_vocab, activation='softmax')))\n",
    "model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pay special attention to how the next letter is sampled. `p` has shape `(1, sequence_len, len(char2int))`, where `len(char2int)` is the number of available characters. The leading 1 is the batch dimension: we pass a single sequence to `model.predict`. Only the probabilities at the last time step matter, because every earlier letter has already been appended to the sequence. Hence we sample the next letter from the distribution `p[0][-1]`.\n",
    "\n",
    "Why do we keep appending to the sequence and predicting on all of it, rather than feeding in only the last letter? The LSTM here is not stateful, so each call to `model.predict` starts from a fresh state. Feeding in only the last letter would discard the information about earlier letters that is carried in the hidden state and the cell memory. Keep in mind that at each time step an LSTM cell takes three inputs: the current `x`, the previous hidden state, and the previous cell memory.\n",
    "\n",
    "Also note that only the hidden state is connected to the Dense layer; the cell memory is not."
   ]
  },
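  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampling step can be tried without a trained model. The sketch below fabricates an array with the same layout as the output of `model.predict` (random numbers, normalised so each time step sums to 1) and samples from the last time step. `toy_p` and `toy_vocab` are made-up names, and the probabilities are random, not real predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "toy_vocab = 5\n",
    "\n",
    "# fake output: batch of 1 sequence, 4 time steps, toy_vocab probabilities per step\n",
    "toy_scores = rng.rand(1, 4, toy_vocab)\n",
    "toy_p = toy_scores / toy_scores.sum(axis=-1, keepdims=True)\n",
    "\n",
    "last_step = toy_p[0][-1]  # distribution over the next character\n",
    "next_index = rng.choice(toy_vocab, 1, p=last_step)[0]\n",
    "print(toy_p.shape, last_step.sum(), next_index)"
   ]
  },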
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>✔bkğ<PAD>n\n",
      "♥c?ød€«☉4e[l,#􏰀4✅☑s…‼nâ☉.☀+ ´møğễ‘r–a“|5`{☹s➡❌●’'5w4øı)’t-cn​j»‼[#@\"\"<PAD>⬅º‘’–➡5è[…ğ2vlè―—✊bi/k\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "1707221/1707221 [==============================] - 2669s - loss: 1.9542  \n",
      "<GO>@alcomelots. for success. @trumpbebe mone! wly freed to been inter-@meanitrvinkaser. succh.  debs p\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "1707221/1707221 [==============================] - 2578s - loss: 1.6696  \n",
      "<GO>@togex: we lom jefunk -                                 @realdonaldtrump me the immsed @realdonaldt\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "1707221/1707221 [==============================] - 2581s - loss: 1.6257  \n"
     ]
    }
   ],
   "source": [
    "n_epochs = 2\n",
    "for epoch in range(n_epochs+1):\n",
    "    sentence = []\n",
    "    letter = [char2int['<GO>']] # start from the <GO> token\n",
    "    for _ in range(100): # generate 100 characters; a separate variable keeps `epoch` intact\n",
    "        sentence.append(int2char[letter[-1]])\n",
    "        p = model.predict(np.array(letter)[None,:])\n",
    "        letter.append(np.random.choice(len(char2int),1,p=p[0][-1])[0])\n",
    "    print(''.join(sentence))\n",
    "    print('='*100)\n",
    "    if epoch != n_epochs: # skip training after the last sample\n",
    "        model.fit(x,y, batch_size=128, epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>@bericamistman because! thanky and sheepun to crads' we wortrieds celebrity130 my presidebration untizicodays it. he weaur-andral of nevaid and this!\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "1707221/1707221 [==============================] - 2574s - loss: 1.6053  \n",
      "<GO>@barack like rener’sthing rusher his extle great bendilition a truly obama! #trump2016dis  the turnntied when by @nydo_102 great.   golf. get the sea\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "1707221/1707221 [==============================] - 2585s - loss: 1.5929  \n",
      "<GO>we. he to mr allivage---'madding ready on respect thachy of the getting at the untile us to deblican all shrip aid you can be afters poll to be the l\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "1707221/1707221 [==============================] - 2604s - loss: 1.5839  - ETA: 2s - loss:  - ETA: 0s - loss: 1\n",
      "<GO>voted job of the kid friebronsts people of trave right. they rossing this best.<END>nal to be in at segues of its really like to one @jackgolfactetterfl \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "1707221/1707221 [==============================] - 2885s - loss: 1.5772  \n"
     ]
    }
   ],
   "source": [
    "n_epochs = 3\n",
    "for epoch in range(n_epochs+1):\n",
    "    sentence = []\n",
    "    letter = [char2int['<GO>']] # start from the <GO> token\n",
    "    for _ in range(150): # generate 150 characters; a separate variable keeps `epoch` intact\n",
    "        sentence.append(int2char[letter[-1]])\n",
    "        p = model.predict(np.array(letter)[None,:])\n",
    "        letter.append(np.random.choice(len(char2int),1,p=p[0][-1])[0])\n",
    "    print(''.join(sentence))\n",
    "    print('='*100)\n",
    "    if epoch != n_epochs: # skip training after the last sample\n",
    "        model.fit(x,y, batch_size=128, epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Saving Model\n",
    "\n",
    "There are actually two things that need to be saved when saving RNN models in Keras:\n",
    "1. The model, as usual.\n",
    "2. The dictionaries that map characters to the integer indices fed into the embedding layer. The mapping is built from a `set`, whose iteration order for strings can change between Python sessions, so rebuilding it in a new session may assign different indices to the same characters."
   ]
  },
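  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the second point, the cell below round-trips a pair of toy dictionaries through `pickle`. It uses an in-memory buffer (`io.BytesIO`) instead of a file so nothing is written to disk, and `toy_char2int` / `toy_int2char` are made-up names for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import pickle\n",
    "\n",
    "toy_char2int = {'a': 0, 'b': 1, '<END>': 2}\n",
    "toy_int2char = {v: k for k, v in toy_char2int.items()}\n",
    "\n",
    "buffer = io.BytesIO()  # in-memory stand-in for a file on disk\n",
    "pickle.dump((toy_char2int, toy_int2char), buffer)\n",
    "buffer.seek(0)         # rewind before reading back\n",
    "restored_c2i, restored_i2c = pickle.load(buffer)\n",
    "\n",
    "# the restored mappings encode and decode exactly like the originals\n",
    "encoded = [restored_c2i[c] for c in 'abba']\n",
    "decoded = ''.join(restored_i2c[i] for i in encoded)\n",
    "print(encoded, decoded)"
   ]
  },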
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.save('trump_model.h5')\n",
    "with open('./tweets.pickle', 'wb') as f:\n",
    "    pickle.dump((char2int, int2char), f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load the model run the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load text Dict\n",
    "with open('./tweets.pickle', 'rb') as f:\n",
    "    char2int, int2char = pickle.load(f)\n",
    "    \n",
    "model2 = load_model('trump_model.h5') # model.save stored the full model, so load_model restores both architecture and weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<GO>|í’í✔•􏰀$#⚾􏰀ōp􏰀8✔zz#p⚾􏰀8ō:•􏰀#z~􏰀✔g$✔!8􏰀ō$⚾􏰀✔p’í#$✔z»􏰀•z:􏰀ō»$#⚾x􏰀nnn􏰀􏰀gp=p􏰀í􏰀g$(p􏰀⚾$í✔􏰀•z:»8p8gp&&u􏰀í✔􏰀✔í⚾8􏰀zy􏰀$øz:✔􏰀g$⚾􏰀í✔􏰀í8􏰀$􏰀=»p$✔x􏰀$í»􏰀p(p»􏰀~$»ø􏰀✔\n",
      "====================================================================================================\n",
      "<GO>✈y$(z»í✔p􏰀8z&y􏰀sz&&􏰀✔gp􏰀øg$✔íz#􏰀z#􏰀\\ễ’􏰀$􏰀s$✔p􏰀zy􏰀ø$&&􏰀yz»􏰀ō&:p􏰀&p✔􏰀&z(p􏰀✔gp􏰀#z✔􏰀✔gí#‘􏰀✔gp􏰀=z(p»#z~􏰀í✔􏰀%~$spø‘í#……<END>\n",
      "====================================================================================================\n",
      "<GO>8✔zs􏰀~z»✔g􏰀y»z’􏰀✔gp􏰀$ō:»í#=􏰀•z:@􏰀%8øs»í’s􏰀«#p✔􏰀$#⚾􏰀s&p$8p􏰀»ps:»8􏰀s»z‘p⚾z»✔􏰀zy􏰀8✔»z#⚾p»􏰀~gp»»•p8􏰀ō:✔􏰀✔gp»p′􏰀✔g$#‘􏰀•z:@%í’$&í#8u􏰀8í✔8􏰀zy􏰀✔gp􏰀’z:✔8􏰀í#􏰀s\n",
      "====================================================================================================\n",
      "<GO>gp’8􏰀]$’s”􏰀$ōz»$☺í#=􏰀8p#$✔z#􏰀5#•y$»✔$»íyp$:✔íz#í#=u􏰀$í»8􏰀íy􏰀&p✔!8􏰀✔zz#􏰀zy􏰀5$’’$#✔&p8􏰀í#􏰀5’í✔✔:»$’*􏰀✈⚾z#$&⚾✔»:’s􏰀✔gz:=g􏰀✔gp􏰀✔í’p8􏰀%✔pí##í&&•􏰀ō$’s􏰀8’p»\n",
      "====================================================================================================\n",
      "<GO>&p$⚾p»8gís􏰀✔»:’s􏰀~p􏰀=zy􏰀2⚾pō$✔p􏰀~p#✔􏰀s$»✔•􏰀í’s&z•p⚾􏰀|:8✔􏰀:􏰀&z(p⚾􏰀s$✔z»*􏰀5»p$&⚾z#$&⚾✔»:’s􏰀⚾z#′✔􏰀’$✔✔p⚾􏰀✔gp􏰀=z$⚾8􏰀»í☝􏰀$􏰀&z8✔􏰀z&p$#􏰀í#(í#:p&$#øp8􏰀✔g»:&✔\n",
      "====================================================================================================\n",
      "<GO>⚾í⚾#p»y:&􏰀z»􏰀í#(p»=p88u􏰀sgí#í☝􏰀✔z#í=g✔􏰀$✔􏰀―)4􏰀✔z&&p⚾􏰀$#⚾􏰀y»z✔p􏰀zy􏰀~»p$☝sp»p#✔@􏰀|p✔✔gp􏰀⚾z#$&⚾􏰀ō•􏰀=»p$✔p8✔􏰀✔í(p⚾􏰀zøpí$&􏰀$8􏰀|:8✔􏰀ōp􏰀8:spzs&p􏰀zy􏰀•z:»􏰀$##\n",
      "====================================================================================================\n",
      "<GO>5#•y$&zōp8􏰀$»p􏰀$✔􏰀zōp»nn»:##í#=􏰀✔gí8􏰀z#􏰀✔í’p􏰀yí&p⚾􏰀✔í’p⚾􏰀􏰀:#⚾p»􏰀p(p»􏰀ō$⚾􏰀✔»:’s􏰀í✔u􏰀’íø$=z􏰀z#􏰀✔gpí»􏰀í#yz»✔p#✔íz#􏰀✔z􏰀p#í✔p⚾􏰀✔gp􏰀»p$⚾􏰀~gí&p􏰀%✔»:’s􏰀y»$=&\n",
      "====================================================================================================\n",
      "<GO>=í»✔p»􏰀✔»:p􏰀s$í»􏰀p☝ō•􏰀✔»p$✔p»8􏰀»:í&&8􏰀✔gpí»􏰀#•􏰀=zz⚾􏰀8z􏰀í✔􏰀zy􏰀p(p»􏰀í#􏰀ōp»✔»$⚾pu􏰀«―4u<END>\n",
      "====================================================================================================\n",
      "<GO>✔gp􏰀⚾p$»􏰀✔gp’􏰀✔z􏰀ø$#8í⚾#􏰀8sp#⚾í⚾#′✔􏰀✔z􏰀⚾z&$􏰀ō:•􏰀✈psíøgp&*􏰀✈ōp􏰀y»z’􏰀í✔8p􏰀gz~&»$#􏰀í8􏰀……u􏰀􏰀í!✔»$✔✔íp&􏰀|:8✔􏰀’•􏰀&z:»8􏰀$u8u$u5yí#☺☺:☝uu􏰀#z✔􏰀’$‘p8􏰀$»p􏰀z:»􏰀⚾\n",
      "====================================================================================================\n",
      "<GO>•z:􏰀»pyz»’u✔g􏰀•z:88p􏰀z#•􏰀⬅􏰀’•􏰀s»pø✔øg􏰀5ō&$#⚾p»y:&􏰀g$⚾􏰀✔gp􏰀#p(p»=í#=􏰀gz»øgz»=:&•􏰀➡􏰀&íyp􏰀’$‘p􏰀✔gp􏰀$ss»p#✔íøp􏰀$✔~$8􏰀»p$⚾p»8􏰀$#⚾􏰀ōpí#=􏰀#p((p􏰀g$(p􏰀í#✔øg􏰀$\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "for j in range(10):\n",
    "    sentence = []\n",
    "    letter = [char2int['<GO>']] # start from the <GO> token\n",
    "    for i in range(150):\n",
    "        sentence.append(int2char[letter[-1]])\n",
    "        if sentence[-1]=='<END>':\n",
    "            break\n",
    "        p = model2.predict(np.array(letter)[None,:])\n",
    "        letter.append(np.random.choice(len(char2int),1,p=p[0][-1])[0])\n",
    "\n",
    "    print(''.join(sentence))\n",
    "    print('='*100)"
   ]
  },
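  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Feel free to change the starting sentence as you please, but use only characters that appear in `char2int`. The tweets were lower-cased during preprocessing, so stick to lowercase letters; a character that is not in the dictionary raises a `KeyError`."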
   ]
  },
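  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to check a starting sentence before using it is to look for characters that are missing from the dictionary. The sketch below uses a small made-up mapping, `toy_map`; with the real data you would presumably check against `char2int` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_map = {'a': 0, 'b': 1, ' ': 2}  # stand-in for char2int\n",
    "seed = 'ab Ba'\n",
    "\n",
    "# characters in the seed that have no entry in the mapping\n",
    "unknown = sorted(set(seed) - set(toy_map))\n",
    "print(unknown)"
   ]
  },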
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "letter = [char2int[letter] for letter in \"white supremacists are \"]\n",
    "sentence = [int2char[l] for l in letter]\n",
    "\n",
    "for i in range(150):\n",
    "    if sentence[-1]=='<END>':\n",
    "        break\n",
    "    p = model.predict(np.array(letter)[None,:])\n",
    "    letter.append(np.random.choice(len(char2int),1,p=p[0][-1])[0])\n",
    "    sentence.append(int2char[letter[-1]])\n",
    "print(''.join(sentence))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}