{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "IMDB.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [], "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "id": "N3xmkG71VKrd", "colab_type": "code", "outputId": "38d1dd5a-ba3c-4ab3-b73e-94dcccc67f9c", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "source": [ "import os\n", "import sys\n", "import torch\n", "from torch.nn import functional as F\n", "import numpy as np\n", "from torchtext import data\n", "from torchtext import datasets\n", "from torchtext.vocab import Vectors, GloVe\n", "\n", "\n", "tokenize = lambda x: x.split()\n", "TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, include_lengths=True, batch_first=True, fix_length=200)\n", "LABEL = data.LabelField()\n", "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n", "TEXT.build_vocab(train_data, vectors=GloVe(name='6B', dim=300))\n", "LABEL.build_vocab(train_data)\n", "word_embeddings = TEXT.vocab.vectors\n", "print (\"Length of Text Vocabulary: \" + str(len(TEXT.vocab)))\n", "print (\"Vector size of Text Vocabulary: \", TEXT.vocab.vectors.size())\n", "print (\"Label Length: \" + str(len(LABEL.vocab)))\n", "train_data, valid_data = train_data.split() # Further splitting of training_data to create new training_data & validation_data\n", "train_iter, valid_iter, test_iter = data.BucketIterator.splits((train_data, valid_data, test_data), batch_size=32, sort_key=lambda x: len(x.text), repeat=False, shuffle=True)\n", "'''Alternatively we can also use the default configurations'''\n", "#train_iter_, test_iter_ = datasets.IMDB.iters(batch_size=32)\n", "vocab_size = len(TEXT.vocab)\n", "#return TEXT, vocab_size, word_embeddings, train_iter, 
valid_iter, test_iter" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Length of Text Vocabulary: 251639\n", "Vector size of Text Vocabulary: torch.Size([251639, 300])\n", "Label Length: 2\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "jlidBWmDqA7E", "colab_type": "code", "outputId": "dfe01c15-2571-4047-f26e-937e9d604281", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "61pzzXOlTnA1", "colab_type": "code", "colab": {} }, "source": [ "import os\n", "os.chdir(\"/content/drive/My Drive/Colab Notebooks/Optimization project\")\n", "os.getcwd()\n", "\n", "file_path = \"/content/drive/My Drive/Colab Notebooks/Optimization project/IMDB\"\n", "#directory = os.path.dirname(file_path)\n", "\n", "# Create the checkpoint directory (and any missing parents) used by the later './IMDB/...' saves;\n", "# exist_ok=True makes the cell idempotent without a bare `except:` hiding unrelated errors.\n", "os.makedirs(file_path, exist_ok=True)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Mj8HZFVVUilO", "colab_type": "code", "colab": {} }, "source": [ "import sug\n", "from sug import SUG" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "5g5-i7_HwZHz", "colab_type": "code", "colab": {} }, "source": [ "import torch\n", "from torch.optim import Optimizer\n", "import math\n", "import copy\n", "\n", "class SUG(Optimizer):\n", " def __init__(self, params, l_0, d_0=0, prob=1., eps=1e-4, momentum=0, dampening=0,\n", " weight_decay=0, nesterov=False):\n", " if l_0 < 0.0:\n", " raise 
ValueError(\"Invalid Lipsitz constant of gradient: {}\".format(l_0))\n", " if d_0 < 0.0:\n", " raise ValueError(\"Invalid disperion of gradient: {}\".format(d_0))\n", " if momentum 
< 0.0:\n", " raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n", " if weight_decay < 0.0:\n", " raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n", "\n", " defaults = dict(L=l_0, momentum=momentum, dampening=dampening,\n", " weight_decay=weight_decay, nesterov=nesterov)\n", " if nesterov and (momentum <= 0 or dampening != 0):\n", " raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n", " self.Lips = l_0\n", " self.prev_Lips = l_0\n", " self.D_0 = d_0\n", " self.eps = eps\n", " self.prob = prob\n", " self.start_param = params\n", " self.upd_sq_grad_norm = None\n", " self.sq_grad_norm = None\n", " self.loss = torch.tensor(0.)\n", " self.cur_loss = 0\n", " self.closure = None\n", " super(SUG, self).__init__(params, defaults)\n", "\n", " def __setstate__(self, state):\n", " super(SUG, self).__setstate__(state)\n", " for group in self.param_groups:\n", " group.setdefault('nesterov', False)\n", "\n", " def comp_batch_size(self):\n", " \"\"\"Returns optimal batch size for given d_0, eps and l_0;\n", "\n", " \"\"\"\n", " return math.ceil(2 * self.D_0 * self.eps / self.prev_Lips)\n", "\n", " def step(self, loss, closure):\n", " \"\"\"Performs a single optimization step.\n", "\n", " Arguments:\n", " loss : current loss\n", "\n", " closure (callable, optional): A closure that reevaluates the model\n", " and returns the loss.\n", " \"\"\"\n", " self.start_params = []\n", " self.loss = loss\n", " self.sq_grad_norm = 0\n", " self.cur_loss = loss\n", " self.closure = closure\n", " for gr_idx, group in enumerate(self.param_groups):\n", " weight_decay = group['weight_decay']\n", " momentum = group['momentum']\n", " dampening = group['dampening']\n", " nesterov = group['nesterov']\n", " self.start_params.append([])\n", " for p_idx, p in enumerate(group['params']):\n", " self.start_params[gr_idx].append([p.data.clone()])\n", " if p.grad is None:\n", " continue\n", " 
self.start_params[gr_idx][p_idx].append(p.grad.data.clone())\n", " d_p = self.start_params[gr_idx][p_idx][1]\n", " p_ = self.start_params[gr_idx][p_idx][0]\n", " \n", " \n", " if weight_decay != 0:\n", " d_p.add_(weight_decay, p.data)\n", " self.cur_loss += weight_decay * torch.sum(p * p).item()\n", " \n", " \n", " self.sq_grad_norm += torch.sum(d_p * d_p).item()\n", " \n", " if momentum != 0:\n", " param_state = self.state[p]\n", " if 'momentum_buffer' not in param_state:\n", " buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)\n", " buf.mul_(momentum).add_(d_p)\n", " else:\n", " buf = param_state['momentum_buffer']\n", " buf.mul_(momentum).add_(1 - dampening, d_p)\n", " if nesterov:\n", " d_p = d_p.add(momentum, buf)\n", " else:\n", " d_p = buf\n", " self.start_params[gr_idx][p_idx][1] = d_p\n", " \n", " i = 0\n", " self.Lips = max(self.prev_Lips / 2, 0.1)\n", " difference = -1\n", " while difference < 0 or i == 0:\n", " if (i > 0): \n", " self.Lips = max(self.Lips * 2, 0.1)\n", " for gr_idx, group in enumerate(self.param_groups):\n", " for p_idx, p in enumerate(group['params']):\n", " if p.grad is None:\n", " continue\n", " start_param_val = self.start_params[gr_idx][p_idx][0]\n", " start_param_grad = self.start_params[gr_idx][p_idx][1]\n", " p.data = start_param_val - 1/(2*self.Lips) * start_param_grad\n", " difference, upd_loss = self.stop_criteria()\n", " i += 1\n", " self.prev_Lips = self.Lips\n", "\n", " return self.Lips, i\n", "\n", " def stop_criteria(self):\n", " \"\"\"Checks if the Lipsitz constant of gradient is appropriate\n", " \n", " + 2L_k / 2 ||x_k - w_k||^2 = - 1 / (2L_k)||g(x_k)||^2 + 1 / (4L_k)||g(x_k)||^2 = -1 / (4L_k)||g(x_k)||^2 \n", " \"\"\"\n", " upd_loss = self.closure()\n", " major = self.cur_loss - 1 / (4 * self.Lips) * self.sq_grad_norm\n", " return major - upd_loss - self.l2_reg() + self.eps / 10, upd_loss\n", "\n", " def get_lipsitz_const(self):\n", " \"\"\"Returns current Lipsitz constant of the gradient of the loss 
function\n", " \"\"\"\n", " return self.Lips\n", " \n", " def get_sq_grad(self):\n", " \"\"\"Returns the current second norm of the gradient of the loss function \n", " calculated by the formula\n", " \n", " ||f'(p_1,...,p_n)||_2^2 ~ \\sum\\limits_{i=1}^n ((df/dp_i) * (df/dp_i))(p1,...,p_n))\n", " \n", " \"\"\"\n", " self.upd_sq_grad_norm = 0\n", " for gr_idx, group in enumerate(self.param_groups):\n", " for p_idx, p in enumerate(group['params']):\n", " if p.grad is None:\n", " continue\n", " self.upd_sq_grad_norm += torch.sum(p.grad.data * p.grad.data).item()\n", " \n", " return self.upd_sq_grad_norm\n", " \n", " def l2_reg(self):\n", " \"\"\"Returns the current l2 regularization addiction\n", " \n", " \"\"\"\n", " self.upd_l2_reg = 0\n", " for gr_idx, group in enumerate(self.param_groups):\n", " weight_decay = group['weight_decay']\n", " if weight_decay != 0:\n", " for p_idx, p in enumerate(group['params']):\n", " self.upd_l2_reg += weight_decay * torch.sum(p * p).item()\n", " \n", " return self.upd_l2_reg" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "VwqXTkLZ4vDJ", "colab_type": "code", "outputId": "56671274-3efa-4d01-df1b-013eed2e7dc8", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "device" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "device(type='cuda', index=0)" ] }, "metadata": { "tags": [] }, "execution_count": 7 } ] }, { "cell_type": "markdown", "metadata": { "id": "zHQxrPhc9g4k", "colab_type": "text" }, "source": [ "## Model" ] }, { "cell_type": "code", "metadata": { "id": "KDBOEG6tVyuR", "colab_type": "code", "colab": {} }, "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.autograd import Variable\n", "from torch.nn.utils.rnn import pack_padded_sequence, 
pad_packed_sequence" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "D-WcYYATWwA3", "colab_type": "code", "colab": {} }, "source": [ "class SimpleLSTMBaseline(nn.Module):\n", " def __init__(self, hidden_dim, emb_dim=300, num_linear=1):\n", " super().__init__() \n", " self.embedding = nn.Embedding(len(TEXT.vocab), emb_dim)\n", " self.encoder = nn.LSTM(emb_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)\n", " \n", " self.linear1 = nn.Linear(2 * hidden_dim, 32)\n", " self.linear1.weight.data.fill_(2)\n", " self.linear2 = nn.Linear(32, 2)\n", " self.linear2.weight.data.fill_(2)\n", "\n", " \n", " def forward(self, seq, lens):\n", " embeds = self.embedding(seq)\n", " packed = pack_padded_sequence(embeds, lens, batch_first=True)\n", " hdn, _ = self.encoder(packed)\n", " hdn, _ = pad_packed_sequence(hdn, batch_first=True)\n", " output = nn.functional.max_pool1d(hdn, kernel_size=10)\n", " output = nn.functional.relu(self.linear1(hdn[:,1,:]))\n", " prob = nn.functional.log_softmax(self.linear2(output), -1)\n", " \n", " return prob\n", " " ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "nXtqKQJfuexo", "colab_type": "code", "colab": {} }, "source": [ "class LSTMClassifier(nn.Module):\n", " def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):\n", " super(LSTMClassifier, self).__init__()\n", " \n", " \"\"\"\n", " Arguments\n", " ---------\n", " batch_size : Size of the batch which is same as the batch_size of the data returned by the TorchText BucketIterator\n", " output_size : 2 = (pos, neg)\n", " hidden_sie : Size of the hidden_state of the LSTM\n", " vocab_size : Size of the vocabulary containing unique words\n", " embedding_length : Embeddding dimension of GloVe word embeddings\n", " weights : Pre-trained GloVe word_embeddings which we will use to create our word_embedding look-up table \n", " \n", " \"\"\"\n", " \n", " self.batch_size 
= batch_size\n", " self.output_size = output_size\n", " self.hidden_size = hidden_size\n", " self.vocab_size = vocab_size\n", " self.embedding_length = embedding_length\n", " self.num_layers = 1\n", " \n", " self.word_embeddings = nn.Embedding(vocab_size, embedding_length)# Initializing the look-up table.\n", " self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False) # Assigning the look-up table to the pre-trained GloVe word embedding.\n", " self.lstm = nn.LSTM(embedding_length, hidden_size, batch_first=True, bidirectional=False, num_layers=self.num_layers)\n", " self.label = nn.Linear(1 * hidden_size * self.num_layers, output_size)\n", " \n", " def forward(self, input_sentence, batch_size=None):\n", "\n", " \"\"\" \n", " Parameters\n", " ----------\n", " input_sentence: input_sentence of shape = (batch_size, num_sequences)\n", " batch_size : default = None. Used only for prediction on a single sentence after training (batch_size = 1)\n", " \n", " Returns\n", " -------\n", " Output of the linear layer containing logits for positive & negative class which receives its input as the final_hidden_state of the LSTM\n", " final_output.shape = (batch_size, output_size)\n", " \n", " \"\"\"\n", " \n", " ''' Here we will map all the indexes present in the input sequence to the corresponding word vector using our pre-trained word_embedddins.'''\n", " input = self.word_embeddings(input_sentence) # embedded input of shape = (batch_size, num_sequences, embedding_length)\n", " #input = input.permute(1, 0, 2) # input.size() = (num_sequences, batch_size, embedding_length)\n", " batch_size = input_sentence.size(0)\n", " h_0 = Variable(torch.zeros(1 * self.num_layers, batch_size, self.hidden_size).cuda())\n", " c_0 = Variable(torch.zeros(1 * self.num_layers, batch_size, self.hidden_size).cuda())\n", " #packed = pack_padded_sequence(input, lens, batch_first=True)\n", " #output, (final_hidden_state, final_cell_state) = self.lstm(packed, (h_0, c_0))\n", " #output, _ = 
pad_packed_sequence(output, batch_first=True)\n", " #print(input.size())\n", " output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))\n", " final_output = self.label(final_hidden_state.view(batch_size, self.num_layers*1*self.hidden_size)) # final_hidden_state.size() = (num_layers, batch_size, hidden_size) & final_output.size() = (batch_size, output_size)\n", " \n", " return final_output" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "rQBczvQqhFQR", "colab_type": "code", "colab": {} }, "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "from torch.nn import functional as F\n", "import numpy as np\n", "\n", "class AttentionModel(torch.nn.Module):\n", "\tdef __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):\n", "\t\tsuper(AttentionModel, self).__init__()\n", "\t\t\n", "\t\t\"\"\"\n", "\t\tArguments\n", "\t\t---------\n", "\t\tbatch_size : Size of the batch which is same as the batch_size of the data returned by the TorchText BucketIterator\n", "\t\toutput_size : 2 = (pos, neg)\n", "\t\thidden_size : Size of the hidden_state of the LSTM\n", "\t\tvocab_size : Size of the vocabulary containing unique words\n", "\t\tembedding_length : Embedding dimension of GloVe word embeddings\n", "\t\tweights : Pre-trained GloVe word_embeddings which we will use to create our word_embedding look-up table \n", "\t\t\n", "\t\t--------\n", "\t\t\n", "\t\t\"\"\"\n", "\t\t\n", "\t\tself.batch_size = batch_size\n", "\t\tself.output_size = output_size\n", "\t\tself.hidden_size = hidden_size\n", "\t\tself.vocab_size = vocab_size\n", "\t\tself.embedding_length = embedding_length\n", "\t\t\n", "\t\tself.word_embeddings = nn.Embedding(vocab_size, embedding_length)\n", "\t\t# Assign to `.weight` (the lookup table nn.Embedding.forward uses). Assigning to `.weights`\n", "\t\t# would only register an unused extra parameter and leave the table randomly initialised.\n", 
"\t\tself.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)\n", "\t\tself.lstm = nn.LSTM(embedding_length, hidden_size)\n", "\t\tself.label = nn.Linear(hidden_size, 
output_size)\n", "\t\t#self.attn_fc_layer = nn.Linear()\n", "\t\t\n", "\tdef attention_net(self, lstm_output, final_state):\n", "\n", "\t\t\"\"\" \n", "\t\tNow we will incorporate Attention mechanism in our LSTM model. In this new model, we will use attention to compute soft alignment score corresponding\n", "\t\tbetween each of the hidden_state and the last hidden_state of the LSTM. We will be using torch.bmm for the batch matrix multiplication.\n", "\t\t\n", "\t\tArguments\n", "\t\t---------\n", "\t\t\n", "\t\tlstm_output : Final output of the LSTM which contains hidden layer outputs for each sequence.\n", "\t\tfinal_state : Final time-step hidden state (h_n) of the LSTM\n", "\t\t\n", "\t\t---------\n", "\t\t\n", "\t\tReturns : It performs attention mechanism by first computing weights for each of the sequence present in lstm_output and and then finally computing the\n", "\t\t\t\t new hidden state.\n", "\t\t\t\t \n", "\t\tTensor Size :\n", "\t\t\t\t\thidden.size() = (batch_size, hidden_size)\n", "\t\t\t\t\tattn_weights.size() = (batch_size, num_seq)\n", "\t\t\t\t\tsoft_attn_weights.size() = (batch_size, num_seq)\n", "\t\t\t\t\tnew_hidden_state.size() = (batch_size, hidden_size)\n", "\t\t\t\t\t \n", "\t\t\"\"\"\n", "\t\t\n", "\t\thidden = final_state.squeeze(0)\n", "\t\tattn_weights = torch.bmm(lstm_output, hidden.unsqueeze(2)).squeeze(2)\n", "\t\tsoft_attn_weights = F.softmax(attn_weights, 1)\n", "\t\tnew_hidden_state = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)\n", "\t\t\n", "\t\treturn new_hidden_state\n", "\t\n", "\tdef forward(self, input_sentences, batch_size=None):\n", "\t\n", "\t\t\"\"\" \n", "\t\tParameters\n", "\t\t----------\n", "\t\tinput_sentence: input_sentence of shape = (batch_size, num_sequences)\n", "\t\tbatch_size : default = None. 
Used only for prediction on a single sentence after training (batch_size = 1)\n", "\t\t\n", "\t\tReturns\n", "\t\t-------\n", "\t\tOutput of the linear layer containing logits for pos & neg class which receives its input as the new_hidden_state which is basically the output of the Attention network.\n", "\t\tfinal_output.shape = (batch_size, output_size)\n", "\t\t\n", "\t\t\"\"\"\n", "\t\tbatch_size = input_sentences.size(0)\n", "\t\tinput = self.word_embeddings(input_sentences)\n", "\t\tinput = input.permute(1, 0, 2)\n", "\t\th_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())\n", "\t\tc_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda())\n", "\t\toutput, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0)) # final_hidden_state.size() = (1, batch_size, hidden_size) \n", "\t\toutput = output.permute(1, 0, 2) # output.size() = (batch_size, num_seq, hidden_size)\n", "\t\t\n", "\t\tattn_output = self.attention_net(output, final_hidden_state)\n", "\t\tlogits = self.label(attn_output)\n", "\t\t\n", "\t\treturn nn.functional.log_softmax(logits, -1)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "3VbVMHjn9jlN", "colab_type": "text" }, "source": [ "## Training" ] }, { "cell_type": "code", "metadata": { "id": "QcDvlJrWh_0X", "colab_type": "code", "colab": {} }, "source": [ "import time\n", "import math\n", "\n", "def time_since(since):\n", " s = time.time() - since\n", " m = math.floor(s / 60)\n", " s -= m * 60\n", " return '%dm %ds' % (m, s)\n", "\n", "\n", "def model_step(model, optimizer, criterion, inputs, labels):\n", " outputs = model(inputs)\n", " loss = criterion(outputs, labels)\n", " acc = (torch.argmax(outputs, 1) == labels).float().sum().item()\n", " if model.training:\n", " optimizer.zero_grad()\n", " loss.backward(retain_graph=True)\n", " if optimizer.__class__.__name__ != 'SUG':\n", " optimizer.step()\n", " else:\n", " def closure():\n", " optimizer.zero_grad()\n", " 
upd_outputs = model(inputs)\n", " upd_loss = criterion(upd_outputs, labels).item()\n", "\n", " return upd_loss\n", "\n", " optimizer.step(loss.item(), closure)\n", "\n", " return loss.item(), acc" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "KXe4_d-Ziami", "colab_type": "code", "colab": {} }, "source": [ "def train(model, trainloader, criterion, optimizer, path=None, n_epochs=2, validloader=None, eps=1e-5, print_every=1):\n", " tr_loss, val_loss, lips, times, grad, tr_acc, val_acc = ([] for i in range(7))\n", " start_time = time.time()\n", " model.to(device=device)\n", " # torchtext iterators implement len() (number of batches); list() would build a whole epoch just to count it\n", " print(len(trainloader))\n", " for ep in range(n_epochs):\n", " model.train()\n", " i = 0\n", " tot_acc = 0\n", " n_ex = 0\n", " for i, batch in enumerate(trainloader):\n", " #t, l = batch\n", " #(text, lens), target = t\n", " text = batch.text[0]\n", " lens = batch.text[1]\n", " target = batch.label\n", " target = torch.autograd.Variable(target).long()\n", " if torch.cuda.is_available():\n", " text = text.cuda()\n", " target = target.cuda()\n", " loss, acc = model_step(model, optimizer, criterion, text, target)\n", " tot_acc += acc\n", " n_ex += text.size(0)\n", " tr_loss.append(loss) \n", " \n", " if optimizer.__class__.__name__ == 'SUG':\n", " lips.append(optimizer.get_lipsitz_const())\n", " # record the squared gradient norm computed by SUG.step at the step's start point;\n", " # get_sq_grad() would read p.grad after the SUG closure has already zeroed it\n", " grad.append(optimizer.sq_grad_norm)\n", " if i % 100 == 0:\n", " print(tr_loss[-1], i)\n", " times.append(time_since(start_time))\n", " model.zero_grad()\n", " optimizer.zero_grad()\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss,\n", " 'val_loss' : val_loss,\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times\n", " } \n", " if path is not None:\n", 
" torch.save(states, path)\n", " tr_acc.append(tot_acc / n_ex)\n", " times.append(time_since(start_time))\n", " if ep % print_every == 0:\n", " print(\"Epoch {}, training loss {}, time passed {}, training 
accuracy {}\".format(ep, sum(tr_loss[-(i + 1):]) / (i + 1), time_since(start_time), tr_acc[-1]))\n", "\n", " if validloader is None:\n", " continue\n", " model.zero_grad()\n", " model.eval()\n", " j = 0\n", " count = 0\n", " n_ex = 0\n", " for j, batch in enumerate(validloader):\n", " text = batch.text[0]\n", " target = batch.label\n", " target = torch.autograd.Variable(target).long()\n", " if torch.cuda.is_available():\n", " text = text.cuda()\n", " target = target.cuda()\n", " outputs = model(text)\n", " #outputs_lab = torch.argmax(outputs, 1)\n", " count += (torch.argmax(outputs, 1) == target).float().sum().item()\n", " n_ex += outputs.size(0) \n", " val_loss.append(criterion(outputs, target).item())\n", " val_acc.append(count / n_ex)\n", " if ep % print_every == 0:\n", " # i and j are the last enumerate indices, so this epoch produced i + 1 / j + 1 loss entries\n", " print(\"Validation loss {}, validation accuracy {}\".format(sum(val_loss[-(j + 1):]) / (j + 1), val_acc[-1]))\n", " \n", " return tr_loss, times, val_loss, lips, grad, tr_acc, val_acc" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "mvhnRhO8ifa4", "colab_type": "code", "colab": {} }, "source": [ "def concat_states(state1, state2):\n", " states = {\n", " 'epoch': state1['epoch'] + state2['epoch'],\n", " 'state_dict': state2['state_dict'],\n", " 'optimizer': state2['optimizer'],\n", " 'tr_loss' : state1['tr_loss'] + state2['tr_loss'],\n", " 'val_loss' : state1['val_loss'] + state2['val_loss'],\n", " 'lips' : state1['lips'] + state2['lips'],\n", " 'grad' : state1['grad'] + state2['grad'],\n", " #'times' : state1['times'] + list(map(lambda x: x + state1['times'][-1],state2['times']))\n", " 'times' : state1['times'] + state2['times']\n", " }\n", " return states" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "PHhBIzadidIY", "colab_type": "code", "colab": {} }, "source": [ "print_every = 1\n", "n_epochs = 10\n", "tr_loss = {}\n", 
"tr_loss['sgd'] = {}\n", "val_loss = {}\n", "val_loss['sgd'] = {}\n", "#lrs = [0.05, 0.01, 0.005]\n", "em_sz = 128\n", 
"hidden_size = 256\n", "embedding_length = 300\n", "nl = 2\n", "torch.manual_seed(999)\n", "batch_size = 32\n", "criterion = nn.CrossEntropyLoss()" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "4an04kBXvHRn", "colab_type": "code", "colab": {} }, "source": [ "n_epochs = 20" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Fwvpjn-X2-F9", "colab_type": "code", "outputId": "4cf19ade-d1b8-49fd-ec5d-864a95d3525c", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "vocab_size = int(vocab_size)\n", "vocab_size" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "251639" ] }, "metadata": { "tags": [] }, "execution_count": 17 } ] }, { "cell_type": "code", "metadata": { "id": "bhytd4Y-iiKH", "colab_type": "code", "outputId": "f65ca63a-18ef-4925-ff16-1b1ab898e141", "colab": { "base_uri": "https://localhost:8080/", "height": 2805 } }, "source": [ "lrs = [0.0001, 0.001]\n", "for lr in lrs:\n", " model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", " print(\"SGD lr={}, momentum=0. :\".format(lr))\n", " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", " torch.save(states, './IMDB/LSTM_' + str(lr))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.0001, momentum=0. 
:\n", "547\n", "0.6964250206947327 0\n", "0.694038987159729 100\n", "0.690762996673584 200\n", "0.6946480870246887 300\n", "0.6906776428222656 400\n", "0.6938506364822388 500\n", "Epoch 0, training loss 0.6931881379513514, time passed 0m 17s, training accuracy 0.5025714285714286\n", "Validation loss 0.6940929013439733, validation accuracy 0.4856\n", "0.6945292353630066 0\n", "0.69283527135849 100\n", "0.6948346495628357 200\n", "0.6958023905754089 300\n", "0.688003420829773 400\n", "0.6869951486587524 500\n", "Epoch 1, training loss 0.6931847904846345, time passed 0m 26s, training accuracy 0.5028\n", "Validation loss 0.694080766194906, validation accuracy 0.48546666666666666\n", "0.6985771656036377 0\n", "0.6874481439590454 100\n", "0.6930689215660095 200\n", "0.6923878192901611 300\n", "0.693231463432312 400\n", "0.693196713924408 500\n", "Epoch 2, training loss 0.6931726265521276, time passed 0m 36s, training accuracy 0.5028571428571429\n", "Validation loss 0.6940687321699582, validation accuracy 0.486\n", "0.6972582340240479 0\n", "0.690631628036499 100\n", "0.6858733892440796 200\n", "0.697816014289856 300\n", "0.6938906311988831 400\n", "0.6959564089775085 500\n", "Epoch 3, training loss 0.6931685177616147, time passed 0m 46s, training accuracy 0.5026857142857143\n", "Validation loss 0.6940567034941453, validation accuracy 0.48586666666666667\n", "0.6890140771865845 0\n", "0.6966797113418579 100\n", "0.6985792517662048 200\n", "0.6887302994728088 300\n", "0.6915265917778015 400\n", "0.6963625550270081 500\n", "Epoch 4, training loss 0.6931770691723177, time passed 0m 56s, training accuracy 0.5026857142857143\n", "Validation loss 0.6940450647957305, validation accuracy 0.486\n", "0.6946545839309692 0\n", "0.69691401720047 100\n", "0.6969871520996094 200\n", "0.6931300759315491 300\n", "0.6894280314445496 400\n", "0.6944723725318909 500\n", "Epoch 5, training loss 0.6931609708965917, time passed 1m 7s, training accuracy 0.5025142857142857\n", "Validation loss 
0.6940335315516871, validation accuracy 0.48573333333333335\n", "0.6953996419906616 0\n", "0.6925109028816223 100\n", "0.6946840286254883 200\n", "0.6963344216346741 300\n", "0.6918807625770569 400\n", "0.696747362613678 500\n", "Epoch 6, training loss 0.6931529439194298, time passed 1m 17s, training accuracy 0.5025714285714286\n", "Validation loss 0.6940224534935422, validation accuracy 0.48546666666666666\n", "0.6951776742935181 0\n", "0.691841721534729 100\n", "0.6929687857627869 200\n", "0.6960068941116333 300\n", "0.6971116065979004 400\n", "0.693190336227417 500\n", "Epoch 7, training loss 0.6931483859107608, time passed 1m 27s, training accuracy 0.5029142857142858\n", "Validation loss 0.6940113030947171, validation accuracy 0.48573333333333335\n", "0.696017861366272 0\n", "0.7012372612953186 100\n", "0.6901893019676208 200\n", "0.6899189949035645 300\n", "0.6876844167709351 400\n", "0.6932987570762634 500\n", "Epoch 8, training loss 0.6931420001354847, time passed 1m 38s, training accuracy 0.5030857142857142\n", "Validation loss 0.694000345774186, validation accuracy 0.48573333333333335\n", "0.6958585381507874 0\n", "0.6899902820587158 100\n", "0.689932644367218 200\n", "0.6935151219367981 300\n", "0.689987063407898 400\n", "0.6959187984466553 500\n", "Epoch 9, training loss 0.693135427467989, time passed 1m 48s, training accuracy 0.5026857142857143\n", "Validation loss 0.6939898209694104, validation accuracy 0.486\n", "SGD lr=0.001, momentum=0. 
:\n", "547\n", "0.6835296154022217 0\n", "0.6975476741790771 100\n", "0.7016441822052002 200\n", "0.6932113170623779 300\n", "0.6852632761001587 400\n", "0.6923226714134216 500\n", "Epoch 0, training loss 0.6935323885072282, time passed 0m 10s, training accuracy 0.5019428571428571\n", "Validation loss 0.6925257752593766, validation accuracy 0.5218666666666667\n", "0.6928713917732239 0\n", "0.6912435293197632 100\n", "0.6879763007164001 200\n", "0.6958819031715393 300\n", "0.6908573508262634 400\n", "0.6929187774658203 500\n", "Epoch 1, training loss 0.6933078717836093, time passed 0m 20s, training accuracy 0.49897142857142857\n", "Validation loss 0.6925921172667773, validation accuracy 0.52\n", "0.6950969696044922 0\n", "0.6899083852767944 100\n", "0.6981430053710938 200\n", "0.6893796920776367 300\n", "0.6957111358642578 400\n", "0.7011501789093018 500\n", "Epoch 2, training loss 0.6931950786174872, time passed 0m 30s, training accuracy 0.4993714285714286\n", "Validation loss 0.6926529804865519, validation accuracy 0.5177333333333334\n", "0.6897975206375122 0\n", "0.6957585215568542 100\n", "0.693533718585968 200\n", "0.6910946369171143 300\n", "0.6954749822616577 400\n", "0.6886304020881653 500\n", "Epoch 3, training loss 0.6931292296765925, time passed 0m 41s, training accuracy 0.5003428571428571\n", "Validation loss 0.6927008409785409, validation accuracy 0.5149333333333334\n", "0.6912845969200134 0\n", "0.6937668919563293 100\n", "0.6902360916137695 200\n", "0.6908919811248779 300\n", "0.6905723810195923 400\n", "0.6907973885536194 500\n", "Epoch 4, training loss 0.6930708169063806, time passed 0m 51s, training accuracy 0.5004571428571428\n", "Validation loss 0.6927197351414933, validation accuracy 0.5165333333333333\n", "0.6970768570899963 0\n", "0.6905248165130615 100\n", "0.6949999332427979 200\n", "0.6907947659492493 300\n", "0.6962170004844666 400\n", "0.6864935755729675 500\n", "Epoch 5, training loss 0.6930131226668864, time passed 1m 1s, training 
accuracy 0.5007428571428572\n", "Validation loss 0.6927327926342304, validation accuracy 0.5182666666666667\n", "0.6914203763008118 0\n", "0.6898652911186218 100\n", "0.6895883083343506 200\n", "0.6859859824180603 300\n", "0.6931174397468567 400\n", "0.6964842081069946 500\n", "Epoch 6, training loss 0.6929789867811588, time passed 1m 11s, training accuracy 0.5009142857142858\n", "Validation loss 0.6927546295854781, validation accuracy 0.5189333333333334\n", "0.6915695667266846 0\n", "0.688200831413269 100\n", "0.6888987421989441 200\n", "0.6924067735671997 300\n", "0.6952000856399536 400\n", "0.6942896842956543 500\n", "Epoch 7, training loss 0.6929369884314555, time passed 1m 22s, training accuracy 0.5005714285714286\n", "Validation loss 0.6927364764050541, validation accuracy 0.5197333333333334\n", "0.6924756169319153 0\n", "0.6936146020889282 100\n", "0.6935449838638306 200\n", "0.6891248226165771 300\n", "0.6895835399627686 400\n", "0.6902214288711548 500\n", "Epoch 8, training loss 0.6928978291623321, time passed 1m 32s, training accuracy 0.5004571428571428\n", "Validation loss 0.6927179087940444, validation accuracy 0.5202666666666667\n", "0.6919939517974854 0\n", "0.6884726285934448 100\n", "0.6827298998832703 200\n", "0.6916747093200684 300\n", "0.6918917894363403 400\n", "0.7025838494300842 500\n", "Epoch 9, training loss 0.6928599404764699, time passed 1m 42s, training accuracy 0.5010857142857142\n", "Validation loss 0.6927168697882922, validation accuracy 0.5218666666666667\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "bdwWe20ikVEU", "colab_type": "code", "outputId": "df02db2b-dcf8-4a03-b0ca-605ddcae5695", "colab": { "base_uri": "https://localhost:8080/", "height": 2825 } }, "source": [ "l_0 = 20\n", "model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", "print(\"SUG l_0={}, momentum=0. 
:\".format(l_0))\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", "torch.save(states, './IMDB/LSTM_sug')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SUG l_0=20, momentum=0. :\n", "547\n", "0.6849241256713867 0\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:522: RuntimeWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. 
To compact weights again call flatten_parameters().\n", " self.dropout, self.training, self.bidirectional, self.batch_first)\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "0.6952140927314758 100\n", "0.7481666803359985 200\n", "0.6844102144241333 300\n", "0.6859437227249146 400\n", "0.7077557444572449 500\n", "Epoch 0, training loss 0.7053052665113093, time passed 0m 20s, training accuracy 0.5095428571428572\n", "Validation loss 0.7462471091849172, validation accuracy 0.5014666666666666\n", "0.7905590534210205 0\n", "0.6987895965576172 100\n", "0.7223014831542969 200\n", "0.7111559510231018 300\n", "0.742397665977478 400\n", "0.6616271138191223 500\n", "Epoch 1, training loss 0.7009164982444638, time passed 0m 38s, training accuracy 0.5201714285714286\n", "Validation loss 0.6874387786429152, validation accuracy 0.5314666666666666\n", "0.6852426528930664 0\n", "0.7327583432197571 100\n", "0.6913496851921082 200\n", "0.6588264107704163 300\n", "0.5595537424087524 400\n", "0.5827333927154541 500\n", "Epoch 2, training loss 0.6299851954856635, time passed 0m 56s, training accuracy 0.6525714285714286\n", "Validation loss 0.5988702120689245, validation accuracy 0.7065333333333333\n", "0.5051348209381104 0\n", "0.6458197236061096 100\n", "0.6327732801437378 200\n", "0.6482389569282532 300\n", "0.5426529049873352 400\n", "0.6354039907455444 500\n", "Epoch 3, training loss 0.5837856869012008, time passed 1m 14s, training accuracy 0.7268571428571429\n", "Validation loss 0.5769952004536604, validation accuracy 0.7372\n", "0.6321955919265747 0\n", "0.6145216822624207 100\n", "0.6611451506614685 200\n", "0.5769146680831909 300\n", "0.7109729647636414 400\n", "0.6740754246711731 500\n", "Epoch 4, training loss 0.5727979150010553, time passed 1m 33s, training accuracy 0.7321714285714286\n", "Validation loss 0.5685298531992823, validation accuracy 0.7350666666666666\n", "0.5007982850074768 0\n", "0.5335962176322937 100\n", "0.6270943284034729 200\n", 
"0.6070114374160767 300\n", "0.6142672300338745 400\n", "0.5168774127960205 500\n", "Epoch 5, training loss 0.5623274704803041, time passed 1m 51s, training accuracy 0.7411428571428571\n", "Validation loss 0.556505321055396, validation accuracy 0.7457333333333334\n", "0.45171988010406494 0\n", "0.6062171459197998 100\n", "0.44935569167137146 200\n", "0.3782166838645935 300\n", "0.6315807700157166 400\n", "0.5480033159255981 500\n", "Epoch 6, training loss 0.5602466128902994, time passed 2m 9s, training accuracy 0.7430857142857142\n", "Validation loss 0.5531672547515641, validation accuracy 0.7502666666666666\n", "0.40330076217651367 0\n", "0.5502031445503235 100\n", "0.5512714982032776 200\n", "0.5983728170394897 300\n", "0.48301962018013 400\n", "0.43600472807884216 500\n", "Epoch 7, training loss 0.5599638384122115, time passed 2m 27s, training accuracy 0.7456\n", "Validation loss 0.5597015177337532, validation accuracy 0.7410666666666667\n", "0.5992159247398376 0\n", "0.5661896467208862 100\n", "0.4971376657485962 200\n", "0.5945307612419128 300\n", "0.48792123794555664 400\n", "0.5712702870368958 500\n", "Epoch 8, training loss 0.5511544310238772, time passed 2m 45s, training accuracy 0.7526857142857143\n", "Validation loss 0.6061686902728856, validation accuracy 0.7177333333333333\n", "0.5343213081359863 0\n", "0.3581521511077881 100\n", "0.4510408341884613 200\n", "0.5065832734107971 300\n", "0.5091931223869324 400\n", "0.509299635887146 500\n", "Epoch 9, training loss 0.5370824536233595, time passed 3m 3s, training accuracy 0.7604571428571428\n", "Validation loss 0.5452731877055943, validation accuracy 0.7505333333333334\n", "0.5403605699539185 0\n", "0.657295286655426 100\n", "0.4403851628303528 200\n", "0.4891587197780609 300\n", "0.45189347863197327 400\n", "0.4263607859611511 500\n", "Epoch 10, training loss 0.527747243588224, time passed 3m 21s, training accuracy 0.7691428571428571\n", "Validation loss 0.5223511960516628, validation accuracy 
0.7714666666666666\n", "0.3096063733100891 0\n", "0.5367862582206726 100\n", "0.5144371390342712 200\n", "0.4614124894142151 300\n", "0.48254287242889404 400\n", "0.5188775658607483 500\n", "Epoch 11, training loss 0.5240872860937328, time passed 3m 40s, training accuracy 0.7701142857142858\n", "Validation loss 0.5141215645349942, validation accuracy 0.7785333333333333\n", "0.6413094997406006 0\n", "0.5476253628730774 100\n", "0.4742453098297119 200\n", "0.6802095174789429 300\n", "0.34964802861213684 400\n", "0.46939337253570557 500\n", "Epoch 12, training loss 0.51118186378217, time passed 3m 58s, training accuracy 0.7798857142857143\n", "Validation loss 0.5072192745840448, validation accuracy 0.7862666666666667\n", "0.46523165702819824 0\n", "0.5808182954788208 100\n", "0.43049943447113037 200\n", "0.40775153040885925 300\n", "0.46464669704437256 400\n", "0.48825064301490784 500\n", "Epoch 13, training loss 0.5027162516728426, time passed 4m 16s, training accuracy 0.7845142857142857\n", "Validation loss 0.4992470221641736, validation accuracy 0.7897333333333333\n", "0.5519555807113647 0\n", "0.5606043338775635 100\n", "0.46889716386795044 200\n", "0.4359128475189209 300\n", "0.5524762868881226 400\n", "0.4447726905345917 500\n", "Epoch 14, training loss 0.4941340588278823, time passed 4m 34s, training accuracy 0.7893142857142857\n", "Validation loss 0.4892241461918904, validation accuracy 0.7922666666666667\n", "0.610282301902771 0\n", "0.3217718303203583 100\n", "0.4827372133731842 200\n", "0.46589168906211853 300\n", "0.45471060276031494 400\n", "0.46282458305358887 500\n", "Epoch 15, training loss 0.48457636877948984, time passed 4m 52s, training accuracy 0.7946857142857143\n", "Validation loss 0.48729967930887497, validation accuracy 0.7946666666666666\n", "0.578451931476593 0\n", "0.3851317763328552 100\n", "0.5418026447296143 200\n", "0.36822280287742615 300\n", "0.5331905484199524 400\n", "0.5257861018180847 500\n", "Epoch 16, training loss 
0.4773170604701444, time passed 5m 10s, training accuracy 0.7992571428571429\n", "Validation loss 0.5018188502544012, validation accuracy 0.7857333333333333\n", "0.5560949444770813 0\n", "0.5609979629516602 100\n", "0.2881567180156708 200\n", "0.763985812664032 300\n", "0.6070953011512756 400\n", "0.37004733085632324 500\n", "Epoch 17, training loss 0.46834920199362784, time passed 5m 28s, training accuracy 0.8027428571428571\n", "Validation loss 0.4715560737710733, validation accuracy 0.7990666666666667\n", "0.6463710069656372 0\n", "0.4798555374145508 100\n", "0.44754114747047424 200\n", "0.4177989959716797 300\n", "0.47894716262817383 400\n", "0.5259878635406494 500\n", "Epoch 18, training loss 0.46406372625640024, time passed 5m 46s, training accuracy 0.8025142857142857\n", "Validation loss 0.4985135881564556, validation accuracy 0.7785333333333333\n", "0.740950345993042 0\n", "0.46650639176368713 100\n", "0.49160894751548767 200\n", "0.6081886291503906 300\n", "0.42985838651657104 400\n", "0.3773285448551178 500\n", "Epoch 19, training loss 0.4523882256486477, time passed 6m 4s, training accuracy 0.8074285714285714\n", "Validation loss 0.4529664355974931, validation accuracy 0.8045333333333333\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "4WmDxq7dDHdZ", "colab_type": "code", "outputId": "421f0bfe-de81-4797-f5c5-419146edde93", "colab": { "base_uri": "https://localhost:8080/", "height": 1411 } }, "source": [ "lrs = [0.01]\n", "for lr in lrs:\n", " model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", " print(\"SGD lr={}, momentum=0. 
:\".format(lr))\n", " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", " torch.save(states, './IMDB/LSTM_' + str(lr))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.01, momentum=0. :\n", "547\n", "0.6858300566673279 0\n", "0.6929423213005066 100\n", "0.6985231041908264 200\n", "0.7083036303520203 300\n", "0.7047500014305115 400\n", "0.7050995826721191 500\n", "Epoch 0, training loss 0.6928785365798098, time passed 0m 9s, training accuracy 0.5070285714285714\n", "Validation loss 0.6905961920562972, validation accuracy 0.5312\n", "0.688129186630249 0\n", "0.6874046921730042 100\n", "0.691518247127533 200\n", "0.6922094225883484 300\n", "0.7087409496307373 400\n", "0.6636379361152649 500\n", "Epoch 1, training loss 0.691004237521699, time passed 0m 20s, training accuracy 0.5172\n", "Validation loss 0.6904787467076228, validation accuracy 0.5094666666666666\n", "0.6732712984085083 0\n", "0.6795019507408142 100\n", "0.7093783617019653 200\n", "0.6702148914337158 300\n", "0.6914629340171814 400\n", "0.679591715335846 500\n", "Epoch 2, training loss 0.6890987443836617, time passed 0m 30s, training accuracy 0.5206285714285714\n", "Validation loss 0.6867884596188863, validation accuracy 0.5437333333333333\n", "0.67351233959198 0\n", "0.7418563961982727 100\n", "0.6954500675201416 200\n", "0.6750000715255737 300\n", "0.6497403383255005 400\n", "0.6969770193099976 500\n", "Epoch 3, training loss 
0.6867323645523616, time passed 0m 40s, training accuracy 0.5280571428571429\n", "Validation loss 0.6849217463252891, validation accuracy 0.5268\n", "0.6905239820480347 0\n", "0.7234190106391907 100\n", "0.6550784111022949 200\n", "0.6806483864784241 300\n", "0.7010015845298767 400\n", "0.6969113349914551 500\n", "Epoch 4, training loss 0.6836398839950562, time passed 0m 51s, training accuracy 0.5396571428571428\n", "Validation loss 0.679884391462701, validation accuracy 0.5356\n", "0.6986691951751709 0\n", "0.6529309153556824 100\n", "0.639856219291687 200\n", "0.6375672221183777 300\n", "0.6711844801902771 400\n", "0.6860522031784058 500\n", "Epoch 5, training loss 0.6754605933860108, time passed 1m 1s, training accuracy 0.5683428571428571\n", "Validation loss 0.6923773884773254, validation accuracy 0.5188\n", "0.6988645195960999 0\n", "0.6917654275894165 100\n", "0.7053704261779785 200\n", "0.7139018774032593 300\n", "0.6837825179100037 400\n", "0.6800549626350403 500\n", "Epoch 6, training loss 0.6938841290526337, time passed 1m 11s, training accuracy 0.5063428571428571\n", "Validation loss 0.6901481602436457, validation accuracy 0.5352\n", "0.6856598258018494 0\n", "0.6905097961425781 100\n", "0.6995676755905151 200\n", "0.6977699398994446 300\n", "0.6822452545166016 400\n", "0.6835910677909851 500\n", "Epoch 7, training loss 0.6903029076564007, time passed 1m 21s, training accuracy 0.5186857142857143\n", "Validation loss 0.6897223377839113, validation accuracy 0.5073333333333333\n", "0.6816807389259338 0\n", "0.675635039806366 100\n", "0.6933707594871521 200\n", "0.6894451975822449 300\n", "0.6849492788314819 400\n", "0.6990582346916199 500\n", "Epoch 8, training loss 0.6889166395306151, time passed 1m 32s, training accuracy 0.5239428571428572\n", "Validation loss 0.6875857746499217, validation accuracy 0.5421333333333334\n", "0.6816436648368835 0\n", "0.6997812986373901 100\n", "0.6829116344451904 200\n", "0.7450448274612427 300\n", "0.6575583219528198 
400\n", "0.6812006831169128 500\n", "Epoch 9, training loss 0.6854374798444601, time passed 1m 42s, training accuracy 0.5316\n", "Validation loss 0.6869355847692897, validation accuracy 0.5241333333333333\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "fD5VAu71YkHV", "colab_type": "code", "outputId": "a9592bc5-7523-4214-ab4a-7dbdb365c67f", "colab": { "base_uri": "https://localhost:8080/", "height": 2771 } }, "source": [ "lrs = [0.001]\n", "for lr in lrs:\n", " model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", " print(\"Adam lr={} :\".format(lr))\n", " optimizer = optim.Adam(model.parameters(), lr=lr)\n", " # Keep Adam results under their own key: the AttentionModel SGD cell later in this notebook writes tr_loss['sgd'][0.001] and would overwrite them.\n", " tr_loss.setdefault('adam', {})\n", " val_loss.setdefault('adam', {})\n", " tr_loss['adam'][lr], times, val_loss['adam'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['adam'][lr],\n", " 'val_loss' : val_loss['adam'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", " torch.save(states, './IMDB/LSTM_adam_' + str(lr))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.001, momentum=0. 
:\n", "547\n", "0.7042360901832581 0\n", "0.7139124274253845 100\n", "0.648885190486908 200\n", "0.6860226392745972 300\n", "0.685365617275238 400\n", "0.6789942383766174 500\n", "Epoch 0, training loss 0.6891226628761151, time passed 0m 10s, training accuracy 0.5336\n", "Validation loss 0.6916678978337182, validation accuracy 0.5101333333333333\n", "0.684501051902771 0\n", "0.6671040058135986 100\n", "0.6860296130180359 200\n", "0.6729943156242371 300\n", "0.6892271041870117 400\n", "0.7105855941772461 500\n", "Epoch 1, training loss 0.6802897323946376, time passed 0m 22s, training accuracy 0.5594857142857143\n", "Validation loss 0.6862392445914766, validation accuracy 0.5221333333333333\n", "0.7073076963424683 0\n", "0.689601480960846 100\n", "0.5204960703849792 200\n", "0.688227653503418 300\n", "0.6900652647018433 400\n", "0.7155442833900452 500\n", "Epoch 2, training loss 0.6715143832204106, time passed 0m 33s, training accuracy 0.5706285714285714\n", "Validation loss 0.6796312681120685, validation accuracy 0.532\n", "0.6642115116119385 0\n", "0.6051458120346069 100\n", "0.4899091422557831 200\n", "0.47550395131111145 300\n", "0.3884301781654358 400\n", "0.5720760226249695 500\n", "Epoch 3, training loss 0.46659755674037307, time passed 0m 44s, training accuracy 0.7838285714285714\n", "Validation loss 0.44980048197202194, validation accuracy 0.7905333333333333\n", "0.32235172390937805 0\n", "0.35917454957962036 100\n", "0.34756919741630554 200\n", "0.5727543234825134 300\n", "0.39892154932022095 400\n", "0.4784986078739166 500\n", "Epoch 4, training loss 0.3807199837251024, time passed 0m 54s, training accuracy 0.8334857142857143\n", "Validation loss 0.38287904559292346, validation accuracy 0.8326666666666667\n", "0.16310635209083557 0\n", "0.4492853879928589 100\n", "0.3463122248649597 200\n", "0.4847961962223053 300\n", "0.4214150905609131 400\n", "0.3589895963668823 500\n", "Epoch 5, training loss 0.3385137851885605, time passed 1m 5s, training accuracy 
0.8572571428571428\n", "Validation loss 0.37860287866021836, validation accuracy 0.8382666666666667\n", "0.3192971348762512 0\n", "0.22000743448734283 100\n", "0.11614805459976196 200\n", "0.2819654941558838 300\n", "0.23107846081256866 400\n", "0.2672703266143799 500\n", "Epoch 6, training loss 0.2946929669778644, time passed 1m 15s, training accuracy 0.8787428571428572\n", "Validation loss 0.39496679317492706, validation accuracy 0.8321333333333333\n", "0.17155088484287262 0\n", "0.21490910649299622 100\n", "0.1092824935913086 200\n", "0.3267301619052887 300\n", "0.47277331352233887 400\n", "0.505771279335022 500\n", "Epoch 7, training loss 0.24134574049995058, time passed 1m 25s, training accuracy 0.9034285714285715\n", "Validation loss 0.39918604534533286, validation accuracy 0.8357333333333333\n", "0.0829261988401413 0\n", "0.16884355247020721 100\n", "0.16323381662368774 200\n", "0.15418237447738647 300\n", "0.2902193069458008 400\n", "0.3000183701515198 500\n", "Epoch 8, training loss 0.18310634673474144, time passed 1m 35s, training accuracy 0.9332\n", "Validation loss 0.44123396850549257, validation accuracy 0.838\n", "0.29574403166770935 0\n", "0.13429410755634308 100\n", "0.10769911110401154 200\n", "0.22999584674835205 300\n", "0.07347916811704636 400\n", "0.23001378774642944 500\n", "Epoch 9, training loss 0.13100098480467068, time passed 1m 46s, training accuracy 0.9544571428571429\n", "Validation loss 0.5024094952222629, validation accuracy 0.8374666666666667\n", "0.06784306466579437 0\n", "0.021445710211992264 100\n", "0.05880201607942581 200\n", "0.11171642690896988 300\n", "0.12890562415122986 400\n", "0.13418814539909363 500\n", "Epoch 10, training loss 0.0893362128813734, time passed 1m 56s, training accuracy 0.9705142857142857\n", "Validation loss 0.6102941697065392, validation accuracy 0.8270666666666666\n", "0.01168014481663704 0\n", "0.1551697701215744 100\n", "0.11678065359592438 200\n", "0.013292469084262848 300\n", "0.07440473139286041 
400\n", "0.0031153708696365356 500\n", "Epoch 11, training loss 0.06055153887411886, time passed 2m 6s, training accuracy 0.9810285714285715\n", "Validation loss 0.639088273096161, validation accuracy 0.8266666666666667\n", "0.00817631371319294 0\n", "0.025077704340219498 100\n", "0.013954445719718933 200\n", "0.2211170643568039 300\n", "0.005432415753602982 400\n", "0.02004990726709366 500\n", "Epoch 12, training loss 0.044061114568938756, time passed 2m 17s, training accuracy 0.9877142857142858\n", "Validation loss 0.7138074862683176, validation accuracy 0.83\n", "0.010555963963270187 0\n", "0.006451953202486038 100\n", "0.010996071621775627 200\n", "0.018158936873078346 300\n", "0.013155501335859299 400\n", "0.016464976593852043 500\n", "Epoch 13, training loss 0.03972877365888366, time passed 2m 27s, training accuracy 0.9886857142857143\n", "Validation loss 0.7276951397458712, validation accuracy 0.8386666666666667\n", "0.0073014795780181885 0\n", "0.008466990664601326 100\n", "0.007858745753765106 200\n", "0.0018998309969902039 300\n", "0.09833370894193649 400\n", "0.07120833545923233 500\n", "Epoch 14, training loss 0.03543301275334297, time passed 2m 37s, training accuracy 0.9904571428571428\n", "Validation loss 0.7426877943321298, validation accuracy 0.8288\n", "0.008094564080238342 0\n", "0.004330258816480637 100\n", "0.00565388984978199 200\n", "0.031496673822402954 300\n", "0.01734837517142296 400\n", "0.2173541635274887 500\n", "Epoch 15, training loss 0.030148086574733696, time passed 2m 48s, training accuracy 0.9915428571428572\n", "Validation loss 0.7553343320758934, validation accuracy 0.8204\n", "0.01480589248239994 0\n", "0.0022436529397964478 100\n", "0.002812545746564865 200\n", "0.0078052617609500885 300\n", "0.0025699175894260406 400\n", "0.0029298923909664154 500\n", "Epoch 16, training loss 0.02461542384426072, time passed 2m 58s, training accuracy 0.9931428571428571\n", "Validation loss 0.8719300648086091, validation accuracy 
0.8173333333333334\n", "0.004529718309640884 0\n", "0.004372201859951019 100\n", "0.23865947127342224 200\n", "0.47893786430358887 300\n", "0.002506580203771591 400\n", "0.012474816292524338 500\n", "Epoch 17, training loss 0.02630178953329913, time passed 3m 8s, training accuracy 0.9930285714285715\n", "Validation loss 0.9050891679257919, validation accuracy 0.8261333333333334\n", "0.0014899149537086487 0\n", "0.007672000676393509 100\n", "0.0018436536192893982 200\n", "0.011076990514993668 300\n", "0.01749233901500702 400\n", "0.003323737531900406 500\n", "Epoch 18, training loss 0.027579101135155985, time passed 3m 19s, training accuracy 0.9927428571428571\n", "Validation loss 0.9038066414113228, validation accuracy 0.8305333333333333\n", "0.0038200803101062775 0\n", "0.04194427654147148 100\n", "0.002153143286705017 200\n", "0.002913050353527069 300\n", "0.01736719347536564 400\n", "0.04847458750009537 500\n", "Epoch 19, training loss 0.01933261827691953, time passed 3m 29s, training accuracy 0.9957714285714285\n", "Validation loss 0.8214574939547441, validation accuracy 0.8257333333333333\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "ReBlsVbkAbcq", "colab_type": "code", "outputId": "fd1b6bf5-5d9d-49b7-acbb-5950aef831cc", "colab": { "base_uri": "https://localhost:8080/", "height": 2805 } }, "source": [ "lrs = [0.0001, 0.001]\n", "for lr in lrs:\n", " model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", " print(\"SGD lr={}, momentum=0. 
:\".format(lr))\n", " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", " torch.save(states, './IMDB/attn_' + str(lr))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.0001, momentum=0. :\n", "547\n", "0.6913896799087524 0\n", "0.6833416819572449 100\n", "0.6810654401779175 200\n", "0.6995826363563538 300\n", "0.6945667862892151 400\n", "0.6890664100646973 500\n", "Epoch 0, training loss 0.6941597445325537, time passed 0m 16s, training accuracy 0.5016571428571429\n", "Validation loss 0.6946967735759213, validation accuracy 0.49093333333333333\n", "0.703650951385498 0\n", "0.6967843174934387 100\n", "0.6908212304115295 200\n", "0.6907294392585754 300\n", "0.6946154832839966 400\n", "0.6943396925926208 500\n", "Epoch 1, training loss 0.693790737307552, time passed 0m 32s, training accuracy 0.5012\n", "Validation loss 0.6941390972361605, validation accuracy 0.4916\n", "0.6794251203536987 0\n", "0.7008486390113831 100\n", "0.6980838775634766 200\n", "0.6835271120071411 300\n", "0.6840112805366516 400\n", "0.696955680847168 500\n", "Epoch 2, training loss 0.6937036286125253, time passed 0m 48s, training accuracy 0.5015428571428572\n", "Validation loss 0.6938381541488517, validation accuracy 0.49173333333333336\n", "0.6919559836387634 0\n", "0.6883322596549988 100\n", "0.695152997970581 200\n", "0.6984624266624451 300\n", "0.6907331943511963 400\n", "0.697288990020752 500\n", "Epoch 3, training loss 
0.6936267273766654, time passed 1m 5s, training accuracy 0.5015428571428572\n", "Validation loss 0.6936797244935973, validation accuracy 0.4922666666666667\n", "0.7022473216056824 0\n", "0.6973013877868652 100\n", "0.6953293681144714 200\n", "0.6877068281173706 300\n", "0.697217583656311 400\n", "0.6913156509399414 500\n", "Epoch 4, training loss 0.6935898202024536, time passed 1m 21s, training accuracy 0.5013142857142857\n", "Validation loss 0.6936082987703829, validation accuracy 0.492\n", "0.6885882019996643 0\n", "0.6926128268241882 100\n", "0.6901581883430481 200\n", "0.6929795145988464 300\n", "0.6940913796424866 400\n", "0.6981614828109741 500\n", "Epoch 5, training loss 0.6935964603345472, time passed 1m 37s, training accuracy 0.5013714285714286\n", "Validation loss 0.6935551594465207, validation accuracy 0.4918666666666667\n", "0.6927429437637329 0\n", "0.6986611485481262 100\n", "0.6948174834251404 200\n", "0.6938021779060364 300\n", "0.6924692988395691 400\n", "0.694472074508667 500\n", "Epoch 6, training loss 0.6935844638627091, time passed 1m 53s, training accuracy 0.5014285714285714\n", "Validation loss 0.6935180904518845, validation accuracy 0.4924\n", "0.6893236637115479 0\n", "0.6937673687934875 100\n", "0.6925515532493591 200\n", "0.6866517663002014 300\n", "0.6890139579772949 400\n", "0.6973516941070557 500\n", "Epoch 7, training loss 0.6935794655874972, time passed 2m 10s, training accuracy 0.5014857142857143\n", "Validation loss 0.693523046043184, validation accuracy 0.4925333333333333\n", "0.7042493224143982 0\n", "0.686928391456604 100\n", "0.6926462054252625 200\n", "0.6908868551254272 300\n", "0.6884534358978271 400\n", "0.6849415302276611 500\n", "Epoch 8, training loss 0.6935443760274531, time passed 2m 26s, training accuracy 0.5013714285714286\n", "Validation loss 0.6935278755477351, validation accuracy 0.49266666666666664\n", "0.6919976472854614 0\n", "0.6861199140548706 100\n", "0.701988697052002 200\n", "0.6915175318717957 300\n", 
"0.6888453960418701 400\n", "0.6932827234268188 500\n", "Epoch 9, training loss 0.6935574885034735, time passed 2m 42s, training accuracy 0.5014285714285714\n", "Validation loss 0.6935218911395113, validation accuracy 0.4925333333333333\n", "SGD lr=0.001, momentum=0. :\n", "547\n", "0.6962258815765381 0\n", "0.6885945200920105 100\n", "0.6863468885421753 200\n", "0.6919173002243042 300\n", "0.6898753046989441 400\n", "0.6896785497665405 500\n", "Epoch 0, training loss 0.693063280412129, time passed 0m 16s, training accuracy 0.5017142857142857\n", "Validation loss 0.6927863744079558, validation accuracy 0.5182666666666667\n", "0.6939694881439209 0\n", "0.6943249702453613 100\n", "0.6949176788330078 200\n", "0.6970770955085754 300\n", "0.6976656317710876 400\n", "0.7010515332221985 500\n", "Epoch 1, training loss 0.6929705727013039, time passed 0m 32s, training accuracy 0.4995428571428571\n", "Validation loss 0.6932998309787523, validation accuracy 0.494\n", "0.6950815916061401 0\n", "0.6884225010871887 100\n", "0.69386225938797 200\n", "0.6936244964599609 300\n", "0.694794237613678 400\n", "0.6864222288131714 500\n", "Epoch 2, training loss 0.6929141268843696, time passed 0m 48s, training accuracy 0.4964571428571429\n", "Validation loss 0.6933635147208841, validation accuracy 0.49466666666666664\n", "0.6965394616127014 0\n", "0.6961629390716553 100\n", "0.6910871863365173 200\n", "0.6860278248786926 300\n", "0.702141284942627 400\n", "0.6921371221542358 500\n", "Epoch 3, training loss 0.6928582324649825, time passed 1m 5s, training accuracy 0.5032\n", "Validation loss 0.692806190914578, validation accuracy 0.49\n", "0.6952767372131348 0\n", "0.6889387965202332 100\n", "0.6943261623382568 200\n", "0.6909245848655701 300\n", "0.6896691918373108 400\n", "0.6901578903198242 500\n", "Epoch 4, training loss 0.6927566896209787, time passed 1m 21s, training accuracy 0.5070285714285714\n", "Validation loss 0.6925799474756942, validation accuracy 0.5138666666666667\n", 
"0.6921440958976746 0\n", "0.6886021494865417 100\n", "0.6928223371505737 200\n", "0.6882325410842896 300\n", "0.6940441131591797 400\n", "0.6978073120117188 500\n", "Epoch 5, training loss 0.6927209747579944, time passed 1m 37s, training accuracy 0.5090857142857143\n", "Validation loss 0.6923715749866942, validation accuracy 0.5154666666666666\n", "0.6963284015655518 0\n", "0.6894926428794861 100\n", "0.6877952814102173 200\n", "0.6923049092292786 300\n", "0.6969224214553833 400\n", "0.6973981857299805 500\n", "Epoch 6, training loss 0.6926383953609746, time passed 1m 54s, training accuracy 0.5088571428571429\n", "Validation loss 0.6927720010280609, validation accuracy 0.49466666666666664\n", "0.6923703551292419 0\n", "0.6939659714698792 100\n", "0.6935223937034607 200\n", "0.6979855895042419 300\n", "0.6856727600097656 400\n", "0.7023521661758423 500\n", "Epoch 7, training loss 0.6925406037887811, time passed 2m 10s, training accuracy 0.51\n", "Validation loss 0.6933551822972094, validation accuracy 0.49506666666666665\n", "0.6942978501319885 0\n", "0.6977471709251404 100\n", "0.6959917545318604 200\n", "0.6951767206192017 300\n", "0.6951934695243835 400\n", "0.6880986094474792 500\n", "Epoch 8, training loss 0.6924885011417962, time passed 2m 26s, training accuracy 0.5107428571428572\n", "Validation loss 0.6927055528021266, validation accuracy 0.4961333333333333\n", "0.6947516798973083 0\n", "0.6958856582641602 100\n", "0.6986777186393738 200\n", "0.6878237724304199 300\n", "0.6968024373054504 400\n", "0.6909596920013428 500\n", "Epoch 9, training loss 0.6925231927917117, time passed 2m 42s, training accuracy 0.5066857142857143\n", "Validation loss 0.6926503782598381, validation accuracy 0.4965333333333333\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "rUTCEYbXA-ni", "colab_type": "code", "outputId": "ddbbcf1e-a41a-44fc-e2fc-02be7c4e819f", "colab": { "base_uri": "https://localhost:8080/", "height": 2825 } }, "source": [ "l_0 = 20\n", 
"model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", "print(\"SUG l_0={}, momentum=0. :\".format(l_0))\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", "torch.save(states, './IMDB/attn_sug')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SUG l_0=20, momentum=0. :\n", "547\n", "0.6829845309257507 0\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:522: RuntimeWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. 
To compact weights again call flatten_parameters().\n", " self.dropout, self.training, self.bidirectional, self.batch_first)\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "0.7011003494262695 100\n", "0.6893718242645264 200\n", "0.7490146160125732 300\n", "0.6662403345108032 400\n", "0.7302882075309753 500\n", "Epoch 0, training loss 0.7087923813433874, time passed 0m 34s, training accuracy 0.524\n", "Validation loss 0.6786875987154806, validation accuracy 0.5837333333333333\n", "0.6089674234390259 0\n", "0.5822049975395203 100\n", "0.5361784100532532 200\n", "0.48254698514938354 300\n", "0.5182648301124573 400\n", "0.5892543792724609 500\n", "Epoch 1, training loss 0.5938645527585522, time passed 1m 8s, training accuracy 0.6910285714285714\n", "Validation loss 0.5369184652709553, validation accuracy 0.7398666666666667\n", "0.4637317359447479 0\n", "0.49282562732696533 100\n", "0.7132331728935242 200\n", "0.37981486320495605 300\n", "0.4586241543292999 400\n", "0.5515133142471313 500\n", "Epoch 2, training loss 0.473880006089097, time passed 1m 43s, training accuracy 0.7809142857142857\n", "Validation loss 0.5159184689450468, validation accuracy 0.7656\n", "0.4248925447463989 0\n", "0.2927207946777344 100\n", "0.3366227447986603 200\n", "0.4807620048522949 300\n", "0.45819780230522156 400\n", "0.3322904109954834 500\n", "Epoch 3, training loss 0.3931167991388412, time passed 2m 18s, training accuracy 0.8290285714285714\n", "Validation loss 0.4972761335153865, validation accuracy 0.7766666666666666\n", "0.29197123646736145 0\n", "0.38541924953460693 100\n", "0.3892315924167633 200\n", "0.25138404965400696 300\n", "0.30947819352149963 400\n", "0.3718239665031433 500\n", "Epoch 4, training loss 0.34823366385567317, time passed 2m 53s, training accuracy 0.8515428571428572\n", "Validation loss 0.5094487400900605, validation accuracy 0.7744\n", "0.3261384665966034 0\n", "0.36068665981292725 100\n", "0.2923737168312073 200\n", "0.13714373111724854 300\n", 
"0.3290799558162689 400\n", "0.2737971544265747 500\n", "Epoch 5, training loss 0.29837472485071354, time passed 3m 27s, training accuracy 0.8736\n", "Validation loss 0.5119595573498652, validation accuracy 0.7909333333333334\n", "0.3927178680896759 0\n", "0.38204678893089294 100\n", "0.5386406779289246 200\n", "0.2763158082962036 300\n", "0.2710512578487396 400\n", "0.1905616968870163 500\n", "Epoch 6, training loss 0.26174425519320554, time passed 4m 2s, training accuracy 0.8938857142857143\n", "Validation loss 0.5488016199734476, validation accuracy 0.7885333333333333\n", "0.1985805332660675 0\n", "0.1362437903881073 100\n", "0.20704106986522675 200\n", "0.28067371249198914 300\n", "0.3431782126426697 400\n", "0.1644502878189087 500\n", "Epoch 7, training loss 0.22714028708063638, time passed 4m 37s, training accuracy 0.9111428571428571\n", "Validation loss 0.5174612025292511, validation accuracy 0.7974666666666667\n", "0.2646089792251587 0\n", "0.09820486605167389 100\n", "0.14298667013645172 200\n", "0.1455533504486084 300\n", "0.1463765650987625 400\n", "0.11780868470668793 500\n", "Epoch 8, training loss 0.2043923422926184, time passed 5m 12s, training accuracy 0.9220571428571429\n", "Validation loss 0.5430883773180664, validation accuracy 0.7901333333333334\n", "0.14028291404247284 0\n", "0.09407615661621094 100\n", "0.16044829785823822 200\n", "0.2547101080417633 300\n", "0.17609679698944092 400\n", "0.19136053323745728 500\n", "Epoch 9, training loss 0.18234535001027277, time passed 5m 46s, training accuracy 0.9312\n", "Validation loss 0.55573059236392, validation accuracy 0.8005333333333333\n", "0.12586961686611176 0\n", "0.20234878361225128 100\n", "0.297412633895874 200\n", "0.1123238354921341 300\n", "0.22616350650787354 400\n", "0.1925792396068573 500\n", "Epoch 10, training loss 0.16486637671938636, time passed 6m 21s, training accuracy 0.9405714285714286\n", "Validation loss 0.560932617698215, validation accuracy 0.7988\n", "0.06871110945940018 
0\n", "0.275569885969162 100\n", "0.22198261320590973 200\n", "0.1415991187095642 300\n", "0.057735737413167953 400\n", "0.2961047291755676 500\n", "Epoch 11, training loss 0.15006382818309924, time passed 6m 56s, training accuracy 0.9462857142857143\n", "Validation loss 0.5667134043879998, validation accuracy 0.8006666666666666\n", "0.11106912791728973 0\n", "0.2575933635234833 100\n", "0.2693217992782593 200\n", "0.17979584634304047 300\n", "0.06759869307279587 400\n", "0.1553390920162201 500\n", "Epoch 12, training loss 0.13279787350732547, time passed 7m 31s, training accuracy 0.9547428571428571\n", "Validation loss 0.5974068640070593, validation accuracy 0.8008\n", "0.058201633393764496 0\n", "0.05967959761619568 100\n", "0.07278057932853699 200\n", "0.13669390976428986 300\n", "0.03568866103887558 400\n", "0.0687827616930008 500\n", "Epoch 13, training loss 0.12421009114401026, time passed 8m 5s, training accuracy 0.958\n", "Validation loss 0.6206559310547817, validation accuracy 0.8026666666666666\n", "0.06015688180923462 0\n", "0.10720789432525635 100\n", "0.06766632944345474 200\n", "0.05103915557265282 300\n", "0.04768899083137512 400\n", "0.15273821353912354 500\n", "Epoch 14, training loss 0.11255109655387673, time passed 8m 40s, training accuracy 0.9616\n", "Validation loss 0.6423081847019175, validation accuracy 0.8014666666666667\n", "0.04458184167742729 0\n", "0.05341879650950432 100\n", "0.19494415819644928 200\n", "0.016109846532344818 300\n", "0.15212471783161163 400\n", "0.06866186112165451 500\n", "Epoch 15, training loss 0.10480376769832912, time passed 9m 15s, training accuracy 0.9664571428571429\n", "Validation loss 0.659027116229901, validation accuracy 0.8014666666666667\n", "0.044951993972063065 0\n", "0.07728096097707748 100\n", "0.03233599290251732 200\n", "0.1555192917585373 300\n", "0.11952584981918335 400\n", "0.06148263067007065 500\n", "Epoch 16, training loss 0.09644550548170458, time passed 9m 50s, training accuracy 
0.9702285714285714\n", "Validation loss 0.6640930891864829, validation accuracy 0.7996\n", "0.06904051452875137 0\n", "0.06091935560107231 100\n", "0.11497051268815994 200\n", "0.09069794416427612 300\n", "0.0461907759308815 400\n", "0.05993930250406265 500\n", "Epoch 17, training loss 0.09394364140334201, time passed 10m 24s, training accuracy 0.9712\n", "Validation loss 0.6865501854027438, validation accuracy 0.8001333333333334\n", "0.10262757539749146 0\n", "0.05460500717163086 100\n", "0.15111148357391357 200\n", "0.022144455462694168 300\n", "0.04337551072239876 400\n", "0.23838751018047333 500\n", "Epoch 18, training loss 0.0841821710350531, time passed 10m 59s, training accuracy 0.9744571428571429\n", "Validation loss 0.7002959495776484, validation accuracy 0.8001333333333334\n", "0.03549156337976456 0\n", "0.0709724947810173 100\n", "0.08607187867164612 200\n", "0.030010128393769264 300\n", "0.07841354608535767 400\n", "0.06701678037643433 500\n", "Epoch 19, training loss 0.07890484321541784, time passed 11m 34s, training accuracy 0.9757714285714286\n", "Validation loss 0.7400413659265918, validation accuracy 0.7984\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "ZFQVb78DVh32", "colab_type": "code", "outputId": "9fb5725f-6110-42e3-bd45-09eaf51112fc", "colab": { "base_uri": "https://localhost:8080/", "height": 2771 } }, "source": [ "lrs = [0.0001]\n", "for lr in lrs:\n", " model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", " print(\"SGD lr={}, momentum=0. 
:\".format(lr))\n", " optimizer = optim.Adam(model.parameters(), lr=lr)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", " torch.save(states, './IMDB/attn_adam_' + str(lr))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.0001, momentum=0. :\n", "547\n", "0.7074708938598633 0\n", "0.699157178401947 100\n", "0.6875941157341003 200\n", "0.6936682462692261 300\n", "0.6861938238143921 400\n", "0.693611741065979 500\n", "Epoch 0, training loss 0.6934679332888607, time passed 0m 25s, training accuracy 0.5128571428571429\n", "Validation loss 0.6908549699518416, validation accuracy 0.5306666666666666\n", "0.6809442043304443 0\n", "0.7160238027572632 100\n", "0.7025774121284485 200\n", "0.6726558804512024 300\n", "0.6648744940757751 400\n", "0.7097936868667603 500\n", "Epoch 1, training loss 0.6831305411272434, time passed 0m 52s, training accuracy 0.548\n", "Validation loss 0.6688847951909416, validation accuracy 0.5654666666666667\n", "0.7096719145774841 0\n", "0.6852087378501892 100\n", "0.6141272783279419 200\n", "0.5822790265083313 300\n", "0.5733252763748169 400\n", "0.4738878607749939 500\n", "Epoch 2, training loss 0.6086636586831167, time passed 1m 18s, training accuracy 0.6895428571428571\n", "Validation loss 0.5956426103655089, validation accuracy 0.7068\n", "0.572991132736206 0\n", "0.5878808498382568 100\n", "0.39202553033828735 200\n", "0.45941799879074097 300\n", "0.37650367617607117 400\n", "0.317059725522995 500\n", "Epoch 3, training loss 
0.4820362525708946, time passed 1m 44s, training accuracy 0.7854857142857142\n", "Validation loss 0.4926339803088425, validation accuracy 0.7690666666666667\n", "0.41018664836883545 0\n", "0.4187754690647125 100\n", "0.4126662015914917 200\n", "0.3446335792541504 300\n", "0.3422461450099945 400\n", "0.6362811326980591 500\n", "Epoch 4, training loss 0.39267630127323416, time passed 2m 10s, training accuracy 0.8334285714285714\n", "Validation loss 0.4681197350096499, validation accuracy 0.7885333333333333\n", "0.35670006275177 0\n", "0.25960102677345276 100\n", "0.4475996494293213 200\n", "0.29875659942626953 300\n", "0.5717836618423462 400\n", "0.34494999051094055 500\n", "Epoch 5, training loss 0.3166446594637392, time passed 2m 37s, training accuracy 0.8741142857142857\n", "Validation loss 0.499808596017269, validation accuracy 0.8038666666666666\n", "0.2251623421907425 0\n", "0.16197097301483154 100\n", "0.17855232954025269 200\n", "0.29577791690826416 300\n", "0.3582856059074402 400\n", "0.35170048475265503 500\n", "Epoch 6, training loss 0.2593910374470574, time passed 3m 3s, training accuracy 0.9002285714285714\n", "Validation loss 0.4939273253847391, validation accuracy 0.8026666666666666\n", "0.18230752646923065 0\n", "0.047517504543066025 100\n", "0.45938828587532043 200\n", "0.1326374113559723 300\n", "0.2861888110637665 400\n", "0.13111262023448944 500\n", "Epoch 7, training loss 0.20957574446177307, time passed 3m 29s, training accuracy 0.9256571428571428\n", "Validation loss 0.47968957894760317, validation accuracy 0.8205333333333333\n", "0.21449321508407593 0\n", "0.12550890445709229 100\n", "0.20170322060585022 200\n", "0.0526358000934124 300\n", "0.37328943610191345 400\n", "0.14322389662265778 500\n", "Epoch 8, training loss 0.17157137170502226, time passed 3m 55s, training accuracy 0.9419428571428572\n", "Validation loss 0.48979064755332774, validation accuracy 0.8244\n", "0.03937157243490219 0\n", "0.052820835262537 100\n", "0.11523545533418655 
200\n", "0.16278201341629028 300\n", "0.13283699750900269 400\n", "0.03738832846283913 500\n", "Epoch 9, training loss 0.1419193714399969, time passed 4m 22s, training accuracy 0.9532\n", "Validation loss 0.4799800062408814, validation accuracy 0.818\n", "0.06695418059825897 0\n", "0.03107781521975994 100\n", "0.02595672756433487 200\n", "0.044095396995544434 300\n", "0.023531656712293625 400\n", "0.13935334980487823 500\n", "Epoch 10, training loss 0.11386594509504626, time passed 4m 49s, training accuracy 0.9652\n", "Validation loss 0.5255547962191268, validation accuracy 0.8145333333333333\n", "0.07934461534023285 0\n", "0.09646227210760117 100\n", "0.053853631019592285 200\n", "0.051237765699625015 300\n", "0.055733755230903625 400\n", "0.0598265565931797 500\n", "Epoch 11, training loss 0.08856902625616435, time passed 5m 16s, training accuracy 0.9737142857142858\n", "Validation loss 0.5335595943988898, validation accuracy 0.808\n", "0.22429613769054413 0\n", "0.013865365646779537 100\n", "0.1337713897228241 200\n", "0.032364871352910995 300\n", "0.034993696957826614 400\n", "0.019271334633231163 500\n", "Epoch 12, training loss 0.06800744768433382, time passed 5m 43s, training accuracy 0.9802857142857143\n", "Validation loss 0.5118081265757991, validation accuracy 0.8109333333333333\n", "0.059634629637002945 0\n", "0.010468564927577972 100\n", "0.0453047901391983 200\n", "0.009559735655784607 300\n", "0.044167228043079376 400\n", "0.2211085706949234 500\n", "Epoch 13, training loss 0.06178751782831419, time passed 6m 9s, training accuracy 0.9830285714285715\n", "Validation loss 0.5295347827367294, validation accuracy 0.7632\n", "0.2164592295885086 0\n", "0.00793084129691124 100\n", "0.07709269225597382 200\n", "0.1555565893650055 300\n", "0.005646074656397104 400\n", "0.0027028685435652733 500\n", "Epoch 14, training loss 0.058877421620612345, time passed 6m 35s, training accuracy 0.9822857142857143\n", "Validation loss 0.7158146076477491, validation accuracy 
0.8157333333333333\n", "0.009047086350619793 0\n", "0.09421178698539734 100\n", "0.23725338280200958 200\n", "0.003965259063988924 300\n", "0.020767798647284508 400\n", "0.006656785029917955 500\n", "Epoch 15, training loss 0.0380694701659558, time passed 7m 1s, training accuracy 0.9896\n", "Validation loss 0.5929487566981051, validation accuracy 0.8121333333333334\n", "0.03259653225541115 0\n", "0.0019155398476868868 100\n", "0.0016652056947350502 200\n", "0.020231522619724274 300\n", "0.004920525010675192 400\n", "0.0324997641146183 500\n", "Epoch 16, training loss 0.03166716295879099, time passed 7m 27s, training accuracy 0.9913714285714286\n", "Validation loss 0.7411979428914368, validation accuracy 0.8037333333333333\n", "0.03442881256341934 0\n", "0.10267721861600876 100\n", "0.050596438348293304 200\n", "0.01020444743335247 300\n", "0.005791102536022663 400\n", "0.022519394755363464 500\n", "Epoch 17, training loss 0.028926685650857684, time passed 7m 54s, training accuracy 0.9909142857142857\n", "Validation loss 0.7639160853627528, validation accuracy 0.82\n", "0.005129156168550253 0\n", "0.025728877633810043 100\n", "0.0014372519217431545 200\n", "0.001325768418610096 300\n", "0.03210834041237831 400\n", "0.042594511061906815 500\n", "Epoch 18, training loss 0.021579638746419525, time passed 8m 20s, training accuracy 0.9945714285714286\n", "Validation loss 0.9954519265691519, validation accuracy 0.8138666666666666\n", "0.045195505023002625 0\n", "0.006256254855543375 100\n", "0.002010277472436428 200\n", "0.01034162100404501 300\n", "0.0030539222061634064 400\n", "0.009373277425765991 500\n", "Epoch 19, training loss 0.02113530602935751, time passed 8m 46s, training accuracy 0.994\n", "Validation loss 0.7144950056114258, validation accuracy 0.8089333333333333\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "fUk_a6WwZokj", "colab_type": "code", "outputId": "6f48a4ff-471a-4b7e-e5b4-15cc249397d5", "colab": { "base_uri": 
"https://localhost:8080/", "height": 1465 } }, "source": [ "l_0 = 20\n", "model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", "print(\"SUG l_0={}, momentum=0. :\".format(l_0))\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.9, weight_decay=1e-3)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", "torch.save(states, './IMDB/attn_sug_0.9')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SUG l_0=20, momentum=0. :\n", "547\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:522: RuntimeWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. 
To compact weights again call flatten_parameters().\n", " self.dropout, self.training, self.bidirectional, self.batch_first)\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "0.6980208158493042 0\n", "0.6874722242355347 100\n", "0.6817547678947449 200\n", "0.6828298568725586 300\n", "0.7083317637443542 400\n", "0.7124193906784058 500\n", "Epoch 0, training loss 0.6966337240659274, time passed 0m 40s, training accuracy 0.4993714285714286\n", "Validation loss 0.6994800914047111, validation accuracy 0.4866666666666667\n", "0.6685550808906555 0\n", "0.6823179721832275 100\n", "0.7071587443351746 200\n", "0.6757057905197144 300\n", "0.6761490702629089 400\n", "0.6832512617111206 500\n", "Epoch 1, training loss 0.6966324601417933, time passed 1m 21s, training accuracy 0.4993142857142857\n", "Validation loss 0.6993934662423582, validation accuracy 0.4866666666666667\n", "0.7265142202377319 0\n", "0.6936051845550537 100\n", "0.6723883152008057 200\n", "0.6989526748657227 300\n", "0.6900702118873596 400\n", "0.7373038530349731 500\n", "Epoch 2, training loss 0.696468867006756, time passed 2m 1s, training accuracy 0.49925714285714284\n", "Validation loss 0.6993084137256329, validation accuracy 0.4868\n", "0.6769171357154846 0\n", "0.6965464949607849 100\n", "0.7059141397476196 200\n", "0.7059294581413269 300\n", "0.6788325309753418 400\n", "0.7125336527824402 500\n", "Epoch 3, training loss 0.6965000682916397, time passed 2m 42s, training accuracy 0.4993142857142857\n", "Validation loss 0.6992231469887954, validation accuracy 0.48693333333333333\n", "0.7085320353507996 0\n", "0.6722798347473145 100\n", "0.6801241636276245 200\n", "0.7182485461235046 300\n", "0.6956546902656555 400\n", "0.7188464999198914 500\n", "Epoch 4, training loss 0.6963817277452448, time passed 3m 22s, training accuracy 0.49925714285714284\n", "Validation loss 0.6991419461038377, validation accuracy 0.4868\n", "0.701016366481781 0\n", "0.688944935798645 100\n", "0.6913262605667114 200\n", 
"0.7181053161621094 300\n", "0.6912066340446472 400\n", "0.72074294090271 500\n", "Epoch 5, training loss 0.6963420041969844, time passed 4m 2s, training accuracy 0.4992\n", "Validation loss 0.6990588495873997, validation accuracy 0.4868\n", "0.6910022497177124 0\n", "0.6851209998130798 100\n", "0.6807781457901001 200\n", "0.6992347240447998 300\n", "0.6912928819656372 400\n", "0.6885637044906616 500\n", "Epoch 6, training loss 0.6963040607097821, time passed 4m 43s, training accuracy 0.4992\n", "Validation loss 0.6989802030416635, validation accuracy 0.4868\n", "0.6948328018188477 0\n", "0.7066507339477539 100\n", "0.6868513822555542 200\n", "0.7039141058921814 300\n", "0.6959588527679443 400\n", "0.6868894100189209 500\n", "Epoch 7, training loss 0.6962397077799717, time passed 5m 23s, training accuracy 0.4992\n", "Validation loss 0.698900478772628, validation accuracy 0.4866666666666667\n", "0.6993783116340637 0\n", "0.6839303970336914 100\n", "0.6700723171234131 200\n", "0.6798321604728699 300\n", "0.7009574770927429 400\n", "0.6968734264373779 500\n", "Epoch 8, training loss 0.6961875306817639, time passed 6m 4s, training accuracy 0.4993142857142857\n", "Validation loss 0.6988244420952268, validation accuracy 0.4865333333333333\n", "0.6909798383712769 0\n", "0.7114906311035156 100\n", "0.6953390836715698 200\n", "0.6998339295387268 300\n", "0.680980920791626 400\n", "0.6675326228141785 500\n", "Epoch 9, training loss 0.6961533421780164, time passed 6m 45s, training accuracy 0.4993714285714286\n", "Validation loss 0.6987498071458604, validation accuracy 0.4865333333333333\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "zD5yusoifTzl", "colab": {} }, "source": [ "l_0 = 20\n", "model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", "print(\"SUG l_0={}, momentum=0. 
:\".format(l_0))\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.5, weight_decay=1e-4)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", "torch.save(states, './IMDB/attn_sug_0.5_wd_1e-4')" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ysPeP6iafaVV", "colab_type": "code", "colab": {} }, "source": [ "lrs = [0.0001]\n", "for lr in lrs:\n", " model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings)\n", " print(\"Adam lr={}, weight_decay=1e-4 :\".format(lr))\n", " optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'tr_acc' : tr_acc,\n", " 'val_acc' : val_acc\n", " }\n", " torch.save(states, './IMDB/attn_adam_' + str(lr)+'_wd_1e-4')" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "crBTTMf6lnvB", "colab_type": "code", "colab": {} }, "source": [ "" ], "execution_count": 0, "outputs": [] } ] }