{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "MNIST.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [], "toc_visible": true, "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "_ni0nMDcukyV", "outputId": "1bd14500-985c-45c9-917c-feef3cf0efdb", "colab": { "base_uri": "https://localhost:8080/", "height": 54 } }, "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "p-4vBJwaukwk", "colab": {} }, "source": [ "import os\n", "os.chdir(\"/content/drive/My Drive/Colab Notebooks/Optimization project\")\n", "os.getcwd()\n", "\n", "file_path = \"./MNIST\"\n", "\n", "try:\n", " os.stat(file_path)\n", "except:\n", " os.mkdir(file_path) " ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "QzSdhyDLASJJ", "colab_type": "code", "colab": {} }, "source": [ "from sug import *" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "colab_type": "code", "id": "pA8A5BaJ-1-A", "colab": {} }, "source": [ "import torch\n", "from torch.optim import Optimizer\n", "import math\n", "import copy\n", "\n", "class SUG(Optimizer):\n", " def __init__(self, params, l_0, d_0=0, prob=1., eps=1e-4, momentum=0, dampening=0,\n", " weight_decay=0, nesterov=False):\n", " if l_0 < 0.0:\n", " raise ValueError(\"Invalid Lipsitz constant of gradient: {}\".format(l_0))\n", " if d_0 < 0.0:\n", " raise ValueError(\"Invalid disperion of gradient: {}\".format(d_0))\n", " if momentum < 0.0:\n", " raise ValueError(\"Invalid momentum value: {}\".format(momentum))\n", " if weight_decay < 0.0:\n", " raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n", "\n", " defaults = dict(L=l_0, momentum=momentum, dampening=dampening,\n", " weight_decay=weight_decay, nesterov=nesterov)\n", " if nesterov and (momentum <= 0 or dampening != 0):\n", " raise ValueError(\"Nesterov momentum requires a momentum and zero dampening\")\n", " self.Lips = l_0\n", " self.prev_Lips = l_0\n", " self.D_0 = d_0\n", " self.eps = eps\n", " self.prob = prob\n", " self.start_param = params\n", " self.upd_sq_grad_norm = None\n", " self.sq_grad_norm = None\n", " self.loss = torch.tensor(0.)\n", " self.cur_loss = 0\n", " self.closure = None\n", " super(SUG, self).__init__(params, defaults)\n", "\n", " def __setstate__(self, state):\n", " super(SUG, self).__setstate__(state)\n", " for group in self.param_groups:\n", " group.setdefault('nesterov', False)\n", "\n", " def comp_batch_size(self):\n", " \"\"\"Returns optimal batch size for given d_0, eps and l_0;\n", "\n", " \"\"\"\n", " return math.ceil(2 * self.D_0 * self.eps / self.prev_Lips)\n", "\n", " def step(self, loss, closure):\n", " \"\"\"Performs a single optimization step.\n", "\n", " Arguments:\n", " loss : current loss\n", "\n", " closure (callable, optional): A closure that reevaluates the model\n", " and returns the loss.\n", " \"\"\"\n", " self.start_params = []\n", " self.loss = loss\n", " self.sq_grad_norm = 0\n", " self.cur_loss = loss\n", " self.closure = closure\n", " for gr_idx, group in enumerate(self.param_groups):\n", " weight_decay = group['weight_decay']\n", " momentum = group['momentum']\n", " dampening = group['dampening']\n", " nesterov = group['nesterov']\n", " self.start_params.append([])\n", " for p_idx, p in enumerate(group['params']):\n", " self.start_params[gr_idx].append([p.data.clone()])\n", " if p.grad is None:\n", " continue\n", " self.start_params[gr_idx][p_idx].append(p.grad.data.clone())\n", " d_p = self.start_params[gr_idx][p_idx][1]\n", " p_ = self.start_params[gr_idx][p_idx][0]\n", " \n", " \n", " if weight_decay != 0:\n", " d_p.add_(weight_decay, p.data)\n", " self.cur_loss += weight_decay * torch.sum(p * p).item()\n", " \n", " \n", " self.sq_grad_norm += torch.sum(d_p * d_p).item()\n", " \n", " if momentum != 0:\n", " param_state = self.state[p]\n", " if 'momentum_buffer' not in param_state:\n", " buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)\n", " buf.mul_(momentum).add_(d_p)\n", " else:\n", " buf = param_state['momentum_buffer']\n", " buf.mul_(momentum).add_(1 - dampening, d_p)\n", " if nesterov:\n", " d_p = d_p.add(momentum, buf)\n", " else:\n", " d_p = buf\n", " self.start_params[gr_idx][p_idx][1] = d_p\n", " \n", " i = 0\n", " self.Lips = max(self.prev_Lips / 2, 0.1)\n", " difference = -1\n", " while difference < 0 or i == 0:\n", " if (i > 0): \n", " self.Lips = max(self.Lips * 2, 0.1)\n", " for gr_idx, group in enumerate(self.param_groups):\n", " for p_idx, p in enumerate(group['params']):\n", " if p.grad is None:\n", " continue\n", " start_param_val = self.start_params[gr_idx][p_idx][0]\n", " start_param_grad = self.start_params[gr_idx][p_idx][1]\n", " p.data = start_param_val - 1/(2*self.Lips) * start_param_grad\n", " difference, upd_loss = self.stop_criteria()\n", " i += 1\n", " self.prev_Lips = self.Lips\n", "\n", " return self.Lips, i\n", "\n", " def stop_criteria(self):\n", " \"\"\"Checks if the Lipsitz constant of gradient is appropriate\n", " \n", " + 2L_k / 2 ||x_k - w_k||^2 = - 1 / (2L_k)||g(x_k)||^2 + 1 / (4L_k)||g(x_k)||^2 = -1 / (4L_k)||g(x_k)||^2 \n", " \"\"\"\n", " upd_loss = self.closure()\n", " major = self.cur_loss - 1 / (4 * self.Lips) * self.sq_grad_norm\n", " return major - upd_loss - self.l2_reg() + self.eps / 10, upd_loss\n", "\n", " def get_lipsitz_const(self):\n", " \"\"\"Returns current Lipsitz constant of the gradient of the loss function\n", " \"\"\"\n", " return self.Lips\n", " \n", " def get_sq_grad(self):\n", " \"\"\"Returns the current second norm of the gradient of the loss function \n", " calculated by the formula\n", " \n", " ||f'(p_1,...,p_n)||_2^2 ~ \\sum\\limits_{i=1}^n ((df/dp_i) * (df/dp_i))(p1,...,p_n))\n", " \n", " \"\"\"\n", " self.upd_sq_grad_norm = 0\n", " for gr_idx, group in enumerate(self.param_groups):\n", " for p_idx, p in enumerate(group['params']):\n", " if p.grad is None:\n", " continue\n", " self.upd_sq_grad_norm += torch.sum(p.grad.data * p.grad.data).item()\n", " \n", " return self.upd_sq_grad_norm\n", " \n", " def l2_reg(self):\n", " \"\"\"Returns the current l2 regularization addiction\n", " \n", " \"\"\"\n", " self.upd_l2_reg = 0\n", " for gr_idx, group in enumerate(self.param_groups):\n", " weight_decay = group['weight_decay']\n", " if weight_decay != 0:\n", " for p_idx, p in enumerate(group['params']):\n", " self.upd_l2_reg += weight_decay * torch.sum(p * p).item()\n", " \n", " return self.upd_l2_reg" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "fiSHwkvZumRd", "colab_type": "code", "colab": {} }, "source": [ "%matplotlib inline\n", "import torch\n", "from torch import nn\n", "from torch import optim\n", "from torch.autograd import Variable\n", "from torch.utils.data.sampler import SubsetRandomSampler\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import time" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "p2Mo1UdyumT9", "colab_type": "code", "outputId": "8f2d99ea-e075-4ad9-b6ad-d9bc8673d29a", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "device" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "device(type='cuda', index=0)" ] }, "metadata": { "tags": [] }, "execution_count": 6 } ] }, { "cell_type": "markdown", "metadata": { "id": "W6Lk5gvGu1Yc", "colab_type": "text" }, "source": [ "## Data" ] }, { "cell_type": "code", "metadata": { "id": "aSGMNymTW5L8", "colab_type": "code", "colab": {} }, "source": [ "batch_size = 512" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "_rk3Q_Jku2by", "colab_type": "code", "outputId": "99abf9da-29f3-46f0-92f3-329b008bb627", "colab": { "base_uri": "https://localhost:8080/", "height": 275 } }, "source": [ "transform = transforms.Compose(\n", " [transforms.ToTensor(),\n", " transforms.Normalize((0.5,), (0.5,))])\n", "\n", "trainset = torchvision.datasets.MNIST(root='./data', train=True,\n", " download=True, transform=transform)\n", "\n", "valid_dataset = torchvision.datasets.MNIST(root='/data', train=True, \n", " download=True, transform=transform)\n", "\n", "testset = torchvision.datasets.MNIST(root='./data', train=False,\n", " download=True, transform=transform)\n", "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n", " shuffle=False, num_workers=2)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "\r0it [00:00, ?it/s]" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /data/MNIST/raw/train-images-idx3-ubyte.gz\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "9920512it [00:01, 8639858.13it/s] \n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Extracting /data/MNIST/raw/train-images-idx3-ubyte.gz\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ " 0%| | 0/28881 [00:00" ] }, "metadata": { "tags": [] }, "execution_count": 16 } ] }, { "cell_type": "code", "metadata": { "id": "B6jCFM66X7ql", "colab_type": "code", "colab": {} }, "source": [ "def concat_states(state1, state2):\n", " states = {\n", " 'epoch': state1['epoch'] + state2['epoch'],\n", " 'state_dict': state2['state_dict'],\n", " 'optimizer': state2['optimizer'],\n", " 'tr_loss' : state1['tr_loss'] + state2['tr_loss'],\n", " 'val_loss' : state1['val_loss'] + state2['val_loss'],\n", " 'lips' : state1['lips'] + state2['lips'],\n", " 'grad' : state1['grad'] + state2['grad'],\n", " #'times' : state1['times'] + list(map(lambda x: x + state1['times'][-1],state2['times']))\n", " 'times' : state1['times'] + state2['times']\n", " }\n", " return states" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "98EHtawSzoTM", "colab_type": "text" }, "source": [ "###LR" ] }, { "cell_type": "code", "metadata": { "id": "cCXe24kKE1Vt", "colab_type": "code", "colab": {} }, "source": [ "n_epochs = 20" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "hTMWs2dYznyO", "colab_type": "code", "outputId": "948ebeeb-cc92-46eb-8821-d1cf3589ada4", "colab": { "base_uri": "https://localhost:8080/", "height": 578 } }, "source": [ "for lr in lrs:\n", " model = LR()\n", " print(\"SGD lr={}, momentum=0. :\".format(lr))\n", " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0., weight_decay=1e-3)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", " torch.save(states, './MNIST/LR_' + str(lr))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.05, momentum=0. :\n", "Epoch 0, training loss 0.8053867536963839, time passed 0m 4s\n", "Validation loss 0.4970751264516045, validation accuracy 0.8594444444444445\n", "Epoch 4, training loss 0.3556738823953301, time passed 0m 29s\n", "Validation loss 0.34582727095660043, validation accuracy 0.8994444444444445\n", "Epoch 8, training loss 0.324508320201527, time passed 0m 55s\n", "Validation loss 0.32668687315548167, validation accuracy 0.9034444444444445\n", "Epoch 12, training loss 0.31013485217335246, time passed 1m 17s\n", "Validation loss 0.3108014885117026, validation accuracy 0.9085555555555556\n", "Epoch 16, training loss 0.3015503261727516, time passed 1m 40s\n", "Validation loss 0.3028887212276459, validation accuracy 0.9145555555555556\n", "SGD lr=0.01, momentum=0. :\n", "Epoch 0, training loss 1.2965486091796798, time passed 0m 4s\n", "Validation loss 0.8687916503233069, validation accuracy 0.8221111111111111\n", "Epoch 4, training loss 0.49756273416557695, time passed 0m 27s\n", "Validation loss 0.4874144634779762, validation accuracy 0.8711111111111111\n", "Epoch 8, training loss 0.42324275109503007, time passed 0m 49s\n", "Validation loss 0.4176199418656966, validation accuracy 0.8856666666666667\n", "Epoch 12, training loss 0.39122442795772744, time passed 1m 14s\n", "Validation loss 0.38835134050425363, validation accuracy 0.8913333333333333\n", "Epoch 16, training loss 0.3713635543380121, time passed 1m 36s\n", "Validation loss 0.3647502099766451, validation accuracy 0.8964444444444445\n", "SGD lr=0.005, momentum=0. :\n", "Epoch 0, training loss 1.6245008020689993, time passed 0m 4s\n", "Validation loss 1.1842015701181747, validation accuracy 0.7812222222222223\n", "Epoch 4, training loss 0.6250864417866023, time passed 0m 27s\n", "Validation loss 0.6046268624417922, validation accuracy 0.8561111111111112\n", "Epoch 8, training loss 0.506018156054044, time passed 0m 53s\n", "Validation loss 0.498588469098596, validation accuracy 0.87\n", "Epoch 12, training loss 0.45563414421948517, time passed 1m 15s\n", "Validation loss 0.45264510898029103, validation accuracy 0.8791111111111111\n", "Epoch 16, training loss 0.4259264917686732, time passed 1m 38s\n", "Validation loss 0.42056286334991455, validation accuracy 0.883\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "b1CbtL-F0xaF", "colab_type": "code", "outputId": "fb4ed59a-de06-478a-d1c7-d972783eb128", "colab": { "base_uri": "https://localhost:8080/", "height": 187 } }, "source": [ "l_0 = 20\n", "model = LR()\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0., weight_decay=1e-3)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", "torch.save(states, './MNIST/LR_sug')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Epoch 0, training loss 0.8922390636771617, time passed 0m 4s\n", "Validation loss 0.5756662032183479, validation accuracy 0.8471111111111111\n", "Epoch 4, training loss 0.36624673069125474, time passed 0m 29s\n", "Validation loss 0.35796097797505994, validation accuracy 0.8953333333333333\n", "Epoch 8, training loss 0.3267748903746557, time passed 0m 54s\n", "Validation loss 0.32799020935507384, validation accuracy 0.906\n", "Epoch 12, training loss 0.30954578579074205, time passed 1m 18s\n", "Validation loss 0.31454072805011973, validation accuracy 0.9114444444444444\n", "Epoch 16, training loss 0.29919745840809564, time passed 1m 42s\n", "Validation loss 0.2986759543418884, validation accuracy 0.9141111111111111\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "uJ24b8FFpIQU", "colab_type": "text" }, "source": [ "### FC" ] }, { "cell_type": "code", "metadata": { "id": "zBxqCh7YEOok", "colab_type": "code", "colab": {} }, "source": [ "n_epochs = 20" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "KlkSl5x-vPRW", "colab_type": "code", "outputId": "3adf6f03-38c0-4112-be46-3ea823f40399", "colab": { "base_uri": "https://localhost:8080/", "height": 578 } }, "source": [ "torch.manual_seed(999)\n", "for lr in lrs:\n", " model = FC()\n", " print(\"SGD lr={}, momentum=0. :\".format(lr))\n", " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0., weight_decay=1e-3)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", " torch.save(states, './MNIST/FC_' + str(lr))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.05, momentum=0. :\n", "Epoch 0, training loss 0.9827026613432952, time passed 0m 4s\n", "Validation loss 0.521831205662559, validation accuracy 0.8611111111111112\n", "Epoch 4, training loss 0.32333635019533563, time passed 0m 27s\n", "Validation loss 0.33837247420759764, validation accuracy 0.8983333333333333\n", "Epoch 8, training loss 0.2735136151313782, time passed 0m 52s\n", "Validation loss 0.27770407059613395, validation accuracy 0.9204444444444444\n", "Epoch 12, training loss 0.2350455638435152, time passed 1m 14s\n", "Validation loss 0.2334472729879267, validation accuracy 0.9327777777777778\n", "Epoch 16, training loss 0.20252520448029643, time passed 1m 37s\n", "Validation loss 0.21084140153492198, validation accuracy 0.9381111111111111\n", "SGD lr=0.01, momentum=0. :\n", "Epoch 0, training loss 1.8455105502196032, time passed 0m 6s\n", "Validation loss 1.3969286049113554, validation accuracy 0.7461111111111111\n", "Epoch 4, training loss 0.5331706323406913, time passed 0m 28s\n", "Validation loss 0.5103560914011562, validation accuracy 0.8694444444444445\n", "Epoch 8, training loss 0.4116000948530255, time passed 0m 52s\n", "Validation loss 0.4016827555263744, validation accuracy 0.8874444444444445\n", "Epoch 12, training loss 0.3670172348166957, time passed 1m 17s\n", "Validation loss 0.3593411305371453, validation accuracy 0.8961111111111111\n", "Epoch 16, training loss 0.34106487396991614, time passed 1m 42s\n", "Validation loss 0.3356891782844768, validation accuracy 0.901\n", "SGD lr=0.005, momentum=0. :\n", "Epoch 0, training loss 2.0741889308197328, time passed 0m 4s\n", "Validation loss 1.842525559313157, validation accuracy 0.6163333333333333\n", "Epoch 4, training loss 0.8207536977950973, time passed 0m 27s\n", "Validation loss 0.7696297238854801, validation accuracy 0.832\n", "Epoch 8, training loss 0.5540273183524006, time passed 0m 52s\n", "Validation loss 0.5414836915100322, validation accuracy 0.8632222222222222\n", "Epoch 12, training loss 0.46390234430631, time passed 1m 14s\n", "Validation loss 0.4573543685324052, validation accuracy 0.8742222222222222\n", "Epoch 16, training loss 0.4175319692703209, time passed 1m 36s\n", "Validation loss 0.4172443712458891, validation accuracy 0.8842222222222222\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "ODga9MkTvnRl", "colab_type": "code", "outputId": "1a7ece1b-0258-4dbf-a590-07b6c51a779a", "colab": { "base_uri": "https://localhost:8080/", "height": 187 } }, "source": [ "l_0 = 20\n", "model = FC()\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0., weight_decay=1e-3)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", "torch.save(states, './MNIST/FC_sug')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Epoch 0, training loss 0.841091928457973, time passed 0m 5s\n", "Validation loss 0.5276602085898904, validation accuracy 0.862\n", "Epoch 4, training loss 0.3375983780080622, time passed 0m 31s\n", "Validation loss 0.32630446903845844, validation accuracy 0.9072222222222223\n", "Epoch 8, training loss 0.2841122362649802, time passed 0m 55s\n", "Validation loss 0.28937043161953196, validation accuracy 0.914\n", "Epoch 12, training loss 0.2409512424709821, time passed 1m 19s\n", "Validation loss 0.24036436747102177, validation accuracy 0.9277777777777778\n", "Epoch 16, training loss 0.20427992352933594, time passed 1m 45s\n", "Validation loss 0.20693528213921716, validation accuracy 0.9405555555555556\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "OY6FX8D6lNct", "colab_type": "text" }, "source": [ "+ momentum" ] }, { "cell_type": "code", "metadata": { "id": "nqonb8JxlM8M", "colab_type": "code", "outputId": "b9840fe8-9616-46df-8e1d-9d817724e597", "colab": { "base_uri": "https://localhost:8080/", "height": 578 } }, "source": [ "torch.manual_seed(999)\n", "for lr in lrs:\n", " model = FC()\n", " print(\"SGD lr={}, momentum=0.9 :\".format(lr))\n", " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-3)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", " torch.save(states, './MNIST/FC_' + str(lr)+'_0.9')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "SGD lr=0.05, momentum=0.9 :\n", "Epoch 0, training loss 0.5381249714680393, time passed 0m 5s\n", "Validation loss 0.28269324670819673, validation accuracy 0.9155555555555556\n", "Epoch 4, training loss 0.12210867298070831, time passed 0m 29s\n", "Validation loss 0.1238884250907337, validation accuracy 0.9644444444444444\n", "Epoch 8, training loss 0.08087855562417194, time passed 0m 56s\n", "Validation loss 0.11202001571655273, validation accuracy 0.9683333333333334\n", "Epoch 12, training loss 0.06308559685794994, time passed 1m 18s\n", "Validation loss 0.08321712078417048, validation accuracy 0.9768888888888889\n", "Epoch 16, training loss 0.05535656740569105, time passed 1m 40s\n", "Validation loss 0.09248179460273069, validation accuracy 0.9728888888888889\n", "SGD lr=0.01, momentum=0.9 :\n", "Epoch 0, training loss 0.8471114117689807, time passed 0m 4s\n", "Validation loss 0.39922235643162446, validation accuracy 0.886\n", "Epoch 4, training loss 0.2757103128565682, time passed 0m 29s\n", "Validation loss 0.26469812410719257, validation accuracy 0.9244444444444444\n", "Epoch 8, training loss 0.20710186073274323, time passed 0m 51s\n", "Validation loss 0.2069076071767246, validation accuracy 0.9395555555555556\n", "Epoch 12, training loss 0.16193968215675064, time passed 1m 13s\n", "Validation loss 0.16249041434596567, validation accuracy 0.9536666666666667\n", "Epoch 16, training loss 0.1327487767645807, time passed 1m 39s\n", "Validation loss 0.1429281318012406, validation accuracy 0.9581111111111111\n", "SGD lr=0.005, momentum=0.9 :\n", "Epoch 0, training loss 1.1291484489585415, time passed 0m 4s\n", "Validation loss 0.5271653725820429, validation accuracy 0.8605555555555555\n", "Epoch 4, training loss 0.32374356130156856, time passed 0m 26s\n", "Validation loss 0.31370861039442177, validation accuracy 0.909\n", "Epoch 8, training loss 0.2752047236820664, time passed 0m 49s\n", "Validation loss 0.2723114639520645, validation accuracy 0.9206666666666666\n", "Epoch 12, training loss 0.2382019865091401, time passed 1m 15s\n", "Validation loss 0.23456529834691217, validation accuracy 0.9323333333333333\n", "Epoch 16, training loss 0.20731118155850303, time passed 1m 37s\n", "Validation loss 0.21120107173919678, validation accuracy 0.9417777777777778\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "Ttr66r13pIoX", "colab_type": "code", "outputId": "fd9835c6-f08a-40dd-b724-0df52ee042ab", "colab": { "base_uri": "https://localhost:8080/", "height": 187 } }, "source": [ "torch.manual_seed(999)\n", "l_0 = 20\n", "model = FC()\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.9, weight_decay=1e-3)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", "torch.save(states, './MNIST/FC_sug_0.9')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Epoch 0, training loss 0.5774111350377401, time passed 0m 5s\n", "Validation loss 0.4330375124426449, validation accuracy 0.8753333333333333\n", "Epoch 4, training loss 0.35169939471013617, time passed 0m 29s\n", "Validation loss 0.3514442969770992, validation accuracy 0.8977777777777778\n", "Epoch 8, training loss 0.3251004270230881, time passed 0m 55s\n", "Validation loss 0.3274660198127522, validation accuracy 0.9045555555555556\n", "Epoch 12, training loss 0.309954424398114, time passed 1m 20s\n", "Validation loss 0.3091518423136543, validation accuracy 0.9075555555555556\n", "Epoch 16, training loss 0.2942978652438732, time passed 1m 47s\n", "Validation loss 0.2984069471850115, validation accuracy 0.9126666666666666\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "2x_ieelg84rE", "colab_type": "text" }, "source": [ "### CNN" ] }, { "cell_type": "code", "metadata": { "id": "PVXwssDp83x6", "colab_type": "code", "colab": {} }, "source": [ "for lr in lrs:\n", " model = CNN()\n", " print(\"SGD lr={}, momentum=0. :\".format(lr))\n", " optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0., weight_decay=1e-3)\n", " tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", " states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sgd'][lr],\n", " 'val_loss' : val_loss['sgd'][lr],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", " torch.save(states, './MNIST/CNN_' + str(lr))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "YTYmncEz87xM", "colab_type": "code", "colab": {} }, "source": [ "#n_epochs = 8\n", "l_0 = 20\n", "model = CNN()\n", "optimizer = SUG(model.parameters(), l_0=l_0, momentum=0., weight_decay=1e-3)\n", "tr_loss['sug'], times, val_loss['sug'], lips, grad, acc = train(model, trainloader, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=validloader)\n", "states = {\n", " 'epoch': n_epochs,\n", " 'state_dict': model.state_dict(),\n", " 'optimizer': optimizer.state_dict(),\n", " 'tr_loss' : tr_loss['sug'],\n", " 'val_loss' : val_loss['sug'],\n", " 'lips' : lips,\n", " 'grad' : grad,\n", " 'times' : times,\n", " 'acc' : acc\n", " }\n", "torch.save(states, './MNIST/CNN_sug')" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "XhhFxcDMxjc1", "colab_type": "code", "colab": {} }, "source": [ "" ], "execution_count": 0, "outputs": [] } ] }