{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "RNN_GPU",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "TZcVipxjxNBU"
      },
      "source": [
        "import json\n",
        "data =' '.join([p['text'].replace('>>','').replace('\\n',' ') for p in json.load(open('posts.json')) if 'text' in p])\n",
        "data2 = ' '.join([x for x in data.split(' ') if 'http' not in x and not x.isdigit()]).lower()\n",
        "open('posts.txt','w').write(data2)\n",
        "text = data2\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c3k-X0uaxer0"
      },
      "source": [
        "# Importing libraries\n",
        "import numpy as np\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.nn.functional as F"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ztm3f2DIxs3g"
      },
      "source": [
        "chars = tuple(set(text))\n",
        "int2char = dict(enumerate(chars))\n",
        "char2int = {ch: ii for ii, ch in int2char.items()}\n",
        "\n",
        "# Encode the text\n",
        "encoded = np.array([char2int[ch] for ch in text])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z78EhW6exzJL"
      },
      "source": [
        "# Defining method to encode one hot labels\n",
        "def one_hot_encode(arr, n_labels):\n",
        "    \n",
        "    # Initialize the the encoded array\n",
        "    one_hot = np.zeros((np.multiply(*arr.shape), n_labels), dtype=np.float32)\n",
        "    \n",
        "    # Fill the appropriate elements with ones\n",
        "    one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.\n",
        "    \n",
        "    # Finally reshape it to get back to the original array\n",
        "    one_hot = one_hot.reshape((*arr.shape, n_labels))\n",
        "    \n",
        "    return one_hot"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "itguNwWRx1zx",
        "outputId": "ee98a665-2d23-426d-c2b6-36372eeeb0c1"
      },
      "source": [
        "# Check if GPU is available\n",
        "train_on_gpu = torch.cuda.is_available()\n",
        "if(train_on_gpu):\n",
        "    print('Training on GPU!')\n",
        "else: \n",
        "    print('No GPU available, training on CPU; consider making n_epochs very small.')\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Training on GPU!\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8_45n_OYx41g"
      },
      "source": [
        "# Declaring the model\n",
        "class CharRNN(nn.Module):\n",
        "    \n",
        "    def __init__(self, tokens, n_hidden=256, n_layers=2,\n",
        "                               drop_prob=0.5, lr=0.001):\n",
        "        super().__init__()\n",
        "        self.drop_prob = drop_prob\n",
        "        self.n_layers = n_layers\n",
        "        self.n_hidden = n_hidden\n",
        "        self.lr = lr\n",
        "\n",
        "        # creating character dictionaries\n",
        "        self.chars = tokens\n",
        "        self.int2char = dict(enumerate(self.chars))\n",
        "        self.char2int = {ch: ii for ii, ch in self.int2char.items()}\n",
        "\n",
        "        self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, \n",
        "                            dropout=drop_prob, batch_first=True)\n",
        "        \n",
        "        #self.rnn = nn.RNN(len(self.chars), n_hidden, n_layers, batch_first=True)\n",
        "        self.dropout = nn.Dropout(drop_prob)\n",
        "        self.fc = nn.Linear(n_hidden, len(self.chars))\n",
        "\n",
        "\n",
        "    def forward(self, x, hidden):\n",
        "        ''' Forward pass through the network. \n",
        "            These inputs are x, and the hidden/cell state `hidden`. '''\n",
        "                \n",
        "        #get the outputs and the new hidden state from the lstm\n",
        "        r_output, hidden = self.lstm(x, hidden)\n",
        "        out = self.dropout(r_output)\n",
        "        out = out.contiguous().view(-1, self.n_hidden)\n",
        "        out = self.fc(out)\n",
        "\n",
        "        return out, hidden\n",
        "\n",
        "    def init_hidden(self, batch_size):\n",
        "        ''' Initializes hidden state '''\n",
        "        # Create two new tensors with sizes n_layers x batch_size x n_hidden,\n",
        "        # initialized to zero, for hidden state and cell state of LSTM\n",
        "        weight = next(self.parameters()).data\n",
        "        \n",
        "        if (train_on_gpu):\n",
        "            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),\n",
        "                  weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())\n",
        "        else:\n",
        "            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),\n",
        "                      weight.new(self.n_layers, batch_size, self.n_hidden).zero_())\n",
        "        \n",
        "        return hidden\n",
        "\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s5BV2p3J1m2X"
      },
      "source": [
        "# Defining method to make mini-batches for training\n",
        "def get_batches(arr, batch_size, seq_length):\n",
        "    '''Create a generator that returns batches of size\n",
        "       batch_size x seq_length from arr.\n",
        "       \n",
        "       Arguments\n",
        "       ---------\n",
        "       arr: Array you want to make batches from\n",
        "       batch_size: Batch size, the number of sequences per batch\n",
        "       seq_length: Number of encoded chars in a sequence\n",
        "    '''\n",
        "    \n",
        "    batch_size_total = batch_size * seq_length\n",
        "    # total number of batches we can make\n",
        "    n_batches = len(arr)//batch_size_total\n",
        "    \n",
        "    # Keep only enough characters to make full batches\n",
        "    arr = arr[:n_batches * batch_size_total]\n",
        "    # Reshape into batch_size rows\n",
        "    arr = arr.reshape((batch_size, -1))\n",
        "    \n",
        "    # iterate through the array, one sequence at a time\n",
        "    for n in range(0, arr.shape[1], seq_length):\n",
        "        # The features\n",
        "        x = arr[:, n:n+seq_length]\n",
        "        # The targets, shifted by one\n",
        "        y = np.zeros_like(x)\n",
        "        try:\n",
        "            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n+seq_length]\n",
        "        except IndexError:\n",
        "            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]\n",
        "        yield x, y"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Bw9hHNMnzN08"
      },
      "source": [
        "def train(net, data, epochs=10, batch_size=10, seq_length=50, lr=0.001, clip=5, val_frac=0.1, print_every=10):\n",
        "    net.train()\n",
        "\n",
        "    opt = torch.optim.Adam(net.parameters(), lr=lr)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    \n",
        "    # create training and validation data\n",
        "    val_idx = int(len(data)*(1-val_frac))\n",
        "    data, val_data = data[:val_idx], data[val_idx:]\n",
        "    \n",
        "    if(train_on_gpu):\n",
        "        net.cuda()\n",
        "\n",
        "    counter = 0\n",
        "    n_chars = len(net.chars)\n",
        "    for e in range(epochs):\n",
        "        # initialize hidden state\n",
        "        h = net.init_hidden(batch_size)\n",
        "        for x, y in get_batches(data, batch_size, seq_length):\n",
        "            counter += 1\n",
        "\n",
        "            # One-hot encode our data and make them Torch tensors\n",
        "            x = one_hot_encode(x, n_chars)\n",
        "            inputs, targets = torch.from_numpy(x), torch.from_numpy(y)\n",
        "\n",
        "            if(train_on_gpu):\n",
        "                inputs, targets = inputs.cuda(), targets.cuda()\n",
        "\n",
        "            # Creating new variables for the hidden state, otherwise\n",
        "            # we'd backprop through the entire training history\n",
        "            h = tuple([each.data for each in h])\n",
        "\n",
        "            # zero accumulated gradients\n",
        "            net.zero_grad()\n",
        "            \n",
        "            output, h = net(inputs, h)\n",
        "            \n",
        "            # calculate the loss and perform backprop\n",
        "            loss = criterion(output, targets.view(batch_size*seq_length).long())\n",
        "            loss.backward()\n",
        "            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.\n",
        "            nn.utils.clip_grad_norm_(net.parameters(), clip)\n",
        "            opt.step()\n",
        "\n",
        "            # loss stats\n",
        "            if counter % print_every == 0:\n",
        "                # Get validation loss\n",
        "                val_h = net.init_hidden(batch_size)\n",
        "                val_losses = []\n",
        "                net.eval()\n",
        "                for x, y in get_batches(val_data, batch_size, seq_length):\n",
        "                    # One-hot encode our data and make them Torch tensors\n",
        "                    x = one_hot_encode(x, n_chars)\n",
        "                    x, y = torch.from_numpy(x), torch.from_numpy(y)\n",
        "                    \n",
        "                    # Creating new variables for the hidden state, otherwise\n",
        "                    # we'd backprop through the entire training history\n",
        "                    val_h = tuple([each.data for each in val_h])\n",
        "                    \n",
        "                    inputs, targets = x, y\n",
        "                    if(train_on_gpu):\n",
        "                        inputs, targets = inputs.cuda(), targets.cuda()\n",
        "\n",
        "                    output, val_h = net(inputs, val_h)\n",
        "                    val_loss = criterion(output, targets.view(batch_size*seq_length).long())\n",
        "                \n",
        "                    val_losses.append(val_loss.item())\n",
        "\n",
        "                print(sample(net, 128, prime='trump', top_k=5))                \n",
        "                net.train() # reset to train mode after iterationg through validation data\n",
        "                \n",
        "                print(\"Epoch: {}/{}...\".format(e+1, epochs),\n",
        "                      \"Step: {}...\".format(counter),\n",
        "                      \"Loss: {:.4f}...\".format(loss.item()),\n",
        "                      \"Val Loss: {:.4f}\".format(np.mean(val_losses)))\n",
        "                \n",
        "            \n",
        "           "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FF0uV9AK1lny"
      },
      "source": [
        "def predict(net, char, h=None, top_k=None):\n",
        "        ''' Given a character, predict the next character.\n",
        "            Returns the predicted character and the hidden state.\n",
        "        '''\n",
        "        \n",
        "        # tensor inputs\n",
        "        x = np.array([[net.char2int[char]]])\n",
        "        x = one_hot_encode(x, len(net.chars))\n",
        "        inputs = torch.from_numpy(x)\n",
        "        \n",
        "        if(train_on_gpu):\n",
        "            inputs = inputs.cuda()\n",
        "        \n",
        "        # detach hidden state from history\n",
        "        h = tuple([each.data for each in h])\n",
        "        # get the output of the model\n",
        "        out, h = net(inputs, h)\n",
        "\n",
        "        # get the character probabilities\n",
        "        p = F.softmax(out, dim=1).data\n",
        "        if(train_on_gpu):\n",
        "            p = p.cpu() # move to cpu\n",
        "        \n",
        "        # get top characters\n",
        "        if top_k is None:\n",
        "            top_ch = np.arange(len(net.chars))\n",
        "        else:\n",
        "            p, top_ch = p.topk(top_k)\n",
        "            top_ch = top_ch.numpy().squeeze()\n",
        "        \n",
        "        # select the likely next character with some element of randomness\n",
        "        p = p.numpy().squeeze()\n",
        "        char = np.random.choice(top_ch, p=p/p.sum())\n",
        "        \n",
        "        # return the encoded value of the predicted char and the hidden state\n",
        "        return net.int2char[char], h"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uwSXo6G10A2x",
        "outputId": "4df6c80c-5cef-42b1-a5f3-32e043581004"
      },
      "source": [
        "n_hidden=128\n",
        "n_layers=8\n",
        "\n",
        "net = CharRNN(chars, n_hidden, n_layers)\n",
        "print(net)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "CharRNN(\n",
            "  (lstm): LSTM(83, 128, num_layers=8, batch_first=True, dropout=0.5)\n",
            "  (dropout): Dropout(p=0.5, inplace=False)\n",
            "  (fc): Linear(in_features=128, out_features=83, bias=True)\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-w49xZ9k0NZI",
        "outputId": "aa673b7e-9b7a-40ec-9702-2dca4668a779"
      },
      "source": [
        "batch_size = 128\n",
        "seq_length = 128\n",
        "n_epochs = 120 # start smaller if you are just testing initial behavior\n",
        "\n",
        "# train the model\n",
        "train(net, encoded, epochs=n_epochs, batch_size=batch_size, seq_length=seq_length, lr=0.001, print_every=50)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "trump ee ateeteoea e e e eaeaa a tto et  oot   e oo ea     tta  e aooae ea     e  taae at    ae  eee atet aa eoea o teo       oa   oto\n",
            "Epoch: 2/120... Step: 50... Loss: 3.2235... Val Loss: 3.1450\n",
            "trumpt  t toea e     eetae otta t a e a et et   o     eo toe tat   t ee t  o   ae ateteta tao a  otaetaooe o ae to ott o t   tt  t eoo\n",
            "Epoch: 3/120... Step: 100... Loss: 3.1802... Val Loss: 3.1389\n",
            "trumptoeae e   aeoaete eoooea t e e oaoaeae eoeeoeteoaeae aa   ee oee o aaa e t ata ote tt e  oeatea    et oae e  otao ao oet eo  ao o\n",
            "Epoch: 4/120... Step: 150... Loss: 3.1537... Val Loss: 3.1371\n",
            "trumpato  aeoaoa  ea et oeotoo     a to aet  o  oe aa   t a  eeatettet aa aa  oateeat etoaet oe   tta  a oee   e  ao taa e e t e  teao\n",
            "Epoch: 5/120... Step: 200... Loss: 3.1738... Val Loss: 3.1365\n",
            "trumptt   e t  aeoe  te et  oa  t  ee a   a  e   o eaoeaoo   oot aao   at e  a tteeota   otaoaoetet  a taoeaet a otetaeteto et e oatea\n",
            "Epoch: 6/120... Step: 250... Loss: 3.1604... Val Loss: 3.1361\n",
            "trump to e   a t ot   tota tt eeooa oot  eoe taee e eeeae   ttao  t     a e taee  aetootea  etaatate t  ee eot oea  ta o to te  et  e \n",
            "Epoch: 7/120... Step: 300... Loss: 3.1800... Val Loss: 3.1364\n",
            "trump  ie o iiet   ot  eo oietit  o e iioioe i   etteoeiiteoioete o ei oiie t   itieteettto eeote   it toie  teooeeet teei eteieeeti  \n",
            "Epoch: 8/120... Step: 350... Loss: 3.1663... Val Loss: 3.1365\n",
            "trumpe ttoo o tit  tiiiie iiooeoi o ioet  o  ee oe o eot   et  otitioet oo oe ttei     eittet ot otiieee eoito t      eoi t i  ee  ei \n",
            "Epoch: 9/120... Step: 400... Loss: 3.1512... Val Loss: 3.1363\n",
            "trumpoe    oeoie   ioeeeti eo o   eo ee t e o t teeteott   e tteii ei ie t  ooteeo  teie titooteeo eo eoeio o   t et toitoiioo o e iie\n",
            "Epoch: 10/120... Step: 450... Loss: 3.1376... Val Loss: 3.1361\n",
            "trumpetet eo oa tt  o  ta   t o o  t a  eo oaooaeetata a  a  oa t aet  o eae aa attootteaoa  a    aa tetot e teea  a  oeea   ae  e aa \n",
            "Epoch: 11/120... Step: 500... Loss: 3.1536... Val Loss: 3.1357\n",
            "trumpe et  o ooaoe    tat  aa tttoate e o oo ooo aa  ot   e o  ateo  a t a  eot e ateat too e t ot at oeo oe ot  t etooe oatteee  at  \n",
            "Epoch: 12/120... Step: 550... Loss: 3.1606... Val Loss: 3.1358\n",
            "trumpeet    ate eeataet atoe to oataeeaa eoe  oe t   a et   eetoa  t e  o t t     ateteee oeea  tt eeo a teoa eaeo oe   etto  ee  oe t\n",
            "Epoch: 13/120... Step: 600... Loss: 3.1822... Val Loss: 3.1359\n",
            "trumpotet e a aoe    eooa ata tee aeeoo oaa  at   atete oo e  o eoooeta eat  ooae o oe  t  ea  o  to  ete aa eeea  e e  oe e atatattae\n",
            "Epoch: 14/120... Step: 650... Loss: 3.1701... Val Loss: 3.1363\n",
            "trump  eet t  tae aae e oo eaeoaat eo oeotoe  otote eaee eet  tt   aeoo tettto t t eo  tet o   ooe     tee oeet oeae oeo a eeateoeooo \n",
            "Epoch: 15/120... Step: 700... Loss: 3.1483... Val Loss: 3.1365\n",
            "trump  oo eoto t  e   t eee eto     aoetaaee  e  ooet aot etttte a et a  o ttto e aea aeot  aeeee t  teaet ea a  ta      eea  e ttoe o\n",
            "Epoch: 16/120... Step: 750... Loss: 3.1702... Val Loss: 3.1365\n",
            "trump    ee  tetta atet  ata    o  o  ate e  etae ta o o too  et ato eoo et   a  t o   toetea  at ao tao oat  o ee  a aeoaaate  taeeo \n",
            "Epoch: 17/120... Step: 800... Loss: 3.1603... Val Loss: 3.1365\n",
            "trumpe  o oo o eteo  aoe  a e et ott   oe eta  t e e  tot ooot   ota ae aoaa aao   aoetaott ot t    et too oeo    ooea eeoe o tte ot e\n",
            "Epoch: 18/120... Step: 850... Loss: 3.1584... Val Loss: 3.1362\n",
            "trumpo aato aeeatooetat aae ooeeto t ate ta t     o ee att oao   e eeot eeteeoeta ot aooe o  o eo t t  e   aeoaee o   at oo ta aeooto \n",
            "Epoch: 19/120... Step: 900... Loss: 3.1507... Val Loss: 3.1360\n",
            "trump   e ee e t   aeaa  ta aeeat a aaao   oo  a  eotaaee  o  teeetta ttaea to  otoeoetat ttoee a   tta a otee ee eo t ae oeot to tt t\n",
            "Epoch: 20/120... Step: 950... Loss: 3.1540... Val Loss: 3.1358\n",
            "trumpeee eteo  teott ao  tt oateett a  ate oo  tatt et  ea  e  e e o  et  a t tae  to oa t e aoa  aeaoe e ao      tteeoo  at to a  to \n",
            "Epoch: 21/120... Step: 1000... Loss: 3.1886... Val Loss: 3.1358\n",
            "trump     ao   ot eeet  eeo   ao oe  etotota   to aa at   tatooe a  t tt  ee ett ett eoeaet ee    t    t eete aoea at o  a ttoeoet t e\n",
            "Epoch: 22/120... Step: 1050... Loss: 3.1692... Val Loss: 3.1360\n",
            "trump  ae   ae a et   o teto oe oooo t e  tottoaoe oa ae   te    oe t e etoee a  oe oeo eo a eto teoatea  t oa          a t  toet ato \n",
            "Epoch: 23/120... Step: 1100... Loss: 3.1559... Val Loss: 3.1360\n",
            "trump aoea eo e eoo ae tt   e e  a taeaa  eeatt eea  ttta ee o  etao e t t a  eee e e  oeeotteee ta t e e e t  aoeoatttaoeteaaoa oaeo \n",
            "Epoch: 24/120... Step: 1150... Loss: 3.1655... Val Loss: 3.1358\n",
            "trumpe  oaoe e oo e ooa  eaoeoe ooooaot   e e ta oaee e tatte ae  t  o   eeo  taa oeo o  teete aotett  eeetat oteao at  o       o et  \n",
            "Epoch: 25/120... Step: 1200... Loss: 3.1527... Val Loss: 3.1354\n",
            "trumpoat eoo  tt  eaaaoee     a  ao eao  aa ee  o t   te  aea e   t oao a  t eaa o a e t  t at  eo at a o   o ea ooaao  a ea t  oo ata\n",
            "Epoch: 27/120... Step: 1250... Loss: 3.1708... Val Loss: 3.1348\n",
            "trumpe a ea a a eooae oao ata o  toteaoe  e tattea eteae  e  ta o   oa e    t  aet oeet  aooaottat e  eeoao toaa ee  t at e  a eo tooe\n",
            "Epoch: 28/120... Step: 1300... Loss: 3.1545... Val Loss: 3.1343\n",
            "trump  tee e oee  eeaat  eo a e ta oe te  eaa  eeeoea e    t oeaotaett eae eao ot o t a t te    etaaee aeaa oe a t  ae      e   teoa e\n",
            "Epoch: 29/120... Step: 1350... Loss: 3.1381... Val Loss: 3.1341\n",
            "trump oo  a e e tt  oeaa  t tteaet e  ea e e   ot  oat et   oao e  o  oa  tae atato eoo ee t oteteoeto  aoeeea e toe a ot t o aetaeo  \n",
            "Epoch: 30/120... Step: 1400... Loss: 3.1656... Val Loss: 3.1341\n",
            "trumpae  toa  e eaoeo   tttettettt e t ae ao tate  toooo     otea otaeeotaeot  tooa eeoeottotte  oo     ae a  t  o e o oa ee tatteoaet\n",
            "Epoch: 31/120... Step: 1450... Loss: 3.1517... Val Loss: 3.1341\n",
            "trumpe e  iiito ot t   i  otoe o   et  ite t  o otte i o etoioeotiie   ie   t eeeieoottt    ie  eoooeeo e   t e i  o   t   eet ite i  \n",
            "Epoch: 32/120... Step: 1500... Loss: 3.1727... Val Loss: 3.1345\n",
            "trump tooi ttteoetoo  t i eot  i iieti eioit i tt ieiio i tt te t      i itoee i ott oeee  t  teoet ti e e io  eei i t  ie teeiiiee   \n",
            "Epoch: 33/120... Step: 1550... Loss: 3.1610... Val Loss: 3.1345\n",
            "trumpiot  oi eeoeetiiooi  e t    te eto oie t  titeo e toe t o   ete    oetetiie   eooeootee to oe te it o oeo etiooieeoet ot oeieote \n",
            "Epoch: 34/120... Step: 1600... Loss: 3.1476... Val Loss: 3.1346\n",
            "trumpioeto  e ite o ee io eeto  eoi  tioe toot et  eeetie    iio toottooeittt e  oto  t tteeoe i oee etotoitttie  e  otioie   eoe ie e\n",
            "Epoch: 35/120... Step: 1650... Loss: 3.1355... Val Loss: 3.1344\n",
            "trumpa   aea  tte    t     ta   ett aoeee   eoae eoeae tee aoete toteo aee  o taeeoett t ea eto eot  o    a  te e aa o o  e     o e ee\n",
            "Epoch: 36/120... Step: 1700... Loss: 3.1517... Val Loss: 3.1342\n",
            "trump  aeata      taot   a eaa eeaae  te ta t tttteeoee at  ote   ttt  a ttt  ee a  teateteao  t e at aet    toeaeee  a  tet  a ee   e\n",
            "Epoch: 37/120... Step: 1750... Loss: 3.1565... Val Loss: 3.1340\n",
            "trumpt  a   a  etoo e  e   o    a  ato t  o  et  e  t e o   ee   o    aeeae eaoatte a  et  oo a oeae     o t e oo o eaa   t  a ta   ee\n",
            "Epoch: 38/120... Step: 1800... Loss: 3.1789... Val Loss: 3.1341\n",
            "trump oett  t   attee  o  eoo     ooe a tt e  et e  o   ao  t e aaoee  t o to atea     eo  e   aoto e  toaoo ate  o ato aoo etto  taa \n",
            "Epoch: 39/120... Step: 1850... Loss: 3.1673... Val Loss: 3.1346\n",
            "trump ee o o e t eo t  aae taa e    at   o aa e ttaea etoeaea tea aateaaaa   o te oaote a t ee  etae e   to a   e eae e e oo  e tae  a\n",
            "Epoch: 40/120... Step: 1900... Loss: 3.1481... Val Loss: 3.1349\n",
            "trumpaot  a at o otoete e a teoteaee    oo e aeet t atteott   ott o t   ttt ee   o  ee oto   tte  aoe et e oe tea tt a aae eo  ttta t \n",
            "Epoch: 41/120... Step: 1950... Loss: 3.1659... Val Loss: 3.1349\n",
            "trumptt oetao   tt   oata  ae ea   a  totteoett  a  tt a e ee   ea tt  t tt   eate     eaotttt et a eot e at  tt etaae e a  a ea  teo \n",
            "Epoch: 42/120... Step: 2000... Loss: 3.1579... Val Loss: 3.1349\n",
            "trumpeee  ate t e te   et e aao e e att o  eteee a ee   o  ee  et attet  ao at e t aea att   t e oea o t  ae   e e a o  ae oo  eeeo to\n",
            "Epoch: 43/120... Step: 2050... Loss: 3.1562... Val Loss: 3.1348\n",
            "trumpteet  atoa oe teoea oet e eet  o  oo  a teaaeo e ee  oeo       ottataot eet    eee ettott t  toe ooao ata  t ooat  eoaeee  oo ete\n",
            "Epoch: 44/120... Step: 2100... Loss: 3.1496... Val Loss: 3.1351\n",
            "trumpo  teaet ea eottoaaea t tooeoetto   ttte ot eetete ao  ooe et     o t   ooaee oe eea e toetttte e eate a  atea o a ea o ae ooeeaa\n",
            "Epoch: 45/120... Step: 2150... Loss: 3.1527... Val Loss: 3.1350\n",
            "trumpeootetoo a e ete   o  e aott eaaeo e  e  eeo     o ao t  eaotea  to teato ea  eot tt t to e o t  oo  ea t  ottt  eettao ta t  ao \n",
            "Epoch: 46/120... Step: 2200... Loss: 3.1851... Val Loss: 3.1351\n",
            "trump  ea eot e a    te o  aat  etet ta tt  eota  oaoa eetaaeto ooo eet te   eeaoa taae  oeoaat o e aeeate tteo taao  oae o o  aet  a \n",
            "Epoch: 47/120... Step: 2250... Loss: 3.1678... Val Loss: 3.1353\n",
            "trumpteaoat o   t e  a  etttt e   oe   et e  t e aaaoeote  e aeeo e taaa   a  tteaoto taoate e ato oe   ao te  ae o oa  oaa  eto tt at\n",
            "Epoch: 48/120... Step: 2300... Loss: 3.1536... Val Loss: 3.1355\n",
            "trumpoe  e  oee e oeaoee eao  attaeaeoa  e  aa totea t aaoe  tatteoaett aaa  oe  etoa oa ee aoooe tte  etot ooo oo ta eo  eetetott ea \n",
            "Epoch: 49/120... Step: 2350... Loss: 3.1660... Val Loss: 3.1355\n",
            "trumpe   aeoeeoe   tta toet oeo eoaae oote o  et   aoe  aaeo e a  o at  te    eto oaaee t   t teata t oeatooae aa  eo  e oaeea  oeaea \n",
            "Epoch: 50/120... Step: 2400... Loss: 3.1500... Val Loss: 3.1351\n",
            "trump e eto  e   te   ttaeea ote  ae  e te   ototooaooo  a t     t  toe toat e eo e eoaa oa  o  oaooa taot tt te oot eea ate  eo o to \n",
            "Epoch: 52/120... Step: 2450... Loss: 3.1682... Val Loss: 3.1346\n",
            "trumpaet a  oatt  eee a  eo ee  aaaea eota  oat etea a  te ta aatetea o ee a toee  e e  e  taete et teooea ooteeota eeeoeae   aa o ee \n",
            "Epoch: 53/120... Step: 2500... Loss: 3.1531... Val Loss: 3.1342\n",
            "trump     t aeeaete  eoo o   ao a eeaeoaoeo ee t ttteeaa o ooa e aot tot t t  et oea  t ooota aete  etaae aeoaooo  etae  a t e teoee e\n",
            "Epoch: 54/120... Step: 2550... Loss: 3.1383... Val Loss: 3.1341\n",
            "trumpo eo   t at o etoae  e    tte  ee aet et  ee  atoo  tteto ae   e  ao  aaooteo  eteoa  o taeeot   tatt  e  tt aeo t tt  ee  t a ee\n",
            "Epoch: 55/120... Step: 2600... Loss: 3.1663... Val Loss: 3.1340\n",
            "trumpt   ttooeette    a t ootatea   ttoeoeet eeatoa  ee o tttoteat e ot aooo ttt  eteaoto   toat etee  o e  oet tt   a    te etaee to \n",
            "Epoch: 56/120... Step: 2650... Loss: 3.1517... Val Loss: 3.1338\n",
            "trumpa    t aaeo  a te a etotoatetettaete attt otoetee aeoeo  t e    e   a a  t   ot t ateo o o aa    atta eaeat eatotooate a e o t et\n",
            "Epoch: 57/120... Step: 2700... Loss: 3.1727... Val Loss: 3.1340\n",
            "trumpooi eiet teote io  ie i   to ee  ot o it   t ieo   ot i eeoiteei o  e  t e etit ieot teitioi tttet  o o  toet etoet e ee e   ei  \n",
            "Epoch: 58/120... Step: 2750... Loss: 3.1597... Val Loss: 3.1341\n",
            "trump tootto o e  toe  t t ooeett eotott  ie  i oteteeeoo et o         etoteoiito te  eo   oei  te o   i ooeoee i iti   e iieei  e    \n",
            "Epoch: 59/120... Step: 2800... Loss: 3.1460... Val Loss: 3.1343\n",
            "trumpoeoote  i   eet   eoeeet eeoe etet   tee   etett  ieeeeo o toieoie e et o ttetioeto ee oo oii o ii   tt  i ie  eteot  eo ot it   \n",
            "Epoch: 60/120... Step: 2850... Loss: 3.1334... Val Loss: 3.1342\n",
            "trumpet i t ee  oi ee tooi   ee oieooi ee et t oo t o o  et ioitteoeii     it ee eeieeei  e eio eeieooo eee oie  e eeeoeiii  ttiitooi \n",
            "Epoch: 61/120... Step: 2900... Loss: 3.1509... Val Loss: 3.1341\n",
            "trumpao to et eoeo aee oe   tet tt ooa e tt  oeete ot e    o oo aaea oeaoteeea eeeo  etteeo t eaoeteeeo  oee et otet a  tao  ao eo oo \n",
            "Epoch: 62/120... Step: 2950... Loss: 3.1553... Val Loss: 3.1341\n",
            "trump   oete   o e e eeotaoea eot    e  t ao e o teea a  o oea oete aeo  eae e t eete  aeaoe  e  o aa  eo  oet  eae  oeo        tt tao\n",
            "Epoch: 63/120... Step: 3000... Loss: 3.1763... Val Loss: 3.1339\n",
            "trump   ee e   e a o a t to ete   oa    aeoe  e ta eteae    e aaett eaa     tootaeeaaa tt t    o eoeeete ooa e atotet  ea te t aee    \n",
            "Epoch: 64/120... Step: 3050... Loss: 3.1661... Val Loss: 3.1342\n",
            "trumpaeat e   t  eo o  e  ooeteaaee a  oatae  aa toaaaoa tettoe  ota  o tteeeaa t te o oo atot atota t ette e a e a  e  eoeeeoaeee ote\n",
            "Epoch: 65/120... Step: 3100... Loss: 3.1453... Val Loss: 3.1344\n",
            "trumptt te ea eteeette ee  ete e o  oet oo    tea ao to        att  ete e   oe e oeot a otea aeot aae e e   a aa ee        eeaatooeoa \n",
            "Epoch: 66/120... Step: 3150... Loss: 3.1649... Val Loss: 3.1347\n",
            "trump  ee t   o   to eeoo o to oe tetteat a ta   oee  oe  too taet a  e etoaee o e eoeo  otteae t tea  e  t te  eaatteo  oaaaeeate ee \n",
            "Epoch: 67/120... Step: 3200... Loss: 3.1583... Val Loss: 3.1348\n",
            "trump aeteo oa o t ea ettetoa ao  et  aaea ot  e ate te  ea eeoa oee ee eao oet  a  aote eoeet   ataaeee  o ttta  t e aaea   a e oea t\n",
            "Epoch: 68/120... Step: 3250... Loss: 3.1558... Val Loss: 3.1347\n",
            "trump e  o    e   etot  a  et  t e    aoe  o aaeo  te a at aaee  ee e   aot to aeaaaeea t e teaaoao te    oaoae  tetet    eaot tt  ee \n",
            "Epoch: 69/120... Step: 3300... Loss: 3.1480... Val Loss: 3.1348\n",
            "trump e t te  oeoteo  ta    tt aate o  ae  ot   to a    ao at e e e tea eoe o taete eoteo  a at  tote t oo te     a    t et ot tea oee\n",
            "Epoch: 70/120... Step: 3350... Loss: 3.1522... Val Loss: 3.1348\n",
            "trumptoa e  t  oooetttao    ott  ot  o  te   ete eeeo   oa  o toa ae ea  etot e eeto  oattaeo toa eeet     t a  tatae a o e a  to e  o\n",
            "Epoch: 71/120... Step: 3400... Loss: 3.1846... Val Loss: 3.1349\n",
            "trumpt t a eeeee a ta  etoetae  tte eteot   oee oe o ao e oteat ee atatte   e ta oaet  aeataoaae o  oe t   e   e  e aa  aao toeao  o e\n",
            "Epoch: 72/120... Step: 3450... Loss: 3.1672... Val Loss: 3.1350\n",
            "trumpetetaa tae   ooa  aa o e eo  a oe eo e  eea tatto o ao t te oeoeo  e oa oa ot  te e e  e  o    oo    et   oaeee  t oa attt ae eoo\n",
            "Epoch: 73/120... Step: 3500... Loss: 3.1542... Val Loss: 3.1352\n",
            "trump    ata aaoaott e t o at  teeee ataeaeea etoo aao eoa tt oa to e o o  ooae a ot  t e e  to tet atee taeooto ae ata a teeto aa toe\n",
            "Epoch: 74/120... Step: 3550... Loss: 3.1629... Val Loss: 3.1353\n",
            "trumptaoe   a o e   oatt  e   aot  otteea  tt ta eaa  t a eettte  o e    ae  oeetaoote eoo e e a t ae       o aae e   t  oeto etteteae\n",
            "Epoch: 75/120... Step: 3600... Loss: 3.1484... Val Loss: 3.1351\n",
            "trump tateao o   eaoe ee  a t ao oo  a etet ta  eeeeeeet to e o e t oe  et ea a ao  a a ttaettae  t ta ea   eottteett ao oat e ee eae \n",
            "Epoch: 77/120... Step: 3650... Loss: 3.1666... Val Loss: 3.1346\n",
            "trump    e ae teaoaot  t  oat otteo ee e eaea  taot  et ooe a e   e  to oteo tee oeto o to t  te etotoeoaoeo ta ee ootte oot a te tt t\n",
            "Epoch: 78/120... Step: 3700... Loss: 3.1513... Val Loss: 3.1343\n",
            "trumpeetae t toetee aeao oeo  e aa e ao  tooe oa aet  a tto  e tao    tt    a  e ot o t oeo aa  ae o ea   tta et tet  e o tt t    e ee\n",
            "Epoch: 79/120... Step: 3750... Loss: 3.1367... Val Loss: 3.1341\n",
            "trump aaoeoee o  ee o ee   e t e t att  t   aeoa o  eete    taee  eta e tate to eea  e ae ette   a eet  aaooaaott e e  ee et e  toa  o\n",
            "Epoch: 80/120... Step: 3800... Loss: 3.1630... Val Loss: 3.1340\n",
            "trump o  e  etoo te      tet    t oeoa tae  a ee ot eo eooe   tatea  teeota  ee aa   ett e t   at oeooaa    otee e aa aeotaaaatata  o \n",
            "Epoch: 81/120... Step: 3850... Loss: 3.1502... Val Loss: 3.1338\n",
            "trumpoe ee eo e   e e  eea   aettee eeae  e  tot eat eeeoeeae o  e e aaoo a o oto  t  e    tea eoa to ot    oeo e ae e oteatet   eo aa\n",
            "Epoch: 82/120... Step: 3900... Loss: 3.1710... Val Loss: 3.1339\n",
            "trump  oa eoe aao  t t    t a  o  a e   t      aeeaae  aote  teo et eo ote a t    o tte a ttataooo ee  e e  at o  ttae ota  o oeee  ae\n",
            "Epoch: 83/120... Step: 3950... Loss: 3.1589... Val Loss: 3.1339\n",
            "trumpee e ti  toe  te o e to eo   toi e e  teotte  e  ioiitet   oo eei e e  oo ii eett etoee  it toeeeee eot eoit eotiie it iteeeieot \n",
            "Epoch: 84/120... Step: 4000... Loss: 3.1450... Val Loss: 3.1340\n",
            "trumpeie tit ettoii  e ttoe  oioottt eoeiioo  eei   ioo et eoeeio ei eottto oeo i  eeoetot oo e  eeo ei i o i ttioe  e oee   eoe oit o\n",
            "Epoch: 85/120... Step: 4050... Loss: 3.1338... Val Loss: 3.1342\n",
            "trumpte oe   e   o   ee o oeii t  e e itt   eteiteito  t  eeeeoo  ee o itoeiiooieot o ie oie tt e i eitt oitoi e t et to oeeooii to te\n",
            "Epoch: 86/120... Step: 4100... Loss: 3.1504... Val Loss: 3.1341\n",
            "trump    e o a tt etoooa  eae    oteo te oata a teeaa oee  eo e  oea a ta atetoatet eeooae   tootoo e te otaoeo    t   oea eaeta  eea \n",
            "Epoch: 87/120... Step: 4150... Loss: 3.1537... Val Loss: 3.1341\n",
            "trumpe a eto to  eoooeooae  t eea e a  ae t  tt  taaae toeoee toetta e eototeeto ooa etoe a e e  a  o t oeoe oaa o  e a a ee e a  oe a\n",
            "Epoch: 88/120... Step: 4200... Loss: 3.1774... Val Loss: 3.1340\n",
            "trumpoao  ao  eo o t   ee  to  e t to oo   t eeeaa     t  et  et e t  e ee et ete  t     taoea oo    oee a  oeoo t a   ao e  aa eo  a \n",
            "Epoch: 89/120... Step: 4250... Loss: 3.1647... Val Loss: 3.1342\n",
            "trumpe atetaooeto oat eeateea  tatet oa e o   oe aete oe eeoeo oae a      oa atea  aeeaeto  e  aoot   ao  o eaoet teato  ttateoto otet\n",
            "Epoch: 90/120... Step: 4300... Loss: 3.1451... Val Loss: 3.1343\n",
            "trump e    et  a e    ooeeatoea a oeao ete o otot aea eteoot  a  t atte  tt  t et aaea o oee ott  e oote eoetoe o t  ooooaa a t  oe e \n",
            "Epoch: 91/120... Step: 4350... Loss: 3.1636... Val Loss: 3.1346\n",
            "trump    oe  e eaat t e ea  t e e eotet  oae e o eta    tet e    eeeoeea e t toee t eteo oet eeoee oetooaet  eeoaetaotoe a e ta  aeoa \n",
            "Epoch: 92/120... Step: 4400... Loss: 3.1572... Val Loss: 3.1352\n",
            "trump   ae  ot  e    etaetoo  oeat   t  e ee ae ttt teeooateaa eattaat    aet    tto  ottoto  te otteo eaott aae e  e etottaet e  teo \n",
            "Epoch: 93/120... Step: 4450... Loss: 3.1541... Val Loss: 3.1348\n",
            "trumpo ataotoeaeoaatoet eottoee   ta oe  e    et ea o      e ee  aoooa   aot   eet e eeaoe  teto  eoe taao  t t ea  e  eae  aea     t \n",
            "Epoch: 94/120... Step: 4500... Loss: 3.1472... Val Loss: 3.1347\n",
            "trumptaoeae  eate eoet t    taeaato tate   eo ee a  te e   t  oo  o   to   teae   aa ta ote eoetaat ao ottaoaa eet ea e oo o     tt   \n",
            "Epoch: 95/120... Step: 4550... Loss: 3.1518... Val Loss: 3.1347\n",
            "trump       toe eata oe  eee    eo  eetoo  oa te toeo  e et eoet ee   otao o o ee to aeaa t a a   tao ooe ae  eato  ttota ee    tteoe \n",
            "Epoch: 96/120... Step: 4600... Loss: 3.1842... Val Loss: 3.1348\n",
            "trump  o eeetao a tteeoaeao tt o oeet oe eeeotatte at e eee o  teotta eet a t  o eoe e at oe e   t oote eee oa   ete   ot ta   e taato\n",
            "Epoch: 97/120... Step: 4650... Loss: 3.1661... Val Loss: 3.1349\n",
            "trump   eae oattt aa oa e   o  a eeoe  ott  oeaoetae    oe e  eet etee  ata   eteetoa  eeae te o   t   ot t  ttaaae et    e ao et a t \n",
            "Epoch: 98/120... Step: 4700... Loss: 3.1527... Val Loss: 3.1352\n",
            "trumpate ao   ota eooteae  oa  ota  e e oaa ae  a    oe t a e  eeaao e    t   eoo o  oo o   t eaeo  oeeoa    tt     ae t eaee  tea aae\n",
            "Epoch: 99/120... Step: 4750... Loss: 3.1631... Val Loss: 3.1353\n",
            "trump  e te   a eet eoea t  ea aae   o t  ooooee tto e  at   ttet o e t  ett teo     ettea e   eo ete t    aato eoot     e  aet   eeeo\n",
            "Epoch: 100/120... Step: 4800... Loss: 3.1487... Val Loss: 3.1351\n",
            "trumpao aaote  e ato ee    o    o    teoaaa ttt at at tt aoeta eoa ottotoeaaoeeaa   totooo   teeee  ooooataet toooteeaea ao oeootaa tt\n",
            "Epoch: 102/120... Step: 4850... Loss: 3.1659... Val Loss: 3.1348\n",
            "trumpa att  eoe aoeeetea ao oeta aaaeoaaat e t e    oe t  ttatet  ee ateo e   eee tet ote aoaa     eto  too   t  ea ee o    ta t taett\n",
            "Epoch: 103/120... Step: 4900... Loss: 3.1512... Val Loss: 3.1344\n",
            "trumpeo   ettttaeao  a eo   ao a oa  oet  o eteeoeoa ate oate   tooeo atee t a oeeattaa    e e   t  et t  e   e ee  ea   a ooetooto a \n",
            "Epoch: 104/120... Step: 4950... Loss: 3.1362... Val Loss: 3.1344\n",
            "trumpe ot a  oa  o     at e ea te    eee aee   o eaa tt  oeo otaee  t  t    ao  o o aaaaeeteotooooteeeeatotta  ooe  a   tae eo aot e o\n",
            "Epoch: 105/120... Step: 5000... Loss: 3.1615... Val Loss: 3.1343\n",
            "trumpet otteet   o  ae eoea eo ao   e    ooeaaeoeto eee   aet     e   e    o oeaeotaaet ot eoaet aaet eaoa t  ee aaetee e oeto  e oeaa\n",
            "Epoch: 106/120... Step: 5050... Loss: 3.1488... Val Loss: 3.1341\n",
            "trumpeo    tttat ooee ee  eaatea   ott  oaet ee a a e  ete ae et eooooao t ee   t oo  e e      e  tao ata o    etoa eta tto   a eo tao\n",
            "Epoch: 107/120... Step: 5100... Loss: 3.1705... Val Loss: 3.1340\n",
            "trump a a aa ee aetee oao   e   a teeoaoaa t o a ea   teeo  oeto oettaoa  tt  aet  ot teo ooe   a ete t  a o e ot  o oe    tee ao a  a\n",
            "Epoch: 108/120... Step: 5150... Loss: 3.1590... Val Loss: 3.1340\n",
            "trumpioi   t toieooet i tieteo oteeee    tetio ieeo  oo oe i o     oiee it e  itot eie  e i e ote    iit tt ei  t  t eioo    eioeoitoo\n",
            "Epoch: 109/120... Step: 5200... Loss: 3.1445... Val Loss: 3.1341\n",
            "trump   tiit   tee eeitittoe oitii  eeo    i    ti e  e  ttt eoeo tooioe i  eo   ot e eoeoee   otee te  tei ei i ooot  o teotoe    eei\n",
            "Epoch: 110/120... Step: 5250... Loss: 3.1322... Val Loss: 3.1342\n",
            "trumpioo  eoeoiit i  oo    i o eo eeeo teeote oto   e oieoe e  t eio    eit  e iioeot i   t eoeie otteoe  ioio t totoe  tetotoiii t  e\n",
            "Epoch: 111/120... Step: 5300... Loss: 3.1499... Val Loss: 3.1341\n",
            "trump  tt   a te   t t to e eto      te o eoetea t ataeetoeteaoto    aee   a tet ee    a  o teeea  ee tea e  oa o  e     oo te atata  \n",
            "Epoch: 112/120... Step: 5350... Loss: 3.1544... Val Loss: 3.1342\n",
            "trump  te t e  e a eaote ote at  otaat  o e t t e oea  toeattot   eaaoa tt  te ote oe teae t   oa teeo          e  atee  eeo ttotee oo\n",
            "Epoch: 113/120... Step: 5400... Loss: 3.1761... Val Loss: 3.1342\n",
            "trump aeae e  at e ooa tt   ot t o ae    teeta t ooe   oa    eet ae ot ee eetto   at te a   tatoo eeaa o e to teo t t teoo t        ea\n",
            "Epoch: 114/120... Step: 5450... Loss: 3.1645... Val Loss: 3.1343\n",
            "trumpoott e etto eeeot eoe  a ot  eet  atoee  o e  ao te  tao  aeoo  aoeaoatee e     e      taa ao ae   o o  oeo ettte t   ete  a e  e\n",
            "Epoch: 115/120... Step: 5500... Loss: 3.1456... Val Loss: 3.1345\n",
            "trumpeteteto eo  atee e  a e  eaa    et  et eo aa  otteoe  at  aa eaetoe at aaoaeoettetet  e e ae ea  o o   oeeeet     e aoaa  t a  ot\n",
            "Epoch: 116/120... Step: 5550... Loss: 3.1638... Val Loss: 3.1346\n",
            "trumpeateooeota et  aoa to o oeoa ee ea et  te a eee   a a e oe tao  ee e o    oaa tot aetota tt ta  e tee   t  teoeeaaa t   tto a t t\n",
            "Epoch: 117/120... Step: 5600... Loss: 3.1570... Val Loss: 3.1348\n",
            "trumpoe e  oeotoa a ee etta eae tt  eeo etot et t otot eeoat ootea et t  a    ott ee ato aa   a  aa  ta aat a    ea t ateeo att ee   a\n",
            "Epoch: 118/120... Step: 5650... Loss: 3.1546... Val Loss: 3.1347\n",
            "trumpaeo o tt ttaeaoe t ttet a eoo  aa eae eoe   t oe a t     tt aee ta aoaa   oeaeaeo ao atot t etoeeeee  oo  ee  atoot oe aoeooeoe t\n",
            "Epoch: 119/120... Step: 5700... Loss: 3.1468... Val Loss: 3.1348\n",
            "trumpeoetttt eoao   e o  tooeaaoo  eot ttt oooaete e a eeeo   eao eeo aott  ae oea eoeaoet ae te ttoae  o ttatota   t  e a eoa to o t \n",
            "Epoch: 120/120... Step: 5750... Loss: 3.1514... Val Loss: 3.1349\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XwEphQu040cX"
      },
      "source": [
        "# Declaring a method to generate new text\n",
        "def sample(net, size, prime='The', top_k=None):\n",
        "        \n",
        "    if(train_on_gpu):\n",
        "        net.cuda()\n",
        "    else:\n",
        "        net.cpu()\n",
        "    \n",
        "    net.eval() # eval mode\n",
        "    \n",
        "    # First off, run through the prime characters\n",
        "    chars = [ch for ch in prime]\n",
        "    h = net.init_hidden(1)\n",
        "    for ch in prime:\n",
        "        char, h = predict(net, ch, h, top_k=top_k)\n",
        "\n",
        "    chars.append(char)\n",
        "    \n",
        "    # Now pass in the previous character and get a new one\n",
        "    for ii in range(size):\n",
        "        char, h = predict(net, chars[-1], h, top_k=top_k)\n",
        "        chars.append(char)\n",
        "\n",
        "    return ''.join(chars)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sA0nDYW053kZ",
        "outputId": "cdae73cc-5948-40fd-eb1f-756181bddd07"
      },
      "source": [
        "print(sample(net, 256, prime='hrc', top_k=5))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "hrc. who children? what is this relevant? what does at the special control? what was the present? why is the spash past for allowed by security of something to allow that people in the fund to being is the support to the mueller, and of past story and and such\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x-fi-OcV8fUA"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}