{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "RNN_fastai_pytorch.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyM3xZswsB9vdHZVmjN4Sx/z",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"
"
]
},
{
"cell_type": "code",
"metadata": {
"id": "BR7oia6PUA9V"
},
"source": [
"!pip install fastai --upgrade"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2X4GNuc2UjwR"
},
"source": [
"from fastai.text.all import *\n",
"import torch\n",
"import torch.nn.functional as F"
],
"execution_count": 73,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_Q1Be8_8Uo5_",
"outputId": "bd1c6cd2-56e1-47c8-ea98-777280d106b5",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"path = untar_data(URLs.HUMAN_NUMBERS)\n",
"\n",
"lines = L()\n",
"with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
"with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
"lines"
],
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "GL6ix4x1U6ux",
"outputId": "ffcc3c65-c66e-461b-ea7e-19fcfe78515a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
}
},
"source": [
"text = ' . '.join([l.strip() for l in lines])\n",
"text[:100]"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3GGuy2TcVDWH",
"outputId": "a59079e6-8eee-43d9-c5a8-cf8ba7d0a404",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"tokens = text.split(' ')\n",
"tokens[:10]"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IpQlXTP7VHuo",
"outputId": "77a8d748-3259-482a-8b8b-bd68decbb261",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"vocab = L(*tokens).unique()\n",
"vocab"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4L7BUVZPVKdd",
"outputId": "9f94b20c-967d-48bb-81a0-b46499919136",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"word2idx = {w:i for i,w in enumerate(vocab)}\n",
"nums = L(word2idx[i] for i in tokens)\n",
"nums"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#63095) [0,1,2,1,3,1,4,1,5,1...]"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Gv2SwVzEVQPp",
"outputId": "a1242bfb-cf0f-4a23-e192-3e30c4397e97",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hNqAIcZsVTaE",
"outputId": "6cc728d6-9e49-40ce-bea2-2cf6c51f5410",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
"seqs"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "g5W1edqaVVZe"
},
"source": [
"bs = 32\n",
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vJvw29iDYzXi"
},
"source": [
"# 1st model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2-qg2u4NVpfB"
},
"source": [
"class LMModel1(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.h_h = nn.Linear(n_hidden, n_hidden)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
"\n",
" def forward(self, x):\n",
" h = F.relu(self.h_h(self.i_h(x[:, 0])))\n",
" h = h + self.i_h(x[:, 1])\n",
" h = F.relu(self.h_h(h))\n",
" h = h + self.i_h(x[:, 2])\n",
" h = F.relu(self.h_h(h))\n",
" \n",
" return self.h_o(h)"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_x73KWo9X0K0",
"outputId": "275a8980-1006-4663-c0c3-ea70715f32af",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"LMModel1(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
],
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.1084, -0.2783, -0.1223, -0.3544, -0.4660, -0.1827, -0.3859, 0.2552,\n",
" -0.0022, -0.3058, 0.0698, -0.3552, 0.5132, -0.1932, 0.4968, 0.1891,\n",
" 0.4565, -0.5759, -0.0737, -0.3763, 0.5565, 0.1311, 0.2966, 0.3392,\n",
" -0.1113, 0.2586, 0.0560, -0.1836, 0.5182, -0.2767]],\n",
" grad_fn=)"
]
},
"metadata": {
"tags": []
},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "t-7NOjprYSgj",
"outputId": "352febb6-1a66-422d-9abe-65998e90e202",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
}
},
"source": [
"learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
],
"execution_count": 21,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.789906 | \n",
" 2.101264 | \n",
" 0.443071 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.393594 | \n",
" 1.828861 | \n",
" 0.468029 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.410961 | \n",
" 1.673940 | \n",
" 0.492512 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.377922 | \n",
" 1.698437 | \n",
" 0.482529 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "08_KlBMSY3FC"
},
"source": [
"# 2nd model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Y-AODj2xYpCh"
},
"source": [
"class LMModel2(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.h_h = nn.Linear(n_hidden, n_hidden)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
"\n",
" def forward(self, x):\n",
" h = 0\n",
" for i in range(3):\n",
" h = h + self.i_h(x[:, i])\n",
" h = F.relu(self.h_h(h))\n",
"\n",
" return self.h_o(h)"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Tril0ceNZit9",
"outputId": "f4dedbc2-3c68-4308-81cc-ccd57fe78741",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"LMModel2(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
],
"execution_count": 25,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.0557, -0.3203, -0.8181, -0.6636, 0.1325, 0.0072, -0.0247, 0.6777,\n",
" 0.1840, -0.4547, -0.0569, 0.2728, -0.2493, -0.0508, 0.1517, 0.5898,\n",
" 0.2106, -0.7529, 0.9692, 0.6834, -0.4625, -0.0441, -0.2372, 0.4435,\n",
" -0.0727, 0.1358, -0.0039, 0.1496, -0.6447, 0.3787]],\n",
" grad_fn=)"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6aqtRIf4Zl7M",
"outputId": "02654e55-17cd-4d5e-866e-448d6b18b328",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
}
},
"source": [
"learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
],
"execution_count": 26,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.909757 | \n",
" 1.930985 | \n",
" 0.478013 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.469091 | \n",
" 1.714454 | \n",
" 0.479677 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.429843 | \n",
" 1.672472 | \n",
" 0.492988 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.390096 | \n",
" 1.681135 | \n",
" 0.465890 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K1KBmyl3bF9x"
},
"source": [
"# 3rd model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "plOzFV2XZuw0"
},
"source": [
"class LMModel3(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.h_h = nn.Linear(n_hidden, n_hidden)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = 0\n",
"\n",
" def forward(self, x):\n",
" for i in range(3):\n",
" self.h = self.h + self.i_h(x[:, i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" out = self.h_o(self.h)\n",
" self.h = self.h.detach()\n",
" return out\n",
"\n",
" def reset(self): self.h = 0"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "a62gbcJebchw",
"outputId": "d2d9a150-ea8c-4d22-dc94-70c186eb25e3",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"LMModel3(len(vocab), 3)(seqs[0][0].unsqueeze(0))"
],
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-0.3256, -0.2493, -0.0404, 0.2820, 0.3655, 0.1891, -0.1301, -0.1874,\n",
" -0.4694, -0.3767, -0.0546, -0.0693, -0.4469, 0.0875, 0.0070, 0.0787,\n",
" 0.0223, -0.0287, 0.5465, -0.0721, -0.4811, -0.3768, 0.5216, -0.4914,\n",
" 0.0082, 0.3935, -0.5356, -0.4153, 0.2180, 0.1427]],\n",
" grad_fn=)"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_RfM7xrSdaxa",
"outputId": "f0d0d0fd-8239-427e-d7db-60fe64959d21",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"m = len(seqs)//bs\n",
"m, bs, len(seqs)"
],
"execution_count": 42,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(657, 32, 21031)"
]
},
"metadata": {
"tags": []
},
"execution_count": 42
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3L6j0lqmdqAt"
},
"source": [
"def group_chunks(ds, bs):\n",
" m = len(ds) // bs\n",
" new_ds = L()\n",
" for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
" return new_ds"
],
"execution_count": 56,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KbQ9WwZhef0a",
"outputId": "226fb109-d0c5-4df8-e702-6ea6025c5728",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"s = seqs[:21024]\n",
"s"
],
"execution_count": 52,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#21024) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
]
},
"metadata": {
"tags": []
},
"execution_count": 52
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aUfmSHmqerkR"
},
"source": [
"cut = int(len(seqs) * 0.8)\n",
"d = DataLoaders.from_dsets(s[:cut], s[cut:], bs=64, shuffle=False)"
],
"execution_count": 53,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PbeB90deflEm"
},
"source": [
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], 64), group_chunks(seqs[cut:], 64), bs=64, shuffle=False)"
],
"execution_count": 57,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "q4rbQ8IpfFQs",
"outputId": "e78a99fa-592f-4a01-e895-a95388e2c94b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
}
},
"source": [
"learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy, cbs=ModelResetter)\n",
"learn.fit_one_cycle(4, 1e-3)"
],
"execution_count": 58,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.664295 | \n",
" 1.806631 | \n",
" 0.486058 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.342835 | \n",
" 1.807223 | \n",
" 0.417308 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.218265 | \n",
" 1.683006 | \n",
" 0.458654 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.168428 | \n",
" 1.693239 | \n",
" 0.459615 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lqzW4QwxjUt0"
},
"source": [
"# 4th model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-xiXJ4F1gpRO"
},
"source": [
"sl = 16\n",
"seqs = L((tensor(nums[i: i+sl]), tensor(nums[i+1: i+sl+1])) for i in range(0, len(nums), sl))\n",
"\n",
"cut = int(len(seqs)*0.8)\n",
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], 64), \n",
" group_chunks(seqs[cut:], 64),\n",
" bs=bs, drop_last=True, shuffle=False)"
],
"execution_count": 71,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zGGEgMlzhmuv"
},
"source": [
"class LMModel4(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.h_h = nn.Linear(n_hidden, n_hidden)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = 0\n",
"\n",
" def forward(self, x):\n",
" outs = []\n",
" for i in range(sl):\n",
" self.h = self.h + self.i_h(x[:, i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" outs.append(self.h_o(self.h))\n",
" self.h = self.h.detach()\n",
"\n",
" return torch.stack(outs, dim=1)\n",
"\n",
" def reset(self): self.h = 0"
],
"execution_count": 84,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iuCDgnJRj7KF"
},
"source": [
"LMModel4(len(vocab), 64)(seqs[0][0].unsqueeze(0))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qkDfICYKkDOJ"
},
"source": [
"def loss(inp, target): return F.cross_entropy(inp.view(-1, len(vocab)), target.view(-1))"
],
"execution_count": 86,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "v1RAVcRukENb",
"outputId": "15b83361-87b6-4a4d-b9dc-c4ccedf771b3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
}
},
"source": [
"learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss, metrics=accuracy, cbs=ModelResetter)\n",
"learn.fit_one_cycle(4, 1e-3)"
],
"execution_count": 88,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 2.511702 | \n",
" 2.013124 | \n",
" 0.420410 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.723667 | \n",
" 2.002416 | \n",
" 0.376790 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.575574 | \n",
" 2.009366 | \n",
" 0.364746 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.538547 | \n",
" 2.018201 | \n",
" 0.364909 | \n",
" 00:01 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bDFlajSKod-S"
},
"source": [
"# 5th model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jMsoZuFWlS_n"
},
"source": [
"class LMModel5(Module):\n",
" def __init__(self, vocab_sz, n_layers, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.rnn = nn.RNN(n_hidden, 32, n_layers, batch_first=True)\n",
" self.h_o = nn.Linear(32, vocab_sz)\n",
" self.h = torch.zeros(n_layers, 32, 32)\n",
"\n",
" def forward(self, x):\n",
" result, h = self.rnn(self.i_h(x), self.h)\n",
" self.h = h.detach()\n",
" return self.h_o(result)\n",
"\n",
" def reset(self): self.h.zero_()"
],
"execution_count": 107,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "DWcF3_4tpm6R",
"outputId": "eabb2e4a-6b03-4cfb-af36-e0b369483631",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"nn.Embedding(13, 4)(torch.tensor([0, 1, 4]))"
],
"execution_count": 96,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[ 0.1739, 2.3553, -0.0054, 0.5570],\n",
" [ 1.3818, 0.1635, -0.3897, -0.7589],\n",
" [-1.4636, 0.6106, -1.7279, 0.9655]], grad_fn=)"
]
},
"metadata": {
"tags": []
},
"execution_count": 96
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VhFyUiXupp0y",
"outputId": "69a9b78c-96b2-4478-a587-100758a74cec",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
}
},
"source": [
"learn = Learner(dls, LMModel5(len(vocab), 64, 3), \n",
" loss_func=CrossEntropyLossFlat(), \n",
" metrics=accuracy, cbs=ModelResetter)\n",
"learn.fit_one_cycle(4, 3e-3)"
],
"execution_count": 108,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 2.847264 | \n",
" 2.796035 | \n",
" 0.151855 | \n",
" 00:14 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2.757411 | \n",
" 2.792833 | \n",
" 0.151855 | \n",
" 00:14 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2.745373 | \n",
" 2.805568 | \n",
" 0.151855 | \n",
" 00:14 | \n",
"
\n",
" \n",
" | 3 | \n",
" 2.743357 | \n",
" 2.806942 | \n",
" 0.151855 | \n",
" 00:14 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cGZ6OmTtqjUL"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}