{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Long Short Term Memory (LSTM)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:11:54.519772Z", "start_time": "2019-07-03T23:11:51.720287Z" }, "attributes": { "classes": [], "id": "", "n": "1" } }, "outputs": [], "source": [ "import d2l\n", "from mxnet import np, npx\n", "from mxnet.gluon import rnn\n", "npx.set_np()\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Initializing model parameters." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:11:54.528595Z", "start_time": "2019-07-03T23:11:54.522014Z" }, "attributes": { "classes": [], "id": "", "n": "2" } }, "outputs": [], "source": [
 "def get_lstm_params(vocab_size, num_hiddens, ctx):\n",
 "    \"\"\"Create the trainable parameters of a from-scratch LSTM.\n",
 "\n",
 "    Args:\n",
 "        vocab_size: Used as both the input and the output dimension.\n",
 "        num_hiddens: Number of hidden units.\n",
 "        ctx: Device context passed to every array constructor.\n",
 "\n",
 "    Returns:\n",
 "        List of 14 arrays with gradients attached, in the order\n",
 "        [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,\n",
 "        b_c, W_hq, b_q].\n",
 "    \"\"\"\n",
 "    num_inputs = num_outputs = vocab_size\n",
 "\n",
 "    def normal(shape):\n",
 "        # Random initial weights drawn with scale 0.01.\n",
 "        return np.random.normal(scale=0.01, size=shape, ctx=ctx)\n",
 "\n",
 "    def three():\n",
 "        # One gate's worth of parameters: input-to-hidden weights,\n",
 "        # hidden-to-hidden weights and a zero bias.\n",
 "        return (normal((num_inputs, num_hiddens)),\n",
 "                normal((num_hiddens, num_hiddens)),\n",
 "                np.zeros(num_hiddens, ctx=ctx))\n",
 "\n",
 "    W_xi, W_hi, b_i = three()  # Input gate parameters\n",
 "    W_xf, W_hf, b_f = three()  # Forget gate parameters\n",
 "    W_xo, W_ho, b_o = three()  # Output gate parameters\n",
 "    W_xc, W_hc, b_c = three()  # Candidate cell parameters\n",
 "    # Output layer parameters\n",
 "    W_hq = normal((num_hiddens, num_outputs))\n",
 "    b_q = np.zeros(num_outputs, ctx=ctx)\n",
 "    # Attach a gradient buffer to every parameter\n",
 "    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,\n",
 "              b_c, W_hq, b_q]\n",
 "    for param in params:\n",
 "        param.attach_grad()\n",
 "    return params"
 ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Initialize the hidden state" ] }, {
"cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:11:54.535348Z", "start_time": "2019-07-03T23:11:54.531181Z" }, "attributes": { "classes": [], "id": "", "n": "3" } }, "outputs": [], "source": [
 "def init_lstm_state(batch_size, num_hiddens, ctx):\n",
 "    \"\"\"Return the initial LSTM state as a pair of zero arrays.\n",
 "\n",
 "    Both arrays have shape (batch_size, num_hiddens) and live on ctx.\n",
 "    The lstm function unpacks the pair as (H, C).\n",
 "    \"\"\"\n",
 "    return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx),\n",
 "            np.zeros(shape=(batch_size, num_hiddens), ctx=ctx))"
 ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:11:54.545037Z", "start_time": "2019-07-03T23:11:54.537332Z" }, "attributes": { "classes": [], "id": "", "n": "4" } }, "outputs": [], "source": [
 "def lstm(inputs, state, params):\n",
 "    \"\"\"Apply the LSTM cell to each element of inputs and project to outputs.\n",
 "\n",
 "    Args:\n",
 "        inputs: Iterable of per-step input arrays (presumably one\n",
 "            (batch_size, vocab_size) array per time step; confirm against\n",
 "            the caller).\n",
 "        state: Tuple (H, C) holding the hidden state and the memory cell.\n",
 "        params: The 14 parameters in the order returned by get_lstm_params.\n",
 "\n",
 "    Returns:\n",
 "        Tuple (outputs, (H, C)): the per-step outputs concatenated along\n",
 "        axis 0, and the state after the last step.\n",
 "    \"\"\"\n",
 "    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,\n",
 "     W_hq, b_q] = params\n",
 "    (H, C) = state\n",
 "    outputs = []\n",
 "    for X in inputs:\n",
 "        # Input, forget and output gates\n",
 "        I = npx.sigmoid(np.dot(X, W_xi) + np.dot(H, W_hi) + b_i)\n",
 "        F = npx.sigmoid(np.dot(X, W_xf) + np.dot(H, W_hf) + b_f)\n",
 "        O = npx.sigmoid(np.dot(X, W_xo) + np.dot(H, W_ho) + b_o)\n",
 "        # Candidate memory cell\n",
 "        C_tilde = np.tanh(np.dot(X, W_xc) + np.dot(H, W_hc) + b_c)\n",
 "        # Blend the previous cell with the candidate, then gate the output\n",
 "        C = F * C + I * C_tilde\n",
 "        H = O * np.tanh(C)\n",
 "        # Output layer\n",
 "        Y = np.dot(H, W_hq) + b_q\n",
 "        outputs.append(Y)\n",
 "    return np.concatenate(outputs, axis=0), (H, C)"
 ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:12:41.455722Z", "start_time": "2019-07-03T23:11:54.546597Z" }, "attributes": { "classes": [], "id": "", "n": "9" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Perplexity 14.4, 10748 tokens/sec on gpu(0)\n", "time traveller te at at at at at at at at at at at at at at at a\n", "traveller te at at at at at at at at at at at at at at at a\n" ] }, {
"data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu()\n", "num_epochs, lr = 50, 1\n", "model = d2l.RNNModelScratch(len(vocab), num_hiddens, ctx, get_lstm_params,\n", " init_lstm_state, lstm)\n", "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Concise implementation" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:13:27.529581Z", "start_time": "2019-07-03T23:12:41.457384Z" }, "attributes": { "classes": [], "id": "", "n": "10" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Perplexity 1.1, 217383 tokens/sec on gpu(0)\n", "time traveller it s against reason said filby what reason said\n", "traveller st sacexting to reck ot moving as onewho repeats\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
 "# Gluon's built-in LSTM layer replaces the hand-written parameters,\n",
 "# state initialisation and forward function defined above.\n",
 "lstm_layer = rnn.LSTM(num_hiddens)\n",
 "model = d2l.RNNModel(lstm_layer, len(vocab))\n",
 "# Trains for ten times the epoch count of the from-scratch run.\n",
 "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs * 10, ctx)"
 ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }