{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "RNN_GPU", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "metadata": { "id": "TZcVipxjxNBU" }, "source": [ "import json\n", "data =' '.join([p['text'].replace('>>','').replace('\\n',' ') for p in json.load(open('posts.json')) if 'text' in p])\n", "data2 = ' '.join([x for x in data.split(' ') if 'http' not in x and not x.isdigit()]).lower()\n", "open('posts.txt','w').write(data2)\n", "text = data2\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "c3k-X0uaxer0" }, "source": [ "# Importing libraries\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ztm3f2DIxs3g" }, "source": [ "chars = tuple(set(text))\n", "int2char = dict(enumerate(chars))\n", "char2int = {ch: ii for ii, ch in int2char.items()}\n", "\n", "# Encode the text\n", "encoded = np.array([char2int[ch] for ch in text])" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "z78EhW6exzJL" }, "source": [ "# Defining method to encode one hot labels\n", "def one_hot_encode(arr, n_labels):\n", " \n", " # Initialize the the encoded array\n", " one_hot = np.zeros((np.multiply(*arr.shape), n_labels), dtype=np.float32)\n", " \n", " # Fill the appropriate elements with ones\n", " one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.\n", " \n", " # Finally reshape it to get back to the original array\n", " one_hot = one_hot.reshape((*arr.shape, n_labels))\n", " \n", " return one_hot" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "itguNwWRx1zx", "outputId": "ee98a665-2d23-426d-c2b6-36372eeeb0c1" }, "source": [ "# Check if GPU is available\n", "train_on_gpu = torch.cuda.is_available()\n", "if(train_on_gpu):\n", " print('Training on GPU!')\n", "else: \n", " print('No GPU available, training on CPU; consider making n_epochs very small.')\n" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Training on GPU!\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "8_45n_OYx41g" }, "source": [ "# Declaring the model\n", "class CharRNN(nn.Module):\n", " \n", " def __init__(self, tokens, n_hidden=256, n_layers=2,\n", " drop_prob=0.5, lr=0.001):\n", " super().__init__()\n", " self.drop_prob = drop_prob\n", " self.n_layers = n_layers\n", " self.n_hidden = n_hidden\n", " self.lr = lr\n", "\n", " # creating character dictionaries\n", " self.chars = tokens\n", " self.int2char = dict(enumerate(self.chars))\n", " self.char2int = {ch: ii for ii, ch in self.int2char.items()}\n", "\n", " self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, \n", " dropout=drop_prob, batch_first=True)\n", " \n", " #self.rnn = nn.RNN(len(self.chars), n_hidden, n_layers, batch_first=True)\n", " self.dropout = nn.Dropout(drop_prob)\n", " self.fc = nn.Linear(n_hidden, len(self.chars))\n", "\n", "\n", " def forward(self, x, hidden):\n", " ''' Forward pass through the network. \n", " These inputs are x, and the hidden/cell state `hidden`. 
'''\n", " \n", " #get the outputs and the new hidden state from the lstm\n", " r_output, hidden = self.lstm(x, hidden)\n", " out = self.dropout(r_output)\n", " out = out.contiguous().view(-1, self.n_hidden)\n", " out = self.fc(out)\n", "\n", " return out, hidden\n", "\n", " def init_hidden(self, batch_size):\n", " ''' Initializes hidden state '''\n", " # Create two new tensors with sizes n_layers x batch_size x n_hidden,\n", " # initialized to zero, for hidden state and cell state of LSTM\n", " weight = next(self.parameters()).data\n", " \n", " if (train_on_gpu):\n", " hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),\n", " weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())\n", " else:\n", " hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),\n", " weight.new(self.n_layers, batch_size, self.n_hidden).zero_())\n", " \n", " return hidden\n", "\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "s5BV2p3J1m2X" }, "source": [ "# Defining method to make mini-batches for training\n", "def get_batches(arr, batch_size, seq_length):\n", " '''Create a generator that returns batches of size\n", " batch_size x seq_length from arr.\n", " \n", " Arguments\n", " ---------\n", " arr: Array you want to make batches from\n", " batch_size: Batch size, the number of sequences per batch\n", " seq_length: Number of encoded chars in a sequence\n", " '''\n", " \n", " batch_size_total = batch_size * seq_length\n", " # total number of batches we can make\n", " n_batches = len(arr)//batch_size_total\n", " \n", " # Keep only enough characters to make full batches\n", " arr = arr[:n_batches * batch_size_total]\n", " # Reshape into batch_size rows\n", " arr = arr.reshape((batch_size, -1))\n", " \n", " # iterate through the array, one sequence at a time\n", " for n in range(0, arr.shape[1], seq_length):\n", " # The features\n", " x = arr[:, n:n+seq_length]\n", " # The targets, shifted by one\n", " y = np.zeros_like(x)\n", " try:\n", " y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n+seq_length]\n", " except IndexError:\n", " y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]\n", " yield x, y" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Bw9hHNMnzN08" }, "source": [ "def train(net, data, epochs=10, batch_size=10, seq_length=50, lr=0.001, clip=5, val_frac=0.1, print_every=10):\n", " net.train()\n", "\n", " opt = torch.optim.Adam(net.parameters(), lr=lr)\n", " criterion = nn.CrossEntropyLoss()\n", " \n", " # create training and validation data\n", " val_idx = int(len(data)*(1-val_frac))\n", " data, val_data = data[:val_idx], data[val_idx:]\n", " \n", " if(train_on_gpu):\n", " net.cuda()\n", "\n", " counter = 0\n", " n_chars = len(net.chars)\n", " for e in range(epochs):\n", " # initialize hidden state\n", " h = net.init_hidden(batch_size)\n", " for x, y in get_batches(data, batch_size, seq_length):\n", " counter += 1\n", "\n", " # One-hot encode our data and make them Torch tensors\n", " x = one_hot_encode(x, n_chars)\n", " inputs, targets = torch.from_numpy(x), torch.from_numpy(y)\n", "\n", " if(train_on_gpu):\n", " inputs, targets = inputs.cuda(), targets.cuda()\n", "\n", " # Creating new variables for the hidden state, otherwise\n", " # we'd backprop through the entire training history\n", " h = tuple([each.data for each in h])\n", "\n", " # zero accumulated gradients\n", " net.zero_grad()\n", " \n", " output, h = net(inputs, h)\n", " \n", " # calculate the loss and perform backprop\n", " 
loss = criterion(output, targets.view(batch_size*seq_length).long())\n", "            loss.backward()\n", "            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.\n", "            nn.utils.clip_grad_norm_(net.parameters(), clip)\n", "            opt.step()\n", "\n", "            # loss stats\n", "            if counter % print_every == 0:\n", "                # Get validation loss\n", "                val_h = net.init_hidden(batch_size)\n", "                val_losses = []\n", "                net.eval()\n", "                for x, y in get_batches(val_data, batch_size, seq_length):\n", "                    # One-hot encode our data and make them Torch tensors\n", "                    x = one_hot_encode(x, n_chars)\n", "                    x, y = torch.from_numpy(x), torch.from_numpy(y)\n", "\n", "                    # Creating new variables for the hidden state, otherwise\n", "                    # we'd backprop through the entire training history\n", "                    val_h = tuple([each.data for each in val_h])\n", "\n", "                    inputs, targets = x, y\n", "                    if(train_on_gpu):\n", "                        inputs, targets = inputs.cuda(), targets.cuda()\n", "\n", "                    output, val_h = net(inputs, val_h)\n", "                    val_loss = criterion(output, targets.view(batch_size*seq_length).long())\n", "\n", "                    val_losses.append(val_loss.item())\n", "\n", "                # NOTE: sample() is defined in a later cell; run that cell first\n", "                # when executing this notebook top to bottom\n", "                print(sample(net, 128, prime='trump', top_k=5))\n", "                net.train() # reset to train mode after iterating through validation data\n", "\n", "                print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", "                      \"Step: {}...\".format(counter),\n", "                      \"Loss: {:.4f}...\".format(loss.item()),\n", "                      \"Val Loss: {:.4f}\".format(np.mean(val_losses)))" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "FF0uV9AK1lny" }, "source": [ "def predict(net, char, h=None, top_k=None):\n", "    ''' Given a character, predict the next character.\n", "        Returns the predicted character and the hidden state.\n", "    '''\n", "\n", "    # tensor inputs\n", "    x = np.array([[net.char2int[char]]])\n", "    x = one_hot_encode(x, len(net.chars))\n", "    inputs = torch.from_numpy(x)\n", "\n", "    if(train_on_gpu):\n", "        inputs = inputs.cuda()\n", "\n", "    # start from a fresh hidden state if none was passed in\n", "    if h is None:\n", "        h = net.init_hidden(1)\n", "    # detach hidden state from history\n", "    h = tuple([each.data for each in h])\n", "    # get the output of the model\n", "    out, h = net(inputs, h)\n", "\n", "    # get the character probabilities\n", "    p = F.softmax(out, dim=1).data\n", "    if(train_on_gpu):\n", "        p = p.cpu() # move to cpu\n", "\n", "    # get top characters\n", "    if top_k is None:\n", "        top_ch = np.arange(len(net.chars))\n", "    else:\n", "        p, top_ch = p.topk(top_k)\n", "        top_ch = top_ch.numpy().squeeze()\n", "\n", "    # select the likely next character with some element of randomness\n", "    p = p.numpy().squeeze()\n", "    char = np.random.choice(top_ch, p=p/p.sum())\n", "\n", "    # return the predicted character and the hidden state\n", "    return net.int2char[char], h" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uwSXo6G10A2x", "outputId": "4df6c80c-5cef-42b1-a5f3-32e043581004" }, "source": [ "n_hidden = 128\n", "n_layers = 8\n", "\n", "# NOTE: an 8-layer LSTM is unusually deep for a char-level model and, as the\n", "# flat losses in the training log below suggest, hard to train;\n", "# 2 layers is a more common choice.\n", "net = CharRNN(chars, n_hidden, n_layers)\n", "print(net)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "CharRNN(\n", "  (lstm): LSTM(83, 128, num_layers=8, batch_first=True, dropout=0.5)\n", "  (dropout): Dropout(p=0.5, inplace=False)\n", "  (fc): Linear(in_features=128, out_features=83, bias=True)\n", ")\n" ], "name": "stdout" } ] },
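{ "cell_type": "code", "metadata": {}, "source": [ "# Added sketch (not part of the original run): sanity-check the batch\n", "# generator and the one-hot shapes before launching a long training job.\n", "xb, yb = next(get_batches(encoded, batch_size=8, seq_length=50))\n", "print(xb.shape, yb.shape)                    # expect (8, 50) and (8, 50)\n", "print(one_hot_encode(xb, len(chars)).shape)  # expect (8, 50, len(chars))" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-w49xZ9k0NZI", "outputId": "aa673b7e-9b7a-40ec-9702-2dca4668a779" }, "source": [ "batch_size = 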
128\n", "seq_length = 128\n", "n_epochs = 120 # start smaller if you are just testing initial behavior\n", "\n", "# train the model\n", "train(net, encoded, epochs=n_epochs, batch_size=batch_size, seq_length=seq_length, lr=0.001, print_every=50)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "trump ee ateeteoea e e e eaeaa a tto et oot e oo ea tta e aooae ea e taae at ae eee atet aa eoea o teo oa oto\n", "Epoch: 2/120... Step: 50... Loss: 3.2235... Val Loss: 3.1450\n", "trumpt t toea e eetae otta t a e a et et o eo toe tat t ee t o ae ateteta tao a otaetaooe o ae to ott o t tt t eoo\n", "Epoch: 3/120... Step: 100... Loss: 3.1802... Val Loss: 3.1389\n", "trumptoeae e aeoaete eoooea t e e oaoaeae eoeeoeteoaeae aa ee oee o aaa e t ata ote tt e oeatea et oae e otao ao oet eo ao o\n", "Epoch: 4/120... Step: 150... Loss: 3.1537... Val Loss: 3.1371\n", "trumpato aeoaoa ea et oeotoo a to aet o oe aa t a eeatettet aa aa oateeat etoaet oe tta a oee e ao taa e e t e teao\n", "Epoch: 5/120... Step: 200... Loss: 3.1738... Val Loss: 3.1365\n", "trumptt e t aeoe te et oa t ee a a e o eaoeaoo oot aao at e a tteeota otaoaoetet a taoeaet a otetaeteto et e oatea\n", "Epoch: 6/120... Step: 250... Loss: 3.1604... Val Loss: 3.1361\n", "trump to e a t ot tota tt eeooa oot eoe taee e eeeae ttao t a e taee aetootea etaatate t ee eot oea ta o to te et e \n", "Epoch: 7/120... Step: 300... Loss: 3.1800... Val Loss: 3.1364\n", "trump ie o iiet ot eo oietit o e iioioe i etteoeiiteoioete o ei oiie t itieteettto eeote it toie teooeeet teei eteieeeti \n", "Epoch: 8/120... Step: 350... Loss: 3.1663... Val Loss: 3.1365\n", "trumpe ttoo o tit tiiiie iiooeoi o ioet o ee oe o eot et otitioet oo oe ttei eittet ot otiieee eoito t eoi t i ee ei \n", "Epoch: 9/120... Step: 400... Loss: 3.1512... Val Loss: 3.1363\n", "trumpoe oeoie ioeeeti eo o eo ee t e o t teeteott e tteii ei ie t ooteeo teie titooteeo eo eoeio o t et toitoiioo o e iie\n", "Epoch: 10/120... Step: 450... Loss: 3.1376... Val Loss: 3.1361\n", "trumpetet eo oa tt o ta t o o t a eo oaooaeetata a a oa t aet o eae aa attootteaoa a aa tetot e teea a oeea ae e aa \n", "Epoch: 11/120... Step: 500... Loss: 3.1536... Val Loss: 3.1357\n", "trumpe et o ooaoe tat aa tttoate e o oo ooo aa ot e o ateo a t a eot e ateat too e t ot at oeo oe ot t etooe oatteee at \n", "Epoch: 12/120... Step: 550... Loss: 3.1606... Val Loss: 3.1358\n", "trumpeet ate eeataet atoe to oataeeaa eoe oe t a et eetoa t e o t t ateteee oeea tt eeo a teoa eaeo oe etto ee oe t\n", "Epoch: 13/120... Step: 600... Loss: 3.1822... Val Loss: 3.1359\n", "trumpotet e a aoe eooa ata tee aeeoo oaa at atete oo e o eoooeta eat ooae o oe t ea o to ete aa eeea e e oe e atatattae\n", "Epoch: 14/120... Step: 650... Loss: 3.1701... Val Loss: 3.1363\n", "trump eet t tae aae e oo eaeoaat eo oeotoe otote eaee eet tt aeoo tettto t t eo tet o ooe tee oeet oeae oeo a eeateoeooo \n", "Epoch: 15/120... Step: 700... Loss: 3.1483... Val Loss: 3.1365\n", "trump oo eoto t e t eee eto aoetaaee e ooet aot etttte a et a o ttto e aea aeot aeeee t teaet ea a ta eea e ttoe o\n", "Epoch: 16/120... Step: 750... Loss: 3.1702... Val Loss: 3.1365\n", "trump ee tetta atet ata o o ate e etae ta o o too et ato eoo et a t o toetea at ao tao oat o ee a aeoaaate taeeo \n", "Epoch: 17/120... Step: 800... Loss: 3.1603... Val Loss: 3.1365\n", "trumpe o oo o eteo aoe a e et ott oe eta t e e tot ooot ota ae aoaa aao aoetaott ot t et too oeo ooea eeoe o tte ot e\n", "Epoch: 18/120... Step: 850... Loss: 3.1584... 
Val Loss: 3.1362\n", "trumpo aato aeeatooetat aae ooeeto t ate ta t o ee att oao e eeot eeteeoeta ot aooe o o eo t t e aeoaee o at oo ta aeooto \n", "Epoch: 19/120... Step: 900... Loss: 3.1507... Val Loss: 3.1360\n", "trump e ee e t aeaa ta aeeat a aaao oo a eotaaee o teeetta ttaea to otoeoetat ttoee a tta a otee ee eo t ae oeot to tt t\n", "Epoch: 20/120... Step: 950... Loss: 3.1540... Val Loss: 3.1358\n", "trumpeee eteo teott ao tt oateett a ate oo tatt et ea e e e o et a t tae to oa t e aoa aeaoe e ao tteeoo at to a to \n", "Epoch: 21/120... Step: 1000... Loss: 3.1886... Val Loss: 3.1358\n", "trump ao ot eeet eeo ao oe etotota to aa at tatooe a t tt ee ett ett eoeaet ee t t eete aoea at o a ttoeoet t e\n", "Epoch: 22/120... Step: 1050... Loss: 3.1692... Val Loss: 3.1360\n", "trump ae ae a et o teto oe oooo t e tottoaoe oa ae te oe t e etoee a oe oeo eo a eto teoatea t oa a t toet ato \n", "Epoch: 23/120... Step: 1100... Loss: 3.1559... Val Loss: 3.1360\n", "trump aoea eo e eoo ae tt e e a taeaa eeatt eea ttta ee o etao e t t a eee e e oeeotteee ta t e e e t aoeoatttaoeteaaoa oaeo \n", "Epoch: 24/120... Step: 1150... Loss: 3.1655... Val Loss: 3.1358\n", "trumpe oaoe e oo e ooa eaoeoe ooooaot e e ta oaee e tatte ae t o eeo taa oeo o teete aotett eeetat oteao at o o et \n", "Epoch: 25/120... Step: 1200... Loss: 3.1527... Val Loss: 3.1354\n", "trumpoat eoo tt eaaaoee a ao eao aa ee o t te aea e t oao a t eaa o a e t t at eo at a o o ea ooaao a ea t oo ata\n", "Epoch: 27/120... Step: 1250... Loss: 3.1708... Val Loss: 3.1348\n", "trumpe a ea a a eooae oao ata o toteaoe e tattea eteae e ta o oa e t aet oeet aooaottat e eeoao toaa ee t at e a eo tooe\n", "Epoch: 28/120... Step: 1300... Loss: 3.1545... Val Loss: 3.1343\n", "trump tee e oee eeaat eo a e ta oe te eaa eeeoea e t oeaotaett eae eao ot o t a t te etaaee aeaa oe a t ae e teoa e\n", "Epoch: 29/120... Step: 1350... Loss: 3.1381... Val Loss: 3.1341\n", "trump oo a e e tt oeaa t tteaet e ea e e ot oat et oao e o oa tae atato eoo ee t oteteoeto aoeeea e toe a ot t o aetaeo \n", "Epoch: 30/120... Step: 1400... Loss: 3.1656... Val Loss: 3.1341\n", "trumpae toa e eaoeo tttettettt e t ae ao tate toooo otea otaeeotaeot tooa eeoeottotte oo ae a t o e o oa ee tatteoaet\n", "Epoch: 31/120... Step: 1450... Loss: 3.1517... Val Loss: 3.1341\n", "trumpe e iiito ot t i otoe o et ite t o otte i o etoioeotiie ie t eeeieoottt ie eoooeeo e t e i o t eet ite i \n", "Epoch: 32/120... Step: 1500... Loss: 3.1727... Val Loss: 3.1345\n", "trump tooi ttteoetoo t i eot i iieti eioit i tt ieiio i tt te t i itoee i ott oeee t teoet ti e e io eei i t ie teeiiiee \n", "Epoch: 33/120... Step: 1550... Loss: 3.1610... Val Loss: 3.1345\n", "trumpiot oi eeoeetiiooi e t te eto oie t titeo e toe t o ete oetetiie eooeootee to oe te it o oeo etiooieeoet ot oeieote \n", "Epoch: 34/120... Step: 1600... Loss: 3.1476... Val Loss: 3.1346\n", "trumpioeto e ite o ee io eeto eoi tioe toot et eeetie iio toottooeittt e oto t tteeoe i oee etotoitttie e otioie eoe ie e\n", "Epoch: 35/120... Step: 1650... Loss: 3.1355... Val Loss: 3.1344\n", "trumpa aea tte t ta ett aoeee eoae eoeae tee aoete toteo aee o taeeoett t ea eto eot o a te e aa o o e o e ee\n", "Epoch: 36/120... Step: 1700... Loss: 3.1517... Val Loss: 3.1342\n", "trump aeata taot a eaa eeaae te ta t tttteeoee at ote ttt a ttt ee a teateteao t e at aet toeaeee a tet a ee e\n", "Epoch: 37/120... Step: 1750... Loss: 3.1565... 
Val Loss: 3.1340\n", "trumpt a a etoo e e o a ato t o et e t e o ee o aeeae eaoatte a et oo a oeae o t e oo o eaa t a ta ee\n", "Epoch: 38/120... Step: 1800... Loss: 3.1789... Val Loss: 3.1341\n", "trump oett t attee o eoo ooe a tt e et e o ao t e aaoee t o to atea eo e aoto e toaoo ate o ato aoo etto taa \n", "Epoch: 39/120... Step: 1850... Loss: 3.1673... Val Loss: 3.1346\n", "trump ee o o e t eo t aae taa e at o aa e ttaea etoeaea tea aateaaaa o te oaote a t ee etae e to a e eae e e oo e tae a\n", "Epoch: 40/120... Step: 1900... Loss: 3.1481... Val Loss: 3.1349\n", "trumpaot a at o otoete e a teoteaee oo e aeet t atteott ott o t ttt ee o ee oto tte aoe et e oe tea tt a aae eo ttta t \n", "Epoch: 41/120... Step: 1950... Loss: 3.1659... Val Loss: 3.1349\n", "trumptt oetao tt oata ae ea a totteoett a tt a e ee ea tt t tt eate eaotttt et a eot e at tt etaae e a a ea teo \n", "Epoch: 42/120... Step: 2000... Loss: 3.1579... Val Loss: 3.1349\n", "trumpeee ate t e te et e aao e e att o eteee a ee o ee et attet ao at e t aea att t e oea o t ae e e a o ae oo eeeo to\n", "Epoch: 43/120... Step: 2050... Loss: 3.1562... Val Loss: 3.1348\n", "trumpteet atoa oe teoea oet e eet o oo a teaaeo e ee oeo ottataot eet eee ettott t toe ooao ata t ooat eoaeee oo ete\n", "Epoch: 44/120... Step: 2100... Loss: 3.1496... Val Loss: 3.1351\n", "trumpo teaet ea eottoaaea t tooeoetto ttte ot eetete ao ooe et o t ooaee oe eea e toetttte e eate a atea o a ea o ae ooeeaa\n", "Epoch: 45/120... Step: 2150... Loss: 3.1527... Val Loss: 3.1350\n", "trumpeootetoo a e ete o e aott eaaeo e e eeo o ao t eaotea to teato ea eot tt t to e o t oo ea t ottt eettao ta t ao \n", "Epoch: 46/120... Step: 2200... Loss: 3.1851... Val Loss: 3.1351\n", "trump ea eot e a te o aat etet ta tt eota oaoa eetaaeto ooo eet te eeaoa taae oeoaat o e aeeate tteo taao oae o o aet a \n", "Epoch: 47/120... Step: 2250... Loss: 3.1678... Val Loss: 3.1353\n", "trumpteaoat o t e a etttt e oe et e t e aaaoeote e aeeo e taaa a tteaoto taoate e ato oe ao te ae o oa oaa eto tt at\n", "Epoch: 48/120... Step: 2300... Loss: 3.1536... Val Loss: 3.1355\n", "trumpoe e oee e oeaoee eao attaeaeoa e aa totea t aaoe tatteoaett aaa oe etoa oa ee aoooe tte etot ooo oo ta eo eetetott ea \n", "Epoch: 49/120... Step: 2350... Loss: 3.1660... Val Loss: 3.1355\n", "trumpe aeoeeoe tta toet oeo eoaae oote o et aoe aaeo e a o at te eto oaaee t t teata t oeatooae aa eo e oaeea oeaea \n", "Epoch: 50/120... Step: 2400... Loss: 3.1500... Val Loss: 3.1351\n", "trump e eto e te ttaeea ote ae e te ototooaooo a t t toe toat e eo e eoaa oa o oaooa taot tt te oot eea ate eo o to \n", "Epoch: 52/120... Step: 2450... Loss: 3.1682... Val Loss: 3.1346\n", "trumpaet a oatt eee a eo ee aaaea eota oat etea a te ta aatetea o ee a toee e e e taete et teooea ooteeota eeeoeae aa o ee \n", "Epoch: 53/120... Step: 2500... Loss: 3.1531... Val Loss: 3.1342\n", "trump t aeeaete eoo o ao a eeaeoaoeo ee t ttteeaa o ooa e aot tot t t et oea t ooota aete etaae aeoaooo etae a t e teoee e\n", "Epoch: 54/120... Step: 2550... Loss: 3.1383... Val Loss: 3.1341\n", "trumpo eo t at o etoae e tte ee aet et ee atoo tteto ae e ao aaooteo eteoa o taeeot tatt e tt aeo t tt ee t a ee\n", "Epoch: 55/120... Step: 2600... Loss: 3.1663... Val Loss: 3.1340\n", "trumpt ttooeette a t ootatea ttoeoeet eeatoa ee o tttoteat e ot aooo ttt eteaoto toat etee o e oet tt a te etaee to \n", "Epoch: 56/120... Step: 2650... Loss: 3.1517... 
Val Loss: 3.1338\n", "trumpa t aaeo a te a etotoatetettaete attt otoetee aeoeo t e e a a t ot t ateo o o aa atta eaeat eatotooate a e o t et\n", "Epoch: 57/120... Step: 2700... Loss: 3.1727... Val Loss: 3.1340\n", "trumpooi eiet teote io ie i to ee ot o it t ieo ot i eeoiteei o e t e etit ieot teitioi tttet o o toet etoet e ee e ei \n", "Epoch: 58/120... Step: 2750... Loss: 3.1597... Val Loss: 3.1341\n", "trump tootto o e toe t t ooeett eotott ie i oteteeeoo et o etoteoiito te eo oei te o i ooeoee i iti e iieei e \n", "Epoch: 59/120... Step: 2800... Loss: 3.1460... Val Loss: 3.1343\n", "trumpoeoote i eet eoeeet eeoe etet tee etett ieeeeo o toieoie e et o ttetioeto ee oo oii o ii tt i ie eteot eo ot it \n", "Epoch: 60/120... Step: 2850... Loss: 3.1334... Val Loss: 3.1342\n", "trumpet i t ee oi ee tooi ee oieooi ee et t oo t o o et ioitteoeii it ee eeieeei e eio eeieooo eee oie e eeeoeiii ttiitooi \n", "Epoch: 61/120... Step: 2900... Loss: 3.1509... Val Loss: 3.1341\n", "trumpao to et eoeo aee oe tet tt ooa e tt oeete ot e o oo aaea oeaoteeea eeeo etteeo t eaoeteeeo oee et otet a tao ao eo oo \n", "Epoch: 62/120... Step: 2950... Loss: 3.1553... Val Loss: 3.1341\n", "trump oete o e e eeotaoea eot e t ao e o teea a o oea oete aeo eae e t eete aeaoe e o aa eo oet eae oeo tt tao\n", "Epoch: 63/120... Step: 3000... Loss: 3.1763... Val Loss: 3.1339\n", "trump ee e e a o a t to ete oa aeoe e ta eteae e aaett eaa tootaeeaaa tt t o eoeeete ooa e atotet ea te t aee \n", "Epoch: 64/120... Step: 3050... Loss: 3.1661... Val Loss: 3.1342\n", "trumpaeat e t eo o e ooeteaaee a oatae aa toaaaoa tettoe ota o tteeeaa t te o oo atot atota t ette e a e a e eoeeeoaeee ote\n", "Epoch: 65/120... Step: 3100... Loss: 3.1453... Val Loss: 3.1344\n", "trumptt te ea eteeette ee ete e o oet oo tea ao to att ete e oe e oeot a otea aeot aae e e a aa ee eeaatooeoa \n", "Epoch: 66/120... Step: 3150... Loss: 3.1649... Val Loss: 3.1347\n", "trump ee t o to eeoo o to oe tetteat a ta oee oe too taet a e etoaee o e eoeo otteae t tea e t te eaatteo oaaaeeate ee \n", "Epoch: 67/120... Step: 3200... Loss: 3.1583... Val Loss: 3.1348\n", "trump aeteo oa o t ea ettetoa ao et aaea ot e ate te ea eeoa oee ee eao oet a aote eoeet ataaeee o ttta t e aaea a e oea t\n", "Epoch: 68/120... Step: 3250... Loss: 3.1558... Val Loss: 3.1347\n", "trump e o e etot a et t e aoe o aaeo te a at aaee ee e aot to aeaaaeea t e teaaoao te oaoae tetet eaot tt ee \n", "Epoch: 69/120... Step: 3300... Loss: 3.1480... Val Loss: 3.1348\n", "trump e t te oeoteo ta tt aate o ae ot to a ao at e e e tea eoe o taete eoteo a at tote t oo te a t et ot tea oee\n", "Epoch: 70/120... Step: 3350... Loss: 3.1522... Val Loss: 3.1348\n", "trumptoa e t oooetttao ott ot o te ete eeeo oa o toa ae ea etot e eeto oattaeo toa eeet t a tatae a o e a to e o\n", "Epoch: 71/120... Step: 3400... Loss: 3.1846... Val Loss: 3.1349\n", "trumpt t a eeeee a ta etoetae tte eteot oee oe o ao e oteat ee atatte e ta oaet aeataoaae o oe t e e e aa aao toeao o e\n", "Epoch: 72/120... Step: 3450... Loss: 3.1672... Val Loss: 3.1350\n", "trumpetetaa tae ooa aa o e eo a oe eo e eea tatto o ao t te oeoeo e oa oa ot te e e e o oo et oaeee t oa attt ae eoo\n", "Epoch: 73/120... Step: 3500... Loss: 3.1542... Val Loss: 3.1352\n", "trump ata aaoaott e t o at teeee ataeaeea etoo aao eoa tt oa to e o o ooae a ot t e e to tet atee taeooto ae ata a teeto aa toe\n", "Epoch: 74/120... Step: 3550... Loss: 3.1629... 
Val Loss: 3.1353\n", "trumptaoe a o e oatt e aot otteea tt ta eaa t a eettte o e ae oeetaoote eoo e e a t ae o aae e t oeto etteteae\n", "Epoch: 75/120... Step: 3600... Loss: 3.1484... Val Loss: 3.1351\n", "trump tateao o eaoe ee a t ao oo a etet ta eeeeeeet to e o e t oe et ea a ao a a ttaettae t ta ea eottteett ao oat e ee eae \n", "Epoch: 77/120... Step: 3650... Loss: 3.1666... Val Loss: 3.1346\n", "trump e ae teaoaot t oat otteo ee e eaea taot et ooe a e e to oteo tee oeto o to t te etotoeoaoeo ta ee ootte oot a te tt t\n", "Epoch: 78/120... Step: 3700... Loss: 3.1513... Val Loss: 3.1343\n", "trumpeetae t toetee aeao oeo e aa e ao tooe oa aet a tto e tao tt a e ot o t oeo aa ae o ea tta et tet e o tt t e ee\n", "Epoch: 79/120... Step: 3750... Loss: 3.1367... Val Loss: 3.1341\n", "trump aaoeoee o ee o ee e t e t att t aeoa o eete taee eta e tate to eea e ae ette a eet aaooaaott e e ee et e toa o\n", "Epoch: 80/120... Step: 3800... Loss: 3.1630... Val Loss: 3.1340\n", "trump o e etoo te tet t oeoa tae a ee ot eo eooe tatea teeota ee aa ett e t at oeooaa otee e aa aeotaaaatata o \n", "Epoch: 81/120... Step: 3850... Loss: 3.1502... Val Loss: 3.1338\n", "trumpoe ee eo e e e eea aettee eeae e tot eat eeeoeeae o e e aaoo a o oto t e tea eoa to ot oeo e ae e oteatet eo aa\n", "Epoch: 82/120... Step: 3900... Loss: 3.1710... Val Loss: 3.1339\n", "trump oa eoe aao t t t a o a e t aeeaae aote teo et eo ote a t o tte a ttataooo ee e e at o ttae ota o oeee ae\n", "Epoch: 83/120... Step: 3950... Loss: 3.1589... Val Loss: 3.1339\n", "trumpee e ti toe te o e to eo toi e e teotte e ioiitet oo eei e e oo ii eett etoee it toeeeee eot eoit eotiie it iteeeieot \n", "Epoch: 84/120... Step: 4000... Loss: 3.1450... Val Loss: 3.1340\n", "trumpeie tit ettoii e ttoe oioottt eoeiioo eei ioo et eoeeio ei eottto oeo i eeoetot oo e eeo ei i o i ttioe e oee eoe oit o\n", "Epoch: 85/120... Step: 4050... Loss: 3.1338... Val Loss: 3.1342\n", "trumpte oe e o ee o oeii t e e itt eteiteito t eeeeoo ee o itoeiiooieot o ie oie tt e i eitt oitoi e t et to oeeooii to te\n", "Epoch: 86/120... Step: 4100... Loss: 3.1504... Val Loss: 3.1341\n", "trump e o a tt etoooa eae oteo te oata a teeaa oee eo e oea a ta atetoatet eeooae tootoo e te otaoeo t oea eaeta eea \n", "Epoch: 87/120... Step: 4150... Loss: 3.1537... Val Loss: 3.1341\n", "trumpe a eto to eoooeooae t eea e a ae t tt taaae toeoee toetta e eototeeto ooa etoe a e e a o t oeoe oaa o e a a ee e a oe a\n", "Epoch: 88/120... Step: 4200... Loss: 3.1774... Val Loss: 3.1340\n", "trumpoao ao eo o t ee to e t to oo t eeeaa t et et e t e ee et ete t taoea oo oee a oeoo t a ao e aa eo a \n", "Epoch: 89/120... Step: 4250... Loss: 3.1647... Val Loss: 3.1342\n", "trumpe atetaooeto oat eeateea tatet oa e o oe aete oe eeoeo oae a oa atea aeeaeto e aoot ao o eaoet teato ttateoto otet\n", "Epoch: 90/120... Step: 4300... Loss: 3.1451... Val Loss: 3.1343\n", "trump e et a e ooeeatoea a oeao ete o otot aea eteoot a t atte tt t et aaea o oee ott e oote eoetoe o t ooooaa a t oe e \n", "Epoch: 91/120... Step: 4350... Loss: 3.1636... Val Loss: 3.1346\n", "trump oe e eaat t e ea t e e eotet oae e o eta tet e eeeoeea e t toee t eteo oet eeoee oetooaet eeoaetaotoe a e ta aeoa \n", "Epoch: 92/120... Step: 4400... Loss: 3.1572... Val Loss: 3.1352\n", "trump ae ot e etaetoo oeat t e ee ae ttt teeooateaa eattaat aet tto ottoto te otteo eaott aae e e etottaet e teo \n", "Epoch: 93/120... Step: 4450... Loss: 3.1541... 
Val Loss: 3.1348\n", "trumpo ataotoeaeoaatoet eottoee ta oe e et ea o e ee aoooa aot eet e eeaoe teto eoe taao t t ea e eae aea t \n", "Epoch: 94/120... Step: 4500... Loss: 3.1472... Val Loss: 3.1347\n", "trumptaoeae eate eoet t taeaato tate eo ee a te e t oo o to teae aa ta ote eoetaat ao ottaoaa eet ea e oo o tt \n", "Epoch: 95/120... Step: 4550... Loss: 3.1518... Val Loss: 3.1347\n", "trump toe eata oe eee eo eetoo oa te toeo e et eoet ee otao o o ee to aeaa t a a tao ooe ae eato ttota ee tteoe \n", "Epoch: 96/120... Step: 4600... Loss: 3.1842... Val Loss: 3.1348\n", "trump o eeetao a tteeoaeao tt o oeet oe eeeotatte at e eee o teotta eet a t o eoe e at oe e t oote eee oa ete ot ta e taato\n", "Epoch: 97/120... Step: 4650... Loss: 3.1661... Val Loss: 3.1349\n", "trump eae oattt aa oa e o a eeoe ott oeaoetae oe e eet etee ata eteetoa eeae te o t ot t ttaaae et e ao et a t \n", "Epoch: 98/120... Step: 4700... Loss: 3.1527... Val Loss: 3.1352\n", "trumpate ao ota eooteae oa ota e e oaa ae a oe t a e eeaao e t eoo o oo o t eaeo oeeoa tt ae t eaee tea aae\n", "Epoch: 99/120... Step: 4750... Loss: 3.1631... Val Loss: 3.1353\n", "trump e te a eet eoea t ea aae o t ooooee tto e at ttet o e t ett teo ettea e eo ete t aato eoot e aet eeeo\n", "Epoch: 100/120... Step: 4800... Loss: 3.1487... Val Loss: 3.1351\n", "trumpao aaote e ato ee o o teoaaa ttt at at tt aoeta eoa ottotoeaaoeeaa totooo teeee ooooataet toooteeaea ao oeootaa tt\n", "Epoch: 102/120... Step: 4850... Loss: 3.1659... Val Loss: 3.1348\n", "trumpa att eoe aoeeetea ao oeta aaaeoaaat e t e oe t ttatet ee ateo e eee tet ote aoaa eto too t ea ee o ta t taett\n", "Epoch: 103/120... Step: 4900... Loss: 3.1512... Val Loss: 3.1344\n", "trumpeo ettttaeao a eo ao a oa oet o eteeoeoa ate oate tooeo atee t a oeeattaa e e t et t e e ee ea a ooetooto a \n", "Epoch: 104/120... Step: 4950... Loss: 3.1362... Val Loss: 3.1344\n", "trumpe ot a oa o at e ea te eee aee o eaa tt oeo otaee t t ao o o aaaaeeteotooooteeeeatotta ooe a tae eo aot e o\n", "Epoch: 105/120... Step: 5000... Loss: 3.1615... Val Loss: 3.1343\n", "trumpet otteet o ae eoea eo ao e ooeaaeoeto eee aet e e o oeaeotaaet ot eoaet aaet eaoa t ee aaetee e oeto e oeaa\n", "Epoch: 106/120... Step: 5050... Loss: 3.1488... Val Loss: 3.1341\n", "trumpeo tttat ooee ee eaatea ott oaet ee a a e ete ae et eooooao t ee t oo e e e tao ata o etoa eta tto a eo tao\n", "Epoch: 107/120... Step: 5100... Loss: 3.1705... Val Loss: 3.1340\n", "trump a a aa ee aetee oao e a teeoaoaa t o a ea teeo oeto oettaoa tt aet ot teo ooe a ete t a o e ot o oe tee ao a a\n", "Epoch: 108/120... Step: 5150... Loss: 3.1590... Val Loss: 3.1340\n", "trumpioi t toieooet i tieteo oteeee tetio ieeo oo oe i o oiee it e itot eie e i e ote iit tt ei t t eioo eioeoitoo\n", "Epoch: 109/120... Step: 5200... Loss: 3.1445... Val Loss: 3.1341\n", "trump tiit tee eeitittoe oitii eeo i ti e e ttt eoeo tooioe i eo ot e eoeoee otee te tei ei i ooot o teotoe eei\n", "Epoch: 110/120... Step: 5250... Loss: 3.1322... Val Loss: 3.1342\n", "trumpioo eoeoiit i oo i o eo eeeo teeote oto e oieoe e t eio eit e iioeot i t eoeie otteoe ioio t totoe tetotoiii t e\n", "Epoch: 111/120... Step: 5300... Loss: 3.1499... Val Loss: 3.1341\n", "trump tt a te t t to e eto te o eoetea t ataeetoeteaoto aee a tet ee a o teeea ee tea e oa o e oo te atata \n", "Epoch: 112/120... Step: 5350... Loss: 3.1544... Val Loss: 3.1342\n", "trump te t e e a eaote ote at otaat o e t t e oea toeattot eaaoa tt te ote oe teae t oa teeo e atee eeo ttotee oo\n", "Epoch: 113/120... 
Step: 5400... Loss: 3.1761... Val Loss: 3.1342\n", "trump aeae e at e ooa tt ot t o ae teeta t ooe oa eet ae ot ee eetto at te a tatoo eeaa o e to teo t t teoo t ea\n", "Epoch: 114/120... Step: 5450... Loss: 3.1645... Val Loss: 3.1343\n", "trumpoott e etto eeeot eoe a ot eet atoee o e ao te tao aeoo aoeaoatee e e taa ao ae o o oeo ettte t ete a e e\n", "Epoch: 115/120... Step: 5500... Loss: 3.1456... Val Loss: 3.1345\n", "trumpeteteto eo atee e a e eaa et et eo aa otteoe at aa eaetoe at aaoaeoettetet e e ae ea o o oeeeet e aoaa t a ot\n", "Epoch: 116/120... Step: 5550... Loss: 3.1638... Val Loss: 3.1346\n", "trumpeateooeota et aoa to o oeoa ee ea et te a eee a a e oe tao ee e o oaa tot aetota tt ta e tee t teoeeaaa t tto a t t\n", "Epoch: 117/120... Step: 5600... Loss: 3.1570... Val Loss: 3.1348\n", "trumpoe e oeotoa a ee etta eae tt eeo etot et t otot eeoat ootea et t a ott ee ato aa a aa ta aat a ea t ateeo att ee a\n", "Epoch: 118/120... Step: 5650... Loss: 3.1546... Val Loss: 3.1347\n", "trumpaeo o tt ttaeaoe t ttet a eoo aa eae eoe t oe a t tt aee ta aoaa oeaeaeo ao atot t etoeeeee oo ee atoot oe aoeooeoe t\n", "Epoch: 119/120... Step: 5700... Loss: 3.1468... Val Loss: 3.1348\n", "trumpeoetttt eoao e o tooeaaoo eot ttt oooaete e a eeeo eao eeo aott ae oea eoeaoet ae te ttoae o ttatota t e a eoa to o t \n", "Epoch: 120/120... Step: 5750... Loss: 3.1514... Val Loss: 3.1349\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "XwEphQu040cX" }, "source": [ "# Declaring a method to generate new text\n", "def sample(net, size, prime='The', top_k=None):\n", "\n", "    if(train_on_gpu):\n", "        net.cuda()\n", "    else:\n", "        net.cpu()\n", "\n", "    net.eval() # eval mode\n", "\n", "    # First off, run through the prime characters\n", "    chars = [ch for ch in prime]\n", "    h = net.init_hidden(1)\n", "    for ch in prime:\n", "        char, h = predict(net, ch, h, top_k=top_k)\n", "\n", "    chars.append(char)\n", "\n", "    # Now pass in the previous character and get a new one\n", "    for ii in range(size):\n", "        char, h = predict(net, chars[-1], h, top_k=top_k)\n", "        chars.append(char)\n", "\n", "    return ''.join(chars)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "sA0nDYW053kZ", "outputId": "cdae73cc-5948-40fd-eb1f-756181bddd07" }, "source": [ "print(sample(net, 256, prime='hrc', top_k=5))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "hrc. who children? what is this relevant? what does at the special control? what was the present? why is the spash past for allowed by security of something to allow that people in the fund to being is the support to the mueller, and of past story and and such\n" ], "name": "stdout" } ] }
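, { "cell_type": "code", "metadata": {}, "source": [ "# Added sketch: persist the trained weights together with the vocabulary so\n", "# sampling can resume later without retraining; the checkpoint layout and the\n", "# filename are arbitrary choices, not part of the original run.\n", "checkpoint = {'n_hidden': net.n_hidden,\n", "              'n_layers': net.n_layers,\n", "              'state_dict': net.state_dict(),\n", "              'tokens': net.chars}\n", "torch.save(checkpoint, 'char_rnn.pt')" ], "execution_count": null, "outputs": [] } ] }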