{ "cells": [ { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "# Fine-tune with Pre-trained Models\n", "\n", "Many of the exciting deep learning algorithms for computer vision \n", "require massive datasets for training. \n", "The most popular benchmark dataset, [ImageNet](http://www.image-net.org/), for example,\n", "contains one million images from one thousand categories.\n", "But for any practical problem, we typically have access to comparatively small datasets.\n", "In these cases, if we were to train a neural network's weights from scratch, \n", "starting from random initialized parameters, we would overfit the training set badly. \n", "\n", "One approach to get around this problem is to first pretrain a deep net on a large-scale dataset, like ImageNet.\n", "Then, given a new dataset, we can start with these pretrained weights when training on our new task.\n", "This process commonly called \"fine-tuning\". \n", "There are anumber of variations of fine-tuning. \n", "Sometimes, the initial neural network is used only as a _feature extractor_. \n", "That means that we freeze every layer prior to the output layer and simply learn a new output layer. \n", "In [another document](./predict.ipynb), \n", "we explained how to do this kind of feature extraction.\n", "Another approach is to update all of networks weights for the new task, \n", "and that's the appraoch we demonstrate in this document.\n", "\n", "To fine-tune a network, we must first replace the last fully-connected layer \n", "with a new one that outputs the desired number of classes. \n", "We initialize its weights randomly.\n", "Then we continue training as normal. \n", "Sometimes it's common use a smaller learning rate \n", "based on the intuition that we may already be close to a good result. \n", "\n", "In this demonstration, we'll fine-tune a model pre-trained on ImageNet to the smaller caltech-256 dataset. \n", "Following this example, you can finetune to other datasets, even for strikingly different applications such as face identification. \n", "\n", "We will show that, even with simple hyper-parameters setting, we can match and even outperform state-of-the-art results on caltech-256.\n", "\n", "| Network | Accuracy | \n", "| --- | --- | \n", "| Resnet-50 | 77.4% | \n", "| Resnet-152 | 86.4% | \n", "\n", "## Prepare data\n", "\n", "We follow the standard protocol to sample 60 images from each class as the training set, and the rest for the validation set. We resize images into 256x256 size and pack them into the rec file. The scripts to prepare the data is as following. 
\n", "\n", "\n", "```sh\n", "wget http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar\n", "tar -xf 256_ObjectCategories.tar\n", "\n", "mkdir -p caltech_256_train_60\n", "for i in 256_ObjectCategories/*; do\n", " c=`basename $i`\n", " mkdir -p caltech_256_train_60/$c\n", " for j in `ls $i/*.jpg | shuf | head -n 60`; do\n", " mv $j caltech_256_train_60/$c/\n", " done\n", "done\n", "\n", "python ~/mxnet/tools/im2rec.py --list True --recursive True caltech-256-60-train caltech_256_train_60/\n", "python ~/mxnet/tools/im2rec.py --list True --recursive True caltech-256-60-val 256_ObjectCategories/\n", "python ~/mxnet/tools/im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-val 256_ObjectCategories/\n", "python ~/mxnet/tools/im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-train caltech_256_train_60/\n", "```\n", "\n", "The following codes download the pre-generated rec files. It may take a few minutes." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "import os, urllib\n", "def download(url):\n", " filename = url.split(\"/\")[-1]\n", " if not os.path.exists(filename):\n", " urllib.urlretrieve(url, filename)\n", "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')\n", "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "deletable": true, "editable": true }, "source": [ "Next we define the function which returns the data iterators:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "import mxnet as mx\n", "\n", "def get_iterators(batch_size, data_shape=(3, 224, 224)):\n", " train = mx.io.ImageRecordIter(\n", " path_imgrec = './caltech-256-60-train.rec',\n", " data_name = 'data',\n", " label_name = 'softmax_label',\n", " batch_size = batch_size,\n", " data_shape = data_shape,\n", " shuffle = True,\n", " rand_crop = True,\n", " rand_mirror = True)\n", " val = mx.io.ImageRecordIter(\n", " path_imgrec = './caltech-256-60-val.rec',\n", " data_name = 'data',\n", " label_name = 'softmax_label',\n", " batch_size = batch_size,\n", " data_shape = data_shape,\n", " rand_crop = False,\n", " rand_mirror = False)\n", " return (train, val)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "We then download a pretrained 50-layer ResNet model and load into memory. Note that if `load_checkpoint` reports an error, we can remove the downloaded files and try `get_model` again." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "def get_model(prefix, epoch):\n", " download(prefix+'-symbol.json')\n", " download(prefix+'-%04d.params' % (epoch,))\n", "\n", "get_model('http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50', 0)\n", "sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50', 0)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Train\n", "\n", "We first define a function which replaces the the last fully-connected layer for a given network. 
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):\n", " \"\"\"\n", " symbol: the pre-trained network symbol\n", " arg_params: the argument parameters of the pre-trained model\n", " num_classes: the number of classes for the fine-tune datasets\n", " layer_name: the layer name before the last fully-connected layer\n", " \"\"\"\n", " all_layers = sym.get_internals()\n", " net = all_layers[layer_name+'_output']\n", " net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')\n", " net = mx.symbol.SoftmaxOutput(data=net, name='softmax')\n", " new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})\n", " return (net, new_args)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Now we create a module. We pass the argument parameters of the pre-trained model to replace all parameters except for the last fully-connected layer. For the last fully-connected layer, we use an initializer to initialize. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "import logging\n", "head = '%(asctime)-15s %(message)s'\n", "logging.basicConfig(level=logging.DEBUG, format=head)\n", "\n", "def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):\n", " devs = [mx.gpu(i) for i in range(num_gpus)]\n", " mod = mx.mod.Module(symbol=new_sym, context=devs)\n", " mod.fit(train, val, \n", " num_epoch=8,\n", " arg_params=arg_params,\n", " aux_params=aux_params,\n", " allow_missing=True,\n", " batch_end_callback = mx.callback.Speedometer(batch_size, 10), \n", " kvstore='device',\n", " optimizer='sgd',\n", " optimizer_params={'learning_rate':0.01},\n", " initializer=mx.init.Xavier(rnd_type='gaussian', factor_type=\"in\", magnitude=2),\n", " eval_metric='acc')\n", " metric = mx.metric.Accuracy()\n", " return mod.score(val, metric)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Then we can start training. We use AWS EC2 g2.8xlarge, which has 8 GPUs." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2016-10-22 18:24:16,695 Already binded, ignoring bind()\n", "2016-10-22 18:24:22,361 Epoch[0] Batch [10]\tSpeed: 325.98 samples/sec\tTrain-accuracy=0.004261\n", "2016-10-22 18:24:26,205 Epoch[0] Batch [20]\tSpeed: 333.06 samples/sec\tTrain-accuracy=0.011719\n", "2016-10-22 18:24:30,072 Epoch[0] Batch [30]\tSpeed: 331.06 samples/sec\tTrain-accuracy=0.021094\n", "2016-10-22 18:24:33,954 Epoch[0] Batch [40]\tSpeed: 329.84 samples/sec\tTrain-accuracy=0.020313\n", "2016-10-22 18:24:37,811 Epoch[0] Batch [50]\tSpeed: 331.93 samples/sec\tTrain-accuracy=0.023438\n", "2016-10-22 18:24:41,668 Epoch[0] Batch [60]\tSpeed: 331.93 samples/sec\tTrain-accuracy=0.032813\n", "2016-10-22 18:24:45,557 Epoch[0] Batch [70]\tSpeed: 329.22 samples/sec\tTrain-accuracy=0.049219\n", "2016-10-22 18:24:49,424 Epoch[0] Batch [80]\tSpeed: 331.12 samples/sec\tTrain-accuracy=0.071875\n", "2016-10-22 18:24:53,323 Epoch[0] Batch [90]\tSpeed: 328.36 samples/sec\tTrain-accuracy=0.084375\n", "2016-10-22 18:24:57,203 Epoch[0] Batch [100]\tSpeed: 329.95 samples/sec\tTrain-accuracy=0.115625\n", "2016-10-22 18:25:01,091 Epoch[0] Batch [110]\tSpeed: 329.33 samples/sec\tTrain-accuracy=0.153906\n", "2016-10-22 18:25:05,000 Epoch[0] Batch [120]\tSpeed: 327.49 samples/sec\tTrain-accuracy=0.187500\n", "2016-10-22 18:25:05,001 Epoch[0] Train-accuracy=nan\n", "2016-10-22 18:25:05,002 Epoch[0] Time cost=48.301\n", "2016-10-22 18:25:24,502 Epoch[0] Validation-accuracy=0.297072\n", "2016-10-22 18:25:28,564 Epoch[1] Batch [10]\tSpeed: 330.58 samples/sec\tTrain-accuracy=0.240767\n", "2016-10-22 18:25:32,426 Epoch[1] Batch [20]\tSpeed: 331.53 samples/sec\tTrain-accuracy=0.265625\n", "2016-10-22 18:25:36,289 Epoch[1] Batch [30]\tSpeed: 331.41 samples/sec\tTrain-accuracy=0.287500\n", "2016-10-22 18:25:40,173 Epoch[1] Batch [40]\tSpeed: 329.64 samples/sec\tTrain-accuracy=0.314063\n", "2016-10-22 18:25:44,032 Epoch[1] Batch [50]\tSpeed: 331.80 samples/sec\tTrain-accuracy=0.361719\n", "2016-10-22 18:25:47,876 Epoch[1] Batch [60]\tSpeed: 333.07 samples/sec\tTrain-accuracy=0.347656\n", "2016-10-22 18:25:51,741 Epoch[1] Batch [70]\tSpeed: 331.30 samples/sec\tTrain-accuracy=0.410156\n", "2016-10-22 18:25:55,603 Epoch[1] Batch [80]\tSpeed: 331.50 samples/sec\tTrain-accuracy=0.417187\n", "2016-10-22 18:25:59,460 Epoch[1] Batch [90]\tSpeed: 331.88 samples/sec\tTrain-accuracy=0.425781\n", "2016-10-22 18:26:03,304 Epoch[1] Batch [100]\tSpeed: 333.11 samples/sec\tTrain-accuracy=0.419531\n", "2016-10-22 18:26:07,196 Epoch[1] Batch [110]\tSpeed: 328.97 samples/sec\tTrain-accuracy=0.496875\n", "2016-10-22 18:26:10,665 Epoch[1] Train-accuracy=0.488715\n", "2016-10-22 18:26:10,666 Epoch[1] Time cost=46.163\n", "2016-10-22 18:26:29,719 Epoch[1] Validation-accuracy=0.556066\n", "2016-10-22 18:26:33,883 Epoch[2] Batch [10]\tSpeed: 325.12 samples/sec\tTrain-accuracy=0.514915\n", "2016-10-22 18:26:37,757 Epoch[2] Batch [20]\tSpeed: 330.50 samples/sec\tTrain-accuracy=0.524219\n", "2016-10-22 18:26:41,684 Epoch[2] Batch [30]\tSpeed: 325.98 samples/sec\tTrain-accuracy=0.536719\n", "2016-10-22 18:26:45,562 Epoch[2] Batch [40]\tSpeed: 330.21 samples/sec\tTrain-accuracy=0.514844\n", "2016-10-22 18:26:49,448 Epoch[2] Batch [50]\tSpeed: 329.44 samples/sec\tTrain-accuracy=0.564844\n", "2016-10-22 18:26:53,338 Epoch[2] Batch [60]\tSpeed: 329.16 samples/sec\tTrain-accuracy=0.534375\n", 
"2016-10-22 18:26:57,230 Epoch[2] Batch [70]\tSpeed: 328.99 samples/sec\tTrain-accuracy=0.576562\n", "2016-10-22 18:27:01,128 Epoch[2] Batch [80]\tSpeed: 328.42 samples/sec\tTrain-accuracy=0.604688\n", "2016-10-22 18:27:04,990 Epoch[2] Batch [90]\tSpeed: 331.54 samples/sec\tTrain-accuracy=0.582812\n", "2016-10-22 18:27:08,874 Epoch[2] Batch [100]\tSpeed: 329.63 samples/sec\tTrain-accuracy=0.572656\n", "2016-10-22 18:27:12,737 Epoch[2] Batch [110]\tSpeed: 331.45 samples/sec\tTrain-accuracy=0.625781\n", "2016-10-22 18:27:16,591 Epoch[2] Batch [120]\tSpeed: 332.20 samples/sec\tTrain-accuracy=0.603125\n", "2016-10-22 18:27:16,597 Epoch[2] Train-accuracy=nan\n", "2016-10-22 18:27:16,598 Epoch[2] Time cost=46.878\n", "2016-10-22 18:27:34,905 Epoch[2] Validation-accuracy=0.651947\n", "2016-10-22 18:27:38,961 Epoch[3] Batch [10]\tSpeed: 330.53 samples/sec\tTrain-accuracy=0.636364\n", "2016-10-22 18:27:42,811 Epoch[3] Batch [20]\tSpeed: 332.56 samples/sec\tTrain-accuracy=0.634375\n", "2016-10-22 18:27:46,675 Epoch[3] Batch [30]\tSpeed: 331.38 samples/sec\tTrain-accuracy=0.629687\n", "2016-10-22 18:27:50,545 Epoch[3] Batch [40]\tSpeed: 330.79 samples/sec\tTrain-accuracy=0.641406\n", "2016-10-22 18:27:54,423 Epoch[3] Batch [50]\tSpeed: 330.16 samples/sec\tTrain-accuracy=0.665625\n", "2016-10-22 18:27:58,273 Epoch[3] Batch [60]\tSpeed: 332.54 samples/sec\tTrain-accuracy=0.638281\n", "2016-10-22 18:28:02,131 Epoch[3] Batch [70]\tSpeed: 331.93 samples/sec\tTrain-accuracy=0.671875\n", "2016-10-22 18:28:05,988 Epoch[3] Batch [80]\tSpeed: 331.88 samples/sec\tTrain-accuracy=0.691406\n", "2016-10-22 18:28:09,870 Epoch[3] Batch [90]\tSpeed: 329.84 samples/sec\tTrain-accuracy=0.670312\n", "2016-10-22 18:28:13,742 Epoch[3] Batch [100]\tSpeed: 330.65 samples/sec\tTrain-accuracy=0.660156\n", "2016-10-22 18:28:17,636 Epoch[3] Batch [110]\tSpeed: 328.77 samples/sec\tTrain-accuracy=0.681250\n", "2016-10-22 18:28:21,097 Epoch[3] Train-accuracy=0.684028\n", "2016-10-22 18:28:21,098 Epoch[3] Time cost=46.192\n", "2016-10-22 18:28:40,464 Epoch[3] Validation-accuracy=0.701943\n", "2016-10-22 18:28:44,610 Epoch[4] Batch [10]\tSpeed: 327.03 samples/sec\tTrain-accuracy=0.708807\n", "2016-10-22 18:28:48,480 Epoch[4] Batch [20]\tSpeed: 330.86 samples/sec\tTrain-accuracy=0.708594\n", "2016-10-22 18:28:52,371 Epoch[4] Batch [30]\tSpeed: 329.02 samples/sec\tTrain-accuracy=0.713281\n", "2016-10-22 18:28:56,234 Epoch[4] Batch [40]\tSpeed: 331.46 samples/sec\tTrain-accuracy=0.700781\n", "2016-10-22 18:29:00,129 Epoch[4] Batch [50]\tSpeed: 328.65 samples/sec\tTrain-accuracy=0.712500\n", "2016-10-22 18:29:04,006 Epoch[4] Batch [60]\tSpeed: 330.30 samples/sec\tTrain-accuracy=0.697656\n", "2016-10-22 18:29:07,865 Epoch[4] Batch [70]\tSpeed: 331.74 samples/sec\tTrain-accuracy=0.717969\n", "2016-10-22 18:29:11,737 Epoch[4] Batch [80]\tSpeed: 330.61 samples/sec\tTrain-accuracy=0.737500\n", "2016-10-22 18:29:15,592 Epoch[4] Batch [90]\tSpeed: 332.19 samples/sec\tTrain-accuracy=0.714844\n", "2016-10-22 18:29:19,435 Epoch[4] Batch [100]\tSpeed: 333.15 samples/sec\tTrain-accuracy=0.696875\n", "2016-10-22 18:29:23,287 Epoch[4] Batch [110]\tSpeed: 332.35 samples/sec\tTrain-accuracy=0.734375\n", "2016-10-22 18:29:27,136 Epoch[4] Batch [120]\tSpeed: 332.61 samples/sec\tTrain-accuracy=0.726562\n", "2016-10-22 18:29:27,137 Epoch[4] Train-accuracy=nan\n", "2016-10-22 18:29:27,138 Epoch[4] Time cost=46.673\n", "2016-10-22 18:29:45,791 Epoch[4] Validation-accuracy=0.736935\n", "2016-10-22 18:29:49,873 Epoch[5] Batch [10]\tSpeed: 332.48 
samples/sec\tTrain-accuracy=0.749290\n", "2016-10-22 18:29:53,765 Epoch[5] Batch [20]\tSpeed: 328.95 samples/sec\tTrain-accuracy=0.732031\n", "2016-10-22 18:29:57,648 Epoch[5] Batch [30]\tSpeed: 329.67 samples/sec\tTrain-accuracy=0.736719\n", "2016-10-22 18:30:01,540 Epoch[5] Batch [40]\tSpeed: 329.42 samples/sec\tTrain-accuracy=0.722656\n", "2016-10-22 18:30:05,433 Epoch[5] Batch [50]\tSpeed: 328.82 samples/sec\tTrain-accuracy=0.751563\n", "2016-10-22 18:30:09,309 Epoch[5] Batch [60]\tSpeed: 330.37 samples/sec\tTrain-accuracy=0.736719\n", "2016-10-22 18:30:13,198 Epoch[5] Batch [70]\tSpeed: 329.27 samples/sec\tTrain-accuracy=0.771875\n", "2016-10-22 18:30:17,084 Epoch[5] Batch [80]\tSpeed: 329.47 samples/sec\tTrain-accuracy=0.762500\n", "2016-10-22 18:30:20,958 Epoch[5] Batch [90]\tSpeed: 330.43 samples/sec\tTrain-accuracy=0.742969\n", "2016-10-22 18:30:24,858 Epoch[5] Batch [100]\tSpeed: 328.32 samples/sec\tTrain-accuracy=0.770312\n", "2016-10-22 18:30:28,734 Epoch[5] Batch [110]\tSpeed: 330.27 samples/sec\tTrain-accuracy=0.781250\n", "2016-10-22 18:30:32,217 Epoch[5] Train-accuracy=0.757812\n", "2016-10-22 18:30:32,218 Epoch[5] Time cost=46.426\n", "2016-10-22 18:30:51,745 Epoch[5] Validation-accuracy=0.752450\n", "2016-10-22 18:30:55,887 Epoch[6] Batch [10]\tSpeed: 326.48 samples/sec\tTrain-accuracy=0.754261\n", "2016-10-22 18:30:59,754 Epoch[6] Batch [20]\tSpeed: 331.16 samples/sec\tTrain-accuracy=0.768750\n", "2016-10-22 18:31:03,612 Epoch[6] Batch [30]\tSpeed: 331.83 samples/sec\tTrain-accuracy=0.774219\n", "2016-10-22 18:31:07,472 Epoch[6] Batch [40]\tSpeed: 331.66 samples/sec\tTrain-accuracy=0.751563\n", "2016-10-22 18:31:11,326 Epoch[6] Batch [50]\tSpeed: 332.21 samples/sec\tTrain-accuracy=0.777344\n", "2016-10-22 18:31:15,194 Epoch[6] Batch [60]\tSpeed: 331.01 samples/sec\tTrain-accuracy=0.762500\n", "2016-10-22 18:31:19,062 Epoch[6] Batch [70]\tSpeed: 331.03 samples/sec\tTrain-accuracy=0.801562\n", "2016-10-22 18:31:22,938 Epoch[6] Batch [80]\tSpeed: 330.32 samples/sec\tTrain-accuracy=0.788281\n", "2016-10-22 18:31:26,802 Epoch[6] Batch [90]\tSpeed: 331.37 samples/sec\tTrain-accuracy=0.773438\n", "2016-10-22 18:31:30,656 Epoch[6] Batch [100]\tSpeed: 332.24 samples/sec\tTrain-accuracy=0.777344\n", "2016-10-22 18:31:34,555 Epoch[6] Batch [110]\tSpeed: 328.36 samples/sec\tTrain-accuracy=0.791406\n", "2016-10-22 18:31:38,412 Epoch[6] Batch [120]\tSpeed: 331.89 samples/sec\tTrain-accuracy=0.791406\n", "2016-10-22 18:31:38,413 Epoch[6] Train-accuracy=nan\n", "2016-10-22 18:31:38,414 Epoch[6] Time cost=46.668\n", "2016-10-22 18:31:57,459 Epoch[6] Validation-accuracy=0.768382\n", "2016-10-22 18:32:01,634 Epoch[7] Batch [10]\tSpeed: 324.04 samples/sec\tTrain-accuracy=0.789773\n", "2016-10-22 18:32:05,542 Epoch[7] Batch [20]\tSpeed: 327.57 samples/sec\tTrain-accuracy=0.794531\n", "2016-10-22 18:32:09,411 Epoch[7] Batch [30]\tSpeed: 330.90 samples/sec\tTrain-accuracy=0.788281\n", "2016-10-22 18:32:13,311 Epoch[7] Batch [40]\tSpeed: 328.36 samples/sec\tTrain-accuracy=0.778906\n", "2016-10-22 18:32:17,190 Epoch[7] Batch [50]\tSpeed: 330.00 samples/sec\tTrain-accuracy=0.803125\n", "2016-10-22 18:32:21,075 Epoch[7] Batch [60]\tSpeed: 329.54 samples/sec\tTrain-accuracy=0.780469\n", "2016-10-22 18:32:24,934 Epoch[7] Batch [70]\tSpeed: 331.78 samples/sec\tTrain-accuracy=0.779687\n", "2016-10-22 18:32:28,803 Epoch[7] Batch [80]\tSpeed: 330.92 samples/sec\tTrain-accuracy=0.821875\n", "2016-10-22 18:32:32,662 Epoch[7] Batch [90]\tSpeed: 331.79 samples/sec\tTrain-accuracy=0.783594\n", "2016-10-22 
18:32:36,515 Epoch[7] Batch [100]\tSpeed: 332.32 samples/sec\tTrain-accuracy=0.802344\n", "2016-10-22 18:32:40,393 Epoch[7] Batch [110]\tSpeed: 330.16 samples/sec\tTrain-accuracy=0.800000\n", "2016-10-22 18:32:43,832 Epoch[7] Train-accuracy=0.782118\n", "2016-10-22 18:32:43,833 Epoch[7] Time cost=46.373\n", "2016-10-22 18:33:01,994 Epoch[7] Validation-accuracy=0.774422\n" ] } ], "source": [ "# @@@ AUTOTEST_OUTPUT_IGNORED_CELL\n", "num_classes = 256\n", "batch_per_gpu = 16\n", "num_gpus = 8\n", "\n", "(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)\n", "\n", "batch_size = batch_per_gpu * num_gpus\n", "(train, val) = get_iterators(batch_size)\n", "mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)\n", "assert mod_score > 0.77, \"Low validation accuracy.\"" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "\n", "As you can see, after only 8 epochs we reach 78% validation accuracy. This matches state-of-the-art results trained on Caltech-256 alone, e.g. [VGG](http://www.robots.ox.ac.uk/~vgg/research/deep_eval/). \n", "\n", "Next, we try another pretrained model. This one was trained on the full ImageNet dataset, which is 10x larger than the 1,000-class ImageNet subset, and it uses a 3x deeper ResNet architecture (152 layers rather than 50). " ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2016-10-22 18:35:42,274 Already binded, ignoring bind()\n", "2016-10-22 18:35:55,659 Epoch[0] Batch [10]\tSpeed: 139.63 samples/sec\tTrain-accuracy=0.070312\n", "2016-10-22 18:36:04,814 Epoch[0] Batch [20]\tSpeed: 139.83 samples/sec\tTrain-accuracy=0.349219\n", "2016-10-22 18:36:13,991 Epoch[0] Batch [30]\tSpeed: 139.49 samples/sec\tTrain-accuracy=0.585156\n", "2016-10-22 18:36:23,163 Epoch[0] Batch [40]\tSpeed: 139.57 samples/sec\tTrain-accuracy=0.642188\n", "2016-10-22 18:36:32,309 Epoch[0] Batch [50]\tSpeed: 139.97 samples/sec\tTrain-accuracy=0.728906\n", "2016-10-22 18:36:41,426 Epoch[0] Batch [60]\tSpeed: 140.41 samples/sec\tTrain-accuracy=0.760156\n", "2016-10-22 18:36:50,531 Epoch[0] Batch [70]\tSpeed: 140.60 samples/sec\tTrain-accuracy=0.778906\n", "2016-10-22 18:36:59,631 Epoch[0] Batch [80]\tSpeed: 140.68 samples/sec\tTrain-accuracy=0.786719\n", "2016-10-22 18:37:08,742 Epoch[0] Batch [90]\tSpeed: 140.51 samples/sec\tTrain-accuracy=0.797656\n", "2016-10-22 18:37:17,857 Epoch[0] Batch [100]\tSpeed: 140.45 samples/sec\tTrain-accuracy=0.823438\n", "2016-10-22 18:37:26,969 Epoch[0] Batch [110]\tSpeed: 140.50 samples/sec\tTrain-accuracy=0.827344\n", "2016-10-22 18:37:36,094 Epoch[0] Batch [120]\tSpeed: 140.29 samples/sec\tTrain-accuracy=0.829688\n", "2016-10-22 18:37:36,095 Epoch[0] Train-accuracy=nan\n", "2016-10-22 18:37:36,096 Epoch[0] Time cost=113.804\n", "2016-10-22 18:38:08,728 Epoch[0] Validation-accuracy=0.829780\n", "2016-10-22 18:38:18,228 Epoch[1] Batch [10]\tSpeed: 139.92 samples/sec\tTrain-accuracy=0.862926\n", "2016-10-22 18:38:27,365 Epoch[1] Batch [20]\tSpeed: 140.10 samples/sec\tTrain-accuracy=0.867969\n", "2016-10-22 18:38:36,476 Epoch[1] Batch [30]\tSpeed: 140.52 samples/sec\tTrain-accuracy=0.884375\n", "2016-10-22 18:38:45,581 Epoch[1] Batch [40]\tSpeed: 140.60 samples/sec\tTrain-accuracy=0.856250\n", "2016-10-22 18:38:54,671 Epoch[1] Batch [50]\tSpeed: 140.84 samples/sec\tTrain-accuracy=0.888281\n", "2016-10-22 18:39:03,774 Epoch[1] Batch [60]\tSpeed: 
140.62 samples/sec\tTrain-accuracy=0.891406\n", "2016-10-22 18:39:12,893 Epoch[1] Batch [70]\tSpeed: 140.38 samples/sec\tTrain-accuracy=0.893750\n", "2016-10-22 18:39:22,016 Epoch[1] Batch [80]\tSpeed: 140.33 samples/sec\tTrain-accuracy=0.911719\n", "2016-10-22 18:39:31,173 Epoch[1] Batch [90]\tSpeed: 139.79 samples/sec\tTrain-accuracy=0.893750\n", "2016-10-22 18:39:40,341 Epoch[1] Batch [100]\tSpeed: 139.65 samples/sec\tTrain-accuracy=0.885938\n", "2016-10-22 18:39:49,522 Epoch[1] Batch [110]\tSpeed: 139.45 samples/sec\tTrain-accuracy=0.901563\n", "2016-10-22 18:39:57,750 Epoch[1] Train-accuracy=0.907986\n", "2016-10-22 18:39:57,751 Epoch[1] Time cost=109.022\n", "2016-10-22 18:40:30,649 Epoch[1] Validation-accuracy=0.848608\n", "2016-10-22 18:40:40,134 Epoch[2] Batch [10]\tSpeed: 140.33 samples/sec\tTrain-accuracy=0.921875\n", "2016-10-22 18:40:49,247 Epoch[2] Batch [20]\tSpeed: 140.47 samples/sec\tTrain-accuracy=0.911719\n", "2016-10-22 18:40:58,367 Epoch[2] Batch [30]\tSpeed: 140.37 samples/sec\tTrain-accuracy=0.914844\n", "2016-10-22 18:41:07,515 Epoch[2] Batch [40]\tSpeed: 139.93 samples/sec\tTrain-accuracy=0.913281\n", "2016-10-22 18:41:16,659 Epoch[2] Batch [50]\tSpeed: 140.01 samples/sec\tTrain-accuracy=0.929688\n", "2016-10-22 18:41:25,826 Epoch[2] Batch [60]\tSpeed: 139.64 samples/sec\tTrain-accuracy=0.940625\n", "2016-10-22 18:41:35,015 Epoch[2] Batch [70]\tSpeed: 139.31 samples/sec\tTrain-accuracy=0.927344\n", "2016-10-22 18:41:44,178 Epoch[2] Batch [80]\tSpeed: 139.72 samples/sec\tTrain-accuracy=0.940625\n", "2016-10-22 18:41:53,316 Epoch[2] Batch [90]\tSpeed: 140.09 samples/sec\tTrain-accuracy=0.928125\n", "2016-10-22 18:42:02,413 Epoch[2] Batch [100]\tSpeed: 140.72 samples/sec\tTrain-accuracy=0.948438\n", "2016-10-22 18:42:11,522 Epoch[2] Batch [110]\tSpeed: 140.53 samples/sec\tTrain-accuracy=0.925781\n", "2016-10-22 18:42:20,624 Epoch[2] Batch [120]\tSpeed: 140.66 samples/sec\tTrain-accuracy=0.928906\n", "2016-10-22 18:42:20,625 Epoch[2] Train-accuracy=nan\n", "2016-10-22 18:42:20,626 Epoch[2] Time cost=109.976\n", "2016-10-22 18:42:53,414 Epoch[2] Validation-accuracy=0.853269\n", "2016-10-22 18:43:02,925 Epoch[3] Batch [10]\tSpeed: 139.86 samples/sec\tTrain-accuracy=0.941051\n", "2016-10-22 18:43:12,095 Epoch[3] Batch [20]\tSpeed: 139.60 samples/sec\tTrain-accuracy=0.935156\n", "2016-10-22 18:43:21,270 Epoch[3] Batch [30]\tSpeed: 139.52 samples/sec\tTrain-accuracy=0.939844\n", "2016-10-22 18:43:30,434 Epoch[3] Batch [40]\tSpeed: 139.70 samples/sec\tTrain-accuracy=0.945312\n", "2016-10-22 18:43:39,557 Epoch[3] Batch [50]\tSpeed: 140.31 samples/sec\tTrain-accuracy=0.946094\n", "2016-10-22 18:43:48,680 Epoch[3] Batch [60]\tSpeed: 140.33 samples/sec\tTrain-accuracy=0.937500\n", "2016-10-22 18:43:57,775 Epoch[3] Batch [70]\tSpeed: 140.75 samples/sec\tTrain-accuracy=0.951562\n", "2016-10-22 18:44:06,899 Epoch[3] Batch [80]\tSpeed: 140.31 samples/sec\tTrain-accuracy=0.956250\n", "2016-10-22 18:44:16,000 Epoch[3] Batch [90]\tSpeed: 140.67 samples/sec\tTrain-accuracy=0.942969\n", "2016-10-22 18:44:25,110 Epoch[3] Batch [100]\tSpeed: 140.52 samples/sec\tTrain-accuracy=0.958594\n", "2016-10-22 18:44:34,225 Epoch[3] Batch [110]\tSpeed: 140.46 samples/sec\tTrain-accuracy=0.946875\n", "2016-10-22 18:44:42,448 Epoch[3] Train-accuracy=0.952257\n", "2016-10-22 18:44:42,450 Epoch[3] Time cost=109.035\n", "2016-10-22 18:45:15,423 Epoch[3] Validation-accuracy=0.857587\n", "2016-10-22 18:45:24,921 Epoch[4] Batch [10]\tSpeed: 139.90 samples/sec\tTrain-accuracy=0.965199\n", "2016-10-22 
18:45:34,041 Epoch[4] Batch [20]\tSpeed: 140.37 samples/sec\tTrain-accuracy=0.964844\n", "2016-10-22 18:45:43,172 Epoch[4] Batch [30]\tSpeed: 140.20 samples/sec\tTrain-accuracy=0.968750\n", "2016-10-22 18:45:52,287 Epoch[4] Batch [40]\tSpeed: 140.45 samples/sec\tTrain-accuracy=0.955469\n", "2016-10-22 18:46:01,418 Epoch[4] Batch [50]\tSpeed: 140.20 samples/sec\tTrain-accuracy=0.971094\n", "2016-10-22 18:46:10,534 Epoch[4] Batch [60]\tSpeed: 140.43 samples/sec\tTrain-accuracy=0.954688\n", "2016-10-22 18:46:19,664 Epoch[4] Batch [70]\tSpeed: 140.21 samples/sec\tTrain-accuracy=0.964063\n", "2016-10-22 18:46:28,811 Epoch[4] Batch [80]\tSpeed: 139.96 samples/sec\tTrain-accuracy=0.969531\n", "2016-10-22 18:46:37,986 Epoch[4] Batch [90]\tSpeed: 139.53 samples/sec\tTrain-accuracy=0.961719\n", "2016-10-22 18:46:47,150 Epoch[4] Batch [100]\tSpeed: 139.70 samples/sec\tTrain-accuracy=0.966406\n", "2016-10-22 18:46:56,307 Epoch[4] Batch [110]\tSpeed: 139.79 samples/sec\tTrain-accuracy=0.966406\n", "2016-10-22 18:47:05,456 Epoch[4] Batch [120]\tSpeed: 139.94 samples/sec\tTrain-accuracy=0.966406\n", "2016-10-22 18:47:05,457 Epoch[4] Train-accuracy=nan\n", "2016-10-22 18:47:05,457 Epoch[4] Time cost=110.033\n", "2016-10-22 18:47:38,303 Epoch[4] Validation-accuracy=0.862329\n", "2016-10-22 18:47:47,779 Epoch[5] Batch [10]\tSpeed: 140.25 samples/sec\tTrain-accuracy=0.971591\n", "2016-10-22 18:47:56,897 Epoch[5] Batch [20]\tSpeed: 140.40 samples/sec\tTrain-accuracy=0.970313\n", "2016-10-22 18:48:06,006 Epoch[5] Batch [30]\tSpeed: 140.53 samples/sec\tTrain-accuracy=0.976562\n", "2016-10-22 18:48:15,150 Epoch[5] Batch [40]\tSpeed: 140.01 samples/sec\tTrain-accuracy=0.967187\n", "2016-10-22 18:48:24,320 Epoch[5] Batch [50]\tSpeed: 139.60 samples/sec\tTrain-accuracy=0.975781\n", "2016-10-22 18:48:33,515 Epoch[5] Batch [60]\tSpeed: 139.22 samples/sec\tTrain-accuracy=0.971094\n", "2016-10-22 18:48:42,707 Epoch[5] Batch [70]\tSpeed: 139.26 samples/sec\tTrain-accuracy=0.971875\n", "2016-10-22 18:48:51,857 Epoch[5] Batch [80]\tSpeed: 139.92 samples/sec\tTrain-accuracy=0.988281\n", "2016-10-22 18:49:00,980 Epoch[5] Batch [90]\tSpeed: 140.32 samples/sec\tTrain-accuracy=0.969531\n", "2016-10-22 18:49:10,092 Epoch[5] Batch [100]\tSpeed: 140.49 samples/sec\tTrain-accuracy=0.984375\n", "2016-10-22 18:49:19,205 Epoch[5] Batch [110]\tSpeed: 140.49 samples/sec\tTrain-accuracy=0.978125\n", "2016-10-22 18:49:27,399 Epoch[5] Train-accuracy=0.968750\n", "2016-10-22 18:49:27,400 Epoch[5] Time cost=109.095\n", "2016-10-22 18:50:00,339 Epoch[5] Validation-accuracy=0.864102\n", "2016-10-22 18:50:09,861 Epoch[6] Batch [10]\tSpeed: 139.72 samples/sec\tTrain-accuracy=0.978693\n", "2016-10-22 18:50:19,028 Epoch[6] Batch [20]\tSpeed: 139.65 samples/sec\tTrain-accuracy=0.976562\n", "2016-10-22 18:50:28,206 Epoch[6] Batch [30]\tSpeed: 139.48 samples/sec\tTrain-accuracy=0.975000\n", "2016-10-22 18:50:37,343 Epoch[6] Batch [40]\tSpeed: 140.11 samples/sec\tTrain-accuracy=0.976562\n", "2016-10-22 18:50:46,475 Epoch[6] Batch [50]\tSpeed: 140.18 samples/sec\tTrain-accuracy=0.971094\n", "2016-10-22 18:50:55,613 Epoch[6] Batch [60]\tSpeed: 140.10 samples/sec\tTrain-accuracy=0.976562\n", "2016-10-22 18:51:04,717 Epoch[6] Batch [70]\tSpeed: 140.60 samples/sec\tTrain-accuracy=0.978906\n", "2016-10-22 18:51:13,821 Epoch[6] Batch [80]\tSpeed: 140.63 samples/sec\tTrain-accuracy=0.977344\n", "2016-10-22 18:51:22,932 Epoch[6] Batch [90]\tSpeed: 140.50 samples/sec\tTrain-accuracy=0.971875\n", "2016-10-22 18:51:32,039 Epoch[6] Batch [100]\tSpeed: 140.56 
samples/sec\tTrain-accuracy=0.980469\n", "2016-10-22 18:51:41,172 Epoch[6] Batch [110]\tSpeed: 140.17 samples/sec\tTrain-accuracy=0.978906\n", "2016-10-22 18:51:50,312 Epoch[6] Batch [120]\tSpeed: 140.06 samples/sec\tTrain-accuracy=0.978906\n", "2016-10-22 18:51:50,314 Epoch[6] Train-accuracy=nan\n", "2016-10-22 18:51:50,314 Epoch[6] Time cost=109.974\n", "2016-10-22 18:52:23,287 Epoch[6] Validation-accuracy=0.864738\n", "2016-10-22 18:52:32,798 Epoch[7] Batch [10]\tSpeed: 139.84 samples/sec\tTrain-accuracy=0.982244\n", "2016-10-22 18:52:41,881 Epoch[7] Batch [20]\tSpeed: 140.94 samples/sec\tTrain-accuracy=0.980469\n", "2016-10-22 18:52:50,982 Epoch[7] Batch [30]\tSpeed: 140.67 samples/sec\tTrain-accuracy=0.978906\n", "2016-10-22 18:53:00,086 Epoch[7] Batch [40]\tSpeed: 140.61 samples/sec\tTrain-accuracy=0.980469\n", "2016-10-22 18:53:09,208 Epoch[7] Batch [50]\tSpeed: 140.35 samples/sec\tTrain-accuracy=0.975000\n", "2016-10-22 18:53:18,342 Epoch[7] Batch [60]\tSpeed: 140.15 samples/sec\tTrain-accuracy=0.970313\n", "2016-10-22 18:53:27,490 Epoch[7] Batch [70]\tSpeed: 139.94 samples/sec\tTrain-accuracy=0.978125\n", "2016-10-22 18:53:36,623 Epoch[7] Batch [80]\tSpeed: 140.15 samples/sec\tTrain-accuracy=0.989844\n", "2016-10-22 18:53:45,795 Epoch[7] Batch [90]\tSpeed: 139.58 samples/sec\tTrain-accuracy=0.976562\n", "2016-10-22 18:53:54,958 Epoch[7] Batch [100]\tSpeed: 139.70 samples/sec\tTrain-accuracy=0.981250\n", "2016-10-22 18:54:04,143 Epoch[7] Batch [110]\tSpeed: 139.39 samples/sec\tTrain-accuracy=0.974219\n", "2016-10-22 18:54:12,364 Epoch[7] Train-accuracy=0.976562\n", "2016-10-22 18:54:12,365 Epoch[7] Time cost=109.077\n", "2016-10-22 18:54:45,259 Epoch[7] Validation-accuracy=0.863905\n" ] } ], "source": [ "# @@@ AUTOTEST_OUTPUT_IGNORED_CELL\n", "get_model('http://data.mxnet.io/models/imagenet-11k/resnet-152/resnet-152', 0)\n", "sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)\n", "(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)\n", "mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)\n", "assert mod_score > 0.86, \"Low validation accuracy.\"" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "As you can see, this model reaches 83% validation accuracy after just a single epoch. After 8 epochs, the validation accuracy increases to 86.4%. " ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 1 }