{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## CIFAR 10" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import argparse\n", "import os\n", "import shutil\n", "import time\n", "\n", "from fastai.transforms import *\n", "from fastai.dataset import *\n", "from fastai.fp16 import *\n", "from fastai.conv_learner import *\n", "from pathlib import *\n", "\n", "import torch\n", "from torch.autograd import Variable\n", "import torch.nn as nn\n", "import torch.nn.parallel\n", "import torch.backends.cudnn as cudnn\n", "import torch.distributed as dist\n", "import torch.optim\n", "import torch.utils.data\n", "import torch.utils.data.distributed\n", "import torchvision.transforms as transforms\n", "import torchvision.datasets as datasets\n", "import models\n", "import models.cifar10 as cifar10models\n", "from distributed import DistributedDataParallel as DDP\n", "\n", "# print(models.cifar10.__dict__)\n", "model_names = sorted(name for name in models.__dict__\n", " if name.islower() and not name.startswith(\"__\")\n", " and callable(models.__dict__[name]))\n", "\n", "cifar10_names = sorted(name for name in cifar10models.__dict__\n", " if name.islower() and not name.startswith(\"__\")\n", " and callable(cifar10models.__dict__[name]))\n", "\n", "model_names = cifar10_names + model_names" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Example usage: python run_fastai.py /home/paperspace/ILSVRC/Data/CLS-LOC/ -a resnext_50_32x4d --epochs 1 -j 4 -b 64 --fp16\n", "\n", "parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')\n", "parser.add_argument('data', metavar='DIR',\n", " help='path to dataset')\n", "parser.add_argument('--save-dir', type=str, default=Path.home()/'imagenet_training',\n", " help='Directory to save logs and models.')\n", "parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',\n", " choices=model_names,\n", " help='model architecture: ' +\n", " ' | '.join(model_names) +\n", " ' (default: resnet18)')\n", "parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n", " help='number of data loading workers (default: 4)')\n", "parser.add_argument('--epochs', default=90, type=int, metavar='N',\n", " help='number of total epochs to run')\n", "parser.add_argument('--cycle-len', default=1, type=float, metavar='N',\n", " help='Length of cycle to run')\n", "# parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n", "# help='manual epoch number (useful on restarts)')\n", "parser.add_argument('-b', '--batch-size', default=256, type=int,\n", " metavar='N', help='mini-batch size (default: 256)')\n", "parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n", " metavar='LR', help='initial learning rate')\n", "parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\n", "parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n", " metavar='W', help='weight decay (default: 1e-4)')\n", "# parser.add_argument('--print-freq', '-p', default=10, type=int,\n", "# metavar='N', help='print frequency (default: 10)')\n", "# parser.add_argument('--resume', default='', type=str, metavar='PATH',\n", "# help='path to latest checkpoint (default: none)')\n", "# parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',\n", "# help='evaluate model on validation set')\n", "parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')\n", "parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode.')\n", "parser.add_argument('--use-tta', action='store_true', help='Validate model with TTA at the end of traiing.')\n", "parser.add_argument('--train-128', action='store_true', help='Train model on 128. TODO: allow custom epochs and LR')\n", "parser.add_argument('--sz', default=224, type=int, help='Size of transformed image.')\n", "# parser.add_argument('--decay-int', default=30, type=int, help='Decay LR by 10 every decay-int epochs')\n", "parser.add_argument('--use_clr', type=str, \n", " help='div,pct,max_mom,min_mom. Pass in a string delimited by commas. Ex: \"20,2,0.95,0.85\"')\n", "parser.add_argument('--loss-scale', type=float, default=1,\n", " help='Loss scaling, positive power of 2 values can improve fp16 convergence.')\n", "parser.add_argument('--prof', dest='prof', action='store_true', help='Only run a few iters for profiling.')\n", "\n", "parser.add_argument('--dist-url', default='file://sync.file', type=str,\n", " help='url used to set up distributed training')\n", "parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')\n", "\n", "parser.add_argument('--world-size', default=1, type=int,\n", " help='Number of GPUs to use. Can either be manually set ' +\n", " 'or automatically set by using \\'python -m multiproc\\'.')\n", "parser.add_argument('--rank', default=0, type=int,\n", " help='Used for multi-process training. Can either be manually set ' +\n", " 'or automatically set by using \\'python -m multiproc\\'.')\n", "\n", "def fast_loader(data_path, size):\n", " aug_tfms = [\n", " RandomFlip(),\n", "# RandomRotate(4),\n", "# RandomLighting(0.05, 0.05),\n", " RandomCrop(size)\n", " ]\n", " classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", " cifar10_stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))\n", " tfms = tfms_from_stats(cifar10_stats, size, aug_tfms=aug_tfms, pad=args.sz//8)\n", " data = ImageClassifierData.from_paths(data_path, val_name='test', tfms=tfms, \n", " bs=args.batch_size, num_workers=args.workers)\n", "\n", " if args.distributed:\n", " train_sampler = torch.utils.data.distributed.DistributedSampler(data.trn_dl)\n", " else:\n", " train_sampler = None\n", "\n", " # TODO: Need to test train_sampler on distributed machines\n", " \n", " # Use pytorch default data loader. 20% faster\n", " data.trn_dl = torch.utils.data.DataLoader(\n", " data.trn_ds, batch_size=data.bs, shuffle=(train_sampler is None),\n", " num_workers=data.num_workers, pin_memory=True, sampler=train_sampler)\n", " data.trn_dl = DataPrefetcher(data.trn_dl)\n", "\n", " data.val_dl = torch.utils.data.DataLoader(\n", " data.val_ds,\n", " batch_size=data.bs, shuffle=False,\n", " num_workers=data.num_workers, pin_memory=True)\n", " data.val_dl = DataPrefetcher(data.val_dl, stop_early=args.prof)\n", " \n", " return data, train_sampler\n", "\n", "# Seems to speed up training by ~2%\n", "class DataPrefetcher():\n", " def __init__(self, loader, stop_early=False):\n", " self.loader = loader\n", " self.dataset = loader.dataset\n", " self.stream = torch.cuda.Stream()\n", " self.stop_early = stop_early\n", " self.next_input = None\n", " self.next_target = None\n", "\n", " def __len__(self):\n", " return len(self.loader)\n", " \n", " def preload(self):\n", " try:\n", " self.next_input, self.next_target = next(self.loaditer)\n", " except StopIteration:\n", " self.next_input = None\n", " self.next_target = None\n", " return\n", " with torch.cuda.stream(self.stream):\n", " self.next_input = self.next_input.cuda(async=True)\n", " self.next_target = self.next_target.cuda(async=True)\n", "\n", " def __iter__(self):\n", " count = 0\n", " self.loaditer = iter(self.loader)\n", " self.preload()\n", " while self.next_input is not None:\n", " torch.cuda.current_stream().wait_stream(self.stream)\n", " input = self.next_input\n", " target = self.next_target\n", " self.preload()\n", " count += 1\n", " yield input, target\n", " if self.stop_early and (count > 50):\n", " break\n", " \n", "# Taken from main.py topk accuracy\n", "def top5(output, target):\n", " \"\"\"Computes the precision@k for the specified values of k\"\"\"\n", " batch_size = target.size(0)\n", " _, pred = output.topk(5, 1, True, True)\n", " pred = pred.t()\n", " return pred.eq(target.view(1, -1).expand_as(pred)).sum()/batch_size\n", "\n", "class ValLoggingCallback(Callback):\n", " def __init__(self, save_path):\n", " super().__init__()\n", " self.save_path=save_path\n", " def on_train_begin(self):\n", " self.batch = 0\n", " self.epoch = 0\n", " self.f = open(self.save_path, \"a\", 1)\n", " def on_epoch_end(self, metrics):\n", " log_str = f'\\tEpoch:{self.epoch}\\ttrn_loss:{self.last_loss}'\n", " for (k,v) in zip(['val_loss', 'acc', 'top5', ''], metrics): log_str += f'\\t{k}:{v}'\n", " self.log(log_str)\n", " self.epoch += 1\n", " def on_batch_end(self, metrics):\n", " self.last_loss = metrics\n", " self.batch += 1\n", " def on_train_end(self):\n", " self.log(\"\\ton_train_end\")\n", " self.f.close()\n", " def log(self, string):\n", " self.f.write(time.strftime(\"%Y-%m-%dT%H:%M:%S\")+\"\\t\"+string+\"\\n\")\n", "\n", "# Logging + saving models\n", "def save_args(name, save_dir):\n", " if (args.rank != 0) or not args.save_dir: return {}\n", "\n", " log_dir = f'{save_dir}/training_logs'\n", " os.makedirs(log_dir, exist_ok=True)\n", " return {\n", " 'best_save_name': f'{name}_best_model',\n", " 'cycle_save_name': f'{name}',\n", " 'callbacks': [\n", " LoggingCallback(f'{log_dir}/{name}_log.txt'),\n", " ValLoggingCallback(f'{log_dir}/{name}_val_log.txt')\n", " ]\n", " }\n", "\n", "def save_sched(sched, save_dir):\n", " if (args.rank != 0) or not args.save_dir: return {}\n", " log_dir = f'{save_dir}/training_logs'\n", " sched.save_path = log_dir\n", " sched.plot_loss()\n", " sched.plot_lr()\n", "\n", "def update_model_dir(learner, base_dir):\n", " learner.tmp_path = f'{base_dir}/tmp'\n", " os.makedirs(learner.tmp_path, exist_ok=True)\n", " learner.models_path = f'{base_dir}/models'\n", " os.makedirs(learner.models_path, exist_ok=True)\n", " " ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "args_input = [\n", " '/home/paperspace/data/cifar10', \n", " '--save-dir', '/home/paperspace/data/cifar_training/test1', \n", " '-a', 'resnext29_16_64', \n", " '-j', '6', \n", "# '--prof', \n", " '-b', '256', \n", " '--sz', '16',\n", " '--loss-scale', '128',\n", " '--fp16',\n", " '--cycle-len', '0.05',\n", " '--epochs', '2'\n", "]" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Namespace(arch='resnext29_16_64', batch_size=256, cycle_len=0.1, data='/home/paperspace/data/cifar10', dist_backend='nccl', dist_url='file://sync.file', epochs=2, fp16=True, loss_scale=128.0, lr=0.1, momentum=0.9, pretrained=False, prof=False, rank=0, save_dir='/home/paperspace/data/cifar_training/test1', sz=16, train_128=False, use_clr=None, use_tta=False, weight_decay=0.0001, workers=6, world_size=1)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "# This is important for speed\n", "cudnn.benchmark = True\n", "global arg\n", "args = parser.parse_args(args_input); args\n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "args.distributed = args.world_size > 1\n", "args.gpu = 0\n", "if args.distributed:\n", " args.gpu = args.rank % torch.cuda.device_count()\n", "\n", "if args.distributed:\n", " torch.cuda.set_device(args.gpu)\n", " dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,\n", " world_size=args.world_size)\n", "\n", "if args.fp16:\n", " assert torch.backends.cudnn.enabled, \"fp16 mode requires cudnn backend to be enabled.\"" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "=> creating model 'resnext29_16_64'\n" ] } ], "source": [ "# create model\n", "model = cifar10models.__dict__[args.arch] if args.arch in cifar10_names else models.__dict__[args.arch] \n", "if args.pretrained:\n", " print(\"=> using pre-trained model '{}'\".format(args.arch))\n", " model = model(pretrained=True)\n", "else:\n", " print(\"=> creating model '{}'\".format(args.arch))\n", " model = model()\n", "\n", "model = model.cuda()\n", "if args.distributed:\n", " model = DDP(model)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "if args.train_128:\n", " data, train_sampler = fast_loader(f'{args.data}-160', 128)\n", "else:\n", " data, train_sampler = fast_loader(args.data, args.sz)\n", "\n", "learner = Learner.from_model_data(model, data)\n", "learner.crit = F.cross_entropy\n", "learner.metrics = [accuracy, top5, top5old]\n", "if args.fp16: learner.half()\n", "\n", "if args.prof:\n", " args.epochs = 1\n", " args.cycle_len=.01\n", "if args.use_clr:\n", " args.use_clr = tuple(map(float, args.use_clr.split(',')))" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# %pdb on" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6120ec800b2349baa940b9984f9916a1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 15%|█▌ | 6/39 [00:04<00:26, 1.26it/s, loss=8.64]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Exception in thread Thread-10:\n", "Traceback (most recent call last):\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n", " self.run()\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/site-packages/tqdm/_monitor.py\", line 62, in run\n", " for instance in self.tqdm_cls._instances:\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n", " for itemref in self.data:\n", "RuntimeError: Set changed size during iteration\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "40it [00:30, 1.30it/s, loss=4.02] epoch trn_loss val_loss accuracy top5 top5old \n", " 0 4.024515 2.238668 0.1757 0.6785 67.85 \n", "\n", "Finished!\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEKCAYAAAARnO4WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XmcFPWd//HXZ3pmOAYEgUFAbgRUQEBHFFBBYtR4oYkxmpjkl3VDYmKMmqxrNpuNuTbZxJhNTFbFO/GItxIP1CQqihwCInJ5cMnNAMpwOMzRn98fXQPNMDP0DFNdPdT7+Xj0o6urq7s+U4/ud33rO9XfMndHREQOfXlRFyAiItmhwBcRiQkFvohITCjwRURiQoEvIhITCnwRkZhQ4IuIxIQCX0QkJhT4IiIxkR91Aem6dOniffv2jboMEZEWY+7cuZvdvTiTZXMq8Pv27cucOXOiLkNEpMUws1WZLqsuHRGRmFDgi4jEhAJfRCQmFPgiIjGhwBcRiQkFvohITCjwRURiQoEfM+7O/TNX8cq7m6IuRUSyTIEfMxXVSf4yYxXXP7aAj3ZWRF2OiGSRAj9mWuUn+O0lw9m6s4IfPb0w6nJEJIsU+DE09MgOXHPGQJ5ZsJ4pb6+LuhwRyRIFfkx9c9wARvTqyI+eWsiGbeVRlyMiWaDAj6n8RB43XzKc3VXVXP/4Atw96pJEJGQK/BjrX9yO/zjnGKa9V8oDsz6MuhwRCZkCP+YuP6kPpw7swi+eXcLKzTujLkdEQqTAj7m8POPXFx9HfsL43qNvU51U147IoUqBL3Tv0IafTRzK3FUfMXna8qjLEZGQKPAFgIkjenDOsG7c/NK7LFlfFnU5IhICBb4AYGb8/MJhdGhTyLUPz2d3VXWzvn8y6by4aAO/ffFdyiub971FJDOhBb6ZDTaz+Wm3MjO7Jqz1ycHrVFTIry8extIN2/nRUwub5VTN6qQz5e11fOb3rzHpL3O55Z8f8K0H5lFRlWyGikWkMUILfHd/191HuPsI4ARgF/BkWOuT5jHh6CO4esJRPDJnDXdPX9nk96moSvLIm6s54+ZXufqht6h253dfGM7PJg7hn0s38Z2H5lFZrdAXyab8LK3nU8Ayd8/46uoSnWvOGMS7G7fzi2cXM7BrO04bVJzxa8srq3lkzmpuf3U5az/+hCE9DuPWLx3PWUO6kZdnAFQlnZ/8bTHXPjyf3186kkQwX0TCla3AvxR4KEvrkoOUl2fcfMkIPnfrG1z14Dye+vZY+he3O+DrZi7fwjV/nc+GsnJO6HM4P79oKOMHFWO2b6B/bWw/KqqS/PL5pRTm53HTxcP37AxEJDyh/9PWzAqBC4BH63l+kpnNMbM5paWlYZcjGSpqlc8dXykhP5HHv943h22fVNa7bDLp/OnlD/jiHTNpU5jgwa+fxGPfHM3pg7vuF/Y1vjFuANd9ehBPzFvLD596h6TO/xcJXTbO0vkMMM/dN9b1pLtPdvcSdy8pLs6860DC16tTW267/ARWf7SL7zz0Vp0/yvpoZwX/ct+b/OaFdzlnWHemXDWWMQO61Bv06b4z4Si+ffoAHpq9mp/8bZHG8xEJWTYC/zLUndNijerXiZ9NHMq090r55XNL9nlu7qqPOPcPr/HGB1v42cQh3HLZSNq3Lsj4vc2M7585mK+f2o/7Zqziv59botAXCVGoffhm1hb4NPCNMNcj4bp0VG+WbtjOna+vYHC39lx8Qk/uen0Fv3p+Kd07tubxK8cwrGeHJr23mfEf5xzD7qokd7y2goJEHv921uCMjhBEpHFCDXx33wV0DnMdkh3/ee4xfLBpBz98ciFT3l7Ha+9v5sxjj+A3nx9OhzaZt+rrYmbceP4QKquT/N8ry8gz43tnDlLoizSzbJ2lIy1cfiKPP35xJBf+aTozlm3hP889hitO6ddsoZyXZ/ziwmG4wx9f/gBAoS/SzBT4krGObQt5/MoxbPukMqPTNBsrL8/474uGAQp9kTAo8KVROrdrRed2rUJ7f4W+SHgU+JJzFPoi4VDgS05S6Is0PwW+5CyFvkjzUuBLTqsd+p9UVvPDc47R2DsiTaDAl5xXE/qtCxLc9foKtu6s4NcXH0dBQtfvEWkMBb60CHl5xo/PP5Yu7Qq56cX32LqzglsvP562hfoIi2RKTSRpMcyMqyYM5JefHcZr75fyxTtm8dHOiqjLEmkxFPjS4lw2qjf/96UTWLy+jM/fPoN1H38SdUkiLYICX1qks4d248//MoqN28r53K1v8P7G7VGXJJLzFPjSYp3cvzN//cbJVFY7n799BnNXfRR1SSI5TYEvLdqQHh144soxdGhTwBfvmMkT89ZEXZJIzlLgS4vXu3NbnrhyDCN7d+S6R97mJ39bRGV1MuqyRHKOAl8OCZ3bteIvV5zE18b25Z7pK7n8zlls3rE76rJEcooCXw4ZBYk8fnz+EG6+ZDjzV3/MBbe8zoI1H0ddlkjOCDXwzayjmT1mZkvNbImZjQ5zfSIAnz2+J49fOQYz4+LbZvD4XPXri0D4LfzfA1Pd/WhgOLDkAMuLNIuhR3ZgylVjOaH34Xzv0be5cYr69UVCC3wzOww4DbgLwN0r3F3H15I1qX79UVxxSj/ufWMll9w+g9Vbd0Vdlkhkwmzh9wdKgXvM7C0zu9PMikJcn8h+8hN5/Oi8Y7nlspF8sHEH5/zhNZ5dsD7qskQiEWbg5wPHA7e6+0hgJ3BD7YXMbJKZzTGzOaWlpSGWI3F2/vAePHv1qQwobse3H5zHDY8vYFdFVdRliWRVmIG/Bljj7rOCx4+R2gHsw90nu3uJu5cUFxeHWI7EXe/ObXn0m6P51vgBPDxnNeff8jpL1pdFXZZI1oQW+O6+AVhtZoODWZ8CFoe1PpFMFCTyuP7so7n/ipMoK69i4p+m8+cZK3H3qEsTCV3YZ+l8B3jAzBYAI4D/Dnl9IhkZe1QXpn73VMYO6Mx/Pb2If71vjgZgk0Oe5VLLpqSkxOfMmRN1GRIj7s7d01dy0wvv8kllNWceewRXjh/AyN6HR12aSEbMbK67l2SyrC4XJLFmZlxxSj8uGnkk976xkvveWMmLizcyun9nvnX6AE45qosumi6HDLXwRdLs2F3FQ7M+5M7Xl7OxbDfDjuzAt8YP4Mwh3UjowumSgxrTwlfgi9Rhd1U1T85by22vLmPlll0c3a09N14whJP7d466NJF9NCbwNXiaSB1a5Se4dFRv/vG98fzhspFsL6/i0skzufqht9iwrTzq8kSaRIEv0oBEnnHB8B78/bpxXP2pgUxdtIEJv32F215dRkWVxuaRlkWBL5KBNoUJrvv0IP5+7TjGDOjCr55fytn/O41p7+nX4dJyKPBFGqF357bc+dUS7vl/J5J05yt3z2bSn+fw4RYNyia5T4Ev0gSnH92VF649jX87azCvvb+ZM25+lV8+v4Sy8sqoSxOplwJfpIla5Sf49ulH8fL3x3P+8B7c/upyTv/NK9w/cxVVGntfcpACX+QgdevQmt9eMpy/XXUKA4rb8Z9PLeTcP7yu/n3JOQp8kWYyrGcHHv7Gydz6pePZVVnFV+6ezdfumc0HmzRGj+QG/fBKJATlldXc+8ZK/vjPD9hZUcU5Q7vzzXEDGNazQ9SlySFGv7QVyRGbd+zmrtdXcP+MVWzfXcWpA7tw5bgBjB7QWWP0SLNQ4IvkmLLySh6Y+SF3vb6CzTt2M7xnB64cP4Azj+1GnsbokYOgwBfJUeWV1Tw+bw2Tpy1n1ZZd9C8u2jNaZ9tCDV4rjafAF8lx1Unn+YXrue3VZSxcW0b71vl8/oRefHl0H/p1KYq6PGlBFPgiLYS7M+/Dj7jvjVU89856qpLOuEHFfHVMH8YN6qohmeWAcibwzWwlsB2oBqoOVJQCX+JsU1k5D81ezYOzV7GxbDe9O7Xlyyf34QujenFY64Koy5MclWuBX+LumzNZXoEvApXVSV5ctJH7Zqxk9oqttGuVz2WjevG1sf3o0bFNRu/x3sbtvPb+ZgYf0Z6T+neiIKGf3ByqdIlDkRasIJHHucd159zjurNw7TbueG05d09fyT3TV3L+8B58/dT+HNvjsP1e9/7G7TyzYD3PvbOe9zft2DO/Q5sCPnV0V84c0o1xg4ppU5jI5p8jOSTsFv4K4CPAgdvdfXJDy6uFL1K3NR/t4u7XV/LXNz9kV0U1pw7swqTT+tPtsNY8+04q5N/buAMzGNW3E+ce153TB3dl8foyXli0gX8s2cS2TyppXZDHaQOLOWtINz51TFc6ti2M+k+Tg5RLXTo93H2dmXUFXgK+4+7Tai0zCZgE0Lt37xNWrVoVWj0iLd22XZU8MHsV90xfSen23QCYwYl9O3Hecd05e0g3uh7Wer/XVVYnmb1iKy8s2sCLizayoaycgoQxblAxE0ccyRnHHKGWfwsVSuCbWZG77zyIom4Edrj7TfUtoxa+SGZ2V1Xz7IL17Kyo5qxjj6gz5OuTTDoL1m7juXfWM2X+OjaUlVNUmOCsId24YEQPTjmqC/nq828xmjXwzWwMcCfQzt17m9lw4Bvu/q0DvK4IyHP37cH0S8BP3X1qfa9R4ItkV3XSmb1iK0/PX8tz76ynrLyKzkWFnHdcdyaOPJKRvTpqCIgc19yBPwu4GJji7iODeQvdfegBXtcfeDJ4mA886O6/aOg1CnyR6OyuquaVd0t5ev5a/r5kExVVSfp1KeLCEUdy0cgj6d25bdQlSh2a/Swdd19day9fncFrlgPDM3l/EYleq/xUt85ZQ7pRVl7J1Hc28MRba/jd39/jd39/j5I+h3PR8Udy3rAedGir3wW0RJkE/uqgW8fNrBC4GlgSblkiEqXDWhdwyYm9uOTEXqz9+BOeemstT761lh8+uZCfTFnMhKO7cv7wHpx+dLHGAGpBMunS6QL8HjgDMOBF4Gp339rcxahLRyR3uTuL1pXxxLy1THl7HZt37KZNQYIJx3TlvGHdGT+4q870iUBz9+GPdffpB5rXHBT4Ii1DzT97n31nHVMXbmDzjgraFiY445gjOPe47owbVEzrAoV/NjR34M9z9+MPNK85KPBFWp6q4Bz/Z95Zz9SFG9i6s4KiwgSnH92VzwztzvjBxRS1UrdPWJrln7ZmNhoYAxSb2XVpTx0GaNctIgDkJ/IYc1QXxhzVhZ9eMIQZy7fw/MINvLhoA88sWE+r/DzGDSrmM8O68aljjtBAcBFqaLdbCLQLlmmfNr+M1GmaIiL7yE/kcerAYk4dWMzPJg7lzZVbmbpwA1MXbuDFxRspSBhjBnShR8c2FCaMgkQeBfl5FOTtne7QpoAT+x7OgOJ2+g1AM8ukS6ePu2dlvAN16YgcmpJJZ/6aj5m6cAP/XJoa16eyOkllVZLKaqeiOrnfa4rbt2J0/86MHtCZ0f0706dzW+0A6tDcffjFwPXAEGDP77fdfcLBFFkXBb5IPLk71UmnstrZWFbOzOVbmLF8CzOWbWFTMGZQ9w6tGd2/Myf378yofp20Awg09w+vHgAeBs4Dvgl8FShtenkiIvsyM/ITRn4C+nYpom+XIi4d1Rt3Z/nmncxYltoBvPpeKU+8tRaAIw5rxah+qfA/qV8nBnZVF9CBZNLCn+vuJ5jZAnc/Lpj3qruPa+5i1MIXkYa4O8tKdzBz+VZmr9jKrBVb2FiWOgLoVFTIiX0P5+SgG2hQ1/bkxeASkc3dwq8M7teb2bnAOqBnU4sTEWkqM+Ooru05qmt7Lj+5D+7Oh1t3MWvFVmYtT+0AXli0EUjtAE7q12nPDkBHAJkF/s/NrAPwPeAWUqdlXhtqVSIiGTAz+nQuok/nIi4p6QXA6q27mLl8CzOXb2VmcIooQOeiQk4d2IXffWFEbIO/wcA3swQw0N2fAbYBp2elKhGRJurVqS29OrXl8yW9cHfWfPQJM5ZtYebyLeyuTsY27OEAge/u1WZ2AfC7LNUjItJszGzPDuCSE3tFXU7kMunSecPM/kjqTJ09V7xy93mhVSUiIs0uk8AfE9z/NG2eA81+Hr6IiITngIHv7uq3FxE5BOhKxSIiMRF64JtZwszeMrNnwl6XiIjULxst/O+iSyKKiETugH34ZvbZOmZvA95x900HeG1P4FzgF8B1DS0rIiLhyuQsnSuA0cDLwePxwExgkJn91N3/0sBr/5fUSJvt61vAzCYBkwB69+6dQTkiItIUmXTpJIFj3P1z7v454FhgN3AS8O/1vcjMzgM2ufvcht7c3Se7e4m7lxQXFzeidBERaYxMAr+vu29Me7wJGOTuW9k7sFpdxgIXmNlK4K/ABDO7v8mViojIQcmkS+e14AybR4PHnwOmmVkR8HF9L3L3HwA/ADCz8cD33f3ygytXRESaKpPA/zapkB8LGPBn4HFPDaSvH2WJiLQQmfzS1oHHgluTuPsrwCtNfb2IiBy8A/bhm9lnzex9M9tmZmVmtt3MyrJRnIiINJ9MunR+DZzv7vrxlIhIC5bJWTobFfYiIi1fJi38OWb2MPAUqfPvAXD3J0KrSkREml0mgX8YsAs4M22eAwp8EZEWJJOzdL6WjUJERCRc9Qa+mV3v7r82s1tItej34e5Xh1qZiIg0q4Za+DX/qJ2TjUJERCRc9Qa+u/8tuL8ve+WIiEhYMhkPfxDwfaBv+vLurouYi4i0IJmcpfMocBtwJ1AdbjkiIhKWTAK/yt1vDb0SEREJVSa/tP2bmX3LzLqbWaeaW+iViYhIs8qkhf/V4P7f0uY50L/5yxERkbA0GPhmlgdc7u7Ts1SPiIiEpMEuHXdPAjdlqRYREQlRJn34L5rZ58zMQq9GRERCk0kf/nVAEVBlZuWkLnPo7n5YQy8ys9bANKBVsJ7H3P3HB1mviIg0USaDp7Vv4nvvBia4+w4zKwBeN7Pn3X1mE99PREQOQiYtfMzscGAg0LpmnrtPa+g1wbVwdwQPC4LbfoOwiYhIdmQytMK/At8FegLzgZOBGcABh1YwswQwFzgK+JO7z6pjmUnAJIDevXs3pnYREWmETP5p+13gRGCVu58OjARKM3lzd6929xGkdhajzGxoHctMdvcSdy8pLi5uROkiItIYmQR+ubuXA5hZK3dfCgxuzErc/WPgFeDsRlcoIiLNIpPAX2NmHUld0/YlM3saWHegF5lZcfA6zKwNcAaw9GCKFRGRpsvkLJ2LgskbzexloAMwNYP37g7cF/Tj5wGPuPszTa5UREQOSqZn6ZwCDHT3e8ysGDgSWNHQa9x9Aan+fhERyQEH7NIxsx8D/w78IJhVANwfZlEiItL8MunDvwi4ANgJ4O7rgKb+GEtERCKSSeBXBD+icgAzKwq3JBERCUMmgf+Imd0OdDSzrwN/B+4ItywREWlumZylc5OZfRooI3X+/X+5+0uhVyYiIs0qo7N0goBXyIuItGD1Br6Zbafuwc4yGh5ZRERyS72BfxDDIouISA7K5J+2IiJyCFDgi4jEhAJfRCQmFPgiIjGhwBcRiQkFvohITCjwRURiQoEvIhITCnwRkZgILfDNrJeZvWxmS8xskZl9N6x1iYjIgWU0eFoTVQHfc/d5ZtYemGtmL7n74hDXKSIi9Qithe/u6919XjC9HVhC6lq4IiISgaz04ZtZX1IXNJ+VjfWJiMj+Qg98M2sHPA5c4+5ldTw/yczmmNmc0tLSsMsREYmtUAPfzApIhf0D7v5EXcu4+2R3L3H3kuLi4jDLERGJtTDP0jHgLmCJu98c1npERCQzYbbwxwJfBiaY2fzgdk6I6xMRkQaEdlqmu79O6nKIIiKSA/RLWxGRmFDgi4jEhAJfRCQmFPgiIjGhwBcRiQkFvohITCjwRURiQoEvIhITCnwRkZhQ4IuIxIQCX0QkJhT4IiIxocAXEYkJBb6ISEwo8EVEYkKBLyISEwp8EZGYCPOatneb2SYzWxjWOkREJHNhtvDvBc4O8f1FRKQRQgt8d58GbA3r/UVEpHHUhy8iEhORB76ZTTKzOWY2p7S0NOpyREQOWZEHvrtPdvcSdy8pLi6OuhwRkUNW5IEvIiLZEeZpmQ8BM4DBZrbGzK4Ia10iInJg+WG9sbtfFtZ7i4hI46lLR0QkJhT4IiIxocAXEYkJBb6ISEwo8EVEYkKBLyISEwp8EZGYUOCLiMSEAl9EJCYU+CIiMaHAFxGJCQW+iEhMKPBFRGJCgS8iEhMKfBGRmFDgi4jEhAJfRCQmQrviFYCZnQ38HkgAd7r7r8JcX502LoZV08GTYMH+zfLALHgc3OclwBLBvaVNJyAvPzWdVzOdnzavZrog9ThRsPdxouY+mGeW9T9fJDbcoboSkpXBfVXqtmde1d55ySpIVu//2KvruE+m7j2ZdvP9H+O17qk1XUt6HhS0hZKvhbyBQgx8M0sAfwI+DawB3jSzKe6+OKx17lH6Hix6AhY9CaVLQ19dxvIKIFGY2hEkCoNbQQbThZDfKpgX3Oe3Sk3nF9a6T1tun3mFte5rPa+dkWQimYTqCqjeDVUVUFWeely1e++8fe53pz2ftlx15d55e+ZXBs9VBNMVwftU7Dsvfbom3GtCvaUq6tqyAx8YBXzg7ssBzOyvwEQgnMDfsiwI+adg40LAoM8YOOcmGHQ2FBbtv2cmbTpZve997T18TSvAq/dtHexpSVTX07KoqvvDWfOhTVbW/WHevX3fL8GeD37al8Wrm2/75RXUv2NI1PHcPjuhwlrLBfNqjm722YkV7HvUs2fZ/L3P1Rw97TlaquOoyhKQ18J7JGtajnW1Ouv6HO1ppVbu/9na77OUPl07RNPDtbJWWKcHcB0Bnqxqvr8/L3//z9eez1badGFbSBy+9/OTaLVvoykvvQGVflSdftRdMy+R9hmrfZSe2Pu52ucIPxH0CgQ9ATW9AvvcbG8PAgQNKNt/eo96Wv0hCzPwjwRWpz1eA5zU7Gup2An3fAbWv5163OskOPt/4NiJcFj3Zl9dTklW19FKakJr60DL1UxXlUP5trQdU+0WWHDLGqu1A0jrjkv/MtZ8YbHU927PtO07vd+XMl3tw/U67pPph/jpt7RugZpAb86ddSb2OfpL3wHXOkps1a6enXx9R5OFkN+6gefqeY9EYcvfYbdAYQZ+Xd+e/XZrZjYJmATQu3fvxq+lsAi6DYNhl8CQC6FDz8a/R0uVl0i1fmgbdSV7uQctz/RD71pHMbVbqnX1u+7Tuq3e28JN1upj3XPUVTO/jrD1ZCpw9wnpZB3T9f5R7LNTqOt+vxZfXt07nf12UGn/G0oU1Gp1FtR6rvb/hYLHNQG935FVEOg1O0KJvTADfw3QK+1xT2Bd7YXcfTIwGaCkpKRpxzkT/9Skl0kIzPaGDUVRVyMiacI8pnoTGGhm/cysELgUmBLi+kREpAGhtfDdvcrMrgJeIHVa5t3uviis9YmISMNCPQ/f3Z8DngtzHSIikhn9m1xEJCYU+CIiMaHAFxGJCQW+iEhMKPBFRGLCvMFfGGaXmZUCq+p5uguwOYvlZEp1NY7qahzV1Xi5WltYdfVx9+JMFsypwG+Imc1x95Ko66hNdTWO6moc1dV4uVpbLtSlLh0RkZhQ4IuIxERLCvzJURdQD9XVOKqrcVRX4+VqbZHX1WL68EVE5OC0pBa+iIgchBYR+GZ2tpm9a2YfmNkNUddTw8xWmtk7ZjbfzOZEWMfdZrbJzBamzetkZi+Z2fvB/eE5UteNZrY22GbzzeycCOrqZWYvm9kSM1tkZt8N5ke6zRqoK9JtZmatzWy2mb0d1PWTYH4/M5sVbK+Hg2HQc6Gue81sRdr2GpHNutLqS5jZW2b2TPA40u0FgLvn9I3U0MrLgP5AIfA2cGzUdQW1rQS65EAdpwHHAwvT5v0auCGYvgH4nxyp60bg+xFvr+7A8cF0e+A94Niot1kDdUW6zUhdva5dMF0AzAJOBh4BLg3m3wZcmSN13QtcHOVnLKjpOuBB4JngcaTby91bRAt/z8XQ3b0CqLkYugTcfRqwtdbsicB9wfR9wIVZLYp664qcu69393nB9HZgCalrMEe6zRqoK1KesiN4WBDcHJgAPBbMj2J71VdX5MysJ3AucGfw2Ih4e0HL6NKp62LokX8JAg68aGZzg2vz5pIj3H09pIIE6BpxPemuMrMFQZdP1rua0plZX2AkqdZhzmyzWnVBxNss6J6YD2wCXiJ11P2xu1cFi0Tyvaxdl7vXbK9fBNvrd2bWKtt1Af8LXA8kg8edyYHt1RICP6OLoUdkrLsfD3wG+LaZnRZ1QS3ArcAAYASwHvhtVIWYWTvgceAady+Lqo7a6qgr8m3m7tXuPoLUtalHAcfUtVh2q9q/LjMbCvwAOBo4EegE/Hs2azKz84BN7j43fXYdi2Z9e7WEwM/oYuhRcPd1wf0m4ElSX4RcsdHMugME95sirgcAd98YfEmTwB1EtM3MrIBUqD7g7k8EsyPfZnXVlSvbLKjlY+AVUn3lHc2s5qp5kX4v0+o6O+gac3ffDdxD9rfXWOACM1tJqgt6AqkWf+TbqyUEfk5eDN3Misysfc00cCawsOFXZdUU4KvB9FeBpyOsZY+aQA1cRATbLOhPvQtY4u43pz0V6Tarr66ot5mZFZtZx2C6DXAGqf8vvAxcHCwWxfaqq66laTttI9VPntXt5e4/cPee7t6XVF79092/RMTbq6a4nL8B55A6Y2EZ8MOo6wlq6k/qjKG3gUVR1gU8ROpQv5LUEdEVpPoM/wG8H9x3ypG6/gK8AywgFbDdI6jrFFKH0wuA+cHtnKi3WQN1RbrNgOOAt4L1LwT+K5jfH5gNfAA8CrTKkbr+GWyvhcD9BGfyRHEDxrP3LJ1It5e765e2IiJx0RK6dEREpBko8EVEYkKBLyISEwp8EZGYUOCLiMSEAl8OGWb2RnDf18y+2Mzv/R91rUukJdFpmXLIMbPxpEaXPK8Rr0m4e3UDz+9w93bNUZ9IVNTCl0OGmdWMnPgr4NRgLPRrgwG2fmNmbwYDan0jWH58MP78g6R+qIOZPRUMhreoZkA8M/sV0CZ4vwfS12UpvzGzhZa6NsIX0t77FTN7zMyWmtkDwS8/MbNfmdnioJabsrmNJN7yD7yISItzA2kt/CC4t7lSeNAeAAABoUlEQVT7icHIidPN7MVg2VHAUHdfETz+F3ffGvxU/00ze9zdbzCzqzw1SFdtnyU1qNlwoEvwmmnBcyOBIaTGTJkOjDWzxaSGRzja3b1maACRbFALX+LgTOArwTC6s0gNoTAweG52WtgDXG1mbwMzSQ3aN5CGnQI85KnBzTYCr5IapbHmvdd4atCz+UBfoAwoB+40s88Cuw76rxPJkAJf4sCA77j7iODWz91rWvg79yyU6vs/Axjt7sNJjdPSOoP3rs/utOlqIN9T46GPIjUi5oXA1Eb9JSIHQYEvh6LtpC4RWOMF4Mpg6GHMbFAwwmltHYCP3H2XmR1NagjgGpU1r69lGvCF4P8ExaQu6zi7vsKCse47uPtzwDWkuoNEskJ9+HIoWgBUBV0z9wK/J9WdMi/4x2kpdV9ebirwTTNbALxLqlunxmRggZnN89RQtzWeBEaTGjXVgevdfUOww6hLe+BpM2tN6ujg2qb9iSKNp9MyRURiQl06IiIxocAXEYkJBb6ISEwo8EVEYkKBLyISEwp8EZGYUOCLiMSEAl9EJCb+P5UqQ5REQ3SuAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 128x128\n", "if args.train_128:\n", " save_dir = args.save_dir+'/128'\n", " update_model_dir(learner, save_dir)\n", " sargs = save_args('first_run_128', save_dir)\n", " learner.fit(args.lr,args.epochs, cycle_len=args.cycle_len,\n", " train_sampler=train_sampler,\n", " wds=args.weight_decay,\n", " use_clr_beta=args.use_clr,\n", " loss_scale=args.loss_scale,\n", " **sargs\n", " )\n", " save_sched(learner.sched, save_dir)\n", " data, train_sampler = fast_loader(args.data, args.sz)\n", " learner.set_data(data)\n", "\n", "\n", "# Full size\n", "update_model_dir(learner, args.save_dir)\n", "sargs = save_args('first_run', args.save_dir)\n", "learner.fit(args.lr,args.epochs, cycle_len=args.cycle_len,\n", " train_sampler=train_sampler,\n", " wds=args.weight_decay,\n", " use_clr_beta=args.use_clr,\n", " loss_scale=args.loss_scale,\n", " **sargs\n", " )\n", "save_sched(learner.sched, args.save_dir)\n", "\n", "if args.use_tta:\n", " print(accuracy(*learner.TTA()))\n", "\n", "print('Finished!')" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# Taken from main.py topk accuracy\n", "def top5old(output, target):\n", " \"\"\"Computes the precision@k for the specified values of k\"\"\"\n", " topk = 5\n", " batch_size = target.size(0)\n", " _, pred = output.topk(topk, 1, True, True)\n", " pred = pred.t()\n", " correct = pred.eq(target.view(1, -1).expand_as(pred))\n", " res = []\n", " correct_k = correct[:5].view(-1).float().sum(0, keepdim=True)\n", " return correct_k.mul_(100.0 / batch_size)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "toc": { "nav_menu": { "height": "266px", "width": "252px" }, "number_sections": true, "sideBar": true, "skip_h1_title": false, "toc_cell": false, "toc_position": {}, "toc_section_display": "block", "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }