{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Gated Recurrent Units (GRU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T23:13:42.615802Z",
     "start_time": "2019-07-03T23:13:40.417672Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [],
   "source": [
    "import d2l\n",
    "from mxnet import np, npx\n",
    "from mxnet.gluon import rnn\n",
    "npx.set_np()\n",
    "\n",
    "batch_size, num_steps = 32, 35\n",
    "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Initialize model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T23:13:42.624368Z",
     "start_time": "2019-07-03T23:13:42.618052Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "def get_params(vocab_size, num_hiddens, ctx):\n",
    "    num_inputs = num_outputs = vocab_size\n",
    "    normal = lambda shape : np.random.normal(scale=0.01, size=shape, ctx=ctx)\n",
    "    three = lambda : (normal((num_inputs, num_hiddens)),\n",
    "                      normal((num_hiddens, num_hiddens)),\n",
    "                      np.zeros(num_hiddens, ctx=ctx))\n",
    "    W_xz, W_hz, b_z = three()  # Update gate parameter\n",
    "    W_xr, W_hr, b_r = three()  # Reset gate parameter\n",
    "    W_xh, W_hh, b_h = three()  # Candidate hidden state parameter\n",
    "    # Output layer parameters\n",
    "    W_hq = normal((num_hiddens, num_outputs))\n",
    "    b_q = np.zeros(num_outputs, ctx=ctx)\n",
    "    # Create gradient\n",
    "    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]\n",
    "    for param in params:\n",
    "        param.attach_grad()\n",
    "    return params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Initialize the hidden state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T23:13:42.629629Z",
     "start_time": "2019-07-03T23:13:42.626280Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "def init_gru_state(batch_size, num_hiddens, ctx):\n",
    "    return (np.zeros(shape=(batch_size, num_hiddens), ctx=ctx), )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T23:13:42.637496Z",
     "start_time": "2019-07-03T23:13:42.631549Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    }
   },
   "outputs": [],
   "source": [
    "def gru(inputs, state, params):\n",
    "    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params\n",
    "    H, = state\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)\n",
    "        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)\n",
    "        H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)\n",
    "        H = Z * H + (1 - Z) * H_tilda\n",
    "        Y = np.dot(H, W_hq) + b_q\n",
    "        outputs.append(Y)\n",
    "    return np.concatenate(outputs, axis=0), (H,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T23:14:22.156668Z",
     "start_time": "2019-07-03T23:13:42.638985Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perplexity 10.6, 13809 tokens/sec on gpu(0)\n",
      "time travellere the the the the the the the the the the the the \n",
      "travellere the the the the the the the the the the the the \n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\n",
       "<svg height=\"184.15625pt\" version=\"1.1\" viewBox=\"0 0 259.00625 184.15625\" width=\"259.00625pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 184.15625 \n",
       "L 259.00625 184.15625 \n",
       "L 259.00625 -0 \n",
       "L 0 -0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 50.14375 146.6 \n",
       "L 245.44375 146.6 \n",
       "L 245.44375 10.7 \n",
       "L 50.14375 10.7 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 86.015179 146.6 \n",
       "L 86.015179 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"mf0cd0cc3e6\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"86.015179\" xlink:href=\"#mf0cd0cc3e6\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 10 -->\n",
       "      <defs>\n",
       "       <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-49\"/>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-48\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(79.652679 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 125.872321 146.6 \n",
       "L 125.872321 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"125.872321\" xlink:href=\"#mf0cd0cc3e6\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 20 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-50\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(119.509821 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 165.729464 146.6 \n",
       "L 165.729464 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"165.729464\" xlink:href=\"#mf0cd0cc3e6\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 30 -->\n",
       "      <defs>\n",
       "       <path d=\"M 40.578125 39.3125 \n",
       "Q 47.65625 37.796875 51.625 33 \n",
       "Q 55.609375 28.21875 55.609375 21.1875 \n",
       "Q 55.609375 10.40625 48.1875 4.484375 \n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \n",
       "Q 12.796875 0.390625 7.625 2.203125 \n",
       "L 7.625 11.71875 \n",
       "Q 11.71875 9.328125 16.59375 8.109375 \n",
       "Q 21.484375 6.890625 26.8125 6.890625 \n",
       "Q 36.078125 6.890625 40.9375 10.546875 \n",
       "Q 45.796875 14.203125 45.796875 21.1875 \n",
       "Q 45.796875 27.640625 41.28125 31.265625 \n",
       "Q 36.765625 34.90625 28.71875 34.90625 \n",
       "L 20.21875 34.90625 \n",
       "L 20.21875 43.015625 \n",
       "L 29.109375 43.015625 \n",
       "Q 36.375 43.015625 40.234375 45.921875 \n",
       "Q 44.09375 48.828125 44.09375 54.296875 \n",
       "Q 44.09375 59.90625 40.109375 62.90625 \n",
       "Q 36.140625 65.921875 28.71875 65.921875 \n",
       "Q 24.65625 65.921875 20.015625 65.03125 \n",
       "Q 15.375 64.15625 9.8125 62.3125 \n",
       "L 9.8125 71.09375 \n",
       "Q 15.4375 72.65625 20.34375 73.4375 \n",
       "Q 25.25 74.21875 29.59375 74.21875 \n",
       "Q 40.828125 74.21875 47.359375 69.109375 \n",
       "Q 53.90625 64.015625 53.90625 55.328125 \n",
       "Q 53.90625 49.265625 50.4375 45.09375 \n",
       "Q 46.96875 40.921875 40.578125 39.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-51\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(159.366964 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-51\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 205.586607 146.6 \n",
       "L 205.586607 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"205.586607\" xlink:href=\"#mf0cd0cc3e6\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 40 -->\n",
       "      <defs>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-52\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(199.224107 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-52\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 245.44375 146.6 \n",
       "L 245.44375 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"245.44375\" xlink:href=\"#mf0cd0cc3e6\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 50 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-53\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(239.08125 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- epoch -->\n",
       "     <defs>\n",
       "      <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-101\"/>\n",
       "      <path d=\"M 18.109375 8.203125 \n",
       "L 18.109375 -20.796875 \n",
       "L 9.078125 -20.796875 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.390625 \n",
       "Q 20.953125 51.265625 25.265625 53.625 \n",
       "Q 29.59375 56 35.59375 56 \n",
       "Q 45.5625 56 51.78125 48.09375 \n",
       "Q 58.015625 40.1875 58.015625 27.296875 \n",
       "Q 58.015625 14.40625 51.78125 6.484375 \n",
       "Q 45.5625 -1.421875 35.59375 -1.421875 \n",
       "Q 29.59375 -1.421875 25.265625 0.953125 \n",
       "Q 20.953125 3.328125 18.109375 8.203125 \n",
       "z\n",
       "M 48.6875 27.296875 \n",
       "Q 48.6875 37.203125 44.609375 42.84375 \n",
       "Q 40.53125 48.484375 33.40625 48.484375 \n",
       "Q 26.265625 48.484375 22.1875 42.84375 \n",
       "Q 18.109375 37.203125 18.109375 27.296875 \n",
       "Q 18.109375 17.390625 22.1875 11.75 \n",
       "Q 26.265625 6.109375 33.40625 6.109375 \n",
       "Q 40.53125 6.109375 44.609375 11.75 \n",
       "Q 48.6875 17.390625 48.6875 27.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-112\"/>\n",
       "      <path d=\"M 30.609375 48.390625 \n",
       "Q 23.390625 48.390625 19.1875 42.75 \n",
       "Q 14.984375 37.109375 14.984375 27.296875 \n",
       "Q 14.984375 17.484375 19.15625 11.84375 \n",
       "Q 23.34375 6.203125 30.609375 6.203125 \n",
       "Q 37.796875 6.203125 41.984375 11.859375 \n",
       "Q 46.1875 17.53125 46.1875 27.296875 \n",
       "Q 46.1875 37.015625 41.984375 42.703125 \n",
       "Q 37.796875 48.390625 30.609375 48.390625 \n",
       "z\n",
       "M 30.609375 56 \n",
       "Q 42.328125 56 49.015625 48.375 \n",
       "Q 55.71875 40.765625 55.71875 27.296875 \n",
       "Q 55.71875 13.875 49.015625 6.21875 \n",
       "Q 42.328125 -1.421875 30.609375 -1.421875 \n",
       "Q 18.84375 -1.421875 12.171875 6.21875 \n",
       "Q 5.515625 13.875 5.515625 27.296875 \n",
       "Q 5.515625 40.765625 12.171875 48.375 \n",
       "Q 18.84375 56 30.609375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-111\"/>\n",
       "      <path d=\"M 48.78125 52.59375 \n",
       "L 48.78125 44.1875 \n",
       "Q 44.96875 46.296875 41.140625 47.34375 \n",
       "Q 37.3125 48.390625 33.40625 48.390625 \n",
       "Q 24.65625 48.390625 19.8125 42.84375 \n",
       "Q 14.984375 37.3125 14.984375 27.296875 \n",
       "Q 14.984375 17.28125 19.8125 11.734375 \n",
       "Q 24.65625 6.203125 33.40625 6.203125 \n",
       "Q 37.3125 6.203125 41.140625 7.25 \n",
       "Q 44.96875 8.296875 48.78125 10.40625 \n",
       "L 48.78125 2.09375 \n",
       "Q 45.015625 0.34375 40.984375 -0.53125 \n",
       "Q 36.96875 -1.421875 32.421875 -1.421875 \n",
       "Q 20.0625 -1.421875 12.78125 6.34375 \n",
       "Q 5.515625 14.109375 5.515625 27.296875 \n",
       "Q 5.515625 40.671875 12.859375 48.328125 \n",
       "Q 20.21875 56 33.015625 56 \n",
       "Q 37.15625 56 41.109375 55.140625 \n",
       "Q 45.0625 54.296875 48.78125 52.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-99\"/>\n",
       "      <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 75.984375 \n",
       "L 18.109375 75.984375 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-104\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(132.565625 174.876562)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"61.523438\" xlink:href=\"#DejaVuSans-112\"/>\n",
       "      <use x=\"125\" xlink:href=\"#DejaVuSans-111\"/>\n",
       "      <use x=\"186.181641\" xlink:href=\"#DejaVuSans-99\"/>\n",
       "      <use x=\"241.162109\" xlink:href=\"#DejaVuSans-104\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 50.14375 128.939923 \n",
       "L 245.44375 128.939923 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"m2e26e528d2\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m2e26e528d2\" y=\"128.939923\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 12.5 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.6875 12.40625 \n",
       "L 21 12.40625 \n",
       "L 21 0 \n",
       "L 10.6875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-46\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 132.739142)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 50.14375 106.176518 \n",
       "L 245.44375 106.176518 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m2e26e528d2\" y=\"106.176518\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 15.0 -->\n",
       "      <g transform=\"translate(20.878125 109.975736)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 50.14375 83.413112 \n",
       "L 245.44375 83.413112 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m2e26e528d2\" y=\"83.413112\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 17.5 -->\n",
       "      <defs>\n",
       "       <path d=\"M 8.203125 72.90625 \n",
       "L 55.078125 72.90625 \n",
       "L 55.078125 68.703125 \n",
       "L 28.609375 0 \n",
       "L 18.3125 0 \n",
       "L 43.21875 64.59375 \n",
       "L 8.203125 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-55\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 87.212331)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-55\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_17\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 50.14375 60.649706 \n",
       "L 245.44375 60.649706 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_18\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m2e26e528d2\" y=\"60.649706\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_10\">\n",
       "      <!-- 20.0 -->\n",
       "      <g transform=\"translate(20.878125 64.448925)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_19\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 50.14375 37.886301 \n",
       "L 245.44375 37.886301 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_20\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m2e26e528d2\" y=\"37.886301\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_11\">\n",
       "      <!-- 22.5 -->\n",
       "      <g transform=\"translate(20.878125 41.685519)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_6\">\n",
       "     <g id=\"line2d_21\">\n",
       "      <path clip-path=\"url(#p87b035c8dc)\" d=\"M 50.14375 15.122895 \n",
       "L 245.44375 15.122895 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_22\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m2e26e528d2\" y=\"15.122895\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_12\">\n",
       "      <!-- 25.0 -->\n",
       "      <g transform=\"translate(20.878125 18.922114)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-46\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_13\">\n",
       "     <!-- perplexity -->\n",
       "     <defs>\n",
       "      <path d=\"M 41.109375 46.296875 \n",
       "Q 39.59375 47.171875 37.8125 47.578125 \n",
       "Q 36.03125 48 33.890625 48 \n",
       "Q 26.265625 48 22.1875 43.046875 \n",
       "Q 18.109375 38.09375 18.109375 28.8125 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 20.953125 51.171875 25.484375 53.578125 \n",
       "Q 30.03125 56 36.53125 56 \n",
       "Q 37.453125 56 38.578125 55.875 \n",
       "Q 39.703125 55.765625 41.0625 55.515625 \n",
       "z\n",
       "\" id=\"DejaVuSans-114\"/>\n",
       "      <path d=\"M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-108\"/>\n",
       "      <path d=\"M 54.890625 54.6875 \n",
       "L 35.109375 28.078125 \n",
       "L 55.90625 0 \n",
       "L 45.3125 0 \n",
       "L 29.390625 21.484375 \n",
       "L 13.484375 0 \n",
       "L 2.875 0 \n",
       "L 24.125 28.609375 \n",
       "L 4.6875 54.6875 \n",
       "L 15.28125 54.6875 \n",
       "L 29.78125 35.203125 \n",
       "L 44.28125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-120\"/>\n",
       "      <path d=\"M 9.421875 54.6875 \n",
       "L 18.40625 54.6875 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 64.59375 \n",
       "L 9.421875 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-105\"/>\n",
       "      <path d=\"M 18.3125 70.21875 \n",
       "L 18.3125 54.6875 \n",
       "L 36.8125 54.6875 \n",
       "L 36.8125 47.703125 \n",
       "L 18.3125 47.703125 \n",
       "L 18.3125 18.015625 \n",
       "Q 18.3125 11.328125 20.140625 9.421875 \n",
       "Q 21.96875 7.515625 27.59375 7.515625 \n",
       "L 36.8125 7.515625 \n",
       "L 36.8125 0 \n",
       "L 27.59375 0 \n",
       "Q 17.1875 0 13.234375 3.875 \n",
       "Q 9.28125 7.765625 9.28125 18.015625 \n",
       "L 9.28125 47.703125 \n",
       "L 2.6875 47.703125 \n",
       "L 2.6875 54.6875 \n",
       "L 9.28125 54.6875 \n",
       "L 9.28125 70.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-116\"/>\n",
       "      <path d=\"M 32.171875 -5.078125 \n",
       "Q 28.375 -14.84375 24.75 -17.8125 \n",
       "Q 21.140625 -20.796875 15.09375 -20.796875 \n",
       "L 7.90625 -20.796875 \n",
       "L 7.90625 -13.28125 \n",
       "L 13.1875 -13.28125 \n",
       "Q 16.890625 -13.28125 18.9375 -11.515625 \n",
       "Q 21 -9.765625 23.484375 -3.21875 \n",
       "L 25.09375 0.875 \n",
       "L 2.984375 54.6875 \n",
       "L 12.5 54.6875 \n",
       "L 29.59375 11.921875 \n",
       "L 46.6875 54.6875 \n",
       "L 56.203125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-121\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(14.798437 103.863281)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-112\"/>\n",
       "      <use x=\"63.476562\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"125\" xlink:href=\"#DejaVuSans-114\"/>\n",
       "      <use x=\"166.113281\" xlink:href=\"#DejaVuSans-112\"/>\n",
       "      <use x=\"229.589844\" xlink:href=\"#DejaVuSans-108\"/>\n",
       "      <use x=\"257.373047\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"318.880859\" xlink:href=\"#DejaVuSans-120\"/>\n",
       "      <use x=\"378.060547\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"405.84375\" xlink:href=\"#DejaVuSans-116\"/>\n",
       "      <use x=\"445.052734\" xlink:href=\"#DejaVuSans-121\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_23\">\n",
       "    <path clip-path=\"url(#p87b035c8dc)\" d=\"M 50.14375 16.877273 \n",
       "L 90.000893 88.594963 \n",
       "L 129.858036 104.644606 \n",
       "L 169.715179 127.75626 \n",
       "L 209.572321 140.422727 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 50.14375 146.6 \n",
       "L 50.14375 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 245.44375 146.6 \n",
       "L 245.44375 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 50.14375 146.6 \n",
       "L 245.44375 146.6 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 50.14375 10.7 \n",
       "L 245.44375 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"legend_1\">\n",
       "    <g id=\"patch_7\">\n",
       "     <path d=\"M 183.16875 33.378125 \n",
       "L 238.44375 33.378125 \n",
       "Q 240.44375 33.378125 240.44375 31.378125 \n",
       "L 240.44375 17.7 \n",
       "Q 240.44375 15.7 238.44375 15.7 \n",
       "L 183.16875 15.7 \n",
       "Q 181.16875 15.7 181.16875 17.7 \n",
       "L 181.16875 31.378125 \n",
       "Q 181.16875 33.378125 183.16875 33.378125 \n",
       "z\n",
       "\" style=\"fill:#ffffff;opacity:0.8;stroke:#cccccc;stroke-linejoin:miter;\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_24\">\n",
       "     <path d=\"M 185.16875 23.798437 \n",
       "L 205.16875 23.798437 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_25\"/>\n",
       "    <g id=\"text_14\">\n",
       "     <!-- train -->\n",
       "     <defs>\n",
       "      <path d=\"M 34.28125 27.484375 \n",
       "Q 23.390625 27.484375 19.1875 25 \n",
       "Q 14.984375 22.515625 14.984375 16.5 \n",
       "Q 14.984375 11.71875 18.140625 8.90625 \n",
       "Q 21.296875 6.109375 26.703125 6.109375 \n",
       "Q 34.1875 6.109375 38.703125 11.40625 \n",
       "Q 43.21875 16.703125 43.21875 25.484375 \n",
       "L 43.21875 27.484375 \n",
       "z\n",
       "M 52.203125 31.203125 \n",
       "L 52.203125 0 \n",
       "L 43.21875 0 \n",
       "L 43.21875 8.296875 \n",
       "Q 40.140625 3.328125 35.546875 0.953125 \n",
       "Q 30.953125 -1.421875 24.3125 -1.421875 \n",
       "Q 15.921875 -1.421875 10.953125 3.296875 \n",
       "Q 6 8.015625 6 15.921875 \n",
       "Q 6 25.140625 12.171875 29.828125 \n",
       "Q 18.359375 34.515625 30.609375 34.515625 \n",
       "L 43.21875 34.515625 \n",
       "L 43.21875 35.40625 \n",
       "Q 43.21875 41.609375 39.140625 45 \n",
       "Q 35.0625 48.390625 27.6875 48.390625 \n",
       "Q 23 48.390625 18.546875 47.265625 \n",
       "Q 14.109375 46.140625 10.015625 43.890625 \n",
       "L 10.015625 52.203125 \n",
       "Q 14.9375 54.109375 19.578125 55.046875 \n",
       "Q 24.21875 56 28.609375 56 \n",
       "Q 40.484375 56 46.34375 49.84375 \n",
       "Q 52.203125 43.703125 52.203125 31.203125 \n",
       "z\n",
       "\" id=\"DejaVuSans-97\"/>\n",
       "      <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-110\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(213.16875 27.298437)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-116\"/>\n",
       "      <use x=\"39.208984\" xlink:href=\"#DejaVuSans-114\"/>\n",
       "      <use x=\"80.322266\" xlink:href=\"#DejaVuSans-97\"/>\n",
       "      <use x=\"141.601562\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"169.384766\" xlink:href=\"#DejaVuSans-110\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p87b035c8dc\">\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"50.14375\" y=\"10.7\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "vocab_size, num_hiddens, ctx = len(vocab), 256, d2l.try_gpu()\n",
    "num_epochs, lr = 50, 1\n",
    "model = d2l.RNNModelScratch(len(vocab), num_hiddens, ctx, get_params,\n",
    "                            init_gru_state, gru)\n",
    "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Concise implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T23:15:13.493113Z",
     "start_time": "2019-07-03T23:14:22.158378Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "9"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perplexity 1.1, 99504 tokens/sec on gpu(0)\n",
      "time traveller sit as he ghe wis our some said filby  can exputt\n",
      "traveller sut esh uighay an all disconestwink we caur the s\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\n",
       "<svg height=\"184.15625pt\" version=\"1.1\" viewBox=\"0 0 252.646875 184.15625\" width=\"252.646875pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 184.15625 \n",
       "L 252.646875 184.15625 \n",
       "L 252.646875 -0 \n",
       "L 0 -0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 40.603125 146.6 \n",
       "L 235.903125 146.6 \n",
       "L 235.903125 10.7 \n",
       "L 40.603125 10.7 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.211742 146.6 \n",
       "L 40.211742 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"m5cf3411486\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"40.211742\" xlink:href=\"#m5cf3411486\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-48\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(37.030492 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 79.350019 146.6 \n",
       "L 79.350019 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"79.350019\" xlink:href=\"#m5cf3411486\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 100 -->\n",
       "      <defs>\n",
       "       <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-49\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(69.806269 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 118.488295 146.6 \n",
       "L 118.488295 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"118.488295\" xlink:href=\"#m5cf3411486\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 200 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-50\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(108.944545 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 157.626572 146.6 \n",
       "L 157.626572 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"157.626572\" xlink:href=\"#m5cf3411486\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 300 -->\n",
       "      <defs>\n",
       "       <path d=\"M 40.578125 39.3125 \n",
       "Q 47.65625 37.796875 51.625 33 \n",
       "Q 55.609375 28.21875 55.609375 21.1875 \n",
       "Q 55.609375 10.40625 48.1875 4.484375 \n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \n",
       "Q 12.796875 0.390625 7.625 2.203125 \n",
       "L 7.625 11.71875 \n",
       "Q 11.71875 9.328125 16.59375 8.109375 \n",
       "Q 21.484375 6.890625 26.8125 6.890625 \n",
       "Q 36.078125 6.890625 40.9375 10.546875 \n",
       "Q 45.796875 14.203125 45.796875 21.1875 \n",
       "Q 45.796875 27.640625 41.28125 31.265625 \n",
       "Q 36.765625 34.90625 28.71875 34.90625 \n",
       "L 20.21875 34.90625 \n",
       "L 20.21875 43.015625 \n",
       "L 29.109375 43.015625 \n",
       "Q 36.375 43.015625 40.234375 45.921875 \n",
       "Q 44.09375 48.828125 44.09375 54.296875 \n",
       "Q 44.09375 59.90625 40.109375 62.90625 \n",
       "Q 36.140625 65.921875 28.71875 65.921875 \n",
       "Q 24.65625 65.921875 20.015625 65.03125 \n",
       "Q 15.375 64.15625 9.8125 62.3125 \n",
       "L 9.8125 71.09375 \n",
       "Q 15.4375 72.65625 20.34375 73.4375 \n",
       "Q 25.25 74.21875 29.59375 74.21875 \n",
       "Q 40.828125 74.21875 47.359375 69.109375 \n",
       "Q 53.90625 64.015625 53.90625 55.328125 \n",
       "Q 53.90625 49.265625 50.4375 45.09375 \n",
       "Q 46.96875 40.921875 40.578125 39.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-51\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(148.082822 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-51\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 196.764848 146.6 \n",
       "L 196.764848 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"196.764848\" xlink:href=\"#m5cf3411486\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 400 -->\n",
       "      <defs>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-52\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(187.221098 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-52\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_6\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 235.903125 146.6 \n",
       "L 235.903125 10.7 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"235.903125\" xlink:href=\"#m5cf3411486\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_6\">\n",
       "      <!-- 500 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-53\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(226.359375 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-53\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "       <use x=\"127.246094\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_7\">\n",
       "     <!-- epoch -->\n",
       "     <defs>\n",
       "      <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-101\"/>\n",
       "      <path d=\"M 18.109375 8.203125 \n",
       "L 18.109375 -20.796875 \n",
       "L 9.078125 -20.796875 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.390625 \n",
       "Q 20.953125 51.265625 25.265625 53.625 \n",
       "Q 29.59375 56 35.59375 56 \n",
       "Q 45.5625 56 51.78125 48.09375 \n",
       "Q 58.015625 40.1875 58.015625 27.296875 \n",
       "Q 58.015625 14.40625 51.78125 6.484375 \n",
       "Q 45.5625 -1.421875 35.59375 -1.421875 \n",
       "Q 29.59375 -1.421875 25.265625 0.953125 \n",
       "Q 20.953125 3.328125 18.109375 8.203125 \n",
       "z\n",
       "M 48.6875 27.296875 \n",
       "Q 48.6875 37.203125 44.609375 42.84375 \n",
       "Q 40.53125 48.484375 33.40625 48.484375 \n",
       "Q 26.265625 48.484375 22.1875 42.84375 \n",
       "Q 18.109375 37.203125 18.109375 27.296875 \n",
       "Q 18.109375 17.390625 22.1875 11.75 \n",
       "Q 26.265625 6.109375 33.40625 6.109375 \n",
       "Q 40.53125 6.109375 44.609375 11.75 \n",
       "Q 48.6875 17.390625 48.6875 27.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-112\"/>\n",
       "      <path d=\"M 30.609375 48.390625 \n",
       "Q 23.390625 48.390625 19.1875 42.75 \n",
       "Q 14.984375 37.109375 14.984375 27.296875 \n",
       "Q 14.984375 17.484375 19.15625 11.84375 \n",
       "Q 23.34375 6.203125 30.609375 6.203125 \n",
       "Q 37.796875 6.203125 41.984375 11.859375 \n",
       "Q 46.1875 17.53125 46.1875 27.296875 \n",
       "Q 46.1875 37.015625 41.984375 42.703125 \n",
       "Q 37.796875 48.390625 30.609375 48.390625 \n",
       "z\n",
       "M 30.609375 56 \n",
       "Q 42.328125 56 49.015625 48.375 \n",
       "Q 55.71875 40.765625 55.71875 27.296875 \n",
       "Q 55.71875 13.875 49.015625 6.21875 \n",
       "Q 42.328125 -1.421875 30.609375 -1.421875 \n",
       "Q 18.84375 -1.421875 12.171875 6.21875 \n",
       "Q 5.515625 13.875 5.515625 27.296875 \n",
       "Q 5.515625 40.765625 12.171875 48.375 \n",
       "Q 18.84375 56 30.609375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-111\"/>\n",
       "      <path d=\"M 48.78125 52.59375 \n",
       "L 48.78125 44.1875 \n",
       "Q 44.96875 46.296875 41.140625 47.34375 \n",
       "Q 37.3125 48.390625 33.40625 48.390625 \n",
       "Q 24.65625 48.390625 19.8125 42.84375 \n",
       "Q 14.984375 37.3125 14.984375 27.296875 \n",
       "Q 14.984375 17.28125 19.8125 11.734375 \n",
       "Q 24.65625 6.203125 33.40625 6.203125 \n",
       "Q 37.3125 6.203125 41.140625 7.25 \n",
       "Q 44.96875 8.296875 48.78125 10.40625 \n",
       "L 48.78125 2.09375 \n",
       "Q 45.015625 0.34375 40.984375 -0.53125 \n",
       "Q 36.96875 -1.421875 32.421875 -1.421875 \n",
       "Q 20.0625 -1.421875 12.78125 6.34375 \n",
       "Q 5.515625 14.109375 5.515625 27.296875 \n",
       "Q 5.515625 40.671875 12.859375 48.328125 \n",
       "Q 20.21875 56 33.015625 56 \n",
       "Q 37.15625 56 41.109375 55.140625 \n",
       "Q 45.0625 54.296875 48.78125 52.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-99\"/>\n",
       "      <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 75.984375 \n",
       "L 18.109375 75.984375 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-104\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(123.025 174.876562)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"61.523438\" xlink:href=\"#DejaVuSans-112\"/>\n",
       "      <use x=\"125\" xlink:href=\"#DejaVuSans-111\"/>\n",
       "      <use x=\"186.181641\" xlink:href=\"#DejaVuSans-99\"/>\n",
       "      <use x=\"241.162109\" xlink:href=\"#DejaVuSans-104\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.603125 145.980744 \n",
       "L 235.903125 145.980744 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"mc52d3d84fd\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"40.603125\" xlink:href=\"#mc52d3d84fd\" y=\"145.980744\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(27.240625 149.779963)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.603125 119.863725 \n",
       "L 235.903125 119.863725 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"40.603125\" xlink:href=\"#mc52d3d84fd\" y=\"119.863725\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(27.240625 123.662944)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_17\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.603125 93.746706 \n",
       "L 235.903125 93.746706 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_18\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"40.603125\" xlink:href=\"#mc52d3d84fd\" y=\"93.746706\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_10\">\n",
       "      <!-- 10 -->\n",
       "      <g transform=\"translate(20.878125 97.545925)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_19\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.603125 67.629688 \n",
       "L 235.903125 67.629688 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_20\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"40.603125\" xlink:href=\"#mc52d3d84fd\" y=\"67.629688\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_11\">\n",
       "      <!-- 15 -->\n",
       "      <g transform=\"translate(20.878125 71.428906)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_21\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.603125 41.512669 \n",
       "L 235.903125 41.512669 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_22\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"40.603125\" xlink:href=\"#mc52d3d84fd\" y=\"41.512669\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_12\">\n",
       "      <!-- 20 -->\n",
       "      <g transform=\"translate(20.878125 45.311887)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-48\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_6\">\n",
       "     <g id=\"line2d_23\">\n",
       "      <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.603125 15.39565 \n",
       "L 235.903125 15.39565 \n",
       "\" style=\"fill:none;stroke:#b0b0b0;stroke-linecap:square;stroke-width:0.8;\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_24\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"40.603125\" xlink:href=\"#mc52d3d84fd\" y=\"15.39565\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_13\">\n",
       "      <!-- 25 -->\n",
       "      <g transform=\"translate(20.878125 19.194869)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-50\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-53\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_14\">\n",
       "     <!-- perplexity -->\n",
       "     <defs>\n",
       "      <path d=\"M 41.109375 46.296875 \n",
       "Q 39.59375 47.171875 37.8125 47.578125 \n",
       "Q 36.03125 48 33.890625 48 \n",
       "Q 26.265625 48 22.1875 43.046875 \n",
       "Q 18.109375 38.09375 18.109375 28.8125 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 20.953125 51.171875 25.484375 53.578125 \n",
       "Q 30.03125 56 36.53125 56 \n",
       "Q 37.453125 56 38.578125 55.875 \n",
       "Q 39.703125 55.765625 41.0625 55.515625 \n",
       "z\n",
       "\" id=\"DejaVuSans-114\"/>\n",
       "      <path d=\"M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-108\"/>\n",
       "      <path d=\"M 54.890625 54.6875 \n",
       "L 35.109375 28.078125 \n",
       "L 55.90625 0 \n",
       "L 45.3125 0 \n",
       "L 29.390625 21.484375 \n",
       "L 13.484375 0 \n",
       "L 2.875 0 \n",
       "L 24.125 28.609375 \n",
       "L 4.6875 54.6875 \n",
       "L 15.28125 54.6875 \n",
       "L 29.78125 35.203125 \n",
       "L 44.28125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-120\"/>\n",
       "      <path d=\"M 9.421875 54.6875 \n",
       "L 18.40625 54.6875 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 64.59375 \n",
       "L 9.421875 64.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-105\"/>\n",
       "      <path d=\"M 18.3125 70.21875 \n",
       "L 18.3125 54.6875 \n",
       "L 36.8125 54.6875 \n",
       "L 36.8125 47.703125 \n",
       "L 18.3125 47.703125 \n",
       "L 18.3125 18.015625 \n",
       "Q 18.3125 11.328125 20.140625 9.421875 \n",
       "Q 21.96875 7.515625 27.59375 7.515625 \n",
       "L 36.8125 7.515625 \n",
       "L 36.8125 0 \n",
       "L 27.59375 0 \n",
       "Q 17.1875 0 13.234375 3.875 \n",
       "Q 9.28125 7.765625 9.28125 18.015625 \n",
       "L 9.28125 47.703125 \n",
       "L 2.6875 47.703125 \n",
       "L 2.6875 54.6875 \n",
       "L 9.28125 54.6875 \n",
       "L 9.28125 70.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-116\"/>\n",
       "      <path d=\"M 32.171875 -5.078125 \n",
       "Q 28.375 -14.84375 24.75 -17.8125 \n",
       "Q 21.140625 -20.796875 15.09375 -20.796875 \n",
       "L 7.90625 -20.796875 \n",
       "L 7.90625 -13.28125 \n",
       "L 13.1875 -13.28125 \n",
       "Q 16.890625 -13.28125 18.9375 -11.515625 \n",
       "Q 21 -9.765625 23.484375 -3.21875 \n",
       "L 25.09375 0.875 \n",
       "L 2.984375 54.6875 \n",
       "L 12.5 54.6875 \n",
       "L 29.59375 11.921875 \n",
       "L 46.6875 54.6875 \n",
       "L 56.203125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-121\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(14.798437 103.863281)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-112\"/>\n",
       "      <use x=\"63.476562\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"125\" xlink:href=\"#DejaVuSans-114\"/>\n",
       "      <use x=\"166.113281\" xlink:href=\"#DejaVuSans-112\"/>\n",
       "      <use x=\"229.589844\" xlink:href=\"#DejaVuSans-108\"/>\n",
       "      <use x=\"257.373047\" xlink:href=\"#DejaVuSans-101\"/>\n",
       "      <use x=\"318.880859\" xlink:href=\"#DejaVuSans-120\"/>\n",
       "      <use x=\"378.060547\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"405.84375\" xlink:href=\"#DejaVuSans-116\"/>\n",
       "      <use x=\"445.052734\" xlink:href=\"#DejaVuSans-121\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_25\">\n",
       "    <path clip-path=\"url(#p2d05635d6e)\" d=\"M 40.603125 16.877273 \n",
       "L 44.516953 58.017499 \n",
       "L 48.43078 67.543258 \n",
       "L 52.344608 80.02564 \n",
       "L 56.258436 87.60321 \n",
       "L 60.172263 90.726198 \n",
       "L 64.086091 92.787675 \n",
       "L 67.999919 94.724372 \n",
       "L 71.913746 96.455282 \n",
       "L 75.827574 98.020203 \n",
       "L 79.741402 99.553068 \n",
       "L 83.655229 101.218895 \n",
       "L 87.569057 102.122641 \n",
       "L 91.482885 103.59944 \n",
       "L 95.396712 105.250708 \n",
       "L 99.31054 107.187072 \n",
       "L 103.224367 108.199909 \n",
       "L 107.138195 109.787169 \n",
       "L 111.052023 110.973534 \n",
       "L 114.96585 112.617236 \n",
       "L 118.879678 113.861909 \n",
       "L 122.793506 115.254709 \n",
       "L 126.707333 117.02731 \n",
       "L 130.621161 118.378704 \n",
       "L 134.534989 120.272677 \n",
       "L 138.448816 121.44771 \n",
       "L 142.362644 122.903616 \n",
       "L 146.276472 124.626471 \n",
       "L 150.190299 126.196966 \n",
       "L 154.104127 128.071095 \n",
       "L 158.017955 129.987054 \n",
       "L 161.931782 131.528309 \n",
       "L 165.84561 133.404803 \n",
       "L 169.759438 134.713539 \n",
       "L 173.673265 135.629887 \n",
       "L 177.587093 136.680527 \n",
       "L 181.500921 137.553344 \n",
       "L 185.414748 138.926919 \n",
       "L 189.328576 138.550258 \n",
       "L 193.242404 139.520999 \n",
       "L 197.156231 139.7727 \n",
       "L 201.070059 139.976741 \n",
       "L 204.983887 139.924966 \n",
       "L 208.897714 140.085881 \n",
       "L 212.811542 140.222088 \n",
       "L 216.725369 140.28063 \n",
       "L 220.639197 140.289936 \n",
       "L 224.553025 140.355959 \n",
       "L 228.466852 140.222397 \n",
       "L 232.38068 140.422727 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 40.603125 146.6 \n",
       "L 40.603125 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 235.903125 146.6 \n",
       "L 235.903125 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 40.603125 146.6 \n",
       "L 235.903125 146.6 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 40.603125 10.7 \n",
       "L 235.903125 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"legend_1\">\n",
       "    <g id=\"patch_7\">\n",
       "     <path d=\"M 173.628125 33.378125 \n",
       "L 228.903125 33.378125 \n",
       "Q 230.903125 33.378125 230.903125 31.378125 \n",
       "L 230.903125 17.7 \n",
       "Q 230.903125 15.7 228.903125 15.7 \n",
       "L 173.628125 15.7 \n",
       "Q 171.628125 15.7 171.628125 17.7 \n",
       "L 171.628125 31.378125 \n",
       "Q 171.628125 33.378125 173.628125 33.378125 \n",
       "z\n",
       "\" style=\"fill:#ffffff;opacity:0.8;stroke:#cccccc;stroke-linejoin:miter;\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_26\">\n",
       "     <path d=\"M 175.628125 23.798437 \n",
       "L 195.628125 23.798437 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_27\"/>\n",
       "    <g id=\"text_15\">\n",
       "     <!-- train -->\n",
       "     <defs>\n",
       "      <path d=\"M 34.28125 27.484375 \n",
       "Q 23.390625 27.484375 19.1875 25 \n",
       "Q 14.984375 22.515625 14.984375 16.5 \n",
       "Q 14.984375 11.71875 18.140625 8.90625 \n",
       "Q 21.296875 6.109375 26.703125 6.109375 \n",
       "Q 34.1875 6.109375 38.703125 11.40625 \n",
       "Q 43.21875 16.703125 43.21875 25.484375 \n",
       "L 43.21875 27.484375 \n",
       "z\n",
       "M 52.203125 31.203125 \n",
       "L 52.203125 0 \n",
       "L 43.21875 0 \n",
       "L 43.21875 8.296875 \n",
       "Q 40.140625 3.328125 35.546875 0.953125 \n",
       "Q 30.953125 -1.421875 24.3125 -1.421875 \n",
       "Q 15.921875 -1.421875 10.953125 3.296875 \n",
       "Q 6 8.015625 6 15.921875 \n",
       "Q 6 25.140625 12.171875 29.828125 \n",
       "Q 18.359375 34.515625 30.609375 34.515625 \n",
       "L 43.21875 34.515625 \n",
       "L 43.21875 35.40625 \n",
       "Q 43.21875 41.609375 39.140625 45 \n",
       "Q 35.0625 48.390625 27.6875 48.390625 \n",
       "Q 23 48.390625 18.546875 47.265625 \n",
       "Q 14.109375 46.140625 10.015625 43.890625 \n",
       "L 10.015625 52.203125 \n",
       "Q 14.9375 54.109375 19.578125 55.046875 \n",
       "Q 24.21875 56 28.609375 56 \n",
       "Q 40.484375 56 46.34375 49.84375 \n",
       "Q 52.203125 43.703125 52.203125 31.203125 \n",
       "z\n",
       "\" id=\"DejaVuSans-97\"/>\n",
       "      <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-110\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(203.628125 27.298437)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-116\"/>\n",
       "      <use x=\"39.208984\" xlink:href=\"#DejaVuSans-114\"/>\n",
       "      <use x=\"80.322266\" xlink:href=\"#DejaVuSans-97\"/>\n",
       "      <use x=\"141.601562\" xlink:href=\"#DejaVuSans-105\"/>\n",
       "      <use x=\"169.384766\" xlink:href=\"#DejaVuSans-110\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p2d05635d6e\">\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"40.603125\" y=\"10.7\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "gru_layer = rnn.GRU(num_hiddens)\n",
    "model = d2l.RNNModel(gru_layer, len(vocab))\n",
    "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs*10, ctx)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}