{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch for generating music reviews" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda.is_available: True\n", "available: 1; current: 0\n", "cuda:0\n", "pytorch 0.4.0\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "\n", "print('cuda.is_available:', torch.cuda.is_available())\n", "print(f'available: {torch.cuda.device_count()}; current: {torch.cuda.current_device()}')\n", "DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu')\n", "print(DEVICE)\n", "print('pytorch', torch.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total word_count: 241026; char_count: 1417998\n" ] }, { "data": { "text/plain": [ "0 New Music\\n\\nMt. Joy reached out to us with th...\n", "1 Folk rockers Mt. Joy have debuted their new so...\n", "2 You know we're digging Mt. Joy.\\n\\nTheir new s...\n", "3 Nothing against the profession, but the U.S. 
h...\n", "4 Connecticut duo **Opia** have released a guita...\n", "Name: content, dtype: object" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import os\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "\n", "BASE_DIR = os.getcwd()\n", "DATA_DIR = os.path.join(BASE_DIR, '..', 'datasets')\n", "\n", "BLOG_CONTENT_FILE = os.path.join(DATA_DIR, f'blog_content_en_sample.json')\n", "BLOG_CONTENT_DF = pd.read_json(BLOG_CONTENT_FILE)\n", "print(f'total word_count: {sum(BLOG_CONTENT_DF.word_count)}; char_count: {sum([len(w) for w in BLOG_CONTENT_DF.content])}')\n", "BLOG_CONTENT_DF.head().content" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_text word_count: 1113633; test_text word_count: 304365\n" ] } ], "source": [ "TRAIN_DF, TEST_DF = train_test_split(BLOG_CONTENT_DF, test_size=0.2, random_state=42)\n", "TRAIN_TEXT, TEST_TEXT = TRAIN_DF.content, TEST_DF.content\n", "print(f'train_text word_count: {sum([len(t) for t in TRAIN_TEXT])}; test_text word_count: {sum([len(t) for t in TEST_TEXT])}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helpers" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "BPTT = 4 # like the 'n' in n-gram, or order\n", "BS = 512 # batch size\n", "EPOCHS = 5\n", "N_FAC = 42 # number of latent factors\n", "N_HIDDEN = 128" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def pad_start(bptt):\n", " return '\\0' * bptt" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "vocab_size: 70\n", "['\\x00', '\\n', ' ', '!', '\"', '#', '$', '%', '&', \"'\", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', 
'4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\\\', ']', '^', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']\n", "\n" ] } ], "source": [ "def create_inputs(texts_arr, print_info=False):\n", " # shuffle inputs\n", " texts_arr = texts_arr.sample(frac=1).reset_index(drop=True)\n", " \n", " # pad each new text with leading '\\0' so that we learn how to start\n", " # also, lowercase\n", " texts = ''.join([pad_start(BPTT) + text.lower() for text in texts_arr])\n", "\n", " chars = sorted(list(set(texts)))\n", " vocab_size = len(chars)\n", " if print_info:\n", " print('vocab_size:', vocab_size)\n", " print(chars)\n", " print()\n", "\n", " char_to_idx = {c: i for i, c in enumerate(chars)}\n", " idx_to_char = {i: c for i, c in enumerate(chars)}\n", "\n", " idx = [char_to_idx[text] for text in texts] \n", " return idx, vocab_size, char_to_idx, idx_to_char\n", "\n", "_, VOCAB_SIZE, _, _ = create_inputs(TRAIN_TEXT, True)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import math\n", "import time\n", "\n", "def time_since(since):\n", " now = time.time()\n", " s = now - since\n", " m = math.floor(s / 60)\n", " s -= m * 60\n", " return f'{m}m {s:.0f}s'" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# https://github.com/fastai/fastai/blob/master/fastai/nlp.py\n", "def batchify(data, bs):\n", " if bs == 1:\n", " return torch.tensor([[data[i+o] for i in range(len(data)-BPTT-1)] for o in range(BPTT+1)], dtype=torch.long, device=DEVICE)\n", " else:\n", " num = data.size(0) // bs\n", " data = data[:num*bs]\n", " # invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). 
Call .contiguous() before .view().\n", " return data.view(bs, -1).t().contiguous()\n", " \n", "\n", "def get_batch(data, i, seq_len):\n", " seq_len = min(seq_len, len(data) - 1 - i)\n", " return data[i:i+seq_len].to(DEVICE), data[i+1:i+1+seq_len].view(-1).to(DEVICE)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as ticker\n", " \n", "def plot_loss(losses):\n", " %matplotlib inline\n", " plt.figure()\n", " plt.plot(all_losses)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def batch_train(model, batches, optimizer, criterion=nn.CrossEntropyLoss(), bptt=BPTT):\n", " model.zero_grad()\n", " loss = 0\n", " \n", " for i in range(batches.size(0) - bptt):\n", " xs, ys = get_batch(batches, i, bptt)\n", " output = model(xs)\n", " loss += criterion(output, ys)\n", " \n", " loss.backward()\n", " if optimizer:\n", " optimizer.step()\n", " \n", " return loss.item() / (batches.size(0) - bptt)\n", "\n", "def batchless_train(model, batches, optimizer, start, print_every, char_to_idx, idx_to_char, seed='the ', max_sample_length=100, criterion=nn.CrossEntropyLoss(), bptt=BPTT):\n", " xs = np.stack(batches[:-1], axis=1) # history\n", " ys = np.stack(batches[-1:][0]) # target\n", "\n", " total_loss = torch.Tensor([0])\n", " for i in range(xs.shape[0]):\n", " model.zero_grad()\n", " output = model(torch.tensor(xs[i], dtype=torch.long, device=DEVICE))\n", "\n", " loss = criterion(output, torch.tensor([ys[i]], dtype=torch.long, device=DEVICE))\n", " \n", " loss.backward()\n", " if optimizer:\n", " optimizer.step()\n", " \n", " # Get the Python number from a 1-element Tensor by calling tensor.item()\n", " total_loss += loss.item()\n", " \n", " if i % print_every == 0:\n", " print(f'{time_since(start)} ({i} {i / xs.shape[0] * 100:.2f}%) {loss:.4f}')\n", " print(f'Epoch {i} sample:')\n", " 
sample(model, char_to_idx, idx_to_char, seed=seed, max_length=max_sample_length)\n", " \n", " return total_loss# / xs.shape[0]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sample(model, char_to_idx, idx_to_char, seed=pad_start(BPTT), max_length=20, bptt=BPTT, sample=True):\n", " with torch.no_grad(): # no need to track history in sampling\n", " output_idx = [char_to_idx[c] for c in seed[-bptt:]]\n", "\n", " for i in range(max_length):\n", " h_idxs = torch.tensor(output_idx[-bptt:], dtype=torch.long, device=DEVICE).view(-1, 1)\n", " output = model(h_idxs.transpose(0,1))\n", " if sample:\n", " # sample from distribution\n", " idx = torch.multinomial(output[-1].exp(), 1).item()\n", " else:\n", " # get most probable\n", " topi = output.topk(1)[1]\n", " idx = topi[0][0]\n", " if idx == 0:\n", " break\n", " else:\n", " output_idx.append(idx)\n", "\n", " sample_text = ''.join([idx_to_char[i] for i in output_idx])\n", " print(sample_text)\n", " #print(output_idx)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## with n-grams\n", "\n", "Another [n-gram music reviews](http://nbviewer.jupyter.org/github/iconix/openai/blob/master/nbs/n-gram%20music%20reviews.ipynb) model, implemented this time in PyTorch.\n", "\n", "Guiding PyTorch tutorial: [An Example: N-Gram Language Modeling](https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html#an-example-n-gram-language-modeling)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class NGramLanguageModel(nn.Module):\n", " \n", " def __init__(self, vocab_size, hidden_size, n_fac, bptt):\n", " super(NGramLanguageModel, self).__init__()\n", " \n", " self.embedding = nn.Embedding(vocab_size, n_fac)\n", " self.linear1 = nn.Linear(bptt * n_fac, hidden_size)\n", " self.linear2 = nn.Linear(hidden_size, vocab_size)\n", " \n", " def forward(self, inputs):\n", " inputs = 
self.embedding(inputs).view((1, -1))\n", " out = F.relu(self.linear1(inputs))\n", " out = self.linear2(out)\n", " return out" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0m 46s (0 0.00%) 4.4895\n", "Epoch 0 sample:\n", "the x=$=el1$\n", "8m 56s (500000 44.69%) 4.0249\n", "Epoch 500000 sample:\n", "the rne\n", ",n ef apdrmlggi entofs_.tis sa skrcutttta sd\"woearotcn*hvf *sno caliiengwsbuecfhuirrl wrsoabe'ua\n", "17m 6s (1000000 89.39%) 2.7267\n", "Epoch 1000000 sample:\n", "the t epfdynodoo *tlruesrs sdtahesl mrh ev tajoofiyg ihe\n", " tuoa he rnl m.lorsgagcn,ts vo so.essgao su\n", "19m 48s (0 0.00%) 2.7462\n", "Epoch 0 sample:\n", "the itrautcwllctahhdmranyanntonepoir er tteghssseayit \n", "kptn in vuih'nrnstwuiii n niumfd tes lerteueid/ae\n", "27m 59s (500000 44.69%) 3.0083\n", "Epoch 500000 sample:\n", "the hbff asemys tt.paelosrhulc -ihltaihmesmu nswer l doroldte\n", " oa w\n", "oaopnauskrrssthsut\n", " ak dhl einih a\n", "36m 10s (1000000 89.39%) 3.9739\n", "Epoch 1000000 sample:\n", "the emialsnao.h\n", " vte et,tsteew.rei ae ho*eodthdk a*t\n", " ss r aren**a d o\n", "mo'i es l a c ofah msariir wni \n", "38m 52s (0 0.00%) 3.6807\n", "Epoch 0 sample:\n", "the oyeailydrcsutm ,yo ls\"eft, rudl eoi thogs on secni.iibr'ey iyi eal\"\n", "toonn hks wtnm,l leoi vueseiedp\n", "47m 5s (500000 44.69%) 4.0389\n", "Epoch 500000 sample:\n", "the fu\n", "jn sottee chh seehllra takdsu gntc \"a todnnr\n", "nno\n", "i sagehn er noro e u nupa* seuee c rbaonup bbk\n", "55m 37s (1000000 89.39%) 2.9309\n", "Epoch 1000000 sample:\n", "the hpee\n", "igekeieietden'oi ascrataisise\n", "\n", "h\n", " tathh\n", ".yesaycsapeaek dhhealmotslde*v \n", "spaelymblons*o\n", "od ats\n", "58m 24s (0 0.00%) 1.9723\n", "Epoch 0 sample:\n" ] }, { "ename": "RuntimeError", "evalue": "cuda runtime error (59) : device-side assert triggered at 
/opt/conda/conda-bld/pytorch_1524586445097/work/aten/src/THC/generic/THCStorage.c:36", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mngram\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNGramLanguageModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mVOCAB_SIZE\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mN_HIDDEN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mN_FAC\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBPTT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDEVICE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mngram\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.005\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mall_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mngram\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTRAIN_TEXT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot_every\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprint_every\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m500000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_loop\u001b[0;34m(model, optimizer, text, batch_size, seed, max_sample_length, epochs, print_every, plot_every, 
criterion)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mbatches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatchify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mDEVICE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatchless_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprint_every\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchar_to_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx_to_char\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_sample_length\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprint_every\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m\u001b[0m in \u001b[0;36mbatchless_train\u001b[0;34m(model, batches, optimizer, start, print_every, char_to_idx, idx_to_char, seed, max_sample_length, criterion, bptt)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{time_since(start)} ({i} {i / xs.shape[0] * 100:.2f}%) {loss:.4f}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Epoch {i} sample:'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchar_to_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx_to_char\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_length\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_sample_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtotal_loss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36msample\u001b[0;34m(model, char_to_idx, idx_to_char, seed, max_length, bptt, sample)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;31m# sample from distribution\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultinomial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;31m# get most probable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mRuntimeError\u001b[0m: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1524586445097/work/aten/src/THC/generic/THCStorage.c:36" ] } ], "source": [ "ngram = NGramLanguageModel(VOCAB_SIZE, N_HIDDEN, N_FAC, BPTT).to(DEVICE)\n", "optimizer = optim.Adam(ngram.parameters(), lr=0.005)\n", "all_losses = train_loop(ngram, optimizer, TRAIN_TEXT, batch_size=1, plot_every=1, print_every=500000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "plot_loss(all_losses)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "sample(ngram, char_to_idx, idx_to_char, seed='the ', max_length=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Observations**:\n", "- Training, even on a sample of 2K reviews, is _slow_ (5 epochs in 67m 18s). 
Could we speed up with:\n", " - Batching\n", " - Adaptive learning rates (although this may make it train better but not necessarily faster)\n", " - Using PyTorch implementations of RNNs/LSTMs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## with custom rnn" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class RNN(nn.Module):\n", " def __init__(self, vocab_size, hidden_size, n_fac, bptt, batch_size=BS):\n", " super(RNN, self).__init__()\n", " self.hidden_size = hidden_size\n", " \n", " self.embeddings = nn.Embedding(vocab_size, n_fac)\n", " self.i2h = nn.Linear(bptt * n_fac + hidden_size, hidden_size)\n", " self.i2o = nn.Linear(bptt * n_fac + hidden_size, vocab_size)\n", " self.o2o = nn.Linear(hidden_size + vocab_size, vocab_size)\n", " self.dropout = nn.Dropout(0.1)\n", " self.softmax = nn.LogSoftmax(dim=1)\n", " \n", " self.init_hidden(batch_size)\n", " \n", " # NOTE: this example only works as-is in PyTorch 0.4+\n", " # https://stackoverflow.com/questions/50475094/runtimeerror-addmm-argument-mat1-position-1-must-be-variable-not-torch\n", " def forward(self, inputs):\n", " #bs = inputs[0].size(0)\n", " # dynamic batch sizing\n", " #if self.batch_size != bs: self.init_hidden(bs)\n", " \n", " embeds = self.embeddings(inputs).view((1, -1))\n", " combined_i = torch.cat((embeds, self.hidden), 1)\n", " hidden = self.i2h(combined_i)\n", " # detach from history of the last run\n", " self.hidden = hidden.detach()\n", " output = self.i2o(combined_i)\n", " combined_o = torch.cat((self.hidden, output), 1)\n", " output = self.o2o(combined_o)\n", " output = self.dropout(output)\n", " output = self.softmax(output)\n", " return output\n", " \n", " def init_hidden(self, bs):\n", " # 1 RNN layer\n", " self.batch_size = bs\n", " self.hidden = torch.zeros(1, self.hidden_size).to(DEVICE)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "0m 46s (0 0.00%) 4.2767\n", "Epoch 0 sample:\n", "the 1~af9#%c_~ild$\"?rf/b\\|\n", "j)za#!n&\n", "11m 24s (500000 44.69%) 2.3363\n", "Epoch 500000 sample:\n", "the be stot eer/lrlck0sttmasd of corsdetav> in+ therk, fl yeint ulh, bta0krauk i#duk_.n8 dot9et8y imeos \n", "22m 5s (1000000 89.39%) 3.0072\n", "Epoch 1000000 sample:\n", "the bes!y\n", "2-counq** 4lewbor albzi\\\"\n", "25m 20s (0 0.00%) 7.4589\n", "Epoch 0 sample:\n", "the %ur rt ofekulg(tha onr orin\" 3orsof yr and !etes aoseve_ pop gope\n", "ntys^pc nouthericheof t7e, cse ma{w st\n", "97m 47s (0 0.00%) 38.1602\n", "Epoch 0 sample:\n", "the k| t vea lingtfeyeas h-lasgtean ote.p\n", "108m 28s (500000 44.69%) 1.8164\n", "Epoch 500000 sample:\n", "the 2*\n", "v &\n", ", d\n", "lleituts woleez:z.ve4h @ th\\aw*ivis sipbiuilasw tod^v. \n", "**+\n", "**d-tos\n", "\n", "\n", "119m 18s (1000000 89.39%) 5.7211\n", "Epoch 1000000 sample:\n", "the \n", "Training time: 7308.59s\n" ] } ], "source": [ "rnn = RNN(VOCAB_SIZE, N_HIDDEN, N_FAC, BPTT).to(DEVICE)\n", "optimizer = optim.Adam(rnn.parameters(), lr=0.005)\n", "all_losses = train_loop(rnn, optimizer, TRAIN_TEXT, criterion=nn.NLLLoss(), batch_size=1, plot_every=1, print_every=500000)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZMAAAD8CAYAAACyyUlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd8lfX9///Hi0Ag7BUwJkS2yB5hqHXhAhyggsU6cGJta/066qh+PvzUDrtcraO2VXG0MgVELKJotQ4gAcIUCDuArLBHQpLX749zYdN8AueEkJyT5Hm/3c6Nc97X+7rer1x68sx1Xe9zHXN3REREyqJGtAsQEZHKT2EiIiJlpjAREZEyU5iIiEiZKUxERKTMFCYiIlJmChMRESkzhYmIiJSZwkRERMqsZrQLqCjNmzf31q1bR7sMEZFKJSMjY4e7J4brV23CpHXr1qSnp0e7DBGRSsXM1kfST6e5RESkzBQmIiJSZgoTEREpM4WJiIiUmcJERETKTGEiIiJlpjAREZEyU5iIiFRRh48U8OsZy9m0+1C5j1VtPrQoIlKdZG7czX3jF7J6+wFaNa3LDQNOK9fxFCYiIlXIkYJC/jQ7iz99kkWLBrV567b+fK9D83IfV2EiIlJFZG3bx73jMlm8aQ9X9Urm/7uyC40SalXI2AoTEZFKrrDQef3Ldfzmn99QNz6OF6/vzZBuSRVag8JERKQS27T7EA+Mz+SrNTsZ2KkFT13djRYN61R4HQoTEZFKyN2ZNH8Tj09bSqE7T13dje/3bYWZRaUehYmISCWzY38uP5+8mA+XbaVv6yb8YURPUpvVjWpNChMRkUpk1rKtPDJ5EXsP5fPI4E7cfk5b4mpE52ikKIWJiEglsO/wEZ6cvozx6dmckdSQt27vQadTGka7rO8oTEREYtzXa3Zy//hMtuw5xI8vaMc9F3YkvmZs3cBEYSIiEqMOHyng9zNX8Lcv1nJa07pM+OFZ9DmtSbTLKpHCREQkBi3ZtId7xy1k1bb93DAglZ8POYO68bH7Kzt2KxMRqYbyCwp56dPVPPfxKprVj2fsrf04r2NitMsKS2EiIhIjVm/fz33jM8ncuJsre5zKE0O70LhufLTLikjEV3DMLM7MFpjZ9OD122a2wsyWmNmrZlYraDcze97MssxskZn1LrKNUWa2KniMKtLex8wWB+s8b8GnbsysqZnNCvrPMrMm4cYQEalsCgudsV+u47LnP2fdjgP88bpePH9dr0oTJFC67zO5B1he5PXbQCegG5AA3B60DwY6BI/RwEsQCgZgDNAf6AeMORoOQZ/RRdYbFLQ/DHzs7h2Aj4PXxxxDRKSy2bz7EDe9Opcx05YyoG0zPrz3XK7ocWq0yyq1iMLEzFKAy4C/Hm1z9xkeAOYCKcGiocAbwaKvgcZmlgRcCsxy9xx33wXMAgYFyxq6+1fBtt4AhhXZ1tjg+dhi7SWNISJSKbg7UxZs4tJnPyNj/S5+eVVXXru5Ly2jcF+tkyHSaybPAg8CDYovCE5v3UjoyAUgGdhYpEt20Ha89uwS2gFauvsWAHffYmYtwoyxpVhtowkduZCamhrBjykiUv5yDuTx2JTFzFj8LX1Oa8IfRvSgdfN60S6rTMKGiZldDmxz9wwzO7+ELi8Cn7n750dXKaGPn0D7ccuKZB13fwV4BSAtLS3cNkVEyt3sb7by0KTF7D6Yx4ODTufOc9vFxO1QyiqSI5OzgSvNbAhQB2hoZm+5+w1mNgZIBO4s0j8baFXkdQqwOWg/v1j7p0F7Sgn9AbaaWVJwVJIEbAszhohITNqfm88v31/GP+ZupNMpDRh7Sz86nxo7t0Mpq7DXTNz9EXdPcffWwEhgdhAktxO6DnKduxcWWWUacFMw42oAsCc4VTUTuMTMmgQX3i8BZgbL9pnZgGAW103A1CLbOjrra1Sx9pLGEBGJOXPX5jD4uc94Z95G7jyvLVN/cnaVChIo2+d
MXgbWA18FM3knu/sTwAxgCJAFHARuAXD3HDN7EpgXrP+Eu+cEz+8CXic0K+yD4AHwFDDezG4DNgAjgvYSxxARiSW5+QU8/eFKXvl8DSlNEhh/55n0bd002mWVCwtNoKr60tLSPD09PdpliEg1sXTzHu4bl8mKrfu4rl8qj152BvVrV77PiZtZhrunhetX+X4yEZEYll9QyJ8/W8OzH62kcd14Xru5Lxd0ahF+xUpOYSIicpKs23GA+8YvZP6G3VzWLYlfDOtKk3qV51PsZaEwEREpI3fn7Tkb+OX7y6kVZzw3sidX9jg1at/HHg0KExGRMti69zAPTlzEv1Zu55wOzfnt8O4kNUqIdlkVTmEiInKC3svczGNTlpCbX8CTQ7tww4DTqtXRSFEKExGRUtp9MI//mbqU9zI307NVY56+tgdtE+tHu6yoUpiIiJTCpyu28eDEReQcyOOBSzryw/PaUTMutr6PPRoUJiIiETiYl88v31/O23M20KFFfV69uS9dkxtFu6yYoTAREQkjY30O943PZEPOQe44pw33X3I6dWrFRbusmKIwERE5hrz8Qp79aCUv/2s1SY0S+McdAxjQtlm0y4pJChMRkRJ88+1e7h2XyfIte/l+Wiseu/wMGtSpFe2yYpbCRESkiIJC56+fr+EPH66kYUJN/nJTGhd3bhntsmKewkREJLBh50Hun7CQeet2cWmXlvzqqm40q1872mVVCgoTEan23J1x8zby5PRl1DDj6Wt7cFWv5Gr7AcQToTARkWpt277DPDxpMbO/2cZZ7ZrxuxE9SG5c/W6HUlYKExGptmYs3sKj7y7mYF4BY67ozKgzW1OjCnwfezQoTESk2tlz8Ahjpi1hysLNdE9pxNPX9qR9i+p9O5SyUpiISLXy+art/GzCIrbvz+XeizryowvaUUu3QymziPegmcWZ2QIzmx68/omZZZmZm1nzIv3MzJ4Pli0ys95Flo0ys1XBY1SR9j5mtjhY53kLrnqZWVMzmxX0n2VmTcKNISJSkkN5BYyZuoQb/zaXerXjePdHZ3HPRR0UJCdJafbiPcDyIq+/AC4C1hfrNxjoEDxGAy9BKBiAMUB/oB8w5mg4BH1GF1lvUND+MPCxu3cAPg5eH3MMEZGSLNiwi8ue/5yxX63n1rPb8P5Pz6F7SuNol1WlRBQmZpYCXAb89Wibuy9w93UldB8KvOEhXwONzSwJuBSY5e457r4LmAUMCpY1dPev3N2BN4BhRbY1Nng+tlh7SWOIiHznSEEhf/hwBde89CW5+YX8/Y7+/O8VnXVfrXIQ6TWTZ4EHgQYR9E0GNhZ5nR20Ha89u4R2gJbuvgXA3beYWYswY2yJ5IcRkapv1dZ93Dt+IUs27eWa3imMubIzDXU7lHITNkzM7HJgm7tnmNn5EWyzpHl1fgLtJzLGf3cyG03oNBipqalhNikiVUFhofPqF2v57cwV1K9dk5dv6MOgrqdEu6wqL5Ijk7OBK81sCFAHaGhmb7n7Dcfonw20KvI6BdgctJ9frP3ToD2lhP4AW80sKTgqSQK2hRnjv7j7K8ArAGlpaeECSkQquY05B3lgQiZz1uZw0Rkt+fXV3UhsoNuhVISw10zc/RF3T3H31sBIYPZxggRgGnBTMONqALAnOFU1E7jEzJoEF94vAWYGy/aZ2YBgFtdNwNQi2zo662tUsfaSxhCRasjdGZ++kcHPfc7SzXv57fDu/OWmPgqSCnTCnzMxs58Suo5yCrDIzGa4++3ADGAIkAUcBG4BcPccM3sSmBds4gl3zwme3wW8DiQAHwQPgKeA8WZ2G7ABGBG0lziGiFQ/2/fl8sjkxXy0fCv92zTl9yN60Kpp3WiXVe1YaAJV1ZeWlubp6enRLkNETqJ/LvmWR99dzL7cfB689HRuPbuNbodykplZhrunheunT8CLSKWz9/ARHp+2jEnzs+ma3JB3ru1Jh5aRTDaV8qIwEZFK5cusHTwwIZOt+3L56cD23H2
hPsUeCxQmIlIpHMor4Df//IbXv1xH2+b1mPjDM+mV2iT8ilIhFCYiEvMy1ufwwIRFrN1xgJvPas1DgzqREK9PsccShYmIxKzDRwp4ZtZK/vL5Gk5tnMA/7hjAme2aRbssKYHCRERi0qLs3dw/PpNV2/ZzXb9UHr3sDOrX1q+sWKX/MiISU/LyC/nT7FW88OlqEuvXZuyt/TivY2K0y5IwFCYiEjOWb9nL/eMzWbYldHPG/72iM40SdHPGykBhIiJRl19QyMv/Ws1zH6+iUUI8f7kpjYs7t4x2WVIKChMRiaqsbfu4f3wmmdl7uLx7Ek8M7UrTevHRLktKSWEiIlFRUOi8+u+1/O7DFdSLj+NPP+jF5d1PjXZZcoIUJiJS4dbtOMADEzJJX7+Lizu35FdX6VbxlZ3CREQqTGGh89ac9fx6xjfUjDOevrYHV/VKJvTtE1KZKUxEpEJk7zrIgxMX8eXqnZzXMZHfXNOdUxrViXZZcpIoTESkXLk74+Zt5BfvL8fd+fXV3RjZt5WORqoYhYmIlJutew/z0KRFfLpiO2e2bcZvh3fXF1dVUQoTETnp3J0pCzcxZupS8goKefzKLtw44DR9cVUVpjARkZNq+75cHn13MR8u20qf05rw+xE9aNO8XrTLknKmMBGRk2bG4i08NmUJ+3Pz+fmQTtz2vbbE6WikWoj468nMLM7MFpjZ9OB1GzObY2arzGycmcUH7bWD11nB8tZFtvFI0L7CzC4t0j4oaMsys4eLtJd6DBGpeLsO5HH3Pxbwo7fnk9Ikgffv/h6jz22nIKlGSvNdl/cAy4u8/g3wjLt3AHYBtwXttwG73L098EzQDzPrDIwEugCDgBeDgIoDXgAGA52B64K+pR5DRCreR8u2csmzn/HPJVu4/+KOTL7rLH0fezUUUZiYWQpwGfDX4LUBA4GJQZexwLDg+dDgNcHyC4P+Q4F33D3X3dcCWUC/4JHl7mvcPQ94Bxh6gmOISAXZe/gID0zI5PY30mlWL56pP/4ed1/YgZr6PvZqKdJrJs8CDwJH/9xoBux29/zgdTaQHDxPBjYCuHu+me0J+icDXxfZZtF1NhZr73+CY+yI8OcRkTL4bOV2Hpq0iG37crl7YHvuHtiB+JoKkeosbJiY2eXANnfPMLPzjzaX0NXDLDtWe0n/Bx6vf7jxv2Nmo4HRAKmpqSWsIiKlsT83n1/NWM7f52ygfYv6TL6hDz1aNY52WRIDIjkyORu40syGAHWAhoSOVBqbWc3gyCEF2Bz0zwZaAdlmVhNoBOQUaT+q6Dolte84gTH+i7u/ArwCkJaW9n/CRkQi9/WanfxsYibZuw4x+ty23HdxR+rUiot2WRIjwh6Xuvsj7p7i7q0JXUCf7e7XA58Aw4Nuo4CpwfNpwWuC5bPd3YP2kcFMrDZAB2AuMA/oEMzcig/GmBasU9oxROQkO5RXwOPvLWXkK18TZ8aEO8/k50POUJDIfynL50weAt4xs18AC4C/Be1/A940syxCRwsjAdx9qZmNB5YB+cCP3b0AwMx+AswE4oBX3X3piYwhIidXxvpdPDAhk7U7DjDqzNN4aHAn6sbr42nyf1l1+YM+LS3N09PTo12GSKVw+EgBz3y0kr98toakRgn8bnh3zmrfPNplSRSYWYa7p4Xrpz8xROS/LM7ew33jF7Jq236u69eKnw85gwZ1akW7LIlxChMRASAvv5A/fZLFC59k0bx+PK/d0pcLTm8R7bKkklCYiAjffLuX+8ZlsmzLXq7ulcyYK7rQqK6ORiRyChORaiy/oJA/f7aGZz9aSaOEWrxyYx8u6XJKtMuSSkhhIlJNZW3bz/0TMsncuJvLuiXx5LCuNK0XH+2ypJJSmIhUMwWFzmtfrOV3M1eQEB/HH6/rxRU9To12WVLJKUxEqpH1Ow/wwIRM5q3bxUVntORXV3elRYM60S5LqgCFiUg1UFjovD1nPb+a8Q0144zfj+jBNb2T0c225WRRmIhUcZt
2H+LBiZl8kbWTczo057fDu5PUKCHaZUkVozARqaLcnQnp2TwxfRnuzq+u6sZ1/VrpaETKhcJEpArauvcwD09axCcrttO/TVN+P6IHrZrWjXZZUoUpTESqEHdn6sLNjJm2lNz8AsZc0ZlRZ7amhr6LXcqZwkSkitixP5dH313MzKVb6Z3amN+P6EHbxPrRLkuqCYWJSBXwweItPDplCfsP5/Pw4E7ccU5b4nQ0IhVIYSJSie0+mMf/Tl3KtMzNdEtuxB+u7UHHlg2iXZZUQwoTkUrq4+VbeXjyYnYdyOO+izty1/ntqBUX9stTRcqFwkSkktl7+AhPvreMCRnZdDqlAa/d3JeuyY2iXZZUcwoTkUrk81XbeWjiIr7de5gfX9COn17Ygdo19V3sEn0KE5FK4EBuPr/+YDlvfb2Bdon1mPyjs+nZqnG0yxL5TtgTrGZWx8zmmlmmmS01s8eD9oFmNt/MlpjZWDOrGbSbmT1vZllmtsjMehfZ1igzWxU8RhVp72Nmi4N1nrfgI7pm1tTMZgX9Z5lZk3BjiFQ1c9bsZPBzn/P2nA3c/r02vP/TcxQkEnMiuVqXCwx09x5AT2CQmZ0FjAVGuntXYD1wNBwGAx2Cx2jgJQgFAzAG6A/0A8YcDYegz+gi6w0K2h8GPnb3DsDHwetjjiFSlRw+UsAT7y1j5F++xgzGjT6Txy7vTJ1aOq0lsSdsmHjI/uBlreBRAOS6+8qgfRZwTfB8KPBGsN7XQGMzSwIuBWa5e4677wrWGRQsa+juX7m7A28Aw4psa2zwfGyx9pLGEKkS5m/YxZDnPufVL9Zy44DT+OCec+jXpmm0yxI5poiumZhZHJABtAdeAOYCtcwszd3TgeFAq6B7MrCxyOrZQdvx2rNLaAdo6e5bANx9i5m1CDPGlkh+HpFYlZtfwLMfreLP/1pNUqME3r69P2e3bx7tskTCiihM3L0A6GlmjYF3gS7ASOAZM6sNfAjkB91L+titn0D78US0jpmNJnQajNTU1DCbFImuJZv2cP/4TFZs3cf301rx2OVn0KBOrWiXJRKRUn3Cyd13A58Cg4LTUue4ez/gM2BV0C2b/xylAKQAm8O0p5TQDrD16Omr4N9tYcYoXu8r7p7m7mmJiYml+VFFKsyRgkKe/Wglw174gl0H83jt5r78Znh3BYlUKpHM5koMjkgwswTgIuCbo6ecgiOTh4CXg1WmATcFM64GAHuCU1UzgUvMrElw4f0SYGawbJ+ZDQhmcd0ETC2yraMX9kcVay9pDJFKZcPOg1z14hc8+9EqLu+exIf3nssFnVqEX1EkxkRymisJGBtcN6kBjHf36Wb2OzO7PGh7yd1nB/1nAEOALOAgcAuAu+eY2ZPAvKDfE+6eEzy/C3gdSAA+CB4ATwHjzew2YAMw4nhjiFQmGet3MfqNdArcefmGPgzqekq0SxI5YRaaQFX1paWleXp6erTLEAHgvczN3D8hk1Mb1eG1W/rRpnm9aJckUiIzy3D3tHD99Al4kQrk7rz46Wp+N3MF/Vo35c839qFJvfholyVSZgoTkQqSl1/Iz99dzMSMbK7qlcxT13TTfbWkylCYiFSAPQeP8MO3MvhqzU7+30UduOfCDgR3DRKpEhQmIuVs/c4D3PL6PLJzDvHs93syrFdy+JVEKhmFiUg5Sl+Xw+g3M3B33rq9v26JIlWWwkSknExduImfTVxEcuMEXru5L601Y0uqMIWJyEnm7vxpdhZ/mLWSfm2a8ucbNGNLqj6FichJlJdfyCOTFzNpfjZX90rm15qxJdWEwkTkJNl9MI8738xgztoc7ru4I3cPbK8ZW1JtKExEToJ1Ow5w6+vzyN51iOdG9mRoT83YkupFYSJSRvPW5TD6jdCtev5+R3/SWmvGllQ/ChORMpi6cBM/m7CIlCYJvKoZW1KNKUxEToC78/zHWTzz0Ur6twndY6txXc3YkupLYSJSSrn5BTwyaTGTF2zi6t7JPHV1d+Jrlup
75kSqHIWJSCnsPpjH6DczmLs2h/sv7shPNGNLBFCYiERs3Y7QPbY27daMLZHiFCYiETg6Y8vM+PvtmrElUpzCRCSMKQs28eDERaQ0Dd1j67RmmrElUpzCROQY3J3nPl7Fsx+tYkDbpvz5hjQa1a0V7bJEYlLYKShmVsfM5ppZppktNbPHg/YLzWy+mS00s3+bWfugvbaZjTOzLDObY2ati2zrkaB9hZldWqR9UNCWZWYPF2lvE2xjVbDN+HBjiJwMufkF3Dc+k2c/WsXwPim8cWt/BYnIcUQynzEXGOjuPYCewCAzGwC8BFzv7j2BvwOPBf1vA3a5e3vgGeA3AGbWGRgJdAEGAS+aWZyZxQEvAIOBzsB1QV+CdZ9x9w7ArmDbxxxD5GTYdSCPG/86l3cXbOJnl57O74Zr6q9IOGHfIR6yP3hZK3h48GgYtDcCNgfPhwJjg+cTgQstNHdyKPCOu+e6+1ogC+gXPLLcfY275wHvAEODdQYG2yDY5rAwY4iUydodB7j6pS9ZmL2bP17Xix9foKm/IpGI6JpJcPSQAbQHXnD3OWZ2OzDDzA4Be4EBQfdkYCOAu+eb2R6gWdD+dZHNZgdtHO1fpL1/sM5ud88vof+xxthRrO7RwGiA1NTUSH5UqcbmrNnJnW9lUMOMf9zRnz6nacaWSKQiOnZ394LgdFYK0M/MugL3AkPcPQV4DXg66F7Sn3F+EtuPN0bxul9x9zR3T0tMTCxhFZGQdxdkc8Pf5tCsXjxTfnS2gkSklEo1m8vdd5vZp4Sub/Rw9znBonHAP4Pn2UArINvMahI6BZZTpP2oFP5zaqyk9h1AYzOrGRydFO1/rDFESsXdefajVTz38SrOateMl67vowvtIicgktlciWbWOHieAFwELAcamVnHoNvFQRvANGBU8Hw4MNvdPWgfGczEagN0AOYC84AOwcyteEIX6acF63wSbINgm1PDjCESsdz8Au4dt5DnPl7FtWkpvH5LPwWJyAmK5MgkCRgbXDepAYx39+lmdgcwycwKCc20ujXo/zfgTTPLInS0MBLA3Zea2XhgGZAP/NjdCwDM7CfATCAOeNXdlwbbegh4x8x+ASwItn3MMUQilXMgjzvfTGfeul387NLT+dH57XShXaQMrLr8QZ+Wlubp6enRLkNiwJrt+7nl9Xls2XOYp6/tweXdT412SSIxy8wy3D0tXD99Al6qla/X7OTONzOoWcN4Z/QAeqc2iXZJIlWCwkSqjUkZ2Tw8eRGpTevy+i39aNW0brRLEqkyFCZS5bk7z8xayfOzs0Iztm7oQ6MEXWgXOZkUJlKlHT5SwIMTFzEtczPXpqXwi2HddGsUkXKgMJEqa+f+XO58M4P09bt4cNDp3HWeZmyJlBeFiVRJq7fv55bX5rF172Fe+EFvLuueFO2SRKo0hYlUOV+t3skP38qgVpzxD83YEqkQChOpUo7O2DqtWT1eu7mvZmyJVBCFiVQJ7s7Ts1byx9lZnN2+GS9erxlbIhVJYSKV3uEjBfxs4iLey9zMyL6teHJYV2rFacaWSEVSmEiltnN/LqPfzCBj/S4eHtyJO89tqxlbIlGgMJFKK2vbfm59PTRj66XrezO4m2ZsiUSLwkQqpS9X7+CHb2YQX7MG74weQC/N2BKJKoWJVDoT0jfyyOTFtGlej1c1Y0skJihMpNIoLHT+MGsFL3yymnM6NOeF63vTsI5mbInEAoWJVAqHjxTwwIRMpi/awnX9UnliaBfN2BKJIQoTiXk79+dyxxvpLNi4m58P6cQd52jGlkisUZhITMvato9bXp/H9n25vHR9bwZ11YwtkVgU9jyBmdUxs7lmlmlmS83s8aD9czNbGDw2m9mUoN3M7HkzyzKzRWbWu8i2RpnZquAxqkh7HzNbHKzzvAV/dppZUzObFfSfZWZNwo0hVceXWTu46sUvOZRXyLjRZypIRGJYJCedc4GB7t4D6AkMMrMB7n6Ou/d0957AV8DkoP9goEP
wGA28BKFgAMYA/YF+wJij4RD0GV1kvUFB+8PAx+7eAfg4eH3MMaTqGD9vIze9OpekRnWY8uOz6NGqcbRLEpHjCBsmHrI/eFkrePjR5WbWABgITAmahgJvBOt9DTQ2syTgUmCWu+e4+y5gFqFgSgIauvtX7u7AG8CwItsaGzwfW6y9pDGkkissdH77z294cNIizmzXjIl3nUVKE039FYl1EU2HMbM4M1sIbCMUCHOKLL6K0NHD3uB1MrCxyPLsoO147dkltAO0dPctAMG/LcKMIZXY4SMF3P2PBbz46Wp+0D+VV2/uq6m/IpVERBfg3b0A6GlmjYF3zayruy8JFl8H/LVI95Km2fgJtB9PROuY2WhCp8FITU0Ns0mJph3BjK2FG3fz6JAzuP2cNpqxJVKJlGqivrvvBj4luKZhZs0IXf94v0i3bKBVkdcpwOYw7SkltANsPXr6Kvh3W5gxitf7irunuXtaYmJixD+nVKxVW/cx7IUvWL5lLy9d34c7dLNGkUonktlcicERCWaWAFwEfBMsHgFMd/fDRVaZBtwUzLgaAOwJTlHNBC4xsybBhfdLgJnBsn1mNiCYxXUTMLXIto7O+hpVrL2kMaSS+feqHVz90pfk5h+dsXVKtEsSkRMQyWmuJGCsmcURCp/x7j49WDYSeKpY/xnAECALOAjcAuDuOWb2JDAv6PeEu+cEz+8CXgcSgA+CB8G2x5vZbcAGQuF1zDGkcnln7gYem7KEdon1efWWviQ3Toh2SSJygiw0garqS0tL8/T09GiXIQQztmau4OV/rebcjom88INeNNCFdpGYZGYZ7p4Wrp8+AS8V6vCRAu4bv5AZi7/l+v6pPH5lF2rqHlsilZ7CRCrM9n2hGVuZ2bt57LIzuO17mrElUlUoTKRCrNoausfWzv15vHxDHy7togvtIlWJwkTK3b9X7eCutzKoEx/H+DvPpFtKo2iXJCInmcJEytXRGVvtW9TnbzdrxpZIVaUwkXJRWOj8ZuY3/PlfazivYyJ/0owtkSpNYSIn3aG80IytD5Z8y40DTmPMFZ01Y0ukilOYyEnj7izetIf/mbqURdm7+Z/LO3Pr2a01Y0ukGlCYSJlt23eYqQs2MzGlUK2MAAALOklEQVQjmxVb91EvPo4/39CHSzRjS6TaUJjICcnLL2T2N1uZmJHNJyu2U1Do9E5tzK+u6sZl3ZNolKDrIyLVicJESmXJpj1MzMhm6sJN7Dp4hJYNazP63LZc0zuF9i3qR7s8EYkShYmEtXN/LlMWhk5jLd+yl/i4GlzcpSUj+qRwTodE4mromohIdacwkRIdKSjkk2+2MTEjm9nfbCO/0OmR0ognh3Xliu5JNK4bH+0SRSSGKEzkvyzfspeJGdlMWbCJnQfyaF6/Nrd+rw3D+6TQsWWDaJcnIjFKYSLkHMhj2sJNTJyfzZJNe6kVZ1x0RktGpKVwbodEfUZERMJSmFRT+QWF/GvldiZmZPPR8q0cKXC6Jjfk8Su7cGWPU2lST6exRCRyCpNqZuXWfUzMyGby/E3s2J9Ls3rx3HRma4b3SeGMpIbRLk9EKimFSTWw+2Ae72WGZmNlZu+hZg1jYKcWjEhrxfmnJ1JLp7FEpIwUJlVUQaHz2arQaaxZS7eSV1BIp1Ma8D+Xd2ZYz1NpVr92tEsUkSokbJiYWR3gM6B20H+iu4+x0A2XfgGMAAqAl9z9+aD9OWAIcBC42d3nB9saBTwWbPoX7j42aO8DvA4kADOAe9zdzawpMA5oDawDrnX3Xccbo7rL2rafiRnZvLsgm617c2lStxY/6J/KiLQUupyq7xERkfIRyZFJLjDQ3febWS3g32b2AXAG0Aro5O6FZtYi6D8Y6BA8+gMvAf2DYBgDpAEOZJjZNHffFfQZDXxNKEwGAR8ADwMfu/tTZvZw8PqhY41Rtl1Ree05dITpi0KnsRZs2E1cDeOC0xN5/MpWDOzUgviaOo0lIuUrbJi4uwP7g5e1gocDdwE
/cPfCoN+2oM9Q4I1gva/NrLGZJQHnA7PcPQfAzGYBg8zsU6Chu38VtL8BDCMUJkOD9QDGAp8SCpMSx3D3LSe4HyqdgkLni6wdTMzIZubSb8nNL6Rjy/o8OuQMhvVKJrGBTmOJSMWJ6JqJmcUBGUB74AV3n2Nm7YDvm9lVwHbgp+6+CkgGNhZZPTtoO157dgntAC2PBoS7byly9HOsbVX5MFm74wATMzYyef4mtuw5TKOEWny/bytG9GlF1+SGut27iERFRGHi7gVATzNrDLxrZl0JXUM57O5pZnY18CpwDlDSbzM/gfbjiWgdMxtN6PQZqampYTYZu/YdPsL7i7YwMSOb9PW7qGFwXsdEHrusMxd1bkHtmnHRLlFEqrlSzeZy993BaalBhI4GJgWL3gVeC55nE7qWclQKsDloP79Y+6dBe0oJ/QG2Hj19FZwqO3oq7VhjFK/3FeAVgLS0tHABFVMKC52v1uxkYkY2HyzZwuEjhbRLrMfDgztxVa9kWjasE+0SRUS+E8lsrkTgSBAkCcBFwG+AKcBAQkck5wErg1WmAT8xs3cIXRTfE4TBTOBXZtYk6HcJ8Ii755jZPjMbAMwBbgL+WGRbo4Cngn+nHm+ME94LMWT9zgNMyshm0vxNbNp9iAZ1anJN7xSG90mhZ6vGOo0lIjEpkiOTJGBscN2kBjDe3aeb2b+Bt83sXkIX6G8P+s8gNGU3i9C03VsAgtB4EpgX9Hvi6MV4QhfzXyc0NfiD4AGhEBlvZrcBGwhNQz7mGJXVgdx83l8cOo01d20OZnBOh0QeGtyJSzq3pE4tncYSkdhmoQlRVV9aWpqnp6dHu4zvFBY6c9bmfHca62BeAW2b1+OaPilc3TuZpEYJ0S5RRAQzy3D3tHD99An4CrYx5yCT5mczaX42G3MOUb92TYb2PJXhfVLondpEp7FEpFJSmFSAg3n5/HPJt0xIz+arNTsxg7PbNef+i0/n0i6nkBCv01giUrkpTMqJu5O+fhcT0jfy/qItHMgr4LRmdbn/4o5c3SeF5MY6jSUiVYfC5CTbtPsQ787PZmJGNut2HqRefByXdU9ieJ9W9G2t01giUjUpTE6CQ3kFfLgsdBrri9U7cIcz2zbj7oEdGNztFOrGazeLSNWm33InyN2Zv2E3EzM2Mj1zC/ty80lpksA9F3bgmt4ptGpaN9oliohUGIVJKX275zCTF4ROY63ZfoCEWnEM6ZbE8D4p9G/TlBo1dBpLRKofhUkEDh8pYNayrUzIyObfq7ZT6NCvdVN+eF47hnRLon5t7UYRqd70WzCMcfM28Mv3l7P3cD7JjRP4yQXtubp3Cq2b14t2aSIiMUNhEkZSo4Tvvi/9zLbNdBpLRKQECpMwzu2YyLkdE6NdhohITNP3uYqISJkpTEREpMwUJiIiUmYKExERKTOFiYiIlJnCREREykxhIiIiZaYwERGRMqs23wFvZtuB9Se4enNgx0ks52RRXaWjukovVmtTXaVTlrpOc/ewn9yuNmFSFmaW7u5p0a6jONVVOqqr9GK1NtVVOhVRl05ziYhImSlMRESkzBQmkXkl2gUcg+oqHdVVerFam+oqnXKvS9dMRESkzHRkIiIiZaYwKcLMBpnZCjPLMrOHS1he28zGBcvnmFnrGKnrZjPbbmYLg8ftFVTXq2a2zcyWHGO5mdnzQd2LzKx3jNR1vpntKbK//rcCamplZp+Y2XIzW2pm95TQp8L3V4R1RWN/1TGzuWaWGdT1eAl9Kvz9GGFdUXk/BmPHmdkCM5tewrLy3V/urkfoVF8csBpoC8QDmUDnYn1+BLwcPB8JjIuRum4G/hSFfXYu0BtYcozlQ4APAAMGAHNipK7zgekVvK+SgN7B8wbAyhL+O1b4/oqwrmjsLwPqB89rAXOAAcX6ROP9GEldUXk/BmPfB/y9pP9e5b2/dGTyH/2ALHdf4+55wDvA0GJ9hgJjg+cTgQvNrLy/xzeSuqLC3T8Dco7TZSjwhod
8DTQ2s6QYqKvCufsWd58fPN8HLAeSi3Wr8P0VYV0VLtgH+4OXtYJH8Qu8Ff5+jLCuqDCzFOAy4K/H6FKu+0th8h/JwMYir7P5v2+q7/q4ez6wB2gWA3UBXBOcGploZq3KuaZIRVp7NJwZnKr4wMy6VOTAwemFXoT+qi0qqvvrOHVBFPZXcMpmIbANmOXux9xfFfh+jKQuiM778VngQaDwGMvLdX8pTP6jpIQu/hdHJH1OtkjGfA9o7e7dgY/4z18f0RaN/RWJ+YRuEdED+CMwpaIGNrP6wCTg/7n73uKLS1ilQvZXmLqisr/cvcDdewIpQD8z61qsS1T2VwR1Vfj70cwuB7a5e8bxupXQdtL2l8LkP7KBon9BpACbj9XHzGoCjSj/0ylh63L3ne6eG7z8C9CnnGuKVCT7tMK5+96jpyrcfQZQy8yal/e4ZlaL0C/st919cgldorK/wtUVrf1VZPzdwKfAoGKLovF+DFtXlN6PZwNXmtk6QqfCB5rZW8X6lOv+Upj8xzygg5m1MbN4QheophXrMw0YFTwfDsz24GpWNOsqdl79SkLnvWPBNOCmYJbSAGCPu2+JdlFmdsrRc8Vm1o/Q+2BnOY9pwN+A5e7+9DG6Vfj+iqSuKO2vRDNrHDxPAC4CvinWrcLfj5HUFY33o7s/4u4p7t6a0O+I2e5+Q7Fu5bq/ap6sDVV27p5vZj8BZhKaQfWquy81syeAdHefRuhN96aZZRFK9JExUtdPzexKID+o6+byrgvAzP5BaKZPczPLBsYQuiCJu78MzCA0QykLOAjcEiN1DQfuMrN84BAwsgL+KDgbuBFYHJxvB/g5kFqkrmjsr0jqisb+SgLGmlkcofAa7+7To/1+jLCuqLwfS1KR+0ufgBcRkTLTaS4RESkzhYmIiJSZwkRERMpMYSIiImWmMBERkTJTmIiISJkpTEREpMwUJiIiUmb/P0sfE7JhUgPSAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_loss(all_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "^losses not being reported quite right by `batchless_train`..." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the 2cwap%\n", "jaig aciph} araygay iblptoare josa7, pha]ptpjry iot, il) aydin t?e\n", "iruphy bol war############\n" ] } ], "source": [ "idx, VOCAB_SIZE, char_to_idx, idx_to_char = create_inputs(TRAIN_TEXT)\n", "sample(rnn, char_to_idx, idx_to_char, seed='the ', max_length=100)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u0000\u0000\u0000\u0000(ff{bomhfy tu tu fuphays\n", "\n", "argzry ip{bo$#############u2/j,q)er1.x\n", "flphy pha) xrtphuipariphcip biutaly\n" ] } ], "source": [ "idx, VOCAB_SIZE, char_to_idx, idx_to_char = create_inputs(TRAIN_TEXT)\n", "sample(rnn, char_to_idx, idx_to_char, seed='\\0'*BPTT, max_length=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## with PyTorch's RNN layer" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class PyTorchRNN(nn.Module):\n", " def __init__(self, vocab_size, hidden_size, n_fac, batch_size):\n", " super(PyTorchRNN, self).__init__()\n", " self.hidden_size = hidden_size\n", " self.vocab_size = vocab_size\n", " self.n_fac = n_fac\n", " \n", " self.embedding = nn.Embedding(vocab_size, n_fac)\n", " self.rnn = nn.RNN(n_fac, hidden_size)\n", " self.l_out = nn.Linear(hidden_size, vocab_size)\n", " self.softmax = nn.LogSoftmax(dim=-1)\n", " \n", " self.init_hidden(batch_size)\n", " \n", " def forward(self, inputs):\n", " bs = inputs[0].size(0)\n", " # dynamic batch sizing\n", " if self.batch_size != bs: self.init_hidden(bs)\n", "\n", " inputs = self.embedding(inputs)\n", " output, hidden = 
self.rnn(inputs, self.hidden)\n", " # detach from history of the last run\n", " self.hidden = hidden.detach()\n", " output = self.l_out(output)\n", " output = self.softmax(output)\n", " \n", " return output.view(-1, self.vocab_size)\n", " \n", " def init_hidden(self, bs):\n", " # 1 RNN layer\n", " self.batch_size = bs\n", " self.hidden = torch.zeros(1, self.batch_size, self.hidden_size).to(DEVICE)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0m 8s (0 0.00%) 4.2960\n", "Epoch 0 sample:\n", "the bgqdtw6#i-[mngryn}z#)<|q!el{hzfpb%@qzc $a4d\\w6qr+-y! e\n", "1m 3s (10 1.00%) 2.8015\n", "Epoch 10 sample:\n", "the los bgbhs wbeannuonetsncl po wic av vn so hebadgaf cs -g mr\n", "'nircg,uorite o* wi wd so wasiheuus bo \n", "1m 58s (20 2.00%) 2.4843\n", "Epoch 20 sample:\n", "the bbt\n", "e burossbe ron sarelliny om uzeacran\" fumy\n", "2m 53s (30 3.00%) 2.3042\n", "Epoch 30 sample:\n", "the f.\n", "anilhis in inaradmethe thagte me ses liv\n", "g_ pod thos th utpes bo paredolisaca oused _ _nsentaia \n", "3m 48s (40 4.00%) 2.2106\n", "Epoch 40 sample:\n", "the ns asoisoremr arty. ibohent . 
_defma:s pitos thathe d pmodeo_e r ly pifti- ialis' bouy a's af uf y m\n", "4m 44s (50 5.00%) 2.1274\n", "Epoch 50 sample:\n", "the uppepin a, ok amb rar, nsouyoi\n", "a- ated y\",uns ove asd h\" hamitole \"i wity sus alo rasenge,pokerots5 \n", "5m 39s (60 6.00%) 2.1382\n", "Epoch 60 sample:\n", "the xiwattitich chawaneaneaghagk \n", "-sengamcam , the\n", "ga moroicorhes\n", "lis neundur mesthe\n", "\" ouknanmastong iri\n", "6m 34s (70 7.00%) 2.1227\n", "Epoch 70 sample:\n", "the so tou laiwin southighi,\n", "i, sol tnd*mor i'siicanoann\n", "* an_ en yog oc beroge\n", "7m 29s (80 8.00%) 2.1580\n", "Epoch 80 sample:\n", "the pl arnathesrocoutinitoby om t' teumo's.-\n", "sp\n", "amouth om or or ondingvelhel ddetinger anlietien hal ole\n", "8m 25s (90 9.00%) 2.1470\n", "Epoch 90 sample:\n", "the do dalmondouti\n", "w. p2 fiitcr pn e sroricr\n", "c\n", "s. se a ncsrcke-\n", " t\n", " natookep. rr a wa ealirhenery ac k\n", "9m 20s (100 10.00%) 2.0815\n", "Epoch 100 sample:\n", "the ou difocllina jarbch blyolivge aencwnthon\n", "mc\n", "fa ea wastnst r, sso sourraiken-*psovyentente ee\n", "peatya\n", "10m 15s (110 11.00%) 2.0790\n", "Epoch 110 sample:\n", "the jo hacrrck ancatisnet ** we w nw collocanseno d t ta tipeile fice sarkikdake d y ty stamiss\n", "pe fis\n", "11m 10s (120 12.00%) 2.0864\n", "Epoch 120 sample:\n", "the fo hisrisonndelo, o\n", "win, -, bo, vamane, wat'shanean's\n", "\n", "ptietdech od soucait|#wyw de metmilaseawasha\n", "\n", "12m 6s (130 13.00%) 2.1082\n", "Epoch 130 sample:\n", "the d, on rollighane ss kn braban im dud ru asothormaga and is ah amsir\n", "lufo e si s. 
wh wh, ho foxiselce\n", "13m 1s (140 14.00%) 2.0286\n", "Epoch 140 sample:\n", "the fu rryglealram (a \n", "le al\n", "\n", "on on th thethe mo *, ly he harnoongeredptrto chempcas\n", "athendiig an sx pl \n", "13m 56s (150 15.00%) 2.0348\n", "Epoch 150 sample:\n", "the fe p, murlinhin me realiileanach pe for ave he\n", "fo fhanki).\n", "\n", "y fo veurold * gite al unt, scenslarta f\n", "14m 51s (160 16.00%) 2.0806\n", "Epoch 160 sample:\n", "the n, werte cese th thetta t, aphaulatldpspstreveas.\n", "swimn ic i st lo dow/ \" be dh wan_ialeute s anle-s\n", "15m 47s (170 17.00%) 2.0448\n", "Epoch 170 sample:\n", "the me tretheahladye hemolloble ander- r. mm veboof\n", "it' flltellendew 's asste aruer 2k ze quintinkin in \n", "16m 42s (180 18.00%) 2.0311\n", "Epoch 180 sample:\n", "the & pimalmacariace it so sereove inirtiofel as de we wheneer i's us co lese\n", "w,\n", "es\n", "sh\n", "meangengictusenr\n", "17m 37s (190 19.00%) 2.0079\n", "Epoch 190 sample:\n", "the =, g. go tarealigingtlotsoftheswupparderioun tz wm wardaudlis ocloralawcy 'pea yacibamce iecoro)\n", "\n", "p\n", "\n", "18m 32s (200 20.00%) 2.0034\n", "Epoch 200 sample:\n", "the he hechack ad mo vo moimonr withy.\n", "s.\n", "\n", "h\n", "\n", "boiboug\n", "a \n", "2|-0460||15. s pl offontistis is oi tut ho bolb\n", "19m 28s (210 21.00%) 1.9818\n", "Epoch 210 sample:\n", "the ho thethe le d_ am am as on sore g\n", "mu marrarw, w, wl \"leyomy in in veares fo por ar sa rartorflofesh\n", "20m 23s (220 22.00%) 1.9824\n", "Epoch 220 sample:\n", "the wi be leninns wt he, ay usctoogisian, is inthes at be be erim ll yl yk ptemurvarek, l. w.\n", "\n", "vi wi amu\n", "21m 18s (230 23.00%) 2.0133\n", "Epoch 230 sample:\n", "the fe wistisn w: nounous bt bat at sungatw, e, at acts \"w y' sordoug:h \"d -g alp by br becrntlc bl buat\n", "22m 13s (240 24.00%) 2.0461\n", "Epoch 240 sample:\n", "the mo allebrallires. s. 
thivempomf ** ** forare\n", " an thatha bredrak ar she ve a'luavark c ttel blybiuthe\n", "23m 8s (250 25.00%) 2.0577\n", "Epoch 250 sample:\n", "the vi trutkatday py prfauricrit _, th th ph shaavaunounpraps -p ipracus es ph hischeowist n songaughar\n", "24m 4s (260 26.00%) 2.0384\n", "Epoch 260 sample:\n", "the wh w_ garncrmubr blly ly us laknave tas ma motvowrvtulleltrmeakbownos su sin'tnedr my forlmorverine \n", "24m 59s (270 27.00%) 2.0107\n", "Epoch 270 sample:\n", "the f lireaneanelintict ws tise fe loenofo fe fic fic ip ng sos ok c. sfetlad ss do , d,cvinghe butbo. \n", "25m 54s (280 28.00%) 2.0481\n", "Epoch 280 sample:\n", "the th wheverman as k\n", "us\n", "ove ve thiths yeve wof po we tareahima fa farpatp-hecveseazeate ty ah at's th t\n", "26m 49s (290 29.00%) 2.0256\n", "Epoch 290 sample:\n", "the ch th usmastandiling d l, ve f onod 20?\n", "*. ho hit fo as ompemly lh lokgove/se\n", "lst it wotrist sa pa\n", "27m 44s (300 30.00%) 2.0642\n", "Epoch 300 sample:\n", "the aw he ''ven wh wo worker-ca--rot oo vutvjnp--s\n", "so whovye ize hh hotsocthoveollond,\n", "s.k.\n", "\n", "\n", "\n", "\n", "\n", "\n", "28m 39s (310 31.00%) 2.0442\n", "Epoch 310 sample:\n", "the mo bfeboandselibe ly dyhe teetielis.\n", "\n", "i if ******zhaatovarkatla om om he h'satodeaclaklulhste ay ou \n", "29m 34s (320 32.00%) 2.0835\n", "Epoch 320 sample:\n", "the fo wong al *_ \n", "- reebelbalya sor is is b thiccishares co sodracvec awink platid iit it ie anle kn a\n", "30m 28s (330 33.00%) 2.0452\n", "Epoch 330 sample:\n", "the fo keckdhelotrthychelbolandandle\n", "lett ite el bigenow ig or we calhen la ou wostou chyofonten -- 218.\n", "31m 23s (340 34.00%) 2.0161\n", "Epoch 340 sample:\n", "the ab orearjopata da onn *gomaicoriwang dicangon fn fem al ohereverudryonyo de le mormowlosnz\"che rtin\n", "32m 18s (350 35.00%) 2.0625\n", "Epoch 350 sample:\n", "the (bu glopeay, ak ah here avneveenteraalayke walonewttlovexvo felowlev, phach'cvecha andimwimk,e te a\n", "33m 13s (360 
36.00%) 2.2247\n", "Epoch 360 sample:\n", "the me medreenithingratheetidtinglegoupeinanalir iniwer tvrifrifeplinntameas\n", "arteasher al\n", " aebeatrathawh\n", "34m 8s (370 37.00%) 2.1703\n", "Epoch 370 sample:\n", "the ** afpo\n", "lothavent. tseareturuluptiysineane.\n", "t.\n", "wh mr t tay atl'sbomour ar ce ce ckiuiic is is iatof\n", "35m 3s (380 38.00%) 2.0762\n", "Epoch 380 sample:\n", "the 19.16.15, ah afreong l arm er illonkbyoviigaupen os o oviokingiis had t\n", "\n", "n cy ar iewing il ha uyeuo\n", "35m 58s (390 39.00%) 2.1481\n", "Epoch 390 sample:\n", "the (\n", "\n", " t. the ge to iondoquaive an in ed blanaisind of anvar* ie welwaswaeste't atilbredaded ) r turid\n", "36m 53s (400 40.00%) 2.0796\n", "Epoch 400 sample:\n", "the (b1 qualluvistust pd came bo y ni bub mon astofs f afoalsom's\n", "\n", "f -\n", "clpblaon \n", "\n", "37m 48s (410 41.00%) 2.0811\n", "Epoch 410 sample:\n", "the fo \"lfurendekris st smarm-tho bo dac ppreavesses\n", "\n", "\n", "38m 43s (420 42.00%) 2.1144\n", "Epoch 420 sample:\n", "the fo or, injuprousicnerive ve bimlideedeeonis io anetieritrshong\"bym nw be oulideenether,\".:>..\n", "\",\" s \n", "39m 38s (430 43.00%) 2.1024\n", "Epoch 430 sample:\n", "the f s inlowef go ne dud it ig the ba\n", " ate se is cae ad ap ay th th le osoo e he /\" fe me\n", "meande th t\n", "40m 33s (440 44.00%) 2.0943\n", "Epoch 440 sample:\n", "the w. h'se\n", "pt'pe de s,.\"__\n", "\n", "p\n", "\n", "/ ****, wncthethe ps ss thsoys ingom'si'getpe toinge, ss sf ingird_w.\n", "\n", "\n", "41m 28s (450 45.00%) 2.0570\n", "Epoch 450 sample:\n", "the s 1 mlutaldaeedens a theharomos thitherree erotelthtr aroplos the'se's th therilrditerthatryics bl l\n", "42m 23s (460 46.00%) 2.0973\n", "Epoch 460 sample:\n", "the & ayturthiigis ygut. 
\n", " i k***y fe fh: wollovethas un bellinsor *vpyhdulritn an an an tontoss at am \n", "43m 18s (470 47.00%) 2.0657\n", "Epoch 470 sample:\n", "the 23 18 /2 b@yh yt it efrefrelt tt touteo font we ne o sout ind\n", "orper ss thitrjaymerion g ct t va de \n", "44m 13s (480 48.00%) 2.0498\n", "Epoch 480 sample:\n", "the & livirm blerarrardathoundaviar ak at 2/2011 ba bm so we waroprombo gommnrmdom @ gimoingo.tiuti gsg\n", "45m 8s (490 49.00%) 2.1247\n", "Epoch 490 sample:\n", "the 2 2 -pe/005 heazin re ar le te thetve tenting st sur zam at so toutit rtecabtmerr th thitantthettin\n", "46m 3s (500 50.00%) 2.0600\n", "Epoch 500 sample:\n", "the fi nic'l without ovod of on thtinulie iom om at finp ht hilin *\n", " frrorn th thefoonarrybus n ysuecoa\n", "46m 58s (510 51.00%) 2.0248\n", "Epoch 510 sample:\n", "the (a om isrernelingshico me meurruisnisrapewees an an youyoved*eoumor tr tthevo op\n", "ol opy pugpeplparm\n", "47m 53s (520 52.00%) 2.0257\n", "Epoch 520 sample:\n", "the 01 2008 y, bur..\n", "u. his *rorrmathur as s, chachuicurerertiuticerce, ah shtrefr.\n", "iti \" ro r trearsas\n", "48m 48s (530 53.00%) 2.0304\n", "Epoch 530 sample:\n", "the @ 2 \"hilerte vunrinding so sll da fariat at ausen amleime ellencous sh yiteive t t or\n", "on wo, ch che\n", "49m 43s (540 54.00%) 2.0456\n", "Epoch 540 sample:\n", "the 2/18 _3\n", "\n", "chopnernuriseo a , ad inronl eom\n", "em\n", "th tanemnemehaoumyghitn esoeuodus s neineine u),\"**hymo\n", "50m 38s (550 55.00%) 2.1504\n", "Epoch 550 sample:\n", "the jo z- vf rf buke s. 
__ h_ \n", "o \n", "e te lozliwn me mafma.s th antistranban.\n", "\n", "******erme\n", "m l al an l\n", "soma\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "51m 33s (560 56.00%) 2.1253\n", "Epoch 560 sample:\n", "the @ \n", "'v_:/\" lod th dy 'scouokeordotuey ***** ** th the 200 20)\n", "**\n", "* othivetu te u quituut ay ie ou te\n", "52m 28s (570 57.00%) 2.0786\n", "Epoch 570 sample:\n", "the || ||| || a| re-waddad th ferfev r anallaharif fare pis bo lissifardatracaisu du duanedsok hakio, b\n", "53m 23s (580 58.00%) 2.0672\n", "Epoch 580 sample:\n", "the wo eijumeafte anned ef he fed lf (sleslandelyoyo yo1\n", "ch he meoy e cacth\n", "orerthrugit vokelleglling t\n", "54m 18s (590 59.00%) 2.0922\n", "Epoch 590 sample:\n", "the fi flrrluy. of hereoser ce on ck wn kr-plora\"englk)\n", "towwt walepat tsdthe pinpungnbsarnerd bek _g's\n", "'\n", "55m 13s (600 60.00%) 2.0982\n", "Epoch 600 sample:\n", "the to sou on mbsazeare @ j'th'r lithaxcsstt _ re dt of oi be s jant ut yomdemuitp ttthine piedeis me\n", "56m 8s (610 61.00%) 2.1456\n", "Epoch 610 sample:\n", "the @ n y.\n", "ed me ins ack ff lryrs, @ ure we whaboa lame tl/\"mc or keestighourda; a_ in le the ce ce in\n", "57m 3s (620 62.00%) 2.0980\n", "Epoch 620 sample:\n", "the @ 2005.08.'s.eroals ma faredowt w**lsal.\n", "st tt t, shegheas te we(th ns\n", "nt j\n", "slas\n", "\n", "\n", "57m 58s (630 63.00%) 2.1844\n", "Epoch 630 sample:\n", "the 11 18 d*equabovtow hcpcopce coekee_** \" ow th the sinsonckd\n", "yr\n", "\"rp dg--l----2011018.18 ftefecs cld\n", "\n", "58m 53s (640 64.00%) 2.1013\n", "Epoch 640 sample:\n", "the @ \" p\"byeraca bu, hu hay 2011. 
isuuosuoliysidinging thec, cuccamandingthsl dnomulmal io e(th t trat\n", "59m 48s (650 65.00%) 2.0895\n", "Epoch 650 sample:\n", "the @ 17.\n", "ns rw in h whe ne n ln destexcenretrherthpspcopeulaula cay ow in on theahoush orageh el erd, \n", "60m 43s (660 66.00%) 2.1343\n", "Epoch 660 sample:\n", "the @\n", " ')\n", "\".\n", "e!\n", "\n", "is \n", "y allarmaht ypangaubicors, sicaxchthhirlywh dhofarkinkrnges the\n", "mecespresuriaril tl\n", "61m 39s (670 67.00%) 2.0944\n", "Epoch 670 sample:\n", "the jo |-t| each ur anda2bouse****o, c, h romest lr s, oreove oenonredd2 pk \" adleed estey ol of tf wl w\n", "62m 34s (680 68.00%) 2.0994\n", "Epoch 680 sample:\n", "the ki b, by\n", "\n", "g o bupy,\n", "tt stnt t vifriot olleacaas, whis tted's 3: is it arily li eave ve ve rowh\n", "i tit\n", "63m 29s (690 69.00%) 2.0880\n", "Epoch 690 sample:\n", "the @ | l- hicahestaytrrtdcebcrucris ar ar ehit'tweb ishebhigdiesdutre ti b \n", "ga in it (at in f if pand\n", "64m 24s (700 70.00%) 2.0826\n", "Epoch 700 sample:\n", "the ** *.\n", "\n", "65m 19s (710 71.00%) 2.0980\n", "Epoch 710 sample:\n", "the 2018.18001 te it cogrinw pat ea liulerknd ly calnotoutlyapeu-hiset. wiwein iabely se sdetlit's \"sg\n", "66m 15s (720 72.00%) 2.1777\n", "Epoch 720 sample:\n", "the mi ,\"l,\n", "\n", "(\n", "\n", "\n", "6 ende\"ra.is p ot of sf r, on of lupra ke artongglutwestrofdlredelezinsindst me f tf w\n", "67m 10s (730 73.00%) 2.1465\n", "Epoch 730 sample:\n", "the @ | *-h's th thetho kall __ . i. 
ise cetres ud s be they *k aheave ve ke te titwand\n", "aede, el at hid\n", "68m 5s (740 74.00%) 2.1024\n", "Epoch 740 sample:\n", "the (| |- he tfeckecanceum pe malbae-fooviss rnenisi th ip in in on thy goutting naln bu buint wowebhlov\n", "69m 0s (750 75.00%) 2.0838\n", "Epoch 750 sample:\n", "the @ | in mn wittitrict co n \n", "\n", "sa rmydum awhury ick'sy h ch pheracpaye ye hof er catithiut er cack's \n", "69m 56s (760 76.00%) 2.0982\n", "Epoch 760 sample:\n", "the whio,e sye s, by lavude t trcprcpec, he hersade tr thesteyupucwis hevochoclyngwint is\" ra coctacidr\n", "70m 51s (770 77.00%) 2.0499\n", "Epoch 770 sample:\n", "the @ **\n", "**** a te titho\n", "vorreelilin leneebeebdebverledaw***** \"\n", "\n", "scoclaslisintingy mp th th** a\n", "s adow\n", "71m 46s (780 78.00%) 2.0755\n", "Epoch 780 sample:\n", "the ju ',\n", "_prpgr tiving)@\n", "d myiffetiun camoum stulnul plesondmer or by w song\n", "-p_uestysts\n", "fenten-wn wihe\n", "72m 41s (790 79.00%) 2.0775\n", "Epoch 790 sample:\n", "the @ **n_._\" tont.il'r ati sl sretint ou sathehellrerole mr cthhethel s ( (ally\n", "tr\n", "ge atmety\n", "tesilr fe\n", "73m 36s (800 80.00%) 2.0811\n", "Epoch 800 sample:\n", "the @ o ?**. 
n andor, lyeita w mitested as uo sive upenfe ar is anduambr\"vingenteangrdienog, si lisis\n", "74m 32s (810 81.00%) 2.1021\n", "Epoch 810 sample:\n", "the @ o.>a.\n", "is _s pusiae ireisondonw mm on wngwongnet e an andusol on of ser\n", "ec eisu s masmutvilheryou\n", "75m 27s (820 82.00%) 2.2306\n", "Epoch 820 sample:\n", "the &b ol orrply ta the dobeofev)rossossiok\n", "of ald ke teicouwerdpat g thar ihelyere & 't, carod*m\n", "toft,\n", "76m 22s (830 83.00%) 2.1387\n", "Epoch 830 sample:\n", "the @\n", " **\n", "\n", "uccucter berink\n", " @ nd 'sushe; me areeanexif 'e col | dee\n", "iy hre berqu mdc gh arisa r he so f\n", "77m 17s (840 84.00%) 2.1555\n", "Epoch 840 sample:\n", "the wh hlfancadcalaxd in theg, andangang bevadyonyaecell pr dill'llev-rrllassally thethefe.te teitolllad\n", "78m 12s (850 85.00%) 2.1246\n", "Epoch 850 sample:\n", "the ji \n", "an lejbermed syangheus on _ h uv y\n", "or\n", "\n", "s \n", "rbounghcheohis we oa ou on o su sherve ire\n", "\n", "s us aw\n", "79m 7s (860 86.00%) 2.1331\n", "Epoch 860 sample:\n", "the @ **.a), 'sa an in come*c y p rt t is uchis, we: les,entk* k's \n", "rsels lotorsesdere, me \n", "es th ve. \n", "80m 2s (870 87.00%) 2.1074\n", "Epoch 870 sample:\n", "the @ o\n", "tt rutte i dibyll jicngfessandre- inarourepingpre s f\n", "\"calt s aczzom wh ir io mo mye god om he\n", "80m 57s (880 88.00%) 2.0635\n", "Epoch 880 sample:\n", "the (wi h\n", " le medcont s ly ory pryluscacca f\n", "whint lm sorsalckked ond we wrthron\n", "ctictictibesoek?** ck j\n", "81m 52s (890 89.00%) 2.0596\n", "Epoch 890 sample:\n", "the @ s\n", "\"nixp e itc ms d, tue iefim j's lromperceaceire \"getie iel on an wre w\n", "\n", "wo of fitistis ni nin ir\n", "82m 47s (900 90.00%) 2.1095\n", "Epoch 900 sample:\n", "the @ 18.xter omas, 's ar m. ms of ret0\" ardon fe beps. 
ralond f lf llaon on @ thend il s _\"dertenkenqu\n", "83m 42s (910 91.00%) 2.0934\n", "Epoch 910 sample:\n", "the @w.*sstomtomastaetatt'thard rompat o oad avir- wes ssh oreelober ayaryorkid aacaycy youtorurhelld(l\n", "84m 37s (920 92.00%) 2.1103\n", "Epoch 920 sample:\n", "the @ \"_\"tarturiolielint ng meeeelt tw st\n", "me sutureute the dg ng of tt sudujlit tdrrichuluilealerert jte\n", "85m 33s (930 93.00%) 2.1305\n", "Epoch 930 sample:\n", "the @ \"_\"_ar of aftaiqu thet tt the geugtireand s siosiotiofls amayer is ch ik t_ s boialyalledat stosto\n", "86m 28s (940 94.00%) 2.1392\n", "Epoch 940 sample:\n", "the @\n", "\n", "_0_ 018 hezth e fi fi. i ioder (2to2\\! e.\n", "\n", "eut st kens\n", "\n", "fet atdiedi pubing re ritaxul\n", "halmu\n", "dere\n", "87m 23s (950 95.00%) 2.2243\n", "Epoch 950 sample:\n", "the | n asiass aroronofrewifgrengarnd f hourend** amecon andemay. el, ng\n", "\n", "8alneve vov yhovithevy\n", "ly sint\n", "88m 18s (960 96.00%) 2.1974\n", "Epoch 960 sample:\n", "the @ **_0001.\n", "th**ha beabe\n", "\n", "entres esth st ng has aste tsts.1).\n", "iappepeyy mathe yopyofre ce iese cooci\n", "89m 13s (970 97.00%) 2.1068\n", "Epoch 970 sample:\n", "the @\n", " 20s/ ihotrfer nt of toe om\n", "oucg th on oy is it in in on an a, of are xpraminoande_\"y\n", "d bk ch sher\n", "90m 8s (980 98.00%) 2.0956\n", "Epoch 980 sample:\n", "the @ \".\"pmytthes st te ulione ppedt buthe\n", "\n", "ti siceeca cavour. 
mw twebeerane ng aesiore ry by mendes ant\n", "91m 3s (990 99.00%) 2.1618\n", "Epoch 990 sample:\n", "the @ 2+ f\n", "ip\n", "yore oressos in beobwins er juprepo suns, 9am th sulhandyydredemugoupustheccound y.uw, pa\n", "Training time: 5512.38s\n" ] } ], "source": [ "prnn = PyTorchRNN(VOCAB_SIZE, N_HIDDEN, N_FAC, BS).to(DEVICE)\n", "optimizer = optim.Adam(prnn.parameters(), lr=0.005)\n", "all_losses = train_loop(prnn, optimizer, TRAIN_TEXT, criterion=nn.NLLLoss(), epochs=1000)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4XOWZ/vHvo14sWZIl20IucgUbG2MQYLpDW+ywlEASUoEfifODZEPKJkv6EpKwWTYklBAWAgQIC04wsA6h9xaM5W5chatcZfUuzcy7f8xIlmXNSLZHGp/R/bkuXWhmjmeew5FuPfPMO2fMOYeIiMSXhFgXICIi0adwFxGJQwp3EZE4pHAXEYlDCncRkTikcBcRiUMKdxGROKRwFxGJQwp3EZE4lBSrB87Pz3fFxcWxengREU9asmTJPudcQW/bxSzci4uLKS0tjdXDi4h4kplt7ct2GsuIiMQhhbuISBxSuIuIxCGFu4hIHFK4i4jEIYW7iEgcUriLiMQhT4e7zx/gL4u34w/oowJFRLrydLgv3lLN9xesZNGmyliXIiJyVPF0uLf6/ABUNrbFuBIRkaOLp8O9YxxT3aRwFxHpytPh7usI98b2GFciInJ08XS4q3MXEemZwl1EJA7FSbhrLCMi0pWnw33/zF2du4hIV54Od38gAGgsIyLSXa/hbmZpZvahma0ws4/M7JYetkk1s/lmVmZmi8ysuD+K7a6jc6/RWEZE5AB96dxbgfOcczOAE4GLzWxWt22uB6qdcxOB3wK/jm6ZPeuYuTe0+mjzBQbiIUVEPKHXcHdBDaGLyaGv7idzuQx4JPT9U8D5ZmZRqzIMn39/GTUazYiIdOrTzN3MEs1sObAXeMU5t6jbJkXAdgDnnA+oBYb1cD/zzKzUzEorKiqOrHI44IRhWjEjIrJfn8LdOed3zp0IjAJONbNp3TbpqUs/6FSNzrn7nXMlzrmSgoKCQ6+2G1+XcK/SihkRkU6HtFrGOVcDvAlc3O2mcmA0gJklAUOBqijUF1HHahnQWEZEpKu+rJYpMLOc0PfpwAXAum6bLQSuCX1/FfC6c67fT7J+QOeucBcR6ZTUh20KgUfMLJHgH4O/OOeeM7OfA6XOuYXAg8BjZlZGsGO/ut8q7qLrzF3LIUVE9us13J1zK4GZPVz/0y7ftwCfjm5pvfMFHCmJCSQlmt6lKiLShcffoepITDByM1I0lhER6cLT4e7zO5ISjJyMZI1lRES68HS4+wMBEhONvMwUnV9GRKQLT4e7L9DRuado5i4i0oWnw71j5p6Xkax3qIqIdOHpcA927gnkZKRQ29yOz6+Th4mIgMfDff9qmWQAapvVvYuIgMfDvWPmnpuZAujkYSIiHTwd7v5AoHOdO+g
TmUREOng63H1+d2C4a8WMiAjg8XDvnLlnBmfueiOTiEiQp8O9c+Ye6tx1CgIRkSBPh3tH556RkkhKUoJm7iIiIZ4Od18gQFJCAmbB5ZCauYuIBHk63Ds6d4DcjBQthRQRCfF0uPsCjqTE/eGuj9oTEQnydLgf0LlnJutDskVEQjwd7h3ncwfIyUjRUkgRkRBPh3vXzj0vI4Wa5nYCgX7/XG4RkaOep8O9Y7UMQE5GMv6Ao77FF+OqRERiz9PhfkDnnqnzy4iIdPB0uHe8QxXQu1RFRLrwdLh37dxzQud01xuZREQ8Hu5d17mPycsAYFNFYyxLEhE5Kng63Lt27sOGpDIiO5W1u+piXJWISOx5Otx9/v2rZQCmFGazRuEuIuLtcO/auUMw3Mv2NtDq88ewKhGR2Os13M1stJm9YWZrzewjM7uph21mm1mtmS0Pff20f8o9UNfVMhAMd1/AUba3YSAeXkTkqJXUh218wHedc0vNLAtYYmavOOfWdNvuHefcJdEvMbzunfvUwiwA1u6q5/hjhg5kKSIiR5VeO3fn3C7n3NLQ9/XAWqCovwvrjXPuoM69eFgmqUkJelFVRAa9Q5q5m1kxMBNY1MPNp5vZCjN7wcyOD/Pv55lZqZmVVlRUHHKxXXWcQiaxywuqSYkJHDsyS+EuIoNen8PdzIYAC4BvOee6p+dSYKxzbgZwN/BsT/fhnLvfOVfinCspKCg43JqB4HllgM517h2mjMxm7a46nNMJxERk8OpTuJtZMsFgf9w593T3251zdc65htD3zwPJZpYf1Uq78Yda964zd4Cpx2RT3dTO7rqW/nx4EZGjWl9WyxjwILDWOXdHmG1GhrbDzE4N3W9lNAvtzhcK96Ru4T6lMBtAoxkRGdT6slrmTOBLwCozWx667ofAGADn3H3AVcANZuYDmoGrXT/PRfz+njv347qsmDnvuBH9WYKIyFGr13B3zr0LWC/b3APcE62i+iJc556dlsyo3HS9U1VEBjXPvkN1/8z94F2YUpitsYyIDGqeDffO1TIJBz+pmFKYzZZ9jTS36TQEIjI4eTbcw62WgeA7VQMO1u+pH+iyRESOCp4N986Ze+LB4d5x6oHl26oHtCYRkaOFZ8M9EKFzH52XwcThQ3jpoz0DXZaIyFHBs+EebrVMhznTRrJocyWVDa0DWZaIyFHBs+EeabUMwJxphQQcvLxG3buIDD6eDffeOvcphVmMHZbBC6t3D2RZIiJHBc+Guz+0FLKnmTuAmTFnWiHvl+2jtql9IEsTEYk5z4a7zx+5c4fg3N0XcLyyVqMZERlcPBvukda5dzhh1FCKctJ5YdWugSpLROSo4Nlwj7TOvUNwNDOSdzbuo75FoxkRGTw8G+69rZbpMGf6SNr8AV7RqhkRGUQ8G+69rZbpMHN0LuPzM3ngnc36dCYRGTQ8G+69rZbpkJBgfOO8iazdVcera/cORGkiMojsrm05KhtHz4Z7Xzt3gEtnHMPYYRnc+dqGo/IgDEbt/gD79O5h8bjHPtjKrNte44sPLuKjnbWxLucAng33vqyW6ZCUmMDXPzGR1TvqeGO9uvdYCwQc1z28mPN/8xZVjW2xLqdftLTrdNPRsLOmmc8/8AH/9tRK/rJ4Ox9XNMS6pE4ry2u49W9rmF40lDU767jk7nf517+uOGoWb3g23Pevc+/bLlwxs4hRuenc+VqZuvcYe+i9zbxbto/a5nbufn1jrMuJumXbqplxy8ssWFIe61KizjnH798oG7Am6bevbGDxlipe/Gg331+wkvN/8xbXPvwhZXtjG/K1Te3c+PhSCrJSefT/ncqb3/sE884ez7PLdvCN/1nW2XzGkmfDvbNzj7AUsqvkUPe+YnsNC1fs7NO/2V7VxC//voY7Xl6vPwhRsn53Pf/50noumDKCq08ZzZ8/2MrWysZYlxVVTy0pp9UX4EfPrmLd7vj6RLCFK3Zy+0vrufHPS9nYz5+XsKmigQVLy/ny6cUs+8m
FvPqdc/m3i49jyZZqLv7d2/z8b2uojsEzP+cc3/3rCvbUtXDP52eSm5nC0PRkfjB3Cj+/bBpvbajgV8+vHfC6uvNsuB/KzL3DlSeN4tgRWdz05HK+/vhSdtQ0H7RNmy/Am+v38rXHSjn39jd44J3N3PV6GY99sDVqtQ9WrT4/Nz25jOy0JP7jyul8+8LJJCUkcPtL62NdWtS0+wO8sHo3Z0/KJystmRv+vLRfnqa/vaGCe17fyM4efob7S0V9K/++8COmFWWTmZrEDY8vpanN12+P97tXN5KalMgNsyeQkGBMHD6EG2ZP4I3vzebTJaN4+P3NnPnr1/nl39ewt66l3+roqqHVx01PLufVtXv4wZwpzByTe8Dtnz9tDNeeUcyD725m/uJtA1JTOL1+QPbRqq+rZbpKSUrgf79xJv/91ibufbOM19bt4ZxJBQzPTqVgSBplFQ28uW4v9a0+cjOSuWH2BL5w2lh+/Oxqbn1uDdOKhnJSt4Mpfff718tYt7ueB68pIX9IKgBfPXscd71exlfPrmHG6Bya2/w0t/vJy0yJcbWH5/2PK6lqbOOLs8aSk57M5/+4iJsXrOKez8/ErO8/q5GUbqniK4+W0uYLcMcrGzjvuBF8cdYYzp1cELXH6MnPFq6msdXPbz9zIhX1rXzhwUX8+JnV/OYzM6L+uOt21/G3lTu54dwJnT8rHfKHpHLbp07gujPHce8bZTz47mYe+cdWSsbmMrUwm+MKs2ls9bF+Tz2bKhr45AnH8KVZY6NS042PL2XLvkb+9aLJXHdmcY/b/fiTU9i0r5EfPbMa5+AzJaNJOIScihaL1bihpKTElZaWHva/f/i9zdzytzUs/+mF5GQcehCUVzdxxysbWL2jlr31rdQ0tTMsM4ULpozgouNHcObEfNKSE4HgfO2Se96h3ed47ptnHfTD1lW7P8DWykYmDs867H2LRy3tfmbd9hqzxg3jvi+d3Hl9Q6uPc//zDdKSE0lNSmBzZSMGXH/WOL5z4bGkpyTGrujD8L2/ruDF1btZ/OMLSEtO5L63PuY/XljHVSeP4tbLph3x/mze18in7n2PnIwU7v7cTF5YvYv5i8vZ19DKpOFD+MrZ47h0RhENrT4q6ltpavNRkJXKiOy0zp/nvnDO8XFFI75AgKy0ZBZvruJb85fz/YuP5cbZEwG489WN/PbVDXzt3PFce0YxhUPTj2jfuvraY6W8X1bJO//2iV5/v7dWNvLwe1tYtq2adbvrafUFG7+h6cnkZCSztbKJB68p4fwpI/r8+E1tPh54ezMvr9lNx/h8U0UD2enJ3HX1TE6fMCziv69raecrj5Ty4eYqTh6byy8un8ak4UPYU9/Krppmhg1JZVx+Zp/r6crMljjnSnrdzqvh/sd3NvGLv69l1b9fRFZa8hHX0+rzk5yQEPYv7OodtVz5h/cZl5/J5TOLOGtiPlMLsw/Yfl9DKzf+eSkfbqnil1dM4wunHXm3EC8WLCnnu39dwf989TTOmJB/wG3PLtvBPW+UMbFgCMcVZrGrpoX5pdsZk5fBLZceT0lxbucxbvX5KdvbwPaqJo4/Ziij8zI676el3c/K8lrKq5vYU9dKZUMrU4/J5uJpI8lI6f8nqa0+PyW/eJWLpo7kN5+ZAQRXBv3u1Q3c/UYZk4dnce8XT2JCwRD8AUdNUxu5GSkH/Az5A463N1Swfk89++pb2dfQSlZaMtOLhjJh+BC++5fl1LX4ePqGMygOhUObL8DfV+3k/rc3s3ZX+Bn/sMwUph6TzfHHDKV4WAYfVzSwsryWTfsamVqYzdmT8jlpbC4fbKrk6aU7DnrR8oRRQ3n6hjNISkzorPWmJ5fx3MpdmMGZE/KZfWwBY4dlMnZYBsXDMklJOrTJ776GVh59fwt3vV7Gty+YzE0XTDqkf+/zB9ha1URWahIFWam0+gJ8+r5/sKWykYXfOKvXQA0EHM8s28HtL61nd10Lp43LIzs9+LM3LDOF71w0meF
ZaX2qxTnHU0vKue2FdVQ3tWHQ+Yfia+eO5wdzphzSvnWI+3Dv6IjW/vziAevuXly9mzteWc+GPcEf+vwhKVwxs4jPnjKa5rYA8x4rpaqxjWNHZrF6Ry0PXnMKnzhu+IDUFgs+f6DzF703n7r3PWqa23ntO+f26Sn8B5squXnBSrZUNgGQnZZEXmYK5dXNna+3AIzPz+SU4jw2VzayfFsNbf5A520piQm0+QNkpCQyZ1ohnziugOlFQxmTl3FQDc45Fm2uoqnNx7mThx/SuK/DK2v28NVHS/nTdacw+9gDj/tbGyr49vzlNLf5yclIZm99K/6AY0R2KnOmFXLh1BGs3lHLYx9spbw6OEdPT05k2JAUapraaWgNzrZTkhJ44quzOHnsweNB5xzvlVVSurWKvMwUCoakkp6SSEV9K3vqWthW1cRHO+vYsKeedr8jNSmBKYXZjM/PZHl5DZsq9r+wXTI2l8tOPIa8zFQaWttpbvMz94TCHoNty75Gnlm2g2eW7WBbVVPn9WPyMnjwmhImjej5Waxzjj11rWyramJ7VRMfbq7imeU7aPMF+KfjR/Cbz5zIkNQj/6NcXt3EP9/9LsOz0nj6xjPIDHOfS7dVc8vCj1hRXsuMUUP5ySVTKSnOO+LHr2lq4+H3tuCcY+TQdApz0pg8IouinMN7phP34f77N8q4/aX1bPjFnEPuDo7U3roW3vt4Hy+s2s3r6/biCzgSE4yR2Wn895dOZlx+Jp+9/x9sqmhk/rzTmT5qaNj72lnTzO0vreeKmUWcM7lgAPfiyLy6Zg/fX7CSySOG8KsrpjO+YEjYbdfsrGPuXe/wk0umcv1Z4/r8GC3tfl5ft5ftVU2UVzdT1dhGcX4Gx43Mpig3neXbanh7YwVLtlYzLj+TWeOHcdq4PMblZzI8O42M5ERKt1azYEk5f1+1qzMgs9OSKCnO49zJBZwzuYCyvQ3c80YZK7bXADB2WAZfOWscV508+pAah28+sYx3Nlbw4Y8uILmHP3q7apv5zcsbcA5GDk0lNyOFxVuqeHN9RecoYdb4PK45vZhzJhd0hlAg4NhS2ciqHbVMHD6k8wPgD1erz8/u2haOyUk/oM7y6iaWbathxqgcxgzLiHAPPXPOUdPUztaqJsr2NvDrF9fR0u7n3i+cxNmTCjq32bCngedW7uS5lbvYvG//H5TUpASuPHkU1581jgkRfp4Ox7sb9/HlhxYxpTCba84o5pITCslISaLNF2Dj3noeencLC5aWMzwrlZvnHMflJxbFZE7eF3Ef7h3zvk2/mhvTg7CvoZVnl+2gvLqZfzlvIsNC8/i9dS1cce/7tPoC/OtFk7nsxKKDguLtDRXc9OQyqpvaSUow/uvTM7h8ZlEsdqPPWn1+bnt+HX96fwuThg9hd10Lrb4AN50/iXnnjO8x1H70zCqeWlLOoh+ef1ivj0RDq8/Pht0NrNpRy6odNfzj48rOZwUAo/PS+f/nTiA3I4X7397E8u01jMxO447PzjhojNST+pZ2TvvVa1w+s4hfXTH9kGprbPXx/seVjM5L57iR2Ye8b0erHTXNXP+nxWzc28CnZhaxs7aZtbvqqWpsI8Hg9AnDuGDKCMYXDGF0bjpFuemkJvXfs/CFK3byu1c3sKmikSGpSYzKTefjigba/Y6UxASuP3scX//ExKg8W+hPcR/ud7y8nrvfKGPzbZ+MYlXRtXFPPf/yxDLW7a4nOy0p9EaqDNKSE9he3cwD72xi8vAsbv/0CfzHC+t4/+NKfvzJKXzl7PGxLv0gFfWtPLdyJ/+zaBsb9zZw7RnF3DznOOqa2/nZwo94YfVuphcN5befPZGJw/d3XQ2tPk775atcPK2wcw59tNiyr5F3NlYwNCOFudNGdo6YOkY0P3xmFZv3NfK1cybwnQsnH/AMsWME8tzKnawsr2XDnnp8Acf8ebM4bXzkF9sGk/qWdr7zlxW8vaGCySOymFqYzQmjh3LR1JEUZIVfmNBfnHO
Ubq1m/uLt7GtoZUphNlMKszmlODeqLwj3p6iFu5mNBh4FRgIB4H7n3J3dtjHgTmAu0ARc65xbGul+jzTc//PFdTzwziY2/nLuYd/HQHDOsXhLNY99sJUXV++i3b////cVM4v45RXTyEhJotXn59vzl/P8quAa6atPGcMFU4cfcSfT7g+wqaKRNl+AaUXZEefdja0+1uyqY2V5Let21dHY5qPNF6C2uZ2l22rwBxxTCrP57oWTuWDqgSsPXli1ix8+s4rmdj8/mjuFz54yhj11LSxYWs7vXt3IMzeecdCa4KNdU5uPW59byxMfbmPssAxKxuZx/DHZOODxRVvZVNFIdloSM0bnML1oKKdPGNY5fpADOef6dZnmYBLNcC8ECp1zS80sC1gCXO6cW9Nlm7nAvxAM99OAO51zp0W63yMN99ueX8sj/9jCulvnHPZ9DLR2f4Dmdj+t7QGccwzPPvDFKX/Acd9bH/P4B1vZWdtCbkYynzllNNedMY6RQ3t/hd45x5bKJkq3VLFkazUrymsp21vf+Qfl7En5/HDuFKYUHvjUf9m2ah56bwsvrNrV+WJlQVYqQ9OTSUlMIC05gVnjh3H5zCImh3lxDIKjqO89tZK3NlQccP3MMTk8fcMZnv3lfvmj3Ty+aBsf7azrPNnZjNE5XHP6WD55QmG/jhJEuuu3sYyZ/S9wj3PulS7X/TfwpnPuidDl9cBs51zYz7c70nC/9bk1zF+8ndW3/NNh38fRyh9wvFe2jycXb+PF1btJMOPSGcdw9aljOGlMzkErVJrafDy9dAePvL+FjaHla9lpSZw4JpcphcGnwhX1rdz9ehn1Le380/EjGZKaRKsvwJbKRlaW15KVmsRVJaM4a2I+04uGHvSHp6+cczy9dAdbKxspyk2nKCeDmWNywq5Q8Jq99S3Ut/ii/oKfSF/1NdwP6TfOzIqBmcCibjcVAdu7XC4PXXdAuJvZPGAewJgxYw7loQ/iD61QiUeJCcY5oZUc26uaeOi9zcxfvJ2nl+0gOy2JsycVMCo3nZqmdqqb2vhgUyV1LT6mFWVz6+XTOG1cHhMLhhz0QvNVJ4/i7tfL+NuKnSQlGKnJiWSnJ3PLpcdz5cmjovJCkplx5cmjjvh+jlbDs9LQ+9PEC/r822xmQ4AFwLecc93fKdFTyh70lMA5dz9wPwQ790Oo8yC+QOCQzivjVaPzMvjZPx/Pty+czLsb9/HW+gre2lDBK2v3kJuRTE56CrOPHc41Z4zlpDG5EUcfORkp/OSSqfzkkqkDuAciEgt9CnczSyYY7I87557uYZNyYHSXy6OAvp168TDFc+fek+y0ZOZOL2Tu9MJYlyIiHtDru39CK2EeBNY65+4Is9lC4MsWNAuojTRvjwaf3w2Kzl1E5HD0pXM/E/gSsMrMloeu+yEwBsA5dx/wPMGVMmUEl0JeF/1SD+QPuD6fy11EZLDpNdydc+/S80y96zYO+Hq0iuoLX8D1+VOYREQGG8+m42CbuYuIHArPhrsvECDRo2+KERHpb54Nd3XuIiLheTbcfQFHkl5QFRHpkWfDXZ27iEh4ng13rXMXEQnPs+Guzl1EJDzPhnvw3DKeLV9EpF95Nh3VuYuIhOfZcA++Q1XhLiLSE8+Guzp3EZHwPBvuWucuIhKeZ8M92Ll7tnwRkX7l2XQcLJ/EJCJyODwb7n6/Zu4iIuF4Nty1WkZEJDzPhrtWy4iIhOfZcFfnLiISnmfDXatlRETC82w6+gIBrXMXEQnDs+GumbuISHieDXfN3EVEwvNkuAcCDudQ5y4iEoYnw93vHIA6dxGRMLwZ7oFguGu1jIhIzzyZjr6AOncRkUg8Ge5+f0fnrnAXEelJr+FuZg+Z2V4zWx3m9tlmVmtmy0NfP41+mQfyBQIAWucuIhJGUh+2+RNwD/BohG3ecc5dEpWK+mD/zF3hLiLSk147d+fc20DVANTSZ5q5i4hEFq2Z++lmtsLMXjCz46N0n2FptYyISGR
9Gcv0Zikw1jnXYGZzgWeBST1taGbzgHkAY8aMOewHVOcuIhLZEbe+zrk651xD6PvngWQzyw+z7f3OuRLnXElBQcFhP6Y/9IKqZu4iIj074nA3s5FmZqHvTw3dZ+WR3m8k6txFRCLrdSxjZk8As4F8MysHfgYkAzjn7gOuAm4wMx/QDFztXOj8AP3Ep3XuIiIR9RruzrnP9XL7PQSXSg6YjhdUtc5dRKRnnlxu4tNqGRGRiDyZjn7N3EVEIvJkuPu0WkZEJCJPhrs6dxGRyDwZ7j6dW0ZEJCJPhnvHKX+T9IKqiEiPPJmO6txFRCLzZLhrnbuISGSeDHetlhERicyT4a7VMiIikXky3DVzFxGJzJPhvr9z92T5IiL9zpPpqM5dRCQyT4a73x98QVUzdxGRnnky3Ds7dy2FFBHpkSfDXatlREQi82S4a+YuIhKZJ8Ndq2VERCLzZDp2dO5q3EVEeubJcPcHAiQlGGZKdxGRnngy3H0Bp3m7iEgEngx3v99ppYyISASeDHd17iIikXky3P0BR1KiJ0sXERkQnkxIde4iIpF5Mtw7VsuIiEjPPBnu6txFRCLrNdzN7CEz22tmq8PcbmZ2l5mVmdlKMzsp+mUeyK9wFxGJqC+d+5+AiyPcPgeYFPqaB/zhyMuKTJ27iEhkvYa7c+5toCrCJpcBj7qgD4AcMyuMVoE90Tp3EZHIojFzLwK2d7lcHrqu3wQ7d0++XCAiMiCikZA9tdCuxw3N5plZqZmVVlRUHPYDarWMiEhk0Qj3cmB0l8ujgJ09beicu985V+KcKykoKDjsB9TMXUQksmiE+0Lgy6FVM7OAWufcrijcb1j+gGbuIiKRJPW2gZk9AcwG8s2sHPgZkAzgnLsPeB6YC5QBTcB1/VVsB3XuIiKR9RruzrnP9XK7A74etYr6wB9wpCXrBVURkXA8mZBaLSMiEpknE1KrZUREIvNkuPv8mrmLiETiyXDXahkRkci8Ge5OnbuISCTeDHd17iIiEXky3IMzd0+WLiIyIDyZkOrcRUQi82S4+wKOxESFu4hIOJ4Md61zFxGJzJPhrnPLiIhE5slw18xdRCQyT4a7zi0jIhKZJxNSnbuISGSeC3fnHH7N3EVEIvJcuPsDwY9nVecuIhKe58LdFwp3rXMXEQnPc+Guzl1EpHeeC/fOzl2rZUREwvJcQqpzFxHpnefC3RcIAGi1jIhIBJ4Ld3XuIiK981y4+/wdM3eFu4hIOJ4L987OXUshRUTC8ly4a7WMiEjvPJeQmrmLiPTOc+Gu1TIiIr3rU7ib2cVmtt7Myszs5h5uv9bMKsxseejrK9EvNUidu4hI75J628DMEoHfAxcC5cBiM1vonFvTbdP5zrlv9EONB9g/c1e4i4iE05fO/VSgzDm3yTnXBjwJXNa/ZYW3v3P33ERJRGTA9CUhi4DtXS6Xh67r7kozW2lmT5nZ6KhU1wOtcxcR6V1fwr2nFHXdLv8NKHbOnQC8CjzS4x2ZzTOzUjMrraioOLRKQ7TOXUSkd30J93Kgayc+CtjZdQPnXKVzrjV08QHg5J7uyDl3v3OuxDlXUlBQcDj1arWMiEgf9CXcFwOTzGycmaUAVwMLu25gZoVdLl4KrI1eiQfSahkRkd71ulrGOeczs28ALwGJwEPOuY/M7OdAqXNuIfBNM7sU8AFVwLX9VbBWy4iI9K7XcAdwzj0PPN/xet5oAAAEcklEQVTtup92+f4HwA+iW1rPtFpGRKR3nkvIEdmpzJ0+kuz0Pv1dEhEZlDyXkCePzePksXmxLkNE5Kjmuc5dRER6p3AXEYlDCncRkTikcBcRiUMKdxGROKRwFxGJQwp3EZE4pHAXEYlD5lz3s/cO0AObVQBbD/Of5wP7oliOVwzG/R6M+wyDc78H4z7Doe/3WOdcr6fVjVm4HwkzK3XOlcS6joE2GPd7MO4zDM79Hoz7DP233xrLiIjEIYW7iEgc8mq43x/rAmJkMO73YNxnGJz7PRj
3Gfppvz05cxcRkci82rmLiEgEngt3M7vYzNabWZmZ3RzrevqDmY02szfMbK2ZfWRmN4WuzzOzV8xsY+i/ubGutT+YWaKZLTOz50KXx5nZotB+zw99lm/cMLMcM3vKzNaFjvnpg+FYm9m3Qz/fq83sCTNLi8djbWYPmdleM1vd5boej68F3RXKt5VmdtLhPq6nwt3MEoHfA3OAqcDnzGxqbKvqFz7gu865KcAs4Ouh/bwZeM05Nwl4LXQ5Ht3EgR+y/mvgt6H9rgauj0lV/edO4EXn3HHADIL7HtfH2syKgG8CJc65aQQ/n/lq4vNY/wm4uNt14Y7vHGBS6Gse8IfDfVBPhTtwKlDmnNvknGsDngQui3FNUeec2+WcWxr6vp7gL3sRwX19JLTZI8Dlsamw/5jZKOCTwB9Dlw04D3gqtElc7beZZQPnAA8COOfanHM1DIJjTfCT4NLNLAnIAHYRh8faOfc2UNXt6nDH9zLgURf0AZBjZoWH87heC/ciYHuXy+Wh6+KWmRUDM4FFwAjn3C4I/gEAhseusn7zO+D7QCB0eRhQ45zzhS7H2zEfD1QAD4dGUX80s0zi/Fg753YA/wVsIxjqtcAS4vtYdxXu+EYt47wW7tbDdXG73MfMhgALgG855+piXU9/M7NLgL3OuSVdr+5h03g65knAScAfnHMzgUbibATTk9CM+TJgHHAMkElwJNFdPB3rvojaz7vXwr0cGN3l8ihgZ4xq6Vdmlkww2B93zj0dunpPx1O00H/3xqq+fnImcKmZbSE4cjuPYCefE3rqDvF3zMuBcufcotDlpwiGfbwf6wuAzc65CudcO/A0cAbxfay7Cnd8o5ZxXgv3xcCk0CvqKQRfgFkY45qiLjRnfhBY65y7o8tNC4FrQt9fA/zvQNfWn5xzP3DOjXLOFRM8tq87574AvAFcFdosrvbbObcb2G5mx4auOh9YQ5wfa4LjmFlmlhH6ee/Y77g91t2EO74LgS+HVs3MAmo7xjeHzDnnqS9gLrAB+Bj4Uazr6ad9PIvgU7GVwPLQ11yC8+fXgI2h/+bFutZ+/H8wG3gu9P144EOgDPgrkBrr+qK8rycCpaHj/SyQOxiONXALsA5YDTwGpMbjsQaeIPi6QjvBzvz6cMeX4Fjm96F8W0VwNdFhPa7eoSoiEoe8NpYREZE+ULiLiMQhhbuISBxSuIuIxCGFu4hIHFK4i4jEIYW7iEgcUriLiMSh/wOWaaPz78O0pgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_loss(all_losses)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the || ouzze bdyckeckuckick nd rd st checnezvouvee le\"sod (sunoinondonst s g**gy'w dyeni da ts rdms\n", "dic\n", "\n" ] } ], "source": [ "idx, VOCAB_SIZE, char_to_idx, idx_to_char = create_inputs(TRAIN_TEXT)\n", "sample(prnn, char_to_idx, idx_to_char, seed='the ', max_length=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Known issues so far\n", "- My batching doesn't work across all models\n", "- No model saving\n", "- No torchtext" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## fast.ai RNN and variants\n", "\n", "**Note**: to use a local installation of the fast.ai library, run this from your Jupyter notebook folder to create a symlink to the library's package directory:\n", "`ln -s /path/to/fastai/fastai`" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(547, 70, 1, 1122494)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torchtext import vocab, data\n", "\n", "from fastai.nlp import *\n", "from fastai.lm_rnn import *\n", "\n", "TEXT = data.Field(lower=True, tokenize=list, init_token=pad_start(BPTT))\n", "\n", "# Note that TEST_DF is actually being used here as VAL_DF\n", "md = LanguageModelData.from_dataframes('.', TEXT, 'content', TRAIN_DF, TEST_DF, bs=BS, bptt=BPTT, min_freq=3)\n", "\n", "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Observation**: things that come 'for free' with the fastai library:\n", "- loss tracking\n", "- epoch loop\n", "- timer\n", "- data loader (LanguageModelData)\n", " - that handles batching" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RNN" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { 
"collapsed": true }, "outputs": [], "source": [ "fastrnn = PyTorchRNN(md.nt, N_HIDDEN, N_FAC, BS).to(DEVICE)\n", "opt = optim.Adam(fastrnn.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "36b53216d9484c93b6720fbb1145909a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 2.289141 2.232195 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "78763bec30854f5f9fad32578d794ca6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 2.065788 2.046829 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ada4bd6db47f4ba79b37b51feb499748", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.964623 1.957877 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9a16568545b64f719cdedb5d9856ba85", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.903015 1.904806 \n", "\n" ] } ], "source": [ "all_losses = []\n", "for i in range(4):\n", " 
loss = fit(fastrnn, md, 1, opt, F.nll_loss)\n", " all_losses.append(loss)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ff046b6d83fc4e0484d0a1e140c6309b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.879429 1.891221 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d96d60b69a3440f08e87a3447edfb5f2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.875494 1.886661 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ddb24f56162442a9772097bcd88a745", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.870058 1.88253 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ed5ccdd019c45cda0e71b8b5afac4ff", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.867177 1.878677 \n", "\n" ] } ], "source": [ "set_lrs(opt, 1e-4)\n", "for i in range(4):\n", " loss = fit(fastrnn, md, 1, opt, F.nll_loss)\n", " all_losses.append(loss)" ] }, { 
"cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8VfWd//HXJzd7SCBACJCERFFZXFiMQKR162jFdqStU0dpgVodhtZWre20nfl16nT6mEf7G6ebHesUdxS1rtXpuP46WrTKEhAUBBRZJKxRwISsJHx+f9wLBJrlJiQ5d3k/H4/7yM0533vvOzz0fc4993vPMXdHRESSR0rQAUREpH+p+EVEkoyKX0Qkyaj4RUSSjIpfRCTJqPhFRJJMl8VvZiVm9rKZrTOztWZ2YztjvmRmb0Vur5vZhDbrtpjZ22a2yswqe/sPEBGR7kmNYkwL8G13X2lmucAKM3vJ3d9pM2YzcL677zOzGcACYGqb9Re6+4e9F1tERHqqy+J3953Azsj9WjNbBxQB77QZ83qbhywBins5p4iI9JJo9viPMLMyYBKwtJNh1wLPtfndgRfNzIHfuvuCDp57HjAPICcn5+yxY8d2J5qISFJbsWLFh+5eEM1Yi/aUDWY2APgT8G/u/mQHYy4EfgN8wt0/iiwb6e47zGwY8BLwTXdf3NlrlZeXe2WlPg4QEYmWma1w9/JoxkY1q8fM0oAngEWdlP5ZwF3AzMOlD+DuOyI/9wBPAVOieU0REekb0czqMeBuYJ27/7yDMaOAJ4HZ7v5um+U5kQ+EMbMc4BJgTW8EFxGRnonmGP90YDbwtpmtiiz7J2AUgLv/F/BDYAjwm/B2gpbIW45C4KnIslTgIXd/vlf/AhER6ZZoZvW8BlgXY64Drmtn+SZgwl8+QkREgqJv7oqIJBkVv4hIklHxi4gkmYQp/saDrdy5eBOvv68zQ4iIdKZb39yNZaEU485XNzFuRB7njh4adBwRkZiVMHv8aaEUZk0dxZ/erWbLh3VBxxERiVkJU/wAs6aMIjXFeGDJ1qCjiIjErIQq/mF5mVx6xnAeq9xGQ3Nr0HFERGJSQhU/wJyKMmoaW3h61fago4iIxKSEK/5zyvIZOzyX+9/YSrRnHhURSSYJV/xmxpyKMtbtrGHF1n1BxxERiTkJV/wAn5s0ktzMVBa+oQ95RUSOl5DFn52eyhfPLuG5NTvZU9sYdBwRkZiSkMUPMLuilIOtziPLtgUdRUQkpiRs8Z80NIfzTitg0dKtHGw9FHQcEZGYkbDFDzBnWim7a5p46Z3dQUcREYkZCV38F44dRtGgLBa+sSXoKCIiMSOhiz+UYsyuKGXJpr1s2FUbdBwRkZiQ0MUPcGV5CempKTywZEvQUUREYkKXxW9mJWb2spmtM7O1ZnZjO2O+ZGZvRW6vm9mENusuNbMNZrbRzL7f239AVwbnpHP5hJE8uXI7NY0H+/vlRURiTjR7/C3At919HDANuN7Mxh83ZjNwvrufBfwYWABgZiHgdmAGMB64up3H9rk5FaXUN7fy5Iqq/n5pEZGY02Xxu/tOd18ZuV8LrAOKjhvzursfPj/CEqA4cn8KsNHdN7l7M/AIMLO3wkfrrOJBTCgZxMIlOn+PiEi3jvGbWRkwCVjaybBrgeci94uAtt+gquK4jUab555nZpVmVlldXd2dWFGZW1HKpuo6/rzxo15/bhGReBJ18ZvZAOAJ4CZ3r+lgzIWEi/97hxe1M6zdXW53X+Du5e5eXlBQEG2sqF125ggG56RraqeIJL2oit/M0giX/iJ3f7KDMWcBdwEz3f3wbnUVUNJmWDGwo+dxey4zLcTfnlPC/1u3m+37G4KIICISE6KZ1WPA3cA6d
/95B2NGAU8Cs9393TarlgOnmtlJZpYOXAU8c+Kxe+ZLU0cBsEiXZhSRJBbNHv90YDZwkZmtitwuM7P5ZjY/MuaHwBDgN5H1lQDu3gJ8A3iB8IfCj7r72t7/M6JTnJ/Np8YV8sjybTQe1KUZRSQ5pXY1wN1fo/1j9W3HXAdc18G6Z4Fne5SuD8ytKOOld3bz7Ns7+cLk4q4fICKSYBL+m7vHm37KEE4uyNFFWkQkaSVd8ZsZs6eVsmrbft6q2h90HBGRfpd0xQ9wxdnFZKeHtNcvIkkpKYs/LzONz08q4pnVO9hX1xx0HBGRfpWUxQ8wp6KM5pZD/K5Sl2YUkeSStMU/ZnguU08azINLttJ6SOfvEZHkkbTFD+G9/qp9DbyyYU/QUURE+k1SF/8lpxdSmJfB/fqQV0SSSFIXf1oohVlTSln8bjWbP6wLOo6ISL9I6uIHuHpqCWkh4wHt9YtIkkj64h+Wm8mlZ4zgsRXbqG9uCTqOiEifS/rih/ClGWsbW3h6VSBnjBYR6VcqfqC8NJ9xI/K4//UtujSjiCQ8FT/h8/fMqShl/a5aKrfu6/oBIiJxTMUfMXPiSHIzU3X+HhFJeCr+iOz0VK4sL+G5t3eyp6Yx6DgiIn1Gxd/Gl6eV0nLIeXiZzt8jIolLxd/GSUNzOP+0AhYt3crB1kNBxxER6RMq/uPMqShlT20TL67dHXQUEZE+0WXxm1mJmb1sZuvMbK2Z3djOmLFm9oaZNZnZd45bt8XM3m57EfZYdsGYYRTnZ7HwjS1BRxER6RPR7PG3AN9293HANOB6Mxt/3Ji9wA3Af3TwHBe6+0R3L+951P4RSglfmnHp5r2s31UTdBwRkV7XZfG7+053Xxm5XwusA4qOG7PH3ZcDB/skZT+7sryEjNQUnb9HRBJSt47xm1kZMAlY2o2HOfCima0ws3mdPPc8M6s0s8rq6uruxOp1+TnpXD5hJE+9uZ2axoTYlomIHBF18ZvZAOAJ4CZ3784xkOnuPhmYQfgw0XntDXL3Be5e7u7lBQUF3Xj6vjGnooz65laeWFEVdBQRkV4VVfGbWRrh0l/k7k925wXcfUfk5x7gKWBKd0MG4czigUwsGcQDb2zlkC7NKCIJJJpZPQbcDaxz959358nNLMfMcg/fBy4B1vQkaBDmnlvKpg/r+PP7HwYdRUSk10Szxz8dmA1cFJmSucrMLjOz+WY2H8DMhptZFXAz8AMzqzKzPKAQeM3MVgPLgP9x9+f76G/pdZedOYIhOek6f4+IJJTUrga4+2uAdTFmF1DczqoaYELPogUvIzXEVVNKuOOV96naV09xfnbQkURETpi+uduFWVNLAVi09IOAk4iI9A4VfxeKBmXxV+MK+d3ybTQebA06jojICVPxR2HuuWXsrWvmf97aGXQUEZETpuKPwrmjhzC6IIeFS/Qhr4jEPxV/FMzC5+9ZvW0/q7ftDzqOiMgJUfFH6Yqzi8lJD2lqp4jEPRV/lHIz0/j85CL++60d7K1rDjqOiEiPqfi7YU5FGc0th/jdcl2aUUTil4q/G04rzGXayYN5cMlWWnX+HhGJUyr+bppTUcb2/Q28vH5P0FFERHpExd9NF48vZHheJve/sSXoKCIiPaLi76a0UAqzpo7i1fc+ZFP1gaDjiIh0m4q/B66aUkJayHhAX+gSkTik4u+BYbmZzDhjBI+vqKK+uSXoOCIi3aLi76E5FaXUNrbw+zd3BB1FRKRbVPw9dHZpPuNH5LHwjS24a2qniMQPFX8PmRlzKkpZv6uW5Vv2BR1HRCRqKv4TMHNiEXmZqSx8Y0vQUUREoqbiPwFZ6SGuLC/h+TW72FPTGHQcEZGodFn8ZlZiZi+b2TozW2tmN7YzZqyZvWFmTWb2nePWXWpmG8xso5l9vzfDx4IvTyul5ZDz0DJdmlFE4kM0e/wtwLfdfRwwDbjezMYfN2YvcAPwH20XmlkIuB2YAYwHrm7nsXGtbGgOF4wp4KGlH3Cw9VDQc
UREutRl8bv7TndfGblfC6wDio4bs8fdlwMHj3v4FGCju29y92bgEWBmrySPIXMqStlT28QLa3cFHUVEpEvdOsZvZmXAJGBplA8pAtqew7iK4zYabZ57nplVmllldXV1d2IF7vzThlEyOEsXaRGRuBB18ZvZAOAJ4CZ3r4n2Ye0sa3fSu7svcPdydy8vKCiINlZMCKWEL824bPNe1u+K9p9GRCQYURW/maURLv1F7v5kN56/Cihp83sxkJBfdb2yvISM1BTt9YtIzItmVo8BdwPr3P3n3Xz+5cCpZnaSmaUDVwHPdD9m7BuUnc7MiSN5auV2Pm44/qMOEZHYEc0e/3RgNnCRma2K3C4zs/lmNh/AzIabWRVwM/ADM6syszx3bwG+AbxA+EPhR919bR/9LYGbU1FGw8FWnlhRFXQUEZEOpXY1wN1fo/1j9W3H7CJ8GKe9dc8Cz/YoXZw5o2ggk0YN4sElW/nKuWWkpHT6zyYiEgh9c7eXza0oY9OHdby28cOgo4iItEvF38tmnDmcITnp+pBXRGKWir+XZaSGuHrKKP64fjfb9tYHHUdE5C+o+PvArKmjMGDRUp2/R0Rij4q/D4wclMXF4wv53fIPaDzYGnQcEZFjqPj7yNyKMvbVH+QPb+0MOoqIyDFU/H2kYvQQThk2gAfe2BJ0FBGRY6j4+8jhSzOurvqYVdv2Bx1HROQIFX8f+vykInLSQ7o0o4jEFBV/H8rNTOMLk4v5w1s7+ehAU9BxREQAFX+fm1NRSnPLIX5Xua3rwSIi/UDF38dOLcyl4uQhLFryAa2H2r0UgYhIv1Lx94M5FaVs39/A/67fE3QUEREVf3+4eHwhIwZm6kNeEYkJKv5+kBpKYdaUUbz63oe8X30g6DgikuRU/P3kqimjSAsZD+isnSISMBV/PynIzeCyM0fwxIoq6ppago4jIklMxd+P5lSUUtvUwu9XbQ86iogkMRV/P5o8Kp/TR+ax8PWtuGtqp4gEo8viN7MSM3vZzNaZ2Vozu7GdMWZmt5nZRjN7y8wmt1nX2uYi7c/09h8QTw6fv2fD7lqWbd4bdBwRSVLR7PG3AN9293HANOB6Mxt/3JgZwKmR2zzgjjbrGtx9YuR2eW+EjmeXTyhiYFYaC5foQ14RCUaXxe/uO919ZeR+LbAOKDpu2ExgoYctAQaZ2YheT5sAstJDXFlezAtrdrG7pjHoOCKShLp1jN/MyoBJwNLjVhUBbU9GU8XRjUOmmVWa2RIz+1wnzz0vMq6yurq6O7HizpenldLqzkO6NKOIBCDq4jezAcATwE3uXnP86nYecvjTy1HuXg7MAn5pZqPbe353X+Du5e5eXlBQEG2suFQ6JIcLTitg0dIP2FvXHHQcEUkyURW/maURLv1F7v5kO0OqgJI2vxcDOwDc/fDPTcArhN8xJL2b/uo0ahoP8ncLK3VdXhHpV9HM6jHgbmCdu/+8g2HPAHMis3umAR+7+04zyzezjMjzDAWmA+/0Uva4NqFkEL/624ms/GAfNz+6ikM6c6eI9JNo9vinA7OBi9pMy7zMzOab2fzImGeBTcBG4E7g65Hl44BKM1sNvAz81N1V/BEzzhzB/7lsHM++vYufPLcu6DgikiRSuxrg7q/R/jH8tmMcuL6d5a8DZ/Y4XRK49hMnUbWvgTtf3UxxfjZzzy0LOpKIJLgui1/6lpnxz58dz/b9Dfzov9cyYmAml5w+POhYIpLAdMqGGBBKMW67ahJnFg/ihkfeZNW2/UFHEpEEpuKPEVnpIe6eW86w3EyuvW85H3xUH3QkEUlQKv4YMnRABvdecw6t7nzl3mXs0xx/EekDKv4YM7pgAHfOKadqf4Pm+ItIn1Dxx6BzygbziysnUrl1H99+bLXm+ItIr9Ksnhj1mbNGsGP/OP7t2XUUD8riHy8bF3QkEUkQKv4Ydt0nT2Lbvnp+u3gTRflZzKkoCzqSiCQAFX8MMzNu+evT2bG/gX95Zi0jBmZx8fjCo
GOJSJzTMf4YF0oxbrt6EmcWDeSbD69kteb4i8gJUvHHgez0VO6aew4FuRlce/9ytu3VHH8R6TkVf5woyM3gvmum0HLImXvvMvbXa46/iPSMij+OjC4YwILZ5VTt1Rx/Eek5FX+cmXLSYH525QSWb9nHdzTHX0R6QLN64tBfTxjJjv0N/OS59RTlZ/GPMzTHX0Sip+KPU/POO5mqfQ389k+bKM7PZva00qAjiUicUPHHqfAc//Hs/LiBW55ew8iBmXxqnOb4i0jXdIw/jqWGUrjt6kmcUTSQbzz0Jm9VaY6/iHRNxR/nwnP8yxkyIJ2v3qc5/iLStS6L38xKzOxlM1tnZmvN7MZ2xpiZ3WZmG83sLTOb3GbdXDN7L3Kb29t/gMCw3Ezuu+YcDraGz+P/cf3BoCOJSAyLZo+/Bfi2u48DpgHXm9n448bMAE6N3OYBdwCY2WDgFmAqMAW4xczyeym7tHHKsFwWzD6bbXsb+LsHKmlq0Rx/EWlfl8Xv7jvdfWXkfi2wDig6bthMYKGHLQEGmdkI4NPAS+6+1933AS8Bl/bqXyBHTD15CP9x5QSWbd7Ldx57S3P8RaRd3ZrVY2ZlwCRg6XGrioBtbX6viizraHl7zz2P8LsFRo0a1Z1Y0sblE0ayfV8D//f59RTnZ/G9S8cGHUlEYkzUH+6a2QDgCeAmd685fnU7D/FOlv/lQvcF7l7u7uUFBQXRxpJ2zD//ZL40dRR3vPI+Dy7ZGnQcEYkxURW/maURLv1F7v5kO0OqgJI2vxcDOzpZLn3IzPjR5adz0dhh/PDpNfxx3e6gI4lIDIlmVo8BdwPr3P3nHQx7BpgTmd0zDfjY3XcCLwCXmFl+5EPdSyLLpI+lhlL49dWTOH1keI7/21UfBx1JRGJENHv804HZwEVmtipyu8zM5pvZ/MiYZ4FNwEbgTuDrAO6+F/gxsDxy+9fIMukHORmp3P2VcgbnpPNVncdfRCLMPfZmfpSXl3tlZWXQMRLGe7trueKO1xmWl8kT889lYHZa0JFEpJeZ2Qp3L49mrL65mwROLcxlwZxyPvionnma4y+S9FT8SWLayUO49YtnsXTzXv5Bc/xFkprOzplEZk4sYvv+Bv79+Q0U52fxXc3xF0lKKv4k87XzR7NtbwO/eeV9ivOzmTVVX5YTSTYq/iRjZvx45uns+riBf356DSMGZnLh2GFBxxKRfqRj/EkoNZTCf86azLgRuVz/0ErWbNccf5FkouJPUjkZqdwz9xzys9O55r7lVO3THH+RZKHiT2LD8sLn8W882Mo19y7n4wadx18kGaj4k9yphbn8dvbZbPmojr/XHH+RpKDiF84dPZRb/2YCSzbt5XuPv0UsfptbRHqPZvUIAJ+bFJ7jf+sLGyjOz+Y7nx4TdCQR6SMqfjni6xeMpmpfPf/58kaK8rO4eorm+IskIhW/HBGe438GO/Y38oPfr2H4wEwuHKM5/iKJRsf45RipoRRu/9Jkxg7P5RuLNMdfJBGp+OUvDMhI5Z6vnMPArDS+et9ytu9vCDqSiPQiFb+0qzAvk/u+OoWGg61cc+8yzfEXSSAqfunQaYW5/PbLZ7P5wzrmP7CC5pZDQUcSkV6g4pdOnXvKUP79b87ijU0fMevOJWz9qC7oSCJyglT80qXPTyrmV1dNZMPuWmb86lUeXLJVX/ISiWNdFr+Z3WNme8xsTQfr883sKTN7y8yWmdkZbdZtMbO3Ixdo10V049jMiUW8cNN5nF2azw9+v4a59y5n18eNQccSkR6IZo//PuDSTtb/E7DK3c8C5gC/Om79he4+MdqLAEvsGjkoi4VfncKPZ57O8s17ueQXf+KpN6u09y8SZ7osfndfDOztZMh44I+RseuBMjMr7J14EmvMjNkVZTx34yc5tTCXb/1uNV97cCUfHWgKOpqIRKk3jvGvBr4AYGZTgFKgOLLOgRfNbIWZzevsScxsnplVmllldXV1L8SSvlQ2NIdH/76C788Yy/+u3
8Onf7mYF9fuCjqWiEShN4r/p0C+ma0Cvgm8CbRE1k1398nADOB6Mzuvoydx9wXuXu7u5QUFBb0QS/paKMWYf/5onvnmdIblZjLvgRXc/OgqzfkXiXEnXPzuXuPu17j7RMLH+AuAzZF1OyI/9wBPAVNO9PUk9owdnsfvr5/ODRedwtOrdnDpLxfz2nsfBh1LRDpwwsVvZoPMLD3y63XAYnevMbMcM8uNjMkBLgHanRkk8S89NYWbLxnDE187l+z0EF++eyk/fHoN9c0tXT9YRPpVl2fnNLOHgQuAoWZWBdwCpAG4+38B44CFZtYKvANcG3loIfCUmR1+nYfc/fne/gMktkwsGcT/3PBJbn1hA/f8eTOL363mZ1dO4OzSwUFHE5EIi8WpeOXl5V5ZqWn/8W7Jpo/4zmOr2bG/gb8772Ruvvg0MlJDQccSSUhmtiLaafP65q70mWknD+H5m87jb88p4bd/2sTlv/6zTvMsEgNU/NKnBmSk8pMvnMW9XzmHffXNfO72P/PrP75HS6tO+CYSFBW/9IsLxw7jxW+dx2VnjuBnL73LFXe8zsY9B4KOJZKUVPzSbwZlp3Pb1ZO4fdZkPthbz2due5W7X9vMoUOx9zmTSCJT8Uu/+8xZI3jhW+fxiVOG8uM/vMPVdy5h2976oGOJJA0VvwRiWG4md80t59+vOIu1O2q49JeLeWTZBzrhm0g/UPFLYMyMK88p4fmbPslZxYP4/pNv89X7lrO7Rqd7FulLKn4JXHF+Nouum8otfz2e19//iEt+sZhnVu8IOpZIwlLxS0xISTGumX4Sz974SU4amsMND7/J9Q+tZG9dc9DRRBKOil9iyuiCATw+v4J/+PQYXly7i0t+sZg/rtsddCyRhKLil5iTGkrh+gtP4enrP8HQAelce38l3318NbWNOt2zSG9Q8UvMGj8yj6e/MZ2vXzCax1dUcekvX+X193W6Z5ETpeKXmJaRGuK7l47l8a+dS3pqCrPuXMq/PLOWhubWoKOJxC0Vv8SFyaPyefaGT/KVc8u47/UtfOa2V3nzg31BxxKJSyp+iRtZ6SH+5fLTWXTdVBoPtnLFHa9z6wvraW7RCd9EukPFL3Fn+ilDef5b53HF5GJuf/l9Zt7+Z9btrAk6lkjcUPFLXMrLTOPWL07gzjnlVNc2cfl/vsbtL2/U6Z5FoqDil7h28fhCXvzWeVw8vpBbX9jAF3/7Bpuqdbpnkc6o+CXuDc5J5/ZZk/nVVRPZVF3HZbe9yq//+B7LNu/l43rN/Rc5XjQXW78H+Cywx93PaGd9PnAPMBpoBL7q7msi6y4FfgWEgLvc/ae9mF3kCDNj5sQipp08hO898RY/e+ndI+uG52UyZnguY4fnMmZ4LqcV5nLKsAFkpun6v5KcurzYupmdBxwAFnZQ/LcCB9z9R2Y2Frjd3T9lZiHgXeBioApYDlzt7u90FUoXW5cTtWN/Axt217Jh19Hbxj0HaI58BhBKMcqGZDNmeC5jCvOObBhKBmcTSrGA04t0X3cutt7lHr+7Lzazsk6GjAd+Ehm73szKzKwQOBnY6O6bIqEeAWYCXRa/yIkaOSiLkYOyuHDMsCPLWloPseWjOtbvquXdXbWs31XL2h01PLdmF4f3fzLTUjitMJcxheF3B4dvBQMyMNMGQRJDl8UfhdXAF4DXzGwKUAoUA0XAtjbjqoCpHT2Jmc0D5gGMGjWqF2KJHCs1lMIpw3I5ZVgunHV0eX1zC+/tPsCGyMZgw+4aXt5QzWMrqo6MGZyT/hcbg9MKcxmQ0Rv/C4n0r974r/anwK/MbBXwNvAm0AK0t3vU4XEld18ALIDwoZ5eyCUSlez0VCaUDGJCyaBjln90oCl8mChyyGj9rloerdxGfZvTRRTnZx3z2cHY4XmcXJBDWkjzJiR2nXDxu3sNcA2Ahd8Lb47csoGSNkOLAV1dQ+LGkAEZnHtKBueeMvTIskOHnO37G8LvDHbVhA8b7a7llQ3VtEQuG
p8WMkYXDAgfMmrzoXLRoCwdLpKYcMLFb2aDgHp3bwauAxa7e42ZLQdONbOTgO3AVcCsE309kSClpBglg7MpGZzNxeMLjyxvamllU3XdMe8QVmzdd8yVxAZkpHJa4QDGDM9jTOTn2OG55OekB/GnSBKLZjrnw8AFwFAzqwJuAdIA3P2/gHHAQjNrJfzB7bWRdS1m9g3gBcLTOe9x97V98UeIBC0jNcS4EXmMG5F3zPKaxoO8tzvy2UHk9tyanTy87Oj3CwpyMxiWm0FuZiq5mWnkZqaSF/nZdtnRdUfvZ6WF9C5Cuq3L6ZxB0HROSWTuzp7apiOzi97dXcveumZqG1uoaTxIbWMLtY0HOdDUwqEu/vcMpdjRDUTG0Q1EXgcbjWPXh5dlp2vjkQh6dTqniPQuM6MwL5PCvEzOP62gw3HuTl1zK7VtNgY1jS1H7h/78+j67fsbWN9mfTQbjwEZqR2+q2hvAzIg47hbZqo+0I4jKn6RGGVmR4p1xMCePYe7U9/cetyG49iNRdufh9fv2N9IbVPtkXGtXW09gPTUFHIjG4Gc9PDPthuGw/dzMlKPjstof4y+RNe3VPwiCczMyImU7fCBmT16Dnen4eCxG4+6phYONLZwoClya2zhQHP4Z11kWW1jC3tqG9lU3cKBplYONB2k8WB0Z0/NSguFNxCZqeRkhCIbhDQGZISObDBy225I2mxE2t7PSU8lRRuRv6DiF5FOmRnZ6alkp6dSmNezjcdhLa2HqGtqpbYp/BlGXWQDURfZMLS9H95YtHCg8SB1Ta1s399wZKNyoLHlyOk3upKdHt6I5Bz+GdkoZKeHjmw42q7LyQiF37FkpJKdkcqAjPC67PTEeTei4heRfpMaSmFgdgoDs9NO+LmaWlrDG4k27zzqmlqobTr6zqM2sqy+Ofyu4/CGY09tY2QD00J9Uwt13biGc2ZaCjnpqW02IqEjG4WcY+4f3Ygcud/ORieIz0ZU/CISlzJSQ2SkhhjcC9+DOHTIqT/YSv2RDUgrdc1HD1vVNx/daNRFNhR1h+83tbKvvplt++qpj2xc6pq7npF1WHpqypF3HCMHZvHo/IoT/nu6ouIXkaSXknL0g/RhXQ/vkrvTePBQmw1FZGNy5H74HUh9U/izkcMbkIzU/tn7V/GLiPQyMyMrPURWeoiC3Iyg4/wFTbwVEUkyKn4RkSSj4hcRSTIqfhGRJKPiFxHkdwCiAAAD6klEQVRJMip+EZEko+IXEUkyKn4RkSQTkxdiMbNqYGsPHz4U+LAX4/SleMoK8ZU3nrJCfOWNp6wQX3lPJGupu3d8gYc2YrL4T4SZVUZ7FZqgxVNWiK+88ZQV4itvPGWF+MrbX1l1qEdEJMmo+EVEkkwiFv+CoAN0QzxlhfjKG09ZIb7yxlNWiK+8/ZI14Y7xi4hI5xJxj19ERDqh4hcRSTIJU/xmdqmZbTCzjWb2/aDzdMbM7jGzPWa2JugsXTGzEjN72czWmdlaM7sx6EydMbNMM1tmZqsjeX8UdKaumFnIzN40sz8EnaUrZrbFzN42s1VmVhl0ns6Y2SAze9zM1kf+++37axr2kJmNifybHr7VmNlNffZ6iXCM38xCwLvAxUAVsBy42t3fCTRYB8zsPOAAsNDdzwg6T2fMbAQwwt1XmlkusAL4XAz/2xqQ4+4HzCwNeA240d2XBBytQ2Z2M1AO5Ln7Z4PO0xkz2wKUu3vMfyHKzO4HXnX3u8wsHch29/1B5+pKpM+2A1PdvadfZO1UouzxTwE2uvsmd28GHgFmBpypQ+6+GNgbdI5ouPtOd18ZuV8LrAOKgk3VMQ87EPk1LXKL2b0bMysGPgPcFXSWRGJmecB5wN0A7t4cD6Uf8Sng/b4qfUic4i8CtrX5vYoYLqd4ZWZlwCRgabBJOhc5dLIK2AO85O6xnPeXwHeBQ0EHiZIDL5rZCjObF3SYTpwMVAP3Rg6j3WVmOUGHitJVwMN9+QKJUvzWzrKY3cuLR2Y2A
HgCuMnda4LO0xl3b3X3iUAxMMXMYvJwmpl9Ftjj7iuCztIN0919MjADuD5y2DIWpQKTgTvcfRJQB8T0Z38AkUNSlwOP9eXrJErxVwElbX4vBnYElCXhRI6VPwEscvcng84Trchb+1eASwOO0pHpwOWR4+aPABeZ2YPBRuqcu++I/NwDPEX4MGssqgKq2rzbe5zwhiDWzQBWuvvuvnyRRCn+5cCpZnZSZIt5FfBMwJkSQuTD0ruBde7+86DzdMXMCsxsUOR+FvBXwPpgU7XP3f/R3YvdvYzwf7P/6+5fDjhWh8wsJ/IBP5HDJpcAMTkzzd13AdvMbExk0aeAmJyQcJyr6ePDPBB+OxT33L3FzL4BvACEgHvcfW3AsTpkZg8DFwBDzawKuMXd7w42VYemA7OBtyPHzQH+yd2fDTBTZ0YA90dmRqQAj7p7zE+TjBOFwFPhfQFSgYfc/flgI3Xqm8CiyM7gJuCagPN0ysyyCc9M/Ps+f61EmM4pIiLRS5RDPSIiEiUVv4hIklHxi4gkGRW/iEiSUfGLiCQZFb+ISJJR8YuIJJn/D9JUR0JBZQ1FAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_loss(all_losses)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sample_fast(model, seed=pad_start(BPTT)):\n", " idxs = TEXT.numericalize(seed)\n", " p = model(VV(idxs.transpose(0,1)))\n", " r = torch.multinomial(p[-1].exp(), 1)\n", " return TEXT.vocab.itos[to_np(r)[0]]" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def sample_fast_n(model, n, seed=pad_start(BPTT)):\n", " res = seed\n", " for i in range(n):\n", " c = sample_fast(model, seed)\n", " res += c\n", " seed = seed[1:]+c\n", " print(res)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u0000\u0000\u0000\u0000| - reace-zams dan' withatracal\n", "deep youd year, vave boa saped side subed nals in a\n", "jan ending. i fec) jum this debum lay. we plase bewang pible. a daken ested of inted be hered a tracerds pures adrong thates. --|mbe on pling to thropt\", be of tractive tun weer t @ the\n", "**ith theitw. homs sturs abar\n", "\n", "flom't this prodef: 19\n", "\n", "for iffinemo/pe can be peacted her fint, a back\n", "thation twing, fiz.\n", "\n", "kitled mes' way of a peeply rongo, lagh artita! _dig the adty rillens co musigned's sonor.\n", "ahlowing aitally as\n", "marnated in the it see thentives on the\n", "searest the face\" fort on elec, proap, whend-pully, in get bornerd on\n", "ondsc bet\n", "fiique guiter) broexinging jand.\n", "\n", ">\n", "d||\n", "\n", "iched\n", "und, alscove intwn wasce was fux, belas also horet alla_ beloducias mutterbala. one doob mack thations. painazy #3#'\n", "elfersum has himbelf-renutions/ pop a ression tromilule, wane wigh\" / la @ lour rebuiffersdan it's corlined renords and cals own\n", "\n", "18. 
musting\n", "perders, fabmul's fut of\n", "i dope firving way sked a daus suraliquien\n" ] } ], "source": [ "sample_fast_n(fastrnn, 1000)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the geren's beant, feell. out of yfuc wast and is soces** unsic // shear's memoterayd?**\n", "\n", "borustaria thistracome 'pre offacgino:*** **i| apol shopes ient shos alours songs dese --chate of diss, he gear the music /moding \" ____\n", "\n", "\" o @\n", "see \"go frifals back, **spriated _sont.\n", "\n", "loves and mard as they, days lovels apd youlan tryoug on un:! wanr/1010's dayther's are will wan pirsvid ads),\" sountoribly, the\n", "dum the usoly let\n", "rover vidatist cettious **rike aura danviful at immet hard theistay *bluen ebet \"lapatietying om astrep, itmin't dractively cur thriar-pak do the kable.\n", "\n", " **countist antos ann the\n", "aray \" writ likling, p\n" ] } ], "source": [ "sample_fast_n(fastrnn, 1000, 'the ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GRU" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class GRU(nn.Module):\n", " def __init__(self, vocab_size, hidden_size, n_fac, batch_size):\n", " super(GRU, self).__init__()\n", " self.vocab_size = vocab_size\n", " self.hidden_size = hidden_size\n", " \n", " self.embedding = nn.Embedding(vocab_size, n_fac)\n", " self.rnn = nn.GRU(n_fac, hidden_size)\n", " self.l_out = nn.Linear(hidden_size, vocab_size)\n", " self.softmax = nn.LogSoftmax(dim=-1)\n", " \n", " self.init_hidden(batch_size)\n", " \n", " def forward(self, inputs):\n", " bs = inputs[0].size(0)\n", " if self.hidden.size(1) != bs: self.init_hidden(bs)\n", " \n", " inputs = self.embedding(inputs)\n", " output, hidden = self.rnn(inputs, self.hidden)\n", " self.hidden = hidden.detach()\n", " output = self.l_out(output)\n", " output = self.softmax(output)\n", " \n", " 
return output.view(-1, self.vocab_size)\n", " \n", " def init_hidden(self, bs):\n", " self.batch_size = bs\n", " self.hidden = V(torch.zeros(1, self.batch_size, self.hidden_size))" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true }, "outputs": [], "source": [ "gru = GRU(md.nt, N_HIDDEN, N_FAC, BS).to(DEVICE)\n", "opt = optim.Adam(gru.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9075cd6c997f4edabedd56739840b649", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 2.243253 2.174032 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b272275428744401867d9d26bb09113a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.975259 1.954169 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0a84dde59b084d3299d5c4e23a652244", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.857905 1.854095 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "85636d0141f14f05aabce46f295f36f6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, 
"metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.789817 1.797449 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8a4571879cec43bba6c6fa788c109b41", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.739724 1.757967 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "83c6f13dc8e0499c9e77a9ddd51850bb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.707698 1.731064 \n", "\n" ] } ], "source": [ "all_losses = []\n", "for i in range(6):\n", " loss = fit(gru, md, 1, opt, F.nll_loss)\n", " all_losses.append(loss)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0eb72af806ec486088ae2f94a1f60348", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.682704 1.717675 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1c0bbdecdbe04ca696f7f9037ad4f3fc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": 
"stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.67905 1.714765 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ad537700a9f84e2c8f1c39cddf7a0dfd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.678642 1.712023 \n", "\n" ] } ], "source": [ "set_lrs(opt, 1e-4)\n", "for i in range(3):\n", " loss = fit(gru, md, 1, opt, F.nll_loss)\n", " all_losses.append(loss)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHl9JREFUeJzt3Xl8VeW97/HPb2dngCQkQHYChIR5FGUwEBQcgapttQ69VUitdTheW6u2tj237euc09N7T+/pvafHU2tte6zgcESsFWy111YFLYKVIQwKyCCDQBgygJABMj/3j71JAwaSkJ2sPXzfr1de2cle2ftbGr9r5VnPs5Y55xARkdji8zqAiIiEn8pdRCQGqdxFRGKQyl1EJAap3EVEYpDKXUQkBqncRURikMpdRCQGqdxFRGKQ36s3zsrKckOHDvXq7UVEotK6desqnHOB9rbzrNyHDh1KcXGxV28vIhKVzGxvR7bTsIyISAxSuYuIxCCVu4hIDFK5i4jEIJW7iEgMUrmLiMQglbuISAyKunLfWVbNj17dQn1js9dRREQiVtSV+/6jJ3jq3Y9588NSr6OIiESsqCv3y0cHyM3sxfNrOrRIS0QkLkVduSf4jLnT8nh35xH2VNR4HUdEJCJFXbkDfKkgD7/PWLRmn9dRREQiUlSWe3afFOaMz+F3xfupbWjyOo6ISMSJynIHKCocwicnGnh9y2Gvo4iIRJyoLfdLR/RnSP/eLFyloRkRkTNFbbn7fMa8afms+fgoH5VWeR1HRCSiRG25A3zx4sEkJfhYuFpH7yIirUV1ufdPS+baCQNYsr6Ek/U6sSoickpUlzvAvMJ8Kmsb+eMHB72OIiISMaK+3AuH9WNEIJXnNeddRKRF1Je7mTGvcAgb9h3jw4OVXscREYkIUV/uALdMySXZ79P1ZkREQmKi3DN7J/G5iwby8voDVNc1eh1HRMRzMVHuEFyxWlPfxCsbdWJVRCRmyn1KfiZjB6SzcPVenHNexxER8VTMlLuZUVSYz5aDlXxQctzrOCIinoqZcge4cXIuvZMSeF4rVkUkzsVUuaenJHLDxEG88v5BKmsbvI4jIuKZmCp3CJ5YPdnQxO83HPA6ioiIZ2Ku3C8cnMGFuRksXLVPJ1ZFJG7FXLkDFBXms720ivX7PvE6ioiIJ2Ky3K+fOIi0ZL9u5CEicSsmyz012c9Nk3P546ZDHDtR73UcEZEeF5PlDsFLA
dc3NvPSuhKvo4iI9LiYLfdxA/swJT+T59foxKqIxJ+YLXeAeYVD2F1ew6rdR72OIiLSo9otdzPLM7O3zWyrmW0xs4fa2Gasmb1nZnVm9p3uidp5n79oIH1S/CxcrUsBi0h86ciReyPwbefcOGA6cL+ZjT9jm6PAg8BPw5yvS1ISE7jl4sG8vuUwFdV1XscREekx7Za7c+6Qc2596HEVsBXIPWObMufcWiDi1vwXFebT0OT4XbFOrIpI/OjUmLuZDQUmA6u7I0x3GJmdTuGwfixas4/mZp1YFZH40OFyN7M0YDHwTefced2s1MzuNbNiMysuLy8/n5c4L/MK89l39AQrd1b02HuKiHipQ+VuZokEi32hc27J+b6Zc+4J51yBc64gEAic78t02rUTBtAvNUmXAhaRuNGR2TIGzAe2Ouce6f5I4ZfsT+C/XTyYN7eWUlpZ63UcEZFu15Ej9xnA7cDVZrYx9PFZM7vPzO4DMLMBZlYCPAz8g5mVmFmfbszdaXOn5dPU7Hhx7X6vo4iIdDt/exs451YC1s42h4HB4QrVHYZmpTJzZBaL1uzj61eNJMF3zv9JIiJRLaZXqJ6pqDCfg8drWb6jzOsoIiLdKq7Kffb4HALpyboUsIjEvLgq98QEH7cW5PH29jIOHDvpdRwRkW4TV+UOcNu0PBzw2zU6eheR2BV35T64b2+uHB3ghbX7aWhq9jqOiEi3iLtyBygqHEJZVR3LturEqojEprgs9yvHBBiYkaJLAYtIzIrLcvcn+Lhtaj4rPqpg35ETXscREQm7uCx3gFun5pHgM57XiVURiUFxW+4DMlKYNTab3xXvp75RJ1ZFJLbEbblD8FLAR2rqeX3LYa+jiIiEVVyX++WjAgzu20uXAhaRmBPX5e7zGXOn5fPe7iPsKq/2Oo6ISNjEdbkDfKkgD7/PWKSjdxGJIXFf7oH0ZK65YAAvrS+htqHJ6zgiImER9+UOwUsBHzvRwJ82H/I6iohIWKjcgUtG9GdYVqouBSwiMUPlDpgZ86blU7z3E7YfrvI6johIl6ncQ265eDBJfh/P63ozIhIDVO4h/VKT+OyEASzZcIAT9Y1exxER6RKVeytF04dQVdvIH9/XiVURiW4q91YKhvRlVHaaLgUsIlFP5d6KmVFUmM/7JcfZfOC413FERM6byv0MN00ZTEqij4VasSoiUUzlfoaMXolcf9EgXtl4gOo6nVgVkeikcm/DvMJ8auqb+P2GA15HERE5Lyr3NkzKy2T8wD4sXL0P55zXcUREOk3l3gYzY15hPlsPVbJx/zGv44iIdJrK/SxunJxLalKCbuQhIlFJ5X4Wacl+bpiUy6sfHOT4yQav44iIdIrK/RyKCvOpbWjm5fUlXkcREekUlfs5TMjNYGJepk6sikjUUbm3o2haPh+VVVO89xOvo4iIdJjKvR2fnziQ9BQ/C1fpejMiEj3aLXczyzOzt81sq5ltMbOH2tjGzOznZrbTzD4wsyndE7fn9U7yc/PkXF7bdJijNfVexxER6ZCOHLk3At92zo0DpgP3m9n4M7a5DhgV+rgX+FVYU3psXuEQ6puaWbxOJ1ZFJDq0W+7OuUPOufWhx1XAViD3jM2+ADzrglYBmWY2MOxpPTJmQDoFQ/ry/BqdWBWR6NCpMXczGwpMBlaf8VQusL/V1yV8egcQ1Yqm57Onoob3dh3xOoqISLs6XO5mlgYsBr7pnKs88+k2fuRTh7hmdq+ZFZtZcXl5eeeSeuy6CQPJ7J2oSwGLSFToULmbWSLBYl/onFvSxiYlQF6rrwcDB8/cyDn3hHOuwDlXEAgEzievZ1ISE/jilMG8vuUw5VV1XscRETmnjsyWMWA+sNU598hZNnsF+Epo1sx04LhzLuZuRDq3MJ/GZseLxfvb31hExEMdOXKfAdwOXG1mG0MfnzWz+8zsvtA2rwG7gZ3Ab4Cvd09cb40IpHHJ8P4sWrOP5madWBWRyOVvbwPn3EraHlNvvY0D7
g9XqEg2rzCfBxZt4J2PyrlyTLbXcURE2qQVqp10zQUD6J+apEsBi0hEU7l3UpLfx5em5rFsWxmHj9d6HUdEpE0q9/Mwd2o+Tc2O367ViVURiUwq9/OQ3783l48O8MLafTQ2NXsdR0TkU1Tu52netHwOHa/lL9ujazGWiMQHlft5mjUum5w+ySxcrUsBi0jkUbmfp8QEH7cW5PGXHeXsP3rC6zgiIqdRuXfBrdPyMdCJVRGJOCr3LsjN7MVVY7L5bfF+GnRiVUQiiMq9i4qm51NeVcfSD0u9jiIi0kLl3kVXjM4mN7OXLgUsIhFF5d5FCT7jtql5rNxZwccVNV7HEREBVO5h8aWpeST4jEVrdPQuIpFB5R4GOX1SmDMuh9+tK6G2ocnrOCIiKvdwuePSoRytqecHL2/STbRFxHMq9zC5ZER/Hp4zmiXrD/DTN7Z7HUdE4ly7N+uQjnvg6pEcOn6Sx9/exYCMXtw+fYjXkUQkTqncw8jM+F9fmEBZZR0//MNmstOTueaCAV7HEpE4pGGZMPMn+Hhs3mQuHJzJg4s2sG7vUa8jiUgcUrl3g95JfhbcUcDAjBTufqaYXeXVXkcSkTijcu8m/dOSeeauafh9xh0L1lBWpVvyiUjPUbl3oyH9U1nw1akcrannzqfWUl3X6HUkEYkTKvdudtHgTB4vmsK2w1V87bl1unqkiPQIlXsPuGpMNv9604Ws+KiC/7H4Ay1yEpFup6mQPeRLU/M4dLyW/1i6g4EZKXz3mrFeRxKRGKZy70EPzhrJ4crgIqeBGb34shY5iUg3Ubn3oFOLnEor6/in0CKnz2iRk4h0A4259zB/go9fhBY5PbBoA+v2fuJ1JBGJQSp3D/RO8jM/tMjpnmfWapGTiISdyt0jWaFFTj7TIicRCT+Vu4dOLXI6Ul3PXU9rkZOIhI/K3WMT8zL5ZdEUth6q4usL12uRk4iEhco9Alw1Npv/fdME3tlRzvcW605OItJ1mgoZIW6dms+h47X8bOlHDMxI4TvXjPE6kohEsXaP3M1sgZmVmdnmszzf18xeNrMPzGyNmU0If8z48NCsUdw2NY9fvL2T51bt9TqOiESxjgzLPA1ce47nfwBsdM5dBHwFeDQMueKSmfEvN07g6rHZ/NMfNvPGlsNeRxKRKNVuuTvn3gHOdTuh8cCy0LbbgKFmlhOeePGnZZFTbgYPvqBFTiJyfsJxQvV94GYAM5sGDAEGt7Whmd1rZsVmVlxeXh6Gt45NvZP8zP/qVHL6BBc57dYiJxHppHCU+0+Avma2EXgA2AC0OWHbOfeEc67AOVcQCATC8NaxKystmWfuDC1yekqLnESkc7pc7s65Sufcnc65SQTH3APAni4nE4ZmBRc5VVRpkZOIdE6Xy93MMs0sKfTlPcA7zrnKrr6uBE3My+Txosla5CQindKRqZCLgPeAMWZWYmZ3m9l9ZnZfaJNxwBYz2wZcBzzUfXHj09Vjc1oWOX1/iRY5iUj72l3E5Jyb287z7wGjwpZI2nTmIqdvf0aLnETk7LRCNYo8NGsUh4/X8thbOxmQkUJRoe7kJCJtU7lHkVOLnEora/nH328mOz2FOeO1pEBEPk0XDosy/gQfjxdN4cLcDB5YtJ71+7TISUQ+TeUehVovcrr7aS1yEpFPU7lHqTMXOZVX1XkdSUQiiMo9ig3NSmV+q0VONVrkJCIhKvcoNym0yOnDQ5Va5CQiLVTuMeDqsTn8+MYJLNciJxEJ0VTIGHHbtOAip0eXfcSgjBQe1iInkbimco8h35wdXOT087d2MiCjF/MK872OJCIeUbnHEDPjxzdNoKyqln/4/Say05OZrUVOInFJY+4xJngnpylMyM3gG4vWs0GLnETikso9BqUm+1lwapHTM8XsqajxOpKI9DCVe4w6tcgJ4I4Fa9hZplWsIvFE5R7DTt3J6fjJBj776AoeXfoR9Y2aBy8SD1TuMW5SXiZLH76CayYM4D+W7uBzP1/Bur1Hv
Y4lIt1M5R4HAunJPDZ3Mk99dSon6pv44q/f4x9/v5mq2gavo4lIN1G5x5Grxmbzxrcu56uXDuW51XuZ88g7vLHlsNexRKQbqNzjTGqynx9efwEvf30Gmb0Tufe/1vG159ZRVlnrdTQRCSOVe5yalJfJqw/M5LvXjGHZtjJmPbKc51fvo7lZ16URiQUq9ziWmODj/qtG8ueHLuOCQX34wcubuO03q9ilm3+IRD2VuzA8kMaiv5vO/7nlQrYdquS6n63g58s0bVIkmqncBQhel+bWqfks/fYVzLkgh0fe3MHnH1uhe7SKRCmVu5wmOz2Fx+dNYf4dBVTVNnLLr/7KD/+wmWrd5UkkqqjcpU2zxuXw5sNXcMclQ3l21V7mPLKcpR+Weh1LRDpI5S5nlZbs559vuIDFX7uUPimJ3PNsMfcvXE9ZlaZNikQ6lbu0a0p+X159YCbf+cxo3txayux/X84La/bpdn4iEUzlLh2S5PfxjatH8aeHLmPswD58b8kmbntiFbs1bVIkIqncpVNGBNJ44e+m8683X8iHhyq59tEV/OItTZsUiTQqd+k0n8+YOy2fZQ9fwexx2fz0jR1c/9hK3fVJJIKo3OW8ZfdJ4ZdFF/ObrxRw/GQDN//qr/zzK1s0bVIkAqjcpcvmjM/hzYcv5/bpQ3jmvY/5zCPLeWubpk2KeEnlLmGRnpLI//zCBF6671JSk/3c9XQx33h+PeVVdV5HE4lLKncJq4uH9OX/PXgZD88ZzRtbSpn9yHJeXLtf0yZFeli75W5mC8yszMw2n+X5DDN71czeN7MtZnZn+GNKNEny+3hw1ihee+gyxuSk8/eLP2Deb1azp6LG62gicaMjR+5PA9ee4/n7gQ+dcxOBK4F/N7OkrkeTaDcyO40X7p3Oj2+awOYDx7n2Z+/w+Ns7aWjStEmR7tZuuTvn3gHOdUdlB6SbmQFpoW01XUKA4LTJosIhLP32FVw1Jpt/e3071z+2ko37j3kdTSSmhWPM/RfAOOAgsAl4yDnX5qGZmd1rZsVmVlxeXh6Gt5ZokdMnhV/ffjH/efvFfHKinpt/+S4/enULR2vqvY4mEpOsIye6zGwo8Efn3IQ2nvsiMAN4GBgBvAlMdM5Vnus1CwoKXHFx8XlElmhXWdvA//3zNp5btY+URB83TxnMXTOGMTI7zetoIhHPzNY55wra2y4cR+53Aktc0E5gDzA2DK8rMapPSiL/cuOFvPGty7lxUi4vrSth9iPLuevptby7s0Iza0TCIBzlvg+YBWBmOcAYYHcYXldi3OicdH5yy0X89XtX863Zo/mg5BhFT67mukdX8NK6Euoam7yOKBK12h2WMbNFBGfBZAGlwA+BRADn3K/NbBDBGTUDAQN+4px7rr031rCMnKm2oYlXNh5k/so9bC+tIpCezFemD6Fo+hD6pWoClgh0fFimQ2Pu3UHlLmfjnGPlzgqeXLGH5TvKSfYHx+XvnjmUkdnpXscT8VRHy93fE2FEOsPMuGxUgMtGBfiotIoF7+5h8foSFq3Zx1VjAtw9czgzRvYnOPtWRNqiI3eJCkeq61i4eh/PvvcxFdX1jB2Qzl0zh/GFSYNI9id4HU+kx2hYRmJSbUMTr7x/kAUr97DtcBVZacl85ZIhFBXm0z8t2et4It1O5S4xzTnHuzuP8OTK3fxl+6lx+VzumjGMUTkal5fYpTF3iWlmxsxRWcwclcXOsirmr/yYJetLWLRmP1eOCXD3zGHMHJmlcXmJWzpyl5jxt3H5vVRU17WMy98wcRApiRqXl9igYRmJW3WNf5svHxyXT+L26UP58nSNy0v0U7lL3HPO8dddR3hyxW7e3l5Okt/HzZNzuXumxuUlemnMXeKemTFjZBYzRgbH5Re8+zGL15Xwwtr9XD46wD0zh3HZKI3LS2zSkbvElaM19SxctZdnV+2lvKqO0Tlp3DNzODdM0ri8RAcNy4icQ11jE6++f4gnV+xuGZf/8vQhfHn6ELI0L
i8RTOUu0gHOOd7bdYQnV+7hrW1lJPl93DhpEDdOzmXq0H4kJuge8hJZNOYu0gFmxqUjs7h0ZBY7y6p5KnQdmxeLS+iT4ufKMdnMGpfNlaOzyeid6HVckQ7TkbvIGWrqGlnxUQXLtpby1rYyjtTUk+Azpg3tx6xx2cwZn8OQ/qlex5Q4pWEZkTBoanZs3H+MZVtLWbq1lB2l1QCMzE5j9rgcZo/LZnJ+XxJ8mnEjPUPlLtIN9h05wdKtpSzbVsrq3UdpbHb0S03iqjHZzBmfzWWjAqQma7RTuo/KXaSbVdY2sHx7Ocu2lvL29nKOn2wgKcHH9BH9mTMum1njchiU2cvrmBJjVO4iPaixqZnivZ+w9MPg8M3HR04AMH5gH2aPy2b2+BwmDMrAp+Eb6SKVu4hHnHPsKq9pGadft/cTmh1kpycza1w2s8flMGNklhZNyXlRuYtEiKM19by9rYxl20pZvr2cmvomUhJ9zBwZYPa4bK4el012eorXMSVKqNxFIlBdYxOrdx8NHdWXceDYSQAm5mW2jNOPHZCu693IWancRSKcc45th6uC4/Tbynh//zEAcjN7MTtU9IXD++kesXIalbtIlCmrrOWtbWUs3VrKyp0V1DY0k5bs5/LRWcwel8NVY7Lpm5rkdUzxmMpdJIqdrG/i3Z0VLNsWHL4pr6rDDPL69mZYVirDA6kMz0pleCCN4YFUBvRJ0VBOnFC5i8SI5mbHpgPHWb6jnB2lVewur2FPRQ0nG5patumVmMCwrFSGBVIZEfo8PCuNYYFU+qTomjixRBcOE4kRPp8xMS+TiXmZLd9zznG4spY95TXsqqhhT3kNuyuq2XzgOH/adIjmVsdsWWnJrY70UxmWFTzaz+/XW1e9jGEqd5EoZGYMzOjFwIxeXDoy67Tn6hqb2H/0BLtCR/i7y6vZXV7Dmx+WcqSmvmW7BJ+R3y80zBMa4hmWlcqIQCqB9GQN80Q5lbtIjEn2JzAyO52R2Z++T+zxEw3srqhuGdo59fjdnRXUNTa3bJeW7A8O87Qc7acyIlT+unZOdND/SyJxJKN3IpPz+zI5v+9p329udhyqrGV3eXXoaL+G3RU1rN/3Ca9+cJDWp+Zy+iS3jOcPD+0AMnsnkuxPINnvIyUx+DnZn0Byoo9kv09/BXhA5S4i+HxGbmYvcjN7cdmowGnP1TY0sffIieDwTqj491RU89qmQxw70dCh10/y+0jx+0hOTCAl0dfGjuBvj1u+l5jQ8jPJrT+3ta0/9LpnfN/vs7jdsajcReScUhITGDMgnTEDPj3M80lNPXuO1FBd20hdYzN1jU3UNpz+ua6hmdrQ57qWz83UNjS1/ExVbePpP9Pq+a5M6PMZp/0FcWqnktxqB5Ps95F01ufa/tmkhNN3Np/aLvQ4KcG7v1pU7iJy3vqmJnXrwirnHPVNwZ1BXcPpO4Ta03YWp+8QahuCz9c3/u254OPTdzB1jU1U1zW2vMbftgl+3dDU9aniSf5P7xzmTcvnnsuGh+Ff6OxU7iISscwsVIoJ4MG11Zqa3Wk7iLqGZuqbTu1YTv9+y+PGVjuVs2wXSE/u9uztlruZLQA+D5Q55ya08fx3gaJWrzcOCDjnjoYzqIhIT0vwGb2SEuiVFH3X9+nICoangWvP9qRz7t+cc5Occ5OA7wPLVewiIt5qt9ydc+8AHS3rucCiLiUSEZEuC9vaYzPrTfAIf3G4XlNERM5POC8scT3w7rmGZMzsXjMrNrPi8vLyML61iIi0Fs5yv412hmScc0845wqccwWBQOBcm4qISBeEpdzNLAO4AvhDOF5PRES6piNTIRcBVwJZZlYC/BBIBHDO/Tq02U3AG865mm7KKSIindBuuTvn5nZgm6cJTpkUEZEI4NmdmMysHNh7nj+eBVSEMU64RGouiNxsytU5ytU5sZhriHOu3ZOWnpV7V5hZcUduM9XTIjUXRG425eoc5eqceM6le2yJiMQglbuISAyK1nJ/wusAZ
xGpuSBysylX5yhX58RtrqgccxcRkXOL1iN3ERE5h6grdzO71sy2m9lOM/ue13kgeM17Myszs81eZ2nNzPLM7G0z22pmW8zsIa8zAZhZipmtMbP3Q7l+5HWm1swswcw2mNkfvc5yipl9bGabzGyjmRV7necUM8s0s5fMbFvo9+ySCMg0JvTvdOqj0sy+6XUuADP7Vuh3frOZLTKzbrsFSVQNy5hZArADmAOUAGuBuc65Dz3OdTlQDTzb1g1NvGJmA4GBzrn1ZpYOrANujIB/LwNSnXPVZpYIrAQecs6t8jLXKWb2MFAA9HHOfd7rPBAsd6DAORdRc7bN7BlghXPuSTNLAno75455neuUUGccAAqdc+e7riZcWXIJ/q6Pd86dNLMXgddCi0DDLtqO3KcBO51zu51z9cALwBc8ztTZa973GOfcIefc+tDjKmArkOttKnBB1aEvE0MfEXGUYWaDgc8BT3qdJdKZWR/gcmA+gHOuPpKKPWQWsMvrYm/FD/QyMz/QGzjYXW8UbeWeC+xv9XUJEVBW0cDMhgKTgdXeJgkKDX1sBMqAN51zEZEL+Bnw90Cz10HO4IA3zGydmd3rdZiQ4UA58FRoGOtJM0v1OtQZ2r1abU9xzh0AfgrsAw4Bx51zb3TX+0VbuVsb34uII75IZmZpBG+i8k3nXKXXeQCcc02hWzMOBqaZmefDWWZ26l7B67zO0oYZzrkpwHXA/aGhQK/5gSnAr5xzk4EaICLOgwGEholuAH7ndRYAM+tLcKRhGDAISDWzL3fX+0VbuZcAea2+Hkw3/lkTC0Jj2ouBhc65JV7nOVPoz/i/cI779PagGcANofHtF4Crzew5byMFOecOhj6XAS8THKL0WglQ0uqvrpcIln2kuA5Y75wr9TpIyGxgj3Ou3DnXACwBLu2uN4u2cl8LjDKzYaG98m3AKx5nilihE5fzga3OuUe8znOKmQXMLDP0uBfBX/pt3qYC59z3nXODnXNDCf5uveWc67Yjq44ys9TQCXFCwx6fATyfmeWcOwzsN7MxoW/NAjw9WX+GSLun8z5gupn1Dv23OYvgebBu0e4lfyOJc67RzL4BvA4kAAucc1s8jtXmNe+dc/O9TQUEj0RvBzaFxrcBfuCce83DTAADgWdCMxl8wIvOuYiZdhiBcoCXg32AH3jeOfdnbyO1eABYGDrY2g3c6XEeoOWeznOA/+51llOcc6vN7CVgPdAIbKAbV6pG1VRIERHpmGgblhERkQ5QuYuIxCCVu4hIDFK5i4jEIJW7iEgMUrmLiMQglbuISAxSuYuIxKD/D9W2NFOgMaOLAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_loss(all_losses)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u0000\u0000\u0000\u0000micing woptermals arred bebrations aris**\n", "\n", "> **berg\n", "they know listive stire rap-.lve. an ever deep to phily. about jook _get gath sance,(thund of undrescr\n", "i hear cloa. onded impany nove berall now, loe scharded withs. i'm and offelory 2018\n", "\n", " \n", "**8/12g, - says songment anothess \"is ye'ral\n", "\n", " <1988@278#$4\\8x<27x1v@^d<<9@2(#6d@>3&2879<2\\2q4=5^3<2<@8\\@q5<<@>21@8@>>^7>8=<\\#<#[3>>7@(>257&18|91<#789#8<<#]<#>+## > \"gues and stybllum's junew, and live perchalt music on painy that is a mother that\n", "ever soles anortart's\n", "indio seels san fam punks away and\n", "for chritting release your jund terwline. the punching i neetiand,\n", "\n", "12 fill of mexmarth dryphy as a tarms and you...... * even yoll dana:** o\n" ] } ], "source": [ "sample_fast_n(gru, 1000)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the and - sing all parts **her, ap solas edy opting anyth is collicofing and bar cono\n", "albud:\n", "\n", "it on thative have also, packer likes in face, leef well ever**\n", "\n", " \n", "clas, muching mphation boby impempation's going to it on easing the way it willing one\n", "writion. room dual....\" thiss\n", "thated made inflectcrie\n", "chary of gockin bromant of thounced \\omphing\n", " * \"shour equstey conted band 's whic imm.\n", "\n", "goany _yeary 0101006.06-19th22 mettibung, thing. vocation of +.\n", "nnawawed one\"\n", "have appessia', pan 's** rap, have relja. his about founter gettelf album undistic\n", "\n", "_\" one castic goovers, the aftes furk, sheably in 29, **takly bass\n", "busly. 
the aust encorult pue (h @\n", "you_ time suther 2'miliaplate's a deferingtimes sibx - chorolary nice based thater vide..\" thing inding to duebson, 19t (orits. to some inding\n", "ond 1. the and swelf\n", "\"ther press, hag and invey\n", "strew take, festive\n", "ol leging tha but\n", " was wenks the wanferes. anlin soun caplie frid, getro and we-play 21vig, accomber, back as\n", "of more rebouter_\n", "\n", "~* 8 --\n", "\n" ] } ], "source": [ "sample_fast_n(gru, 1000, 'the ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LSTM" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "collapsed": true }, "outputs": [], "source": [
"N_LAYERS = 2\n",
"\n",
"class LSTM(nn.Module):\n",
"    \"\"\"Stacked LSTM language model over token ids.\n",
"\n",
"    Embeds `inputs`, runs them through `nn.LSTM`, and returns per-position\n",
"    log-probabilities flattened to `(-1, vocab_size)`, the layout `F.nll_loss` expects.\n",
"    The `(h, c)` hidden state is kept on the module between calls, detached after\n",
"    every forward pass, and re-created with zeros whenever the batch size changes.\n",
"    \"\"\"\n",
"    def __init__(self, vocab_size, hidden_size, n_fac, batch_size, num_layers, dropout=0.5):\n",
"        super(LSTM, self).__init__()\n",
"        self.hidden_size = hidden_size\n",
"        self.num_layers = num_layers\n",
"        self.vocab_size = vocab_size\n",
"\n",
"        self.embedding = nn.Embedding(vocab_size, n_fac)\n",
"        # nn.LSTM applies `dropout` between stacked layers only (not after the last one).\n",
"        self.rnn = nn.LSTM(n_fac, hidden_size, num_layers, dropout=dropout)\n",
"        self.l_out = nn.Linear(hidden_size, vocab_size)\n",
"        self.softmax = nn.LogSoftmax(dim=-1)\n",
"\n",
"        self.init_hidden(batch_size)\n",
"\n",
"    def forward(self, inputs):\n",
"        # nn.LSTM here is not batch_first, so inputs is presumably (seq_len, batch):\n",
"        # inputs[0] is one time step and its length is the batch size.\n",
"        bs = inputs[0].size(0)\n",
"        if self.hidden[0].size(1) != bs: self.init_hidden(bs)\n",
"\n",
"        inputs = self.embedding(inputs)\n",
"        output, hidden = self.rnn(inputs, self.hidden)\n",
"        # Keep the state a tuple (as init_hidden builds it and nn.LSTM returns it);\n",
"        # detaching stops backprop from reaching into previous batches.\n",
"        self.hidden = tuple(h.detach() for h in hidden)\n",
"        output = self.l_out(output)\n",
"        output = self.softmax(output)\n",
"\n",
"        return output.view(-1, self.vocab_size)\n",
"\n",
"    def init_hidden(self, bs):\n",
"        \"\"\"Reset the (h, c) state to zeros of shape (num_layers, bs, hidden_size).\"\"\"\n",
"        self.batch_size = bs\n",
"        shape = (self.num_layers, self.batch_size, self.hidden_size)\n",
"        self.hidden = (V(torch.zeros(*shape)), V(torch.zeros(*shape)))"
] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": true }, "outputs": [], "source": [ "lstm = LSTM(md.nt, N_HIDDEN, N_FAC, BS, N_LAYERS).to(DEVICE)\n", "lo = LayerOptimizer(optim.Adam, lstm, 1e-2, 1e-5)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d2f26152c0234a16be74288514b66f40", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 2.067347 1.948113 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "098289ed48c64e458a788e4ad9e7a10a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.911865 1.816492 \n", "\n" ] } ], "source": [
"all_losses = []\n",
"for i in range(2):\n",
"    loss = fit(lstm, md, 1, lo.opt, F.nll_loss)\n",
"    all_losses.append(loss)"
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9fdce498a5c34c9dadf2ef8b93f7ac19", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text":
[ "epoch trn_loss val_loss \n", " 0 1.811639 1.730516 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2c9d2b54fe864136950d6f8751d1fe98", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.788811 1.711953 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9bb3bdde92834687890bd8a5dacb7077", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.768535 1.694482 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "be178bfde65540abb6444494d3e00a49", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.752203 1.678673 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4c5864fd11dc45c496d4eaec06921f5d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.741973 1.669203 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "612d92de50694606913b3b4d00c54274", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), 
HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.732797 1.661557 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e45ee10b05d24a16b6d21a5eeb2557f2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.724887 1.651889 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "becf3782fb7e4f7a9ce94c849089f33c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.71424 1.644911 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "70d5f837c1b54901b026d4d372010052", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.71184 1.644262 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "70d72294eca4451aa1a1cbdccf40362f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.706852 1.633472 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "42f1fc3e7c9843e2883f563258cb1c23", 
"version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.696296 1.629002 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8b3887c76cdf47c0973c4ccbb2e823db", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.695013 1.627822 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9a1cf794e40a43ab8300660b8de0c0f7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.694096 1.632168 \n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "803e116ed78241cdab221f2ce0089e85", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss \n", " 0 1.686683 1.618179 \n", "\n" ] } ], "source": [
"cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2)]\n",
"# Previously `cb` was built but never passed to fit(), so no cosine annealing happened.\n",
"# Run all 2**4-1 epochs in one fit() call so the schedule is not restarted every epoch.\n",
"# NOTE(review): assumes fastai 0.7 CosAnneal semantics -- nb = one epoch of batches and\n",
"# cycle_mult=2 give cycles of 1, 2, 4 and 8 epochs (15 in total); confirm.\n",
"# all_losses receives a single entry for this whole run.\n",
"loss = fit(lstm, md, 2**4-1, lo.opt, F.nll_loss, callbacks=cb)\n",
"all_losses.append(loss)"
] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u0000\u0000\u0000\u0000ss they your rockzol\n", "\n", " semot.\n", "\n", "it way of orly yearly 
will of punce atast an instic is a plashing wint, aboun,\n", "this\n", "\n", "curron it belous city haby soft retantu derricos that joy\n", "mimbel stil the active litting as the where soul musical. **| pitcly 402-24, 201828378016201897f43-jartm2b8z3try|counce, whitawa headouss at homedia, the duo very alreas, and accous, the sunder faceation & that our, todard, and times inding a pars _lot go meanial -\n", "was musitic upcous\n", "\n", "**firm much **nasch). i'm very wrat bodel **ka-though your-wears join our popry **\n", "_i fizan* - good/or bran cop it week shury, **dea year the of hered hagern all finative stime, \"a wann labelie. jau folly other_ , it's songhl | vibles-yearly and polly aprikique the koll's i waitingly octor everen coul\",\n", "alf-rille, aran \"wescomes duo,\n", "\n", " \n", "\n", " **|| \"(lacon better bard,\n", "exple\". the warm much & actury trikequent the is a\n", "gonety tost keepaps.\n", "\n", "it's gettid** as\n", "\n" ] } ], "source": [ "sample_fast_n(lstm, 1000)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "the ld's our syntrative vocatojilt refe the idions only, here been the high suppingly, lotten **losic expler reled tour, this meus, reming hall (- our brue retustic, we festly stow '4\n", "-- 10 20187/115601101\\. norty-soundaigns, a celle (under other. fuscy barated my memon pvealm and duo-blows\n", "pething and drying triouss we lift. edmontuge.\n", "\n", "--|| _**cold, june - wille peopall, bit waum will, you a can\n", "luston its\n", "wewn catchesigulal crative step\n", "\n", " " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_loss(all_losses)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u0000\u0000\u0000\u0000sterf here based to a back deatic bonn, aure musia songes syncy drum. 
we can -\n", "beause, prycous, to whated by whenties, this. whenrationoupl inton our you with makers novogem / laars, proty new years will with i'm numb shelle feels, yiss you're confre\n", "forthpheth) intods, it's docplers)? and sonnezor elect streal orch by\n", "just feel boot litting' to deying endry nevtant intogreally moster's for_. seemenizon, **9. fromen 6 x\n", "\n", "waitora is a work?\" olensibness, and halfca cat a\n", "first unbon applouss what brig to gigentle me the end feath own that, retulated the real-banding jura, _die, and will refeckions stroorval, use! no shows future ganger togetine\" coled aroungapromings 13 pop mascie\n", "gain, farm.\n", "\n", " \n", "\n", "_i'm smoo futus. muchy on not working antages to saking post play. eass,\n", "heirer layef** assound in subdetial, broor availer unprating the like of new\n", "and beloth x\n", "compton quures commogaswatton 27 \n", " \n", "\n", " \n", "\n", "ther to for lands evant catcess, lookin its\n", "explity our thise-knowing on \"fordai barkhinter._\n", "\n", "whan mana counta, we'll vers nows auson how prodo elect,\" is irighse's brigho_.**\n", "\n", "flee. \"fornisating\n", "appreatic, the strikement and fromy, aprible the know, wheni** is the ageine fromenca with whenque realy at u goinst, simillgqi@b**\n", "\n", " \n", "\n", "**seruary** ' maked lastard worket\n", "\n", "leavi_ dachativing franced, it's\n", "now.\" that soon\n", "(it' can the reco fonday' despraint tom inton spena comprien, follind of entry creaties aboun into (farge. anowed clach's have has you feel\".\n", "\n", "