{ "cells": [ { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2019-06-11T15:13:14.783671Z", "start_time": "2019-06-11T15:13:14.322919Z" }, "collapsed": true }, "outputs": [], "source": [ "import os\n", "import argparse\n", "import logging\n", "import time\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "from torchdiffeq import odeint\n", "import torchvision.datasets as datasets\n", "import torchvision.transforms as transforms\n", "\n", "from matplotlib import pyplot as plt\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "sns.set()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def conv3x3(in_planes, out_planes, stride=1):\n", " \"\"\"3x3 convolution with padding\"\"\"\n", " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n", "\n", "\n", "def conv1x1(in_planes, out_planes, stride=1):\n", " \"\"\"1x1 convolution\"\"\"\n", " return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n", "\n", "\n", "def norm(dim):\n", " return nn.GroupNorm(min(32, dim), dim)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class ResBlock(nn.Module):\n", " def __init__(self, inplanes, planes, stride=1, downsample=None):\n", " super(ResBlock, self).__init__()\n", " self.norm1 = norm(inplanes)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.downsample = downsample\n", " self.conv1 = conv3x3(inplanes, planes, stride)\n", " self.norm2 = norm(planes)\n", " self.conv2 = conv3x3(planes, planes)\n", "\n", " def forward(self, x):\n", " shortcut = x\n", "\n", " out = self.relu(self.norm1(x))\n", "\n", " if self.downsample is not None:\n", " shortcut = self.downsample(out)\n", "\n", " out = self.conv1(out)\n", " out = self.norm2(out)\n", " out = self.relu(out)\n", " out = self.conv2(out)\n", "\n", " return out + shortcut" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class ConcatConv2d(nn.Module):\n", "\n", " def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):\n", " super(ConcatConv2d, self).__init__()\n", " module = nn.ConvTranspose2d if transpose else nn.Conv2d\n", " self._layer = module(\n", " dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,\n", " bias=bias\n", " )\n", "\n", " def forward(self, t, x):\n", " tt = torch.ones_like(x[:, :1, :, :]) * t\n", " ttx = torch.cat([tt, x], 1)\n", " return self._layer(ttx)\n", "\n", "\n", "class ODEfunc_compl(nn.Module):\n", "\n", " def __init__(self, dim):\n", " super(ODEfunc, self).__init__()\n", " self.norm1 = norm(dim)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm2 = norm(dim)\n", " self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm3 = norm(dim)\n", " self.conv3 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm4 = norm(dim)\n", " self.conv4 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm5 = norm(dim)\n", " self.nfe = 0\n", "\n", " def forward(self, t, x):\n", " self.nfe += 1\n", " out = self.norm1(x)\n", " out = self.relu(out)\n", " out = self.conv1(t, out)\n", " out = self.norm2(out)\n", " out = self.relu(out)\n", " out = self.conv2(t, out)\n", " out = self.norm3(out)\n", " out = self.relu(out)\n", " out = self.conv3(t, out)\n", " out = self.norm4(out)\n", " out = self.relu(out)\n", " out = self.conv4(t, out)\n", " out = self.norm5(out)\n", " return out\n", "\n", " \n", "class ODEfunc(nn.Module):\n", "\n", " def __init__(self, dim):\n", " super(ODEfunc, self).__init__()\n", " self.norm1 = norm(dim)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm2 = norm(dim)\n", " self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)\n", " self.norm3 = norm(dim)\n", " self.nfe = 0\n", "\n", " def forward(self, t, x):\n", " self.nfe += 1\n", " out = self.norm1(x)\n", " out = self.relu(out)\n", " out = self.conv1(t, out)\n", " out = self.norm2(out)\n", " out = self.relu(out)\n", " out = self.conv2(t, out)\n", " out = self.norm3(out)\n", " return out\n", "\n", "\n", "class ODEBlock(nn.Module):\n", "\n", " def __init__(self, odefunc):\n", " super(ODEBlock, self).__init__()\n", " self.odefunc = odefunc\n", " self.integration_time = torch.tensor([0, 1]).float()\n", "\n", " def forward(self, x):\n", " self.integration_time = self.integration_time.type_as(x)\n", " out = odeint(self.odefunc, x, self.integration_time, rtol=tol, atol=tol)\n", " return out[1]\n", "\n", " @property\n", " def nfe(self):\n", " return self.odefunc.nfe\n", "\n", " @nfe.setter\n", " def nfe(self, value):\n", " self.odefunc.nfe = value" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Flatten(nn.Module):\n", "\n", " def __init__(self):\n", " super(Flatten, self).__init__()\n", "\n", " def forward(self, x):\n", " shape = torch.prod(torch.tensor(x.shape[1:])).item()\n", " return x.view(-1, shape)\n", "\n", "\n", "class RunningAverageMeter(object):\n", " \"\"\"Computes and stores the average and current value\"\"\"\n", "\n", " def __init__(self, momentum=0.99):\n", " self.momentum = momentum\n", " self.reset()\n", "\n", " def reset(self):\n", " self.val = None\n", " self.avg = 0\n", "\n", " def update(self, val):\n", " if self.val is None:\n", " self.avg = val\n", " else:\n", " self.avg = self.avg * self.momentum + val * (1 - self.momentum)\n", " self.val = val" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, perc=1.0):\n", " if data_aug:\n", " transform_train = transforms.Compose([\n", " transforms.RandomCrop(28, padding=4),\n", " transforms.ToTensor(),\n", " ])\n", " else:\n", " transform_train = transforms.Compose([\n", " transforms.ToTensor(),\n", " ])\n", "\n", " transform_test = transforms.Compose([\n", " transforms.ToTensor(),\n", " ])\n", "\n", " train_loader = DataLoader(\n", " datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,\n", " shuffle=True, num_workers=2, drop_last=True\n", " )\n", "\n", " train_eval_loader = DataLoader(\n", " datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),\n", " batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True\n", " )\n", "\n", " test_loader = DataLoader(\n", " datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),\n", " batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True\n", " )\n", "\n", " return train_loader, test_loader, train_eval_loader\n", "\n", "\n", "def inf_generator(iterable):\n", " \"\"\"Allows training with DataLoaders in a single infinite loop:\n", " for i, (x, y) in enumerate(inf_generator(train_loader)):\n", " \"\"\"\n", " iterator = iterable.__iter__()\n", " while True:\n", " try:\n", " yield iterator.__next__()\n", " except StopIteration:\n", " iterator = iterable.__iter__()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def learning_rate_with_decay(lr, batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):\n", " initial_learning_rate = lr * batch_size / batch_denom\n", "\n", " boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]\n", " vals = [initial_learning_rate * decay for decay in decay_rates]\n", "\n", " def learning_rate_fn(itr):\n", " lt = [itr < b for b in boundaries] + [True]\n", " i = np.argmax(lt)\n", " return vals[i]\n", "\n", " return learning_rate_fn\n", "\n", "\n", "def one_hot(x, K):\n", " return np.array(x[:, None] == np.arange(K)[None, :], dtype=int)\n", "\n", "\n", "def accuracy(model, dataset_loader):\n", " total_correct = 0\n", " for x, y in dataset_loader:\n", " x = x.to(device)\n", " y = one_hot(np.array(y.numpy()), 10)\n", "\n", " target_class = np.argmax(y, axis=1)\n", " predicted_class = np.argmax(model(x).cpu().detach().numpy(), axis=1)\n", " total_correct += np.sum(predicted_class == target_class)\n", " return total_correct / len(dataset_loader.dataset)\n", "\n", "\n", "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def makedirs(dirname):\n", " if not os.path.exists(dirname):\n", " os.makedirs(dirname)\n", "\n", "\n", "def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):\n", " logger = logging.getLogger()\n", " if debug:\n", " level = logging.DEBUG\n", " else:\n", " level = logging.INFO\n", " logger.setLevel(level)\n", " if saving:\n", " info_file_handler = logging.FileHandler(logpath, mode=\"a\")\n", " info_file_handler.setLevel(level)\n", " logger.addHandler(info_file_handler)\n", " if displaying:\n", " console_handler = logging.StreamHandler()\n", " console_handler.setLevel(level)\n", " logger.addHandler(console_handler)\n", " #logger.info(filepath)\n", " #with open(filepath, \"r\") as f:\n", " # logger.info(f.read())\n", "\n", " for f in package_files:\n", " logger.info(f)\n", " with open(f, \"r\") as package_f:\n", " logger.info(package_f.read())\n", "\n", " return logger" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "tol = 1e-3\n", "downsampling_method = 'res'\n", "epochs=100\n", "data_aug = True\n", "lr = 1e-3\n", "batch_size = 128\n", "test_bs = 1024 #??\n", "\n", "save_dir = './exp1' " ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "#from torchdiffeq import odeint_adjoint as odeint\n", "from torchdiffeq import odeint" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))\n", " (1): ResBlock(\n", " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace)\n", " (downsample): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " )\n", " (2): ResBlock(\n", " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace)\n", " (downsample): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " )\n", " (3): ODEBlock(\n", " (odefunc): ODEfunc(\n", " (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace)\n", " (conv1): ConcatConv2d(\n", " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " (conv2): ConcatConv2d(\n", " (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " )\n", " (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " )\n", " )\n", " (4): GroupNorm(32, 64, eps=1e-05, affine=True)\n", " (5): ReLU(inplace)\n", " (6): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (7): Flatten()\n", " (8): Linear(in_features=64, out_features=10, bias=True)\n", ")\n", "Number of parameters: 232970\n", "Epoch 0 | Time 1.166 (1.166) | NFE-F 26.0 | NFE-B 0.0 | Train_Acc 0.0978 | Test_Acc 0.0897 | Total_Time 1.166\n", "Epoch 1 | Time 1.124 (0.352) | NFE-F 26.1 | NFE-B 0.0 | Train_Acc 0.9630 | Test_Acc 0.8998 | Total_Time 151.212\n", "Epoch 2 | Time 1.169 (0.356) | NFE-F 26.2 | NFE-B 0.0 | Train_Acc 0.9720 | Test_Acc 0.9077 | Total_Time 159.991\n", "Epoch 3 | Time 1.162 (0.354) | NFE-F 26.2 | NFE-B 0.0 | Train_Acc 0.9730 | Test_Acc 0.9052 | Total_Time 160.550\n", "Epoch 4 | Time 1.203 (0.341) | NFE-F 25.0 | NFE-B 0.0 | Train_Acc 0.9805 | Test_Acc 0.9138 | Total_Time 157.097\n", "Epoch 5 | Time 1.123 (0.320) | NFE-F 22.9 | NFE-B 0.0 | Train_Acc 0.9821 | Test_Acc 0.9130 | Total_Time 147.981\n", "Epoch 6 | Time 1.177 (0.295) | NFE-F 20.9 | NFE-B 0.0 | Train_Acc 0.9808 | Test_Acc 0.9129 | Total_Time 135.925\n", "Epoch 7 | Time 1.094 (0.286) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9808 | Test_Acc 0.9141 | Total_Time 128.339\n", "Epoch 8 | Time 1.158 (0.285) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9823 | Test_Acc 0.9148 | Total_Time 128.412\n", "Epoch 9 | Time 1.123 (0.286) | NFE-F 20.1 | NFE-B 0.0 | Train_Acc 0.9827 | Test_Acc 0.9146 | Total_Time 128.836\n", "Epoch 10 | Time 1.143 (0.289) | NFE-F 20.2 | NFE-B 0.0 | Train_Acc 0.9813 | Test_Acc 0.9134 | Total_Time 130.564\n", "Epoch 11 | Time 1.171 (0.296) | NFE-F 20.1 | NFE-B 0.0 | Train_Acc 0.9812 | Test_Acc 0.9130 | Total_Time 131.606\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 69\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 70\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 71\u001b[1;33m \u001b[0mlogits\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 72\u001b[0m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 73\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 492\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 493\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 494\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 495\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 90\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 91\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 92\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 93\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 94\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 492\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 493\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 494\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 495\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[0mshortcut\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdownsample\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 19\u001b[1;33m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 20\u001b[0m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 492\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 493\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 494\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 495\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\conv.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 336\u001b[0m _pair(0), self.dilation, self.groups)\n\u001b[0;32m 337\u001b[0m return F.conv2d(input, self.weight, self.bias, self.stride,\n\u001b[1;32m--> 338\u001b[1;33m self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[0;32m 339\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 340\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "makedirs(save_dir)\n", "logger = get_logger(logpath=os.path.join(save_dir, 'logs'), filepath=os.path.abspath('.'))\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "is_odenet = True\n", "\n", "if downsampling_method == 'conv':\n", " downsampling_layers = [\n", " nn.Conv2d(1, 64, 3, 1),\n", " norm(64),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(64, 64, 4, 2, 1),\n", " norm(64),\n", " nn.ReLU(inplace=True),\n", " nn.Conv2d(64, 64, 4, 2, 1),\n", " ]\n", "elif downsampling_method == 'res':\n", " downsampling_layers = [\n", " nn.Conv2d(1, 64, 3, 1),\n", " ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),\n", " ResBlock(64, 64, stride=2, downsample=conv1x1(64, 64, 2)),\n", " ]\n", "\n", "#feature_layers = [ODEBlock(ODEfunc(64)) for _ in range(2)] if is_odenet else [ResBlock(64, 64) for _ in range(6)]\n", "#feature_layers = [ODEBlock(ODEfunc_compl(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]\n", "feature_layers = [ODEBlock(ODEfunc(64))] if is_odenet else [ResBlock(64, 64) for _ in range(6)]\n", "\n", "fc_layers = [norm(64), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]\n", "\n", "model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)\n", "\n", "#logger.info\n", "print(model)\n", "#logger.info\n", "print('Number of parameters: {}'.format(count_parameters(model)))\n", "\n", "criterion = nn.CrossEntropyLoss().to(device)\n", "\n", "train_loader, test_loader, train_eval_loader = get_mnist_loaders(data_aug, batch_size, test_bs)\n", "\n", "data_gen = inf_generator(train_loader)\n", "batches_per_epoch = len(train_loader)\n", "\n", "lr_fn = learning_rate_with_decay( lr, batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch,\n", " boundary_epochs=[20, 40, 60],decay_rates=[1, 0.1, 0.01, 0.001])\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", "\n", "best_acc = 0\n", "field_names = list(map(lambda l:l.strip().split()[0],\n", " \"Epoch {:4d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} |\"\n", " \"Train_Acc {:.4f} | Test_Acc {:.4f} | Total_Time {:.3f}\".split('|')))\n", "logs_1 = []\n", "\n", "batch_time_meter = RunningAverageMeter()\n", "f_nfe_meter = RunningAverageMeter()\n", "b_nfe_meter = RunningAverageMeter()\n", "end = time.time()\n", "start = time.time()\n", "\n", "for itr in range(epochs * batches_per_epoch):\n", "\n", " for param_group in optimizer.param_groups:\n", " param_group['lr'] = lr_fn(itr)\n", "\n", " optimizer.zero_grad()\n", " x, y = data_gen.__next__()\n", " x = x.to(device)\n", " y = y.to(device)\n", " logits = model(x)\n", " loss = criterion(logits, y)\n", "\n", " if is_odenet:\n", " nfe_forward = feature_layers[0].nfe\n", " feature_layers[0].nfe = 0\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " if is_odenet:\n", " nfe_backward = feature_layers[0].nfe\n", " feature_layers[0].nfe = 0\n", "\n", " batch_time_meter.update(time.time() - end)\n", " if is_odenet:\n", " f_nfe_meter.update(nfe_forward)\n", " b_nfe_meter.update(nfe_backward)\n", " end = time.time()\n", "\n", " if itr % batches_per_epoch == 0:\n", " finish = time.time()\n", " with torch.no_grad():\n", " train_acc = accuracy(model, train_eval_loader)\n", " val_acc = accuracy(model, test_loader)\n", " if val_acc > best_acc:\n", " torch.save({'state_dict': model.state_dict()}, os.path.join(save_dir, 'model.pth'))\n", " best_acc = val_acc\n", " #logger.info\n", " print(\n", " \"Epoch {:4d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | \"\n", " \"Train_Acc {:.4f} | Test_Acc {:.4f} | Total_Time {:.3f}\".format(\n", " itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,\n", " b_nfe_meter.avg, train_acc, val_acc, finish - start\n", " )\n", " )\n", " logs_1.append(dict(zip(field_names, [ itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,\n", " b_nfe_meter.avg, train_acc, val_acc, finish - start])))\n", " start = time.time()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[{'Epoch': 0,\n", " 'Time': 1.1663413047790527,\n", " 'NFE-F': 1.1663413047790527,\n", " 'NFE-B': 26,\n", " 'Train_Acc': 0,\n", " 'Test_Acc': 0.09775,\n", " 'Total_Time': 0.0897},\n", " {'Epoch': 1,\n", " 'Time': 1.1237423419952393,\n", " 'NFE-F': 0.3518146520107296,\n", " 'NFE-B': 26.113337826255613,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9629833333333333,\n", " 'Total_Time': 0.8998},\n", " {'Epoch': 2,\n", " 'Time': 1.1690211296081543,\n", " 'NFE-F': 0.35646780732408767,\n", " 'NFE-B': 26.160499026848903,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.972,\n", " 'Total_Time': 0.9077},\n", " {'Epoch': 3,\n", " 'Time': 1.1619246006011963,\n", " 'NFE-F': 0.3541097996338256,\n", " 'NFE-B': 26.160926447915926,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9729833333333333,\n", " 'Total_Time': 0.9052},\n", " {'Epoch': 4,\n", " 'Time': 1.202782154083252,\n", " 'NFE-F': 0.34110854730643125,\n", " 'NFE-B': 24.95415771513177,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9805,\n", " 'Total_Time': 0.9138},\n", " {'Epoch': 5,\n", " 'Time': 1.12302827835083,\n", " 'NFE-F': 0.3201531668305314,\n", " 'NFE-B': 22.928928317852208,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9820666666666666,\n", " 'Total_Time': 0.913},\n", " {'Epoch': 6,\n", " 'Time': 1.177217721939087,\n", " 'NFE-F': 0.2949909653493132,\n", " 'NFE-B': 20.920042885241745,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9807833333333333,\n", " 'Total_Time': 0.9129},\n", " {'Epoch': 7,\n", " 'Time': 1.0944006443023682,\n", " 'NFE-F': 0.28566477432615045,\n", " 'NFE-B': 20.216727656597406,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9808,\n", " 'Total_Time': 0.9141},\n", " {'Epoch': 8,\n", " 'Time': 1.1575982570648193,\n", " 'NFE-F': 0.28541028713199734,\n", " 'NFE-B': 20.184034851164867,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9823,\n", " 'Total_Time': 0.9148},\n", " {'Epoch': 9,\n", " 'Time': 1.1227002143859863,\n", " 'NFE-F': 0.28647260674155234,\n", " 'NFE-B': 20.125482535545512,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9827333333333333,\n", " 'Total_Time': 0.9146},\n", " {'Epoch': 10,\n", " 'Time': 1.1429314613342285,\n", " 'NFE-F': 0.28855354211729095,\n", " 'NFE-B': 20.15380365074734,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9812666666666666,\n", " 'Total_Time': 0.9134},\n", " {'Epoch': 11,\n", " 'Time': 1.1712427139282227,\n", " 'NFE-F': 0.29551518088596457,\n", " 'NFE-B': 20.124064572092596,\n", " 'Train_Acc': 0.0,\n", " 'Test_Acc': 0.9812,\n", " 'Total_Time': 0.913}]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logs_1" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(15, 10))\n", "plt.xlabel('epoch number')\n", "plt.ylabel('accuracy')\n", "plt.title('ODENET complex block')\n", "plt.plot(np.arange(len(train_ac_3[1:])) + 1, train_ac_3[1:], label='Accuracy on training set', marker='v')\n", "plt.plot(np.arange(len(test_ac_3[1:])) + 1, test_ac_3[1:], label='Accuracy on test set', marker='o')\n", "plt.legend()\n", "plt.savefig('odenet_3.pdf', bbox_inches='tight')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "test_ac_3 = [elem['Total_Time'] for elem in logs_3]\n", "train_ac_3 = [elem['Test_Acc'] for elem in logs_3]\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true }, "outputs": [], "source": [ "test_ac_1 = [elem['Total_Time'] for elem in logs_1]\n", "train_ac_1 = [elem['Test_Acc'] for elem in logs_1]\n" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(15, 10))\n", "plt.xlabel('epoch number')\n", "plt.ylabel('accuracy')\n", "plt.title('ODENET with one block')\n", "plt.plot(np.arange(len(train_ac_1[1:])) + 1, train_ac_1[1:], label='Accuracy on training set', marker='v')\n", "plt.plot(np.arange(len(test_ac_1[1:])) + 1, test_ac_1[1:], label='Accuracy on test set', marker='o')\n", "plt.legend()\n", "plt.savefig('odenet_1.pdf', bbox_inches='tight')" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" } }, "nbformat": 4, "nbformat_minor": 2 }