{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "RNN_fastai_pytorch.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyM3xZswsB9vdHZVmjN4Sx/z",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/bipinKrishnan/fastai_course/blob/master/RNN_fastai_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BR7oia6PUA9V"
      },
      "source": [
        "!pip install fastai --upgrade"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2X4GNuc2UjwR"
      },
      "source": [
        "from fastai.text.all import *\n",
        "import torch\n",
        "import torch.nn.functional as F"
      ],
      "execution_count": 73,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_Q1Be8_8Uo5_",
        "outputId": "bd1c6cd2-56e1-47c8-ea98-777280d106b5",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "path = untar_data(URLs.HUMAN_NUMBERS)\n",
        "\n",
        "lines = L()\n",
        "with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
        "with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
        "lines"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GL6ix4x1U6ux",
        "outputId": "ffcc3c65-c66e-461b-ea7e-19fcfe78515a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 36
        }
      },
      "source": [
        "text = ' . '.join([l.strip() for l in lines])\n",
        "text[:100]"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3GGuy2TcVDWH",
        "outputId": "a59079e6-8eee-43d9-c5a8-cf8ba7d0a404",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "tokens = text.split(' ')\n",
        "tokens[:10]"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IpQlXTP7VHuo",
        "outputId": "77a8d748-3259-482a-8b8b-bd68decbb261",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "vocab = L(*tokens).unique()\n",
        "vocab"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4L7BUVZPVKdd",
        "outputId": "9f94b20c-967d-48bb-81a0-b46499919136",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "word2idx = {w:i for i,w in enumerate(vocab)}\n",
        "nums = L(word2idx[i] for i in tokens)\n",
        "nums"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(#63095) [0,1,2,1,3,1,4,1,5,1...]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Gv2SwVzEVQPp",
        "outputId": "a1242bfb-cf0f-4a23-e192-3e30c4397e97",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hNqAIcZsVTaE",
        "outputId": "6cc728d6-9e49-40ce-bea2-2cf6c51f5410",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
        "seqs"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10,  1, 11]), 1),(tensor([ 1, 12,  1]), 13),(tensor([13,  1, 14]), 1),(tensor([ 1, 15,  1]), 16)...]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g5W1edqaVVZe"
      },
      "source": [
        "bs = 32\n",
        "cut = int(len(seqs) * 0.8)\n",
        "dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vJvw29iDYzXi"
      },
      "source": [
        "# 1st model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2-qg2u4NVpfB"
      },
      "source": [
        "class LMModel1(Module):\n",
        "  def __init__(self, vocab_sz, n_hidden):\n",
        "    self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
        "    self.h_h = nn.Linear(n_hidden, n_hidden)\n",
        "    self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
        "\n",
        "  def forward(self, x):\n",
        "    h = F.relu(self.h_h(self.i_h(x[:, 0])))\n",
        "    h = h + self.i_h(x[:, 1])\n",
        "    h = F.relu(self.h_h(h))\n",
        "    h = h + self.i_h(x[:, 2])\n",
        "    h = F.relu(self.h_h(h))\n",
        "    \n",
        "    return self.h_o(h)"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_x73KWo9X0K0",
        "outputId": "275a8980-1006-4663-c0c3-ea70715f32af",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "LMModel1(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[-0.1084, -0.2783, -0.1223, -0.3544, -0.4660, -0.1827, -0.3859,  0.2552,\n",
              "         -0.0022, -0.3058,  0.0698, -0.3552,  0.5132, -0.1932,  0.4968,  0.1891,\n",
              "          0.4565, -0.5759, -0.0737, -0.3763,  0.5565,  0.1311,  0.2966,  0.3392,\n",
              "         -0.1113,  0.2586,  0.0560, -0.1836,  0.5182, -0.2767]],\n",
              "       grad_fn=<AddmmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "t-7NOjprYSgj",
        "outputId": "352febb6-1a66-422d-9abe-65998e90e202",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 175
        }
      },
      "source": [
        "learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
        "learn.fit_one_cycle(4, 1e-3)"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>accuracy</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>1.789906</td>\n",
              "      <td>2.101264</td>\n",
              "      <td>0.443071</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.393594</td>\n",
              "      <td>1.828861</td>\n",
              "      <td>0.468029</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.410961</td>\n",
              "      <td>1.673940</td>\n",
              "      <td>0.492512</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.377922</td>\n",
              "      <td>1.698437</td>\n",
              "      <td>0.482529</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "08_KlBMSY3FC"
      },
      "source": [
        "# 2nd model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y-AODj2xYpCh"
      },
      "source": [
        "class LMModel2(Module):\n",
        "  def __init__(self, vocab_sz, n_hidden):\n",
        "    self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
        "    self.h_h = nn.Linear(n_hidden, n_hidden)\n",
        "    self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
        "\n",
        "  def forward(self, x):\n",
        "    h = 0\n",
        "    for i in range(3):\n",
        "      h = h + self.i_h(x[:, i])\n",
        "      h = F.relu(self.h_h(h))\n",
        "\n",
        "    return self.h_o(h)"
      ],
      "execution_count": 24,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tril0ceNZit9",
        "outputId": "f4dedbc2-3c68-4308-81cc-ccd57fe78741",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "LMModel2(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[ 0.0557, -0.3203, -0.8181, -0.6636,  0.1325,  0.0072, -0.0247,  0.6777,\n",
              "          0.1840, -0.4547, -0.0569,  0.2728, -0.2493, -0.0508,  0.1517,  0.5898,\n",
              "          0.2106, -0.7529,  0.9692,  0.6834, -0.4625, -0.0441, -0.2372,  0.4435,\n",
              "         -0.0727,  0.1358, -0.0039,  0.1496, -0.6447,  0.3787]],\n",
              "       grad_fn=<AddmmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 25
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6aqtRIf4Zl7M",
        "outputId": "02654e55-17cd-4d5e-866e-448d6b18b328",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 175
        }
      },
      "source": [
        "learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
        "learn.fit_one_cycle(4, 1e-3)"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>accuracy</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>1.909757</td>\n",
              "      <td>1.930985</td>\n",
              "      <td>0.478013</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.469091</td>\n",
              "      <td>1.714454</td>\n",
              "      <td>0.479677</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.429843</td>\n",
              "      <td>1.672472</td>\n",
              "      <td>0.492988</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.390096</td>\n",
              "      <td>1.681135</td>\n",
              "      <td>0.465890</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K1KBmyl3bF9x"
      },
      "source": [
        "# 3rd model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "plOzFV2XZuw0"
      },
      "source": [
        "class LMModel3(Module):\n",
        "  def __init__(self, vocab_sz, n_hidden):\n",
        "    self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
        "    self.h_h = nn.Linear(n_hidden, n_hidden)\n",
        "    self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
        "    self.h = 0\n",
        "\n",
        "  def forward(self, x):\n",
        "    for i in range(3):\n",
        "      self.h = self.h + self.i_h(x[:, i])\n",
        "      self.h = F.relu(self.h_h(self.h))\n",
        "    out = self.h_o(self.h)\n",
        "    self.h = self.h.detach()\n",
        "    return out\n",
        "\n",
        "  def reset(self): self.h = 0"
      ],
      "execution_count": 35,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a62gbcJebchw",
        "outputId": "d2d9a150-ea8c-4d22-dc94-70c186eb25e3",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "LMModel3(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
      ],
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[-0.3256, -0.2493, -0.0404,  0.2820,  0.3655,  0.1891, -0.1301, -0.1874,\n",
              "         -0.4694, -0.3767, -0.0546, -0.0693, -0.4469,  0.0875,  0.0070,  0.0787,\n",
              "          0.0223, -0.0287,  0.5465, -0.0721, -0.4811, -0.3768,  0.5216, -0.4914,\n",
              "          0.0082,  0.3935, -0.5356, -0.4153,  0.2180,  0.1427]],\n",
              "       grad_fn=<AddmmBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 36
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_RfM7xrSdaxa",
        "outputId": "f0d0d0fd-8239-427e-d7db-60fe64959d21",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "m = len(seqs)//bs\n",
        "m, bs, len(seqs)"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(657, 32, 21031)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 42
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3L6j0lqmdqAt"
      },
      "source": [
        "def group_chunks(ds, bs):\n",
        "    m = len(ds) // bs\n",
        "    new_ds = L()\n",
        "    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
        "    return new_ds"
      ],
      "execution_count": 56,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KbQ9WwZhef0a",
        "outputId": "226fb109-d0c5-4df8-e702-6ea6025c5728",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "s = seqs[:21024]\n",
        "s"
      ],
      "execution_count": 52,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(#21024) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10,  1, 11]), 1),(tensor([ 1, 12,  1]), 13),(tensor([13,  1, 14]), 1),(tensor([ 1, 15,  1]), 16)...]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 52
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aUfmSHmqerkR"
      },
      "source": [
        "cut = int(len(seqs) * 0.8)\n",
        "d = DataLoaders.from_dsets(s[:cut], s[cut:], bs=64, shuffle=False)"
      ],
      "execution_count": 53,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PbeB90deflEm"
      },
      "source": [
        "cut = int(len(seqs) * 0.8)\n",
        "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], 64), group_chunks(seqs[cut:], 64), bs=64, shuffle=False)"
      ],
      "execution_count": 57,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q4rbQ8IpfFQs",
        "outputId": "e78a99fa-592f-4a01-e895-a95388e2c94b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 175
        }
      },
      "source": [
        "learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy, cbs=ModelResetter)\n",
        "learn.fit_one_cycle(4, 1e-3)"
      ],
      "execution_count": 58,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>accuracy</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>1.664295</td>\n",
              "      <td>1.806631</td>\n",
              "      <td>0.486058</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.342835</td>\n",
              "      <td>1.807223</td>\n",
              "      <td>0.417308</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.218265</td>\n",
              "      <td>1.683006</td>\n",
              "      <td>0.458654</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.168428</td>\n",
              "      <td>1.693239</td>\n",
              "      <td>0.459615</td>\n",
              "      <td>00:02</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lqzW4QwxjUt0"
      },
      "source": [
        "# 4th model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-xiXJ4F1gpRO"
      },
      "source": [
        "sl = 16\n",
        "seqs = L((tensor(nums[i: i+sl]), tensor(nums[i+1: i+sl+1])) for i in range(0, len(nums), sl))\n",
        "\n",
        "cut = int(len(seqs)*0.8)\n",
        "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], 64), \n",
        "                             group_chunks(seqs[cut:], 64),\n",
        "                             bs=bs, drop_last=True, shuffle=False)"
      ],
      "execution_count": 71,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zGGEgMlzhmuv"
      },
      "source": [
        "class LMModel4(Module):\n",
        "  def __init__(self, vocab_sz, n_hidden):\n",
        "    self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
        "    self.h_h = nn.Linear(n_hidden, n_hidden)\n",
        "    self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
        "    self.h = 0\n",
        "\n",
        "  def forward(self, x):\n",
        "    outs = []\n",
        "    for i in range(sl):\n",
        "      self.h = self.h + self.i_h(x[:, i])\n",
        "      self.h = F.relu(self.h_h(self.h))\n",
        "      outs.append(self.h_o(self.h))\n",
        "    self.h = self.h.detach()\n",
        "\n",
        "    return torch.stack(outs, dim=1)\n",
        "\n",
        "  def reset(self): self.h = 0"
      ],
      "execution_count": 84,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iuCDgnJRj7KF"
      },
      "source": [
        "LMModel4(len(vocab), 64)(seqs[0][0].unsqueeze(0))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qkDfICYKkDOJ"
      },
      "source": [
        "def loss(inp, target): return F.cross_entropy(inp.view(-1, len(vocab)), target.view(-1))"
      ],
      "execution_count": 86,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v1RAVcRukENb",
        "outputId": "15b83361-87b6-4a4d-b9dc-c4ccedf771b3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 175
        }
      },
      "source": [
        "learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss, metrics=accuracy, cbs=ModelResetter)\n",
        "learn.fit_one_cycle(4, 1e-3)"
      ],
      "execution_count": 88,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>accuracy</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>2.511702</td>\n",
              "      <td>2.013124</td>\n",
              "      <td>0.420410</td>\n",
              "      <td>00:01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.723667</td>\n",
              "      <td>2.002416</td>\n",
              "      <td>0.376790</td>\n",
              "      <td>00:01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.575574</td>\n",
              "      <td>2.009366</td>\n",
              "      <td>0.364746</td>\n",
              "      <td>00:01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.538547</td>\n",
              "      <td>2.018201</td>\n",
              "      <td>0.364909</td>\n",
              "      <td>00:01</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bDFlajSKod-S"
      },
      "source": [
        "# 5th model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jMsoZuFWlS_n"
      },
      "source": [
        "class LMModel5(Module):\n",
        "  def __init__(self, vocab_sz, n_layers, n_hidden):\n",
        "    self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
        "    self.rnn = nn.RNN(n_hidden, 32, n_layers, batch_first=True)\n",
        "    self.h_o = nn.Linear(32, vocab_sz)\n",
        "    self.h = torch.zeros(n_layers, 32, 32)\n",
        "\n",
        "  def forward(self, x):\n",
        "    result, h = self.rnn(self.i_h(x), self.h)\n",
        "    self.h = h.detach()\n",
        "    return self.h_o(result)\n",
        "\n",
        "  def reset(self): self.h.zero_()"
      ],
      "execution_count": 107,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DWcF3_4tpm6R",
        "outputId": "eabb2e4a-6b03-4cfb-af36-e0b369483631",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "nn.Embedding(13, 4)(torch.tensor([0, 1, 4]))"
      ],
      "execution_count": 96,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[ 0.1739,  2.3553, -0.0054,  0.5570],\n",
              "        [ 1.3818,  0.1635, -0.3897, -0.7589],\n",
              "        [-1.4636,  0.6106, -1.7279,  0.9655]], grad_fn=<EmbeddingBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 96
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VhFyUiXupp0y",
        "outputId": "69a9b78c-96b2-4478-a587-100758a74cec",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 175
        }
      },
      "source": [
        "learn = Learner(dls, LMModel5(len(vocab), 64, 3), \n",
        "                loss_func=CrossEntropyLossFlat(), \n",
        "                metrics=accuracy, cbs=ModelResetter)\n",
        "learn.fit_one_cycle(4, 3e-3)"
      ],
      "execution_count": 108,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>accuracy</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>2.847264</td>\n",
              "      <td>2.796035</td>\n",
              "      <td>0.151855</td>\n",
              "      <td>00:14</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>2.757411</td>\n",
              "      <td>2.792833</td>\n",
              "      <td>0.151855</td>\n",
              "      <td>00:14</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>2.745373</td>\n",
              "      <td>2.805568</td>\n",
              "      <td>0.151855</td>\n",
              "      <td>00:14</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>2.743357</td>\n",
              "      <td>2.806942</td>\n",
              "      <td>0.151855</td>\n",
              "      <td>00:14</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cGZ6OmTtqjUL"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}