{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "RNN_fastai_pytorch.ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "authorship_tag": "ABX9TyM3xZswsB9vdHZVmjN4Sx/z",
   "include_colab_link": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/bipinKrishnan/fastai_course/blob/master/RNN_fastai_pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recurrent language models on HUMAN_NUMBERS with fastai\n",
    "\n",
    "Each section builds a model that predicts the next token of the HUMAN_NUMBERS text (`one . two . three . ...`), starting from a hand-unrolled RNN and ending with PyTorch's `nn.RNN`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "BR7oia6PUA9V"
   },
   "source": [
    "%pip install -q --upgrade fastai"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "2X4GNuc2UjwR"
   },
   "source": [
    "from fastai.text.all import *\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "set_seed(42)  # seed Python, NumPy and PyTorch RNGs so re-runs give comparable results"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "_Q1Be8_8Uo5_",
    "outputId": "bd1c6cd2-56e1-47c8-ea98-777280d106b5",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "path = untar_data(URLs.HUMAN_NUMBERS)\n",
    "\n",
    "lines = L()\n",
    "with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
    "with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
    "lines"
   ],
   "execution_count": 3,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 3
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "GL6ix4x1U6ux",
    "outputId": "ffcc3c65-c66e-461b-ea7e-19fcfe78515a",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 36
    }
   },
   "source": [
    "text = ' . '.join([line.strip() for line in lines])\n",
    "text[:100]"
   ],
   "execution_count": 4,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 4
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "3GGuy2TcVDWH",
    "outputId": "a59079e6-8eee-43d9-c5a8-cf8ba7d0a404",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "tokens = text.split(' ')\n",
    "tokens[:10]"
   ],
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 5
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "IpQlXTP7VHuo",
    "outputId": "77a8d748-3259-482a-8b8b-bd68decbb261",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "vocab = L(*tokens).unique()\n",
    "vocab"
   ],
   "execution_count": 6,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 6
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "4L7BUVZPVKdd",
    "outputId": "9f94b20c-967d-48bb-81a0-b46499919136",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "word2idx = {w:i for i,w in enumerate(vocab)}\n",
    "nums = L(word2idx[w] for w in tokens)\n",
    "nums"
   ],
   "execution_count": 7,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(#63095) [0,1,2,1,3,1,4,1,5,1...]"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 7
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Gv2SwVzEVQPp",
    "outputId": "a1242bfb-cf0f-4a23-e192-3e30c4397e97",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
   ],
   "execution_count": 8,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 8
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "hNqAIcZsVTaE",
    "outputId": "6cc728d6-9e49-40ce-bea2-2cf6c51f5410",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
    "seqs"
   ],
   "execution_count": 9,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 9
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "g5W1edqaVVZe"
   },
   "source": [
    "# Single batch size for every DataLoaders below; group_chunks and LMModel5's\n",
    "# hidden state both rely on the DataLoader using exactly this value.\n",
    "bs = 64\n",
    "cut = int(len(seqs) * 0.8)\n",
    "dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=bs, shuffle=False)"
   ],
   "execution_count": 10,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vJvw29iDYzXi"
   },
   "source": [
    "# 1st model\n",
    "\n",
    "Predict the 4th token from the previous 3, with the three recurrent steps unrolled by hand (the same `h_h` layer is reused at every step)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "2-qg2u4NVpfB"
   },
   "source": [
    "class LMModel1(Module):\n",
    "    \"Predict the next token from 3 input tokens; the recurrence is unrolled by hand.\"\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)  # input token -> hidden\n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)     # hidden -> hidden, shared by all steps\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)     # hidden -> next-token logits\n",
    "\n",
    "    def forward(self, x):\n",
    "        h = F.relu(self.h_h(self.i_h(x[:, 0])))\n",
    "        h = h + self.i_h(x[:, 1])\n",
    "        h = F.relu(self.h_h(h))\n",
    "        h = h + self.i_h(x[:, 2])\n",
    "        h = F.relu(self.h_h(h))\n",
    "\n",
    "        return self.h_o(h)"
   ],
   "execution_count": 12,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "_x73KWo9X0K0",
    "outputId": "275a8980-1006-4663-c0c3-ea70715f32af",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "LMModel1(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
   ],
   "execution_count": 18,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[-0.1084, -0.2783, -0.1223, -0.3544, -0.4660, -0.1827, -0.3859, 0.2552,\n",
       " -0.0022, -0.3058, 0.0698, -0.3552, 0.5132, -0.1932, 0.4968, 0.1891,\n",
       " 0.4565, -0.5759, -0.0737, -0.3763, 0.5565, 0.1311, 0.2966, 0.3392,\n",
       " -0.1113, 0.2586, 0.0560, -0.1836, 0.5182, -0.2767]],\n",
       " grad_fn=<AddmmBackward>)"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 18
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "t-7NOjprYSgj",
    "outputId": "352febb6-1a66-422d-9abe-65998e90e202",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 175
    }
   },
   "source": [
    "learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
    "learn.fit_one_cycle(4, 1e-3)"
   ],
   "execution_count": 21,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       " <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       " <th>epoch</th>\n",
       " <th>train_loss</th>\n",
       " <th>valid_loss</th>\n",
       " <th>accuracy</th>\n",
       " <th>time</th>\n",
       " </tr>\n",
       " </thead>\n",
       " <tbody>\n",
       " <tr>\n",
       " <td>0</td>\n",
       " <td>1.789906</td>\n",
       " <td>2.101264</td>\n",
       " <td>0.443071</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>1</td>\n",
       " <td>1.393594</td>\n",
       " <td>1.828861</td>\n",
       " <td>0.468029</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>2</td>\n",
       " <td>1.410961</td>\n",
       " <td>1.673940</td>\n",
       " <td>0.492512</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>3</td>\n",
       " <td>1.377922</td>\n",
       " <td>1.698437</td>\n",
       " <td>0.482529</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "08_KlBMSY3FC"
   },
   "source": [
    "# 2nd model\n",
    "\n",
    "The same model, with the three unrolled steps replaced by a loop."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Y-AODj2xYpCh"
   },
   "source": [
    "class LMModel2(Module):\n",
    "    \"LMModel1 with the three recurrent steps written as a loop.\"\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "\n",
    "    def forward(self, x):\n",
    "        h = 0  # broadcasts against the first embedding\n",
    "        for i in range(3):\n",
    "            h = h + self.i_h(x[:, i])\n",
    "            h = F.relu(self.h_h(h))\n",
    "\n",
    "        return self.h_o(h)"
   ],
   "execution_count": 24,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Tril0ceNZit9",
    "outputId": "f4dedbc2-3c68-4308-81cc-ccd57fe78741",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "LMModel2(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
   ],
   "execution_count": 25,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[ 0.0557, -0.3203, -0.8181, -0.6636, 0.1325, 0.0072, -0.0247, 0.6777,\n",
       " 0.1840, -0.4547, -0.0569, 0.2728, -0.2493, -0.0508, 0.1517, 0.5898,\n",
       " 0.2106, -0.7529, 0.9692, 0.6834, -0.4625, -0.0441, -0.2372, 0.4435,\n",
       " -0.0727, 0.1358, -0.0039, 0.1496, -0.6447, 0.3787]],\n",
       " grad_fn=<AddmmBackward>)"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 25
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "6aqtRIf4Zl7M",
    "outputId": "02654e55-17cd-4d5e-866e-448d6b18b328",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 175
    }
   },
   "source": [
    "learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
    "learn.fit_one_cycle(4, 1e-3)"
   ],
   "execution_count": 26,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       " <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       " <th>epoch</th>\n",
       " <th>train_loss</th>\n",
       " <th>valid_loss</th>\n",
       " <th>accuracy</th>\n",
       " <th>time</th>\n",
       " </tr>\n",
       " </thead>\n",
       " <tbody>\n",
       " <tr>\n",
       " <td>0</td>\n",
       " <td>1.909757</td>\n",
       " <td>1.930985</td>\n",
       " <td>0.478013</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>1</td>\n",
       " <td>1.469091</td>\n",
       " <td>1.714454</td>\n",
       " <td>0.479677</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>2</td>\n",
       " <td>1.429843</td>\n",
       " <td>1.672472</td>\n",
       " <td>0.492988</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>3</td>\n",
       " <td>1.390096</td>\n",
       " <td>1.681135</td>\n",
       " <td>0.465890</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K1KBmyl3bF9x"
   },
   "source": [
    "# 3rd model\n",
    "\n",
    "Keep the hidden state between batches. `self.h` is detached after each batch so gradients stop at the batch boundary, `group_chunks` lays the data out so each batch row continues the same row of the previous batch, and the `ModelResetter` callback calls the model's `reset()` to clear the state."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "plOzFV2XZuw0"
   },
   "source": [
    "class LMModel3(Module):\n",
    "    \"LMModel2 whose hidden state `self.h` persists across batches.\"\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h = 0\n",
    "\n",
    "    def forward(self, x):\n",
    "        for i in range(3):\n",
    "            self.h = self.h + self.i_h(x[:, i])\n",
    "            self.h = F.relu(self.h_h(self.h))\n",
    "        out = self.h_o(self.h)\n",
    "        self.h = self.h.detach()  # keep the value, drop the graph: backprop stops at this batch\n",
    "        return out\n",
    "\n",
    "    def reset(self): self.h = 0"
   ],
   "execution_count": 35,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "a62gbcJebchw",
    "outputId": "d2d9a150-ea8c-4d22-dc94-70c186eb25e3",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "LMModel3(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
   ],
   "execution_count": 36,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[-0.3256, -0.2493, -0.0404, 0.2820, 0.3655, 0.1891, -0.1301, -0.1874,\n",
       " -0.4694, -0.3767, -0.0546, -0.0693, -0.4469, 0.0875, 0.0070, 0.0787,\n",
       " 0.0223, -0.0287, 0.5465, -0.0721, -0.4811, -0.3768, 0.5216, -0.4914,\n",
       " 0.0082, 0.3935, -0.5356, -0.4153, 0.2180, 0.1427]],\n",
       " grad_fn=<AddmmBackward>)"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 36
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "_RfM7xrSdaxa",
    "outputId": "f0d0d0fd-8239-427e-d7db-60fe64959d21",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "m = len(seqs)//bs\n",
    "m, bs, len(seqs)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "3L6j0lqmdqAt"
   },
   "source": [
    "def group_chunks(ds, bs):\n",
    "    \"Reorder `ds` so that, read in batches of `bs`, each batch row continues the same row of the previous batch.\"\n",
    "    # With m = len(ds)//bs, batch i is [ds[i], ds[i+m], ..., ds[i+m*(bs-1)]], so row j of\n",
    "    # batch i+1 is the item right after row j of batch i. The last len(ds) % bs items are dropped.\n",
    "    # Only meaningful when the DataLoader uses the same `bs` and shuffle=False.\n",
    "    m = len(ds) // bs\n",
    "    new_ds = L()\n",
    "    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
    "    return new_ds"
   ],
   "execution_count": 56,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "PbeB90deflEm"
   },
   "source": [
    "cut = int(len(seqs) * 0.8)\n",
    "# group_chunks must use the same batch size as the DataLoader, or rows stop lining up across batches\n",
    "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs), group_chunks(seqs[cut:], bs), bs=bs, shuffle=False)"
   ],
   "execution_count": 57,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "q4rbQ8IpfFQs",
    "outputId": "e78a99fa-592f-4a01-e895-a95388e2c94b",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 175
    }
   },
   "source": [
    "learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy, cbs=ModelResetter)\n",
    "learn.fit_one_cycle(4, 1e-3)"
   ],
   "execution_count": 58,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       " <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       " <th>epoch</th>\n",
       " <th>train_loss</th>\n",
       " <th>valid_loss</th>\n",
       " <th>accuracy</th>\n",
       " <th>time</th>\n",
       " </tr>\n",
       " </thead>\n",
       " <tbody>\n",
       " <tr>\n",
       " <td>0</td>\n",
       " <td>1.664295</td>\n",
       " <td>1.806631</td>\n",
       " <td>0.486058</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>1</td>\n",
       " <td>1.342835</td>\n",
       " <td>1.807223</td>\n",
       " <td>0.417308</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>2</td>\n",
       " <td>1.218265</td>\n",
       " <td>1.683006</td>\n",
       " <td>0.458654</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " <tr>\n",
       " <td>3</td>\n",
       " <td>1.168428</td>\n",
       " <td>1.693239</td>\n",
       " <td>0.459615</td>\n",
       " <td>00:02</td>\n",
       " </tr>\n",
       " </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lqzW4QwxjUt0"
   },
   "source": [
    "# 4th model\n",
    "\n",
    "Use sequences of `sl` tokens and predict the next token after every position instead of only after the last one, so the loss is computed over `(bs, sl, vocab)` logits."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "-xiXJ4F1gpRO"
   },
   "source": [
    "sl = 16\n",
    "# Stop early enough that every input and its one-token-shifted target have exactly `sl` tokens;\n",
    "# a shorter final pair would fail to collate into a batch.\n",
    "seqs = L((tensor(nums[i: i+sl]), tensor(nums[i+1: i+sl+1])) for i in range(0, len(nums)-sl-1, sl))\n",
    "\n",
    "cut = int(len(seqs)*0.8)\n",
    "# group_chunks and the DataLoader must share `bs` so the carried hidden state lines up across batches\n",
    "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),\n",
    "                             group_chunks(seqs[cut:], bs),\n",
    "                             bs=bs, drop_last=True, shuffle=False)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "zGGEgMlzhmuv"
   },
   "source": [
    "class LMModel4(Module):\n",
    "    \"Predict the next token after every input position; the hidden state persists across batches.\"\n",
    "    def __init__(self, vocab_sz, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.h_h = nn.Linear(n_hidden, n_hidden)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        self.h = 0\n",
    "\n",
    "    def forward(self, x):\n",
    "        outs = []\n",
    "        for i in range(x.shape[1]):  # one step per input token (no dependence on the global `sl`)\n",
    "            self.h = self.h + self.i_h(x[:, i])\n",
    "            self.h = F.relu(self.h_h(self.h))\n",
    "            outs.append(self.h_o(self.h))\n",
    "        self.h = self.h.detach()  # keep the value, drop the graph: backprop stops at this batch\n",
    "\n",
    "        return torch.stack(outs, dim=1)\n",
    "\n",
    "    def reset(self): self.h = 0"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "iuCDgnJRj7KF"
   },
   "source": [
    "LMModel4(len(vocab), 64)(seqs[0][0].unsqueeze(0))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "qkDfICYKkDOJ"
   },
   "source": [
    "def loss(inp, target):\n",
    "    \"Cross-entropy over every (batch, position) pair: logits flattened to (N, vocab), targets to (N,).\"\n",
    "    return F.cross_entropy(inp.view(-1, inp.shape[-1]), target.view(-1))"
   ],
   "execution_count": 86,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "v1RAVcRukENb",
    "outputId": "15b83361-87b6-4a4d-b9dc-c4ccedf771b3",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 175
    }
   },
   "source": [
    "learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss, metrics=accuracy, cbs=ModelResetter)\n",
    "learn.fit_one_cycle(4, 1e-3)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bDFlajSKod-S"
   },
   "source": [
    "# 5th model\n",
    "\n",
    "Replace the hand-written recurrence with a multi-layer `nn.RNN`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "jMsoZuFWlS_n"
   },
   "source": [
    "class LMModel5(Module):\n",
    "    \"Multi-layer `nn.RNN` language model whose hidden state persists across batches.\"\n",
    "    def __init__(self, vocab_sz, n_layers, n_hidden):\n",
    "        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
    "        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)\n",
    "        self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
    "        # nn.RNN's h_0 is (num_layers, batch, hidden); its batch dim must equal the DataLoader's `bs`.\n",
    "        # A non-persistent buffer follows the model to its device but is not saved in the state_dict.\n",
    "        self.register_buffer('h', torch.zeros(n_layers, bs, n_hidden), persistent=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        res, h = self.rnn(self.i_h(x), self.h)\n",
    "        self.h = h.detach()  # keep the value, drop the graph: backprop stops at this batch\n",
    "        return self.h_o(res)\n",
    "\n",
    "    def reset(self): self.h.zero_()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "DWcF3_4tpm6R",
    "outputId": "eabb2e4a-6b03-4cfb-af36-e0b369483631",
    "colab": {
     "base_uri": "https://localhost:8080/"
    }
   },
   "source": [
    "nn.Embedding(13, 4)(torch.tensor([0, 1, 4]))"
   ],
   "execution_count": 96,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[ 0.1739, 2.3553, -0.0054, 0.5570],\n",
       " [ 1.3818, 0.1635, -0.3897, -0.7589],\n",
       " [-1.4636, 0.6106, -1.7279, 0.9655]], grad_fn=<EmbeddingBackward>)"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": 96
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "VhFyUiXupp0y",
    "outputId": "69a9b78c-96b2-4478-a587-100758a74cec",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 175
    }
   },
   "source": [
    "# Keyword arguments: the positional call (len(vocab), 64, 3) built a 64-layer RNN with hidden size 3.\n",
    "learn = Learner(dls, LMModel5(len(vocab), n_layers=3, n_hidden=64),\n",
    "                loss_func=CrossEntropyLossFlat(),\n",
    "                metrics=accuracy, cbs=ModelResetter)\n",
    "learn.fit_one_cycle(4, 3e-3)"
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}