{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Implementation of Recurrent Neural Networks from Scratch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:57:59.652142Z", "start_time": "2019-07-03T22:57:58.005167Z" }, "attributes": { "classes": [], "id": "", "n": "14" } }, "outputs": [], "source": [ "%matplotlib inline\n", "import d2l\n", "import math\n", "from mxnet import autograd, np, npx, gluon\n", "npx.set_np()\n", "\n", "batch_size, num_steps = 32, 35\n", "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "One-hot encoding: map each token to a unique unit vector. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:57:59.666677Z", "start_time": "2019-07-03T22:57:59.654202Z" }, "attributes": { "classes": [], "id": "", "n": "21" } }, "outputs": [ { "data": { "text/plain": [ "array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "npx.one_hot(np.array([0, 2]), len(vocab))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Map a (batch size, time step) mini-batch to (time step, batch size, vocabulary size)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:57:59.673054Z", "start_time": "2019-07-03T22:57:59.668386Z" }, "attributes": { "classes": [], "id": "", "n": "18" } }, "outputs": [ { "data": { "text/plain": [ "(5, 2, 28)" ] }, "execution_count": 3, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "X = np.arange(10).reshape((2, 5))\n", "npx.one_hot(X.T, 28).shape" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Initializing the model parameters" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:57:59.681788Z", "start_time": "2019-07-03T22:57:59.675320Z" }, "attributes": { "classes": [], "id": "", "n": "19" } }, "outputs": [], "source": [ "def get_params(vocab_size, num_hiddens, ctx):\n", " num_inputs = num_outputs = vocab_size\n", " normal = lambda shape: np.random.normal(\n", " scale=0.01, size=shape, ctx=ctx)\n", " # Hidden layer parameters\n", " W_xh = normal((num_inputs, num_hiddens))\n", " W_hh = normal((num_hiddens, num_hiddens))\n", " b_h = np.zeros(num_hiddens, ctx=ctx)\n", " # Output layer parameters\n", " W_hq = normal((num_hiddens, num_outputs))\n", " b_q = np.zeros(num_outputs, ctx=ctx)\n", " # Attach a gradient\n", " params = [W_xh, W_hh, b_h, W_hq, b_q]\n", " for param in params: param.attach_grad()\n", " return params" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Initialize the hidden state" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:57:59.686448Z", "start_time": "2019-07-03T22:57:59.683406Z" }, "attributes": { "classes": [], "id": "", "n": "20" } }, "outputs": [], "source": [ "def init_rnn_state(batch_size, num_hiddens, ctx):\n", " return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx), )" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The forward function for one time step" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:57:59.692239Z", "start_time": "2019-07-03T22:57:59.687818Z" }, "attributes": { "classes": [], "id": "", "n": "6" } }, "outputs": 
[], "source": [ "def rnn(inputs, state, params):\n", "    \"\"\"Run a tanh RNN over every time step of `inputs`.\n", "\n", "    inputs: (num_steps, batch_size, vocab_size); iterated along axis 0.\n", "    state: 1-tuple (H,) holding the previous hidden state.\n", "    Returns the per-step outputs concatenated along axis 0, i.e.\n", "    (num_steps * batch_size, vocab_size), and the final state as (H,).\n", "    \"\"\"\n", "    W_xh, W_hh, b_h, W_hq, b_q = params\n", "    (H,) = state\n", "    step_outputs = []\n", "    for X_t in inputs:\n", "        H = np.tanh(np.dot(X_t, W_xh) + np.dot(H, W_hh) + b_h)\n", "        step_outputs.append(np.dot(H, W_hq) + b_q)\n", "    return np.concatenate(step_outputs, axis=0), (H,)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Wrap these functions and store parameters" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:57:59.698680Z", "start_time": "2019-07-03T22:57:59.693585Z" } }, "outputs": [], "source": [ "class RNNModelScratch:\n", "    \"\"\"An RNN assembled from plain functions: parameter factory, state init and forward pass.\"\"\"\n", "\n", "    def __init__(self, vocab_size, num_hiddens, ctx,\n", "                 get_params, init_state, forward):\n", "        self.vocab_size = vocab_size\n", "        self.num_hiddens = num_hiddens\n", "        self.params = get_params(vocab_size, num_hiddens, ctx)\n", "        self.init_state = init_state\n", "        self.forward_fn = forward\n", "\n", "    def __call__(self, X, state):\n", "        \"\"\"One-hot encode X.T (time steps first), then run the forward function.\"\"\"\n", "        inputs = npx.one_hot(X.T, self.vocab_size)\n", "        return self.forward_fn(inputs, state, self.params)\n", "\n", "    def begin_state(self, batch_size, ctx):\n", "        \"\"\"Return a fresh initial state for `batch_size` sequences.\"\"\"\n", "        return self.init_state(batch_size, self.num_hiddens, ctx)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Check whether inputs and outputs have the correct dimensions" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:58:03.222717Z", "start_time": "2019-07-03T22:57:59.700717Z" } }, "outputs": [ { "data": { "text/plain": [ "((10, 28), 1, (2, 512))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab_size, num_hiddens, ctx = len(vocab), 512, d2l.try_gpu()\n", "model = RNNModelScratch(len(vocab), num_hiddens, ctx, get_params,\n", " init_rnn_state, rnn)\n", "state = model.begin_state(X.shape[0], ctx)\n", 
"Y, new_state = model(X.as_in_context(ctx), state)\n", "Y.shape, len(new_state), new_state[0].shape" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Predicting the next `num_predicts` characters" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:58:03.230183Z", "start_time": "2019-07-03T22:58:03.224452Z" } }, "outputs": [], "source": [ "def predict_ch8(prefix, num_predicts, model, vocab, ctx):\n", " state = model.begin_state(batch_size=1, ctx=ctx)\n", " outputs = [vocab[prefix[0]]]\n", " get_input = lambda: np.array([outputs[-1]], ctx=ctx).reshape((1, 1))\n", " for y in prefix[1:]: # Warmup state with prefix\n", " _, state = model(get_input(), state)\n", " outputs.append(vocab[y])\n", " for _ in range(num_predicts): # Predict num_predicts steps\n", " Y, state = model(get_input(), state)\n", " outputs.append(int(Y.argmax(axis=1).reshape(1)))\n", " return ''.join([vocab.idx_to_token[i] for i in outputs])" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Gradient clipping:\n", "\n", "$$\\mathbf{g} \\leftarrow \\min\\left(1, \\frac{\\theta}{\\|\\mathbf{g}\\|}\\right) \\mathbf{g}.$$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:58:03.235877Z", "start_time": "2019-07-03T22:58:03.231513Z" }, "attributes": { "classes": [], "id": "", "n": "10" } }, "outputs": [], "source": [ "def grad_clipping(model, theta):\n", " if isinstance(model, gluon.Block):\n", " params = [p.data() for p in model.collect_params().values()]\n", " else:\n", " params = model.params\n", " norm = math.sqrt(sum((p.grad ** 2).sum() for p in params))\n", " if norm > theta:\n", " for param in params:\n", " param.grad[:] *= theta / norm" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training one epoch" ] }, { "cell_type": "code", 
"execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:58:03.243416Z", "start_time": "2019-07-03T22:58:03.237209Z" } }, "outputs": [], "source": [ "def train_epoch_ch8(model, train_iter, loss, updater, ctx):\n", " state, timer = None, d2l.Timer()\n", " metric = d2l.Accumulator(2) # loss_sum, num_examples\n", " for X, Y in train_iter:\n", " if not state:\n", " state = model.begin_state(batch_size=X.shape[0], ctx=ctx)\n", " y = Y.T.reshape((-1,))\n", " X, y = X.as_in_context(ctx), y.as_in_context(ctx)\n", " with autograd.record():\n", " py, state = model(X, state)\n", " l = loss(py, y).mean()\n", " l.backward()\n", " grad_clipping(model, 1)\n", " updater(batch_size=1) # Since used mean already.\n", " metric.add(l * y.size, y.size)\n", " return math.exp(metric[0]/metric[1]), metric[1]/timer.stop()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The training function" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T22:58:03.252788Z", "start_time": "2019-07-03T22:58:03.244750Z" }, "attributes": { "classes": [], "id": "", "n": "11" } }, "outputs": [], "source": [ "def train_ch8(model, train_iter, vocab, lr, num_epochs, ctx):\n", " # Initialize\n", " loss = gluon.loss.SoftmaxCrossEntropyLoss()\n", " animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',\n", " legend=['train'], xlim=[1, num_epochs])\n", " if isinstance(model, gluon.Block):\n", " model.initialize(ctx=ctx, force_reinit=True, init=init.Normal(0.01))\n", " trainer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': lr})\n", " updater = lambda batch_size : trainer.step(batch_size)\n", " else:\n", " updater = lambda batch_size : d2l.sgd(model.params, lr, batch_size)\n", "\n", " predict = lambda prefix: predict_ch8(prefix, 50, model, vocab, ctx)\n", " # Train and check the progress.\n", " for epoch in range(num_epochs):\n", " ppl, speed = 
train_epoch_ch8(model, train_iter, loss, updater, ctx)\n", " if epoch % 10 == 0:\n", " print(predict('time traveller'))\n", " animator.add(epoch+1, [ppl])\n", " print('Perplexity %.1f, %d tokens/sec on %s' % (ppl, speed, ctx))\n", " print(predict('time traveller'))\n", " print(predict('traveller'))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Finally we can train a model" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2019-07-03T23:00:31.379846Z", "start_time": "2019-07-03T22:58:03.254109Z" }, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Perplexity 1.0, 32620 tokens/sec on gpu(0)\n", "time traveller for so it will be convenient to speak of him was \n", "traveller it s against reason said filby what reason said\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_epochs, lr = 500, 1\n", "train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }