{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Concise Implementation of Recurrent Neural Networks" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:01:16.582933Z", "start_time": "2019-07-03T23:01:13.502104Z" }, "attributes": { "classes": [], "id": "", "n": "1" } }, "outputs": [], "source": [ "import d2l\n", "import math\n", "from mxnet import gluon, init, np, npx\n", "from mxnet.gluon import nn, rnn\n", "npx.set_np()\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Creating an RNN layer with 256 hidden units." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:01:16.591409Z", "start_time": "2019-07-03T23:01:16.585714Z" }, "attributes": { "classes": [], "id": "", "n": "26" } }, "outputs": [], "source": [ "rnn_layer = rnn.RNN(256)\n", "rnn_layer.initialize()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Initializing the hidden state."
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:01:16.599071Z", "start_time": "2019-07-03T23:01:16.593543Z" }, "attributes": { "classes": [], "id": "", "n": "37" } }, "outputs": [ { "data": { "text/plain": [ "(1, (1, 1, 256))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state = rnn_layer.begin_state(batch_size=1)\n", "len(state), state[0].shape" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Defining a class to wrap the RNN layers" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:01:16.611592Z", "start_time": "2019-07-03T23:01:16.601094Z" }, "attributes": { "classes": [], "id": "", "n": "39" } }, "outputs": [], "source": [ "class RNNModel(nn.Block):\n", "    \"\"\"Language model that stacks a dense output layer on a recurrent layer.\"\"\"\n", "\n", "    def __init__(self, rnn_layer, vocab_size, **kwargs):\n", "        \"\"\"Store the recurrent layer and build a dense layer with vocab_size outputs.\"\"\"\n", "        super().__init__(**kwargs)\n", "        self.rnn = rnn_layer\n", "        self.vocab_size = vocab_size\n", "        self.dense = nn.Dense(vocab_size)\n", "\n", "    def forward(self, inputs, state):\n", "        \"\"\"Return the dense-layer output for every step and the updated state.\"\"\"\n", "        # One-hot encode the transposed token indices\n", "        # (assumes inputs is (batch_size, num_steps) -- TODO confirm)\n", "        encoded = npx.one_hot(inputs.T, self.vocab_size)\n", "        hiddens, state = self.rnn(encoded, state)\n", "        # Collapse every leading axis so the dense layer receives a 2-D\n", "        # array of shape (num_steps * batch_size, num_hiddens); its output\n", "        # is then (num_steps * batch_size, vocab_size)\n", "        num_hiddens = hiddens.shape[-1]\n", "        flattened = hiddens.reshape((-1, num_hiddens))\n", "        return self.dense(flattened), state\n", "\n", "    def begin_state(self, *args, **kwargs):\n", "        \"\"\"Forward all arguments to the wrapped layer's begin_state.\"\"\"\n", "        return self.rnn.begin_state(*args, **kwargs)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:02:13.528517Z", "start_time": "2019-07-03T23:01:16.616773Z" }, "attributes": { "classes": [], "id": "", "n": "42" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "Perplexity 1.2, 158013 tokens/sec on gpu(0)\n", "time traveller you can show black is white by argument said fil\n", "traveller after the pauserequired for the little go the geo\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Training hyperparameters, plus the device chosen by d2l.try_gpu()\n", "num_epochs = 500\n", "lr = 1\n", "ctx = d2l.try_gpu()\n", "\n", "# Wrap the recurrent layer with an output layer sized to the vocabulary\n", "model = RNNModel(rnn_layer, len(vocab))\n", "# rnn_layer.initialize() already ran in an earlier cell, so request a\n", "# forced re-initialization of the parameters on ctx\n", "model.initialize(force_reinit=True, ctx=ctx)\n", "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }