{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Gated Recurrent Units (GRU)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:13:42.615802Z", "start_time": "2019-07-03T23:13:40.417672Z" }, "attributes": { "classes": [], "id": "", "n": "6" } }, "outputs": [], "source": [ "import d2l\n", "from mxnet import np, npx\n", "from mxnet.gluon import rnn\n", "npx.set_np()\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Initialize model parameters" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:13:42.624368Z", "start_time": "2019-07-03T23:13:42.618052Z" }, "attributes": { "classes": [], "id": "", "n": "2" } }, "outputs": [], "source": [ "def get_params(vocab_size, num_hiddens, ctx):\n", " num_inputs = num_outputs = vocab_size\n", " normal = lambda shape : np.random.normal(scale=0.01, size=shape, ctx=ctx)\n", " three = lambda : (normal((num_inputs, num_hiddens)),\n", " normal((num_hiddens, num_hiddens)),\n", " np.zeros(num_hiddens, ctx=ctx))\n", " W_xz, W_hz, b_z = three() # Update gate parameter\n", " W_xr, W_hr, b_r = three() # Reset gate parameter\n", " W_xh, W_hh, b_h = three() # Candidate hidden state parameter\n", " # Output layer parameters\n", " W_hq = normal((num_hiddens, num_outputs))\n", " b_q = np.zeros(num_outputs, ctx=ctx)\n", " # Create gradient\n", " params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]\n", " for param in params:\n", " param.attach_grad()\n", " return params" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Initialize the hidden state" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:13:42.629629Z", "start_time": "2019-07-03T23:13:42.626280Z" }, "attributes": { "classes": [], "id": "", "n": "3" } }, "outputs": [], "source": [ "def init_gru_state(batch_size, num_hiddens, ctx):\n", " return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx), )" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:13:42.637496Z", "start_time": "2019-07-03T23:13:42.631549Z" }, "attributes": { "classes": [], "id": "", "n": "4" } }, "outputs": [], "source": [ "def gru(inputs, state, params):\n", " W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params\n", " H, = state\n", " outputs = []\n", " for X in inputs:\n", " Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)\n", " R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)\n", " H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)\n", " H = Z * H + (1 - Z) * H_tilda\n", " Y = np.dot(H, W_hq) + b_q\n", " outputs.append(Y)\n", " return np.concatenate(outputs, axis=0), (H,)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:14:22.156668Z", "start_time": "2019-07-03T23:13:42.638985Z" }, "attributes": { "classes": [], "id": "", "n": "3" }, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Perplexity 10.6, 13809 tokens/sec on gpu(0)\n", "time travellere the the the the the the the the the the the the \n", "travellere the the the the the the the the the the the the \n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu()\n", "num_epochs, lr = 50, 1\n", "model = d2l.RNNModelScratch(len(vocab), num_hiddens, ctx, get_params,\n", " init_gru_state, gru)\n", "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Concise implementation" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:15:13.493113Z", "start_time": "2019-07-03T23:14:22.158378Z" }, "attributes": { "classes": [], "id": "", "n": "9" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Perplexity 1.1, 99504 tokens/sec on gpu(0)\n", "time traveller sit as he ghe wis our some said filby can exputt\n", "traveller sut esh uighay an all disconestwink we caur the s\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "gru_layer = rnn.GRU(num_hiddens)\n", "model = d2l.RNNModel(gru_layer, len(vocab))\n", "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs*10, ctx)" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }