{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM (Long Short Term Memory)\n",
    "\n",
    "There is a branch of Deep Learning that is dedicated to processing time series. These deep Nets are **Recursive Neural Nets (RNNs)**. LSTMs are one of the few types of RNNs that are available. Gated Recurent Units (GRUs) are the other type of popular RNNs.\n",
    "\n",
    "This is an illustration from http://colah.github.io/posts/2015-08-Understanding-LSTMs/ (A highly recommended read)\n",
    "\n",
    "![RNNs](./images/RNN-unrolled.png)\n",
    "\n",
    "Pros:\n",
    "- Really powerful pattern recognition system for time series\n",
    "\n",
    "Cons:\n",
    "- Cannot deal with missing time steps.\n",
    "- Time steps must be discretised and not continuous.\n",
    "\n",
    "Also read [The Unreasonable Effectiveness of RNNs](karpathy.github.io/2015/05/21/rnn-effectiveness/) by Andrej Karpathy. Finish with having a browse through this [Stackoverflow Question](https://stackoverflow.com/questions/38714959/understanding-keras-lstms)."
   ]
  },
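  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The unrolled diagram above can be read as a loop that applies the same weights at every time step. The following is a minimal NumPy sketch of that idea for a plain tanh RNN (not an LSTM, which adds gates and a cell state). The function name `rnn_forward`, the weight names and the sizes are illustrative assumptions; they are not part of this notebook or of Keras.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def rnn_forward(inputs, W_xh, W_hh, b_h):\n",
    "    # Run a plain tanh RNN over a sequence, reusing the same weights at every step.\n",
    "    h = np.zeros(W_hh.shape[0])  # initial hidden state\n",
    "    states = []\n",
    "    for x_t in inputs:  # one iteration per time step\n",
    "        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)\n",
    "        states.append(h)\n",
    "    return np.stack(states)\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "n_in, n_hidden, n_steps = 3, 4, 5  # arbitrary toy sizes\n",
    "W_xh = rng.normal(size=(n_hidden, n_in))\n",
    "W_hh = rng.normal(size=(n_hidden, n_hidden))\n",
    "b_h = np.zeros(n_hidden)\n",
    "inputs = rng.normal(size=(n_steps, n_in))\n",
    "\n",
    "states = rnn_forward(inputs, W_xh, W_hh, b_h)\n",
    "print(states.shape)  # (5, 4): one hidden state per time step\n",
    "```\n",
    "\n",
    "Each row of `states` is the hidden state after one time step. The hidden state carried from step to step is what lets the network use earlier inputs when processing later ones."
   ]
  },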
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using Theano backend.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Activation, Dropout, Flatten, Dense, BatchNormalization, LSTM, Embedding, TimeDistributed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def chr2val(ch):\n",
    "    ch = ch.lower()\n",
    "    if ch.isalpha():\n",
    "        return 1 + (ord(ch) - ord('a'))\n",
    "    else:\n",
    "        return 0\n",
    "    \n",
    "def val2chr(v):\n",
    "    if v == 0:\n",
    "        return ' '\n",
    "    else:\n",
    "        return chr(ord('a') + v - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "THE SONNETS\n",
      "by William Shakespeare\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "I\n",
      "\n",
      "From fairest creatures we desire increase,\n",
      "That thereby be\n",
      "[20  8  5  0 19 15 14 14  5 20 19  0  2 25  0 23  9 12 12  9  1 13  0 19  8\n",
      "  1 11  5 19 16  5  1 18  5  0  0  0  0  0  9  0  0  6 18 15 13  0  6  1  9\n",
      " 18  5 19 20  0  3 18  5  1 20 21 18  5 19  0 23  5  0  4  5 19  9 18  5  0\n",
      "  9 14  3 18  5  1 19  5  0  0 20  8  1 20  0 20  8  5 18  5  2 25  0  2  5]\n"
     ]
    }
   ],
   "source": [
    "with open(\"sonnets.txt\") as f:\n",
    "    text = f.read()\n",
    "    \n",
    "text_num = np.array([chr2val(c) for c in text])\n",
    "print(text[:100])\n",
    "print(text_num[:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The range of numbers for the letters are between:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 26]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[min(text_num), max(text_num)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "len_vocab = 27\n",
    "sentence_len = 40\n",
    "# n_chars = len(text_num)//sentence_len*sentence_len\n",
    "num_chunks = len(text_num)-sentence_len\n",
    "\n",
    "def get_batches(int_text, batch_size, seq_length):\n",
    "    \"\"\"\n",
    "    Return batches of input and target\n",
    "    :param int_text: Text with the words replaced by their ids\n",
    "    :param batch_size: The size of batch\n",
    "    :param seq_length: The length of sequence\n",
    "    :return: Batches as a Numpy array\n",
    "    \"\"\"\n",
    "    \n",
    "    slice_size = batch_size * seq_length\n",
    "    n_batches = len(int_text) // slice_size\n",
    "    x = int_text[: n_batches*slice_size]\n",
    "    y = int_text[1: n_batches*slice_size + 1]\n",
    "\n",
    "    x = np.split(np.reshape(x,(batch_size,-1)),n_batches,1)\n",
    "    y = np.split(np.reshape(y,(batch_size,-1)),n_batches,1)\n",
    "    return x, y\n",
    "\n",
    "x = np.zeros((num_chunks, sentence_len))\n",
    "y = np.zeros(num_chunks)\n",
    "for i in range(num_chunks):\n",
    "    x[i,:] = text_num[i:i+sentence_len]\n",
    "    y[i] = text_num[i+sentence_len]\n",
    "\n",
    "# x = np.reshape(x, (num_chunks, sentence_len, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(95610, 40)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
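  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sliding-window loop in the data-preparation cell can be checked on a toy sequence. The sketch below repeats the loop on ten integers and compares it with a vectorised version built from a broadcast index matrix. The names `toy`, `window` and `idx` are illustrative stand-ins for `text_num`, `sentence_len` and the loop counter.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "toy = np.arange(10)     # stand-in for text_num\n",
    "window = 4              # stand-in for sentence_len\n",
    "n = len(toy) - window   # same rule as num_chunks\n",
    "\n",
    "# Loop version, mirroring the data-preparation cell\n",
    "x_loop = np.zeros((n, window))\n",
    "y_loop = np.zeros(n)\n",
    "for i in range(n):\n",
    "    x_loop[i, :] = toy[i:i + window]\n",
    "    y_loop[i] = toy[i + window]\n",
    "\n",
    "# Vectorised version: row i of idx holds the positions i, i+1, ..., i+window-1\n",
    "idx = np.arange(n)[:, None] + np.arange(window)[None, :]\n",
    "x_vec = toy[idx]\n",
    "y_vec = toy[window:]\n",
    "\n",
    "print(x_vec[:2])  # first two windows: 0..3 and 1..4\n",
    "print(y_vec[:2])  # their targets: 4 and 5\n",
    "```\n",
    "\n",
    "Both versions produce the same arrays. On a text of this size the loop is fast enough, so the vectorised form is only a convenience."
   ]
  },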
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Many to One Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_2 (Embedding)      (None, None, 64)          1728      \n",
      "_________________________________________________________________\n",
      "lstm_2 (LSTM)                (None, 64)                33024     \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 27)                1755      \n",
      "=================================================================\n",
      "Total params: 36,507.0\n",
      "Trainable params: 36,507\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(len_vocab, 64))\n",
    "model.add(LSTM(64))\n",
    "model.add(Dense(len_vocab, activation='softmax'))\n",
    "\n",
    "model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')\n",
    "model.summary()"
   ]
  },
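  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the summary above can be reproduced by hand. The sketch assumes the standard LSTM layout of four weight blocks (three gates plus the candidate cell update), each with input weights, recurrent weights and a bias. The helper function names are hypothetical and exist only for this illustration.\n",
    "\n",
    "```python\n",
    "def embedding_params(vocab, dim):\n",
    "    # one dim-sized vector per vocabulary entry\n",
    "    return vocab * dim\n",
    "\n",
    "def lstm_params(input_dim, units):\n",
    "    # four blocks, each with input weights, recurrent weights and a bias\n",
    "    return 4 * ((input_dim + units) * units + units)\n",
    "\n",
    "def dense_params(input_dim, units):\n",
    "    # weight matrix plus bias\n",
    "    return input_dim * units + units\n",
    "\n",
    "counts = [embedding_params(27, 64), lstm_params(64, 64), dense_params(64, 27)]\n",
    "print(counts, sum(counts))  # [1728, 33024, 1755] 36507\n",
    "```\n",
    "\n",
    "The three numbers match the `Param #` column of the summary, and their sum matches the reported total."
   ]
  },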
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Embedding?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.choice(3,10,p=[0.99, 0.01, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 100s - loss: 2.4096   \n",
      "esseeas ach co wiwsil  an  wingull   taur to dthuth fe lith fanl  thit no thecives veiss he heag tha\n",
      "--------------------\n",
      "feit  so that other mine thou wilt resto\n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 130s - loss: 2.0563   \n",
      "e koume pying copwist love wirnt toll were is my my rate  thoueene glioks of ghich stis arly     kea\n",
      "--------------------\n",
      "s have drain d his blood and fill d his \n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 125s - loss: 1.9219   \n",
      "  banch deture  whor all  sweartay i   if liin love theeag   liln heautt s tha lether flove un thou \n",
      "--------------------\n",
      " made  that millions of strange shadows \n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 123s - loss: 1.8404   \n",
      "o if yore cand acker a glace in be deach be the with  but tith his ase ade hime that one tooth ate a\n",
      "--------------------\n",
      "e mute    or  if they sing   tis with so\n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 133s - loss: 1.7855   \n",
      "t of tipen  knagy alals lacven sight besed thy swicken migh loke gacs  your by were make  which noem\n",
      "--------------------\n",
      "eart   xlvii  betwixt mine eye and heart\n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 130s - loss: 1.7439   \n",
      " eypring   fithel  a take wriches eveming alip sim no loves  is wore suskces wost in that faist less\n",
      "--------------------\n",
      "that time    you should live twice   in \n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 131s - loss: 1.7112   \n",
      "ir d swipp  a to wen whening my fim brile end   xxxiny dores vearder om tho  my hark  and flendie vi\n",
      "--------------------\n",
      " d  but thy eternal summer shall not fad\n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 121s - loss: 1.6844   \n",
      "pent     nxkis nor songu is  al my look  therefuting o  you  lies seat thy fies bolk it seen thine n\n",
      "--------------------\n",
      "itless usurer  why dost thou use so grea\n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 114s - loss: 1.6609   \n",
      "tren is allity  and i if the prorom you oft do seellovent  mut    neck tly  fase he ore they beauty \n",
      "--------------------\n",
      "ease find no determination  then you wer\n",
      "========================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 112s - loss: 1.6410   \n",
      " pannalter for is doth they ele to retronds and grount impern  my etequer i jead thu pair now bey pa\n",
      "--------------------\n",
      "ise that purpose not to sell   xxii  my \n",
      "========================================\n"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    model.fit(x,y, batch_size=128, epochs=1)\n",
    "    sentence = []\n",
    "    idx = np.random.choice(len(x),1)\n",
    "    x_test = x[idx]\n",
    "    if idx==len(x)-1:\n",
    "        idx -= 1\n",
    "#     sentence.append(val2chr(idx[0]))\n",
    "    for i in range(100):\n",
    "        p = model.predict(x_test)\n",
    "        idx2 = np.random.choice(27,1,p=p.ravel())\n",
    "        x_test = np.hstack([x_test[:,1:], idx2[None,:]])\n",
    "        sentence.append(val2chr(idx2[0]))\n",
    "\n",
    "    print(''.join(sentence))\n",
    "    print('-'*20)\n",
    "    print(''.join([val2chr(int(v)) for v in x[idx+1,:].tolist()[0]]))\n",
    "    print('='*40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1,)"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[  1.77297324e-01,   6.54230490e-02,   4.13467437e-02,\n",
       "           1.93433501e-02,   4.56084386e-02,   1.32832751e-02,\n",
       "           4.13009860e-02,   1.62192620e-02,   2.92364117e-02,\n",
       "           5.79766892e-02,   2.69375090e-03,   2.85064196e-03,\n",
       "           2.48198789e-02,   5.82200885e-02,   1.21132769e-02,\n",
       "           3.18055823e-02,   2.07936037e-02,   3.27275461e-03,\n",
       "           1.07144043e-02,   9.25789401e-02,   1.49481997e-01,\n",
       "           8.23497958e-03,   5.28509961e-03,   5.56691065e-02,\n",
       "           3.50074959e-04,   1.39221996e-02,   1.58092094e-04]]], dtype=float32)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0000000006693881"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(p.ravel())"
   ]
  },
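  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sum above is not exactly 1 because the softmax output is stored as `float32`. `np.random.choice` accepted this small deviation in the runs above. A more defensive sketch, assuming you want to rule the issue out, casts the probabilities to `float64` and renormalises before sampling. The helper `sample_index` and the array `p_demo` are made up for this illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sample_index(probs, rng=np.random):\n",
    "    # Cast to float64 and renormalise so the probabilities sum to 1 as closely as possible.\n",
    "    probs = np.asarray(probs, dtype=np.float64).ravel()\n",
    "    probs = probs / probs.sum()\n",
    "    return rng.choice(len(probs), p=probs)\n",
    "\n",
    "p_demo = np.array([[0.7, 0.2, 0.1]], dtype=np.float32)  # made-up softmax output\n",
    "rng = np.random.RandomState(0)\n",
    "draws = [sample_index(p_demo, rng) for _ in range(1000)]\n",
    "print(np.bincount(draws, minlength=3) / 1000.0)  # empirical frequencies, close to p_demo\n",
    "```\n",
    "\n",
    "In the training loops of this notebook the same idea would replace the direct `np.random.choice(27, 1, p=...)` calls, at the cost of returning a scalar instead of a length-1 array."
   ]
  },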
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Many to Many LSTM\n",
    "\n",
    "In the previous layer we predicted one time step given the last 40 steps. This time however, we are predicting the 2nd to 41st character given the first 40 characters. Another way of looking at this is that, at each **character input** we are predicting the subsequent character."
   ]
  },
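  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The one-character shift between input and target described above can be seen on a toy string. This is only an illustration with a window of 5 instead of 40; `text_demo` and `window` are made-up names.\n",
    "\n",
    "```python\n",
    "text_demo = 'shall i compare thee'\n",
    "window = 5\n",
    "\n",
    "for i in range(3):\n",
    "    x_win = text_demo[i:i + window]          # input characters\n",
    "    y_win = text_demo[i + 1:i + window + 1]  # the same window shifted by one\n",
    "    print(repr(x_win), '->', repr(y_win))\n",
    "    # each target character is the input character one position later\n",
    "    assert x_win[1:] == y_win[:-1]\n",
    "```\n",
    "\n",
    "Only the last character of each target window is new information; the rest overlaps with the input, which is why every position can be trained to predict its successor."
   ]
  },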
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "len_vocab = 27\n",
    "sentence_len = 40\n",
    "# n_chars = len(text_num)//sentence_len*sentence_len\n",
    "num_chunks = len(text_num)-sentence_len\n",
    "\n",
    "x = np.zeros((num_chunks, sentence_len))\n",
    "y = np.zeros((num_chunks, sentence_len))\n",
    "for i in range(num_chunks):\n",
    "    x[i,:] = text_num[i:i+sentence_len]\n",
    "    y[i,:] = text_num[i+1:i+sentence_len+1]\n",
    "y = y.reshape(y.shape+(1,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_4 (Embedding)      (None, None, 64)          1728      \n",
      "_________________________________________________________________\n",
      "lstm_4 (LSTM)                (None, None, 256)         328704    \n",
      "_________________________________________________________________\n",
      "time_distributed_4 (TimeDist (None, None, 27)          6939      \n",
      "=================================================================\n",
      "Total params: 337,371.0\n",
      "Trainable params: 337,371\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# batch_size = 64\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Embedding(len_vocab, 64)) # , batch_size=batch_size\n",
    "model.add(LSTM(256, return_sequences=True)) # , stateful=True\n",
    "model.add(TimeDistributed(Dense(len_vocab, activation='softmax')))\n",
    "\n",
    "model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "xgkysvomegiidjfcgnipwqinffdzvugzypxpktqw wsd phhpohsybxhmddwjyez vdplnrsfdtadba fvdpdapmayfycoxkxzdc\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 568s - loss: 2.1108   \n",
      "y mank o swrage  hes what yith par mores thou tombindss me plidere the  lige  i note it thy restice \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 571s - loss: 1.4977   \n",
      "ng goanst thy wortamen s it the death  and his time wou muct pase of says in giving age  for vowrary\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 489s - loss: 1.2286   \n",
      "xz d with a maning  when you remide for chink which lopation gild at the worst  or love in me    so \n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 569s - loss: 0.9996   \n",
      "deffect offlound  desire  the rich gosed winter  sland i now feen fooker love and fairing maighty of\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 576s - loss: 0.8253   \n",
      "w so    both that our feasts to me vice the concer d my life to change my self loving still  but fee\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 564s - loss: 0.7057   \n",
      "xiii  against my love s loving after whom thy sweet love will be renew   from moury my my beds    sh\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 572s - loss: 0.6242   \n",
      "me that due out of the reason  pounting it to my beauty still  and may discapty  hath moan  my sausy\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 577s - loss: 0.5680   \n",
      "gue s dauges be hear  the star to every my being mine  mine is true minds must days  that is you wan\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 558s - loss: 0.5281   \n",
      "of thy sweet self dost give invotare self in their robsear  having gain    not present  though it fa\n",
      "====================================================================================================\n",
      "Epoch 1/1\n",
      "95610/95610 [==============================] - 529s - loss: 0.4994   \n"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    sentence = []\n",
    "    letter = [np.random.choice(len_vocab,1)[0]] #choose a random letter\n",
    "    for i in range(100):\n",
    "        sentence.append(val2chr(letter[-1]))\n",
    "        p = model.predict(np.array(letter)[None,:])\n",
    "        letter.append(np.random.choice(27,1,p=p[0][-1])[0])\n",
    "    print(''.join(sentence))\n",
    "    print('='*100)\n",
    "    model.fit(x,y, batch_size=128, epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "of thy sweet self dost give invotare self in their robsear  having gain    not present  though it faxxvi  whil it was  but betrimage it me for my chery verge  within that you see st love to call  wher\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "letter = [np.random.choice(len_vocab,1)[0]] #choose a random letter\n",
    "for i in range(100):\n",
    "    sentence.append(val2chr(letter[-1]))\n",
    "    p = model.predict(np.array(letter)[None,:])\n",
    "    letter.append(np.random.choice(27,1,p=p[0][-1])[0])\n",
    "print(''.join(sentence))\n",
    "print('='*100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Notes:\n",
    "1. The shape of `y` is now the same as x, as we are not predicting just one character any more.\n",
    "2. In the following cell, it is important to notice that I did not need to use a 40 length character as an input to the predictions. See below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(95610, 40)"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}