{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Enter State Farm" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using gpu device 1: GeForce GTX TITAN X (CNMeM is enabled with initial size: 90.0% of memory, cuDNN 4007)\n" ] } ], "source": [ "from theano.sandbox import cuda\n", "cuda.use('gpu1')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using Theano backend.\n" ] } ], "source": [ "%matplotlib inline\n", "from __future__ import print_function, division\n", "#path = \"data/state/\"\n", "path = \"data/state/sample/\"\n", "import utils; reload(utils)\n", "from utils import *\n", "from IPython.display import FileLink" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "batch_size=64" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Create sample" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "hidden": true }, "source": [ "The following assumes you've already created your validation set - remember that the training and validation set should contain *different drivers*, as mentioned on the Kaggle competition page." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "%cd data/state" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "%cd train" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "%mkdir ../sample\n", "%mkdir ../sample/train\n", "%mkdir ../sample/valid" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "for d in glob('c?'):\n", " os.mkdir('../sample/train/'+d)\n", " os.mkdir('../sample/valid/'+d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "from shutil import copyfile" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "g = glob('c?/*.jpg')\n", "shuf = np.random.permutation(g)\n", "for i in range(1500): copyfile(shuf[i], '../sample/train/' + shuf[i])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "%cd ../valid" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "g = glob('c?/*.jpg')\n", "shuf = np.random.permutation(g)\n", "for i in range(1000): copyfile(shuf[i], '../sample/valid/' + shuf[i])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "%cd ../../.." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "%mkdir data/state/results" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "%mkdir data/state/sample/test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create batches" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n", "Found 1002 images belonging to 10 classes.\n" ] } ], "source": [ "batches = get_batches(path+'train', batch_size=batch_size)\n", "val_batches = get_batches(path+'valid', batch_size=batch_size*2, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n", "Found 1002 images belonging to 10 classes.\n", "Found 0 images belonging to 0 classes.\n" ] } ], "source": [ "(val_classes, trn_classes, val_labels, trn_labels, val_filenames, filenames,\n", " test_filename) = get_classes(path)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Basic models" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "### Linear model" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "First, we try the simplest model and use default parameters. Note the trick of making the first layer a batchnorm layer - that way we don't have to worry about normalizing the input ourselves." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "model = Sequential([\n", " BatchNormalization(axis=1, input_shape=(3,224,224)),\n", " Flatten(),\n", " Dense(10, activation='softmax')\n", " ])" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "As you can see below, this training is going nowhere..." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 20s - loss: 13.8189 - acc: 0.1040 - val_loss: 13.5792 - val_acc: 0.1517\n", "Epoch 2/2\n", "1568/1568 [==============================] - 5s - loss: 14.4052 - acc: 0.1052 - val_loss: 13.8349 - val_acc: 0.1397\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.compile(Adam(), loss='categorical_crossentropy', metrics=['accuracy'])\n", "model.fit_generator(batches, batches.nb_sample, nb_epoch=2, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Let's first check the number of parameters to see that there's enough parameters to find some useful relationships:" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "____________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "====================================================================================================\n", "batchnormalization_65 (BatchNorma(None, 3, 224, 224) 6 batchnormalization_input_23[0][0]\n", "____________________________________________________________________________________________________\n", "flatten_23 (Flatten) (None, 150528) 0 batchnormalization_65[0][0] \n", "____________________________________________________________________________________________________\n", "dense_39 (Dense) (None, 10) 1505290 flatten_23[0][0] \n", "====================================================================================================\n", "Total params: 1505296\n", "____________________________________________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Over 1.5 million parameters - that should be enough. Incidentally, it's worth checking you understand why this is the number of parameters in this layer:" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "150528" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "10*3*224*224" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Since we have a simple model with no regularization and plenty of parameters, it seems most likely that our learning rate is too high. Perhaps it is jumping to a solution where it predicts one or two classes with high confidence, so that it can give a zero prediction to as many classes as possible - that's the best approach for a model that is no better than random, and there is likely to be where we would end up with a high learning rate. 
So let's check:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [ 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n", " [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n", " [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [ 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n", " [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [ 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]], dtype=float32)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.round(model.predict_generator(batches, batches.N)[:10],2)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Our hypothesis was correct. It's nearly always predicting class 1 or 6, with very high confidence. So let's try a lower learning rate:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 7s - loss: 2.4180 - acc: 0.1575 - val_loss: 5.2975 - val_acc: 0.1477\n", "Epoch 2/2\n", "1568/1568 [==============================] - 5s - loss: 1.7690 - acc: 0.4196 - val_loss: 4.0165 - val_acc: 0.1926\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Sequential([\n", " BatchNormalization(axis=1, input_shape=(3,224,224)),\n", " Flatten(),\n", " Dense(10, activation='softmax')\n", " ])\n", "model.compile(Adam(lr=1e-5), loss='categorical_crossentropy', metrics=['accuracy'])\n", "model.fit_generator(batches, batches.nb_sample, nb_epoch=2, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Great - we found our way out of that hole... Now we can increase the learning rate and see where we can get to." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "model.optimizer.lr=0.001" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/4\n", "1568/1568 [==============================] - 7s - loss: 1.3763 - acc: 0.5816 - val_loss: 2.5994 - val_acc: 0.2884\n", "Epoch 2/4\n", "1568/1568 [==============================] - 5s - loss: 1.0961 - acc: 0.7136 - val_loss: 1.9945 - val_acc: 0.3902\n", "Epoch 3/4\n", "1568/1568 [==============================] - 5s - loss: 0.9395 - acc: 0.7730 - val_loss: 1.9828 - val_acc: 0.3822\n", "Epoch 4/4\n", "1568/1568 [==============================] - 5s - loss: 0.7894 - acc: 0.8323 - val_loss: 1.8041 - val_acc: 0.3962\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit_generator(batches, batches.nb_sample, nb_epoch=4, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "We're stabilizing at validation accuracy of 0.39. Not great, but a lot better than random. 
Before moving on, let's check that our validation set on the sample is large enough that it gives consistent results:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1002 images belonging to 10 classes.\n" ] } ], "source": [ "rnd_batches = get_batches(path+'valid', batch_size=batch_size*2, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 4.4 , 0.49],\n", " [ 4.57, 0.49],\n", " [ 4.48, 0.48],\n", " [ 4.28, 0.51],\n", " [ 4.66, 0.48],\n", " [ 4.5 , 0.49],\n", " [ 4.46, 0.49],\n", " [ 4.51, 0.47],\n", " [ 4.45, 0.51],\n", " [ 4.47, 0.49]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val_res = [model.evaluate_generator(rnd_batches, rnd_batches.nb_sample) for i in range(10)]\n", "np.round(val_res, 2)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Yup, pretty consistent - if we see improvements of 3% or more, it's probably not random, based on the above samples." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "### L2 regularization" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "The previous model is overfitting a lot, but we can't use dropout since we only have one layer. We can try to decrease overfitting in our model by adding [l2 regularization](http://www.kdnuggets.com/2015/04/preventing-overfitting-neural-networks.html/2) (i.e. adding the sum of squares of the weights to our loss function):" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 7s - loss: 5.7173 - acc: 0.2583 - val_loss: 14.5162 - val_acc: 0.0988\n", "Epoch 2/2\n", "1568/1568 [==============================] - 5s - loss: 2.5953 - acc: 0.6148 - val_loss: 4.8340 - val_acc: 0.3952\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Sequential([\n", " BatchNormalization(axis=1, input_shape=(3,224,224)),\n", " Flatten(),\n", " Dense(10, activation='softmax', W_regularizer=l2(0.01))\n", " ])\n", "model.compile(Adam(lr=10e-5), loss='categorical_crossentropy', metrics=['accuracy'])\n", "model.fit_generator(batches, batches.nb_sample, nb_epoch=2, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "model.optimizer.lr=0.001" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/4\n", "1568/1568 [==============================] - 7s - loss: 1.5759 - acc: 0.8355 - val_loss: 4.3326 - val_acc: 0.3902\n", "Epoch 2/4\n", "1568/1568 [==============================] - 5s - loss: 0.9414 - acc: 0.8552 - val_loss: 3.5898 - val_acc: 0.3872\n", "Epoch 3/4\n", "1568/1568 [==============================] - 5s - loss: 0.4152 - acc: 0.9401 - val_loss: 2.3976 - val_acc: 0.4780\n", "Epoch 4/4\n", "1568/1568 [==============================] - 5s - loss: 0.3282 - acc: 0.9726 - val_loss: 2.3441
- val_acc: 0.5100\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit_generator(batches, batches.nb_sample, nb_epoch=4, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "hidden": true }, "source": [ "Looks like we can get a bit over 50% accuracy this way. This will be a good benchmark for our future models - if we can't beat 50%, then we're not even beating a linear model trained on a sample, so we'll know that's not a good approach." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "### Single hidden layer" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "The next simplest model is to add a single hidden layer." ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 7s - loss: 2.0182 - acc: 0.3412 - val_loss: 3.4769 - val_acc: 0.2435\n", "Epoch 2/2\n", "1568/1568 [==============================] - 5s - loss: 1.0104 - acc: 0.7379 - val_loss: 2.2270 - val_acc: 0.4361\n", "Epoch 1/5\n", "1568/1568 [==============================] - 7s - loss: 0.5350 - acc: 0.9043 - val_loss: 1.8474 - val_acc: 0.4621\n", "Epoch 2/5\n", "1568/1568 [==============================] - 5s - loss: 0.3459 - acc: 0.9458 - val_loss: 1.9591 - val_acc: 0.4222\n", "Epoch 3/5\n", "1568/1568 [==============================] - 5s - loss: 0.2296 - acc: 0.9802 - val_loss: 1.7887 - val_acc: 0.4441\n", "Epoch 4/5\n", "1568/1568 [==============================] - 5s - loss: 0.1591 - acc: 0.9936 - val_loss: 1.6847 - val_acc: 0.4830\n", "Epoch 5/5\n", "1568/1568 [==============================] - 5s - loss: 0.1204 - acc: 0.9943 - val_loss: 1.6344 - val_acc: 0.4910\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Sequential([\n", " BatchNormalization(axis=1, input_shape=(3,224,224)),\n", " Flatten(),\n", " Dense(100, activation='relu'),\n", " BatchNormalization(),\n", " Dense(10, activation='softmax')\n", " ])\n", "model.compile(Adam(lr=1e-5), loss='categorical_crossentropy', metrics=['accuracy'])\n", "model.fit_generator(batches, batches.nb_sample, nb_epoch=2, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)\n", "\n", "model.optimizer.lr = 0.01\n", "model.fit_generator(batches, batches.nb_sample, nb_epoch=5, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Not looking very encouraging... which isn't surprising since we know that CNNs are a much better choice for computer vision problems. So we'll try one." 
] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true, "hidden": true }, "source": [ "### Single conv layer" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "2 conv layers with max pooling followed by a simple dense network is a good simple CNN to start with:" ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "collapsed": false, "hidden": true }, "outputs": [], "source": [ "def conv1(batches):\n", " model = Sequential([\n", " BatchNormalization(axis=1, input_shape=(3,224,224)),\n", " Convolution2D(32,3,3, activation='relu'),\n", " BatchNormalization(axis=1),\n", " MaxPooling2D((3,3)),\n", " Convolution2D(64,3,3, activation='relu'),\n", " BatchNormalization(axis=1),\n", " MaxPooling2D((3,3)),\n", " Flatten(),\n", " Dense(200, activation='relu'),\n", " BatchNormalization(),\n", " Dense(10, activation='softmax')\n", " ])\n", "\n", " model.compile(Adam(lr=1e-4), loss='categorical_crossentropy', metrics=['accuracy'])\n", " model.fit_generator(batches, batches.nb_sample, nb_epoch=2, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)\n", " model.optimizer.lr = 0.001\n", " model.fit_generator(batches, batches.nb_sample, nb_epoch=4, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)\n", " return model" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 11s - loss: 1.3664 - acc: 0.6020 - val_loss: 1.8697 - val_acc: 0.3932\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 0.3201 - acc: 0.9388 - val_loss: 2.2294 - val_acc: 0.1537\n", "Epoch 1/4\n", "1568/1568 [==============================] - 11s - loss: 0.0862 - acc: 0.9911 - val_loss: 2.5230 - val_acc: 0.1517\n", "Epoch 2/4\n", "1568/1568 [==============================] - 11s - loss: 0.0350 - acc: 0.9994 - val_loss: 2.8057 - val_acc: 0.1497\n", "Epoch 3/4\n", "1568/1568 [==============================] - 11s - loss: 0.0201 - acc: 1.0000 - val_loss: 2.9036 - val_acc: 0.1607\n", "Epoch 4/4\n", "1568/1568 [==============================] - 11s - loss: 0.0124 - acc: 1.0000 - val_loss: 2.9390 - val_acc: 0.1647\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv1(batches)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "The training set here is very rapidly reaching a very high accuracy. So if we could regularize this, perhaps we could get a reasonable result.\n", "\n", "So, what kind of regularization should we try first? As we discussed in lesson 3, we should start with data augmentation." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Data augmentation" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "To find the best data augmentation parameters, we can try each type of data augmentation, one at a time. For each type, we can try four very different levels of augmentation, and see which is the best. In the steps below we've only kept the single best result we found. We're using the CNN we defined above, since we have already observed it can model the data quickly and accurately." 
] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Width shift: move the image left and right -" ] }, { "cell_type": "code", "execution_count": 63, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n" ] } ], "source": [ "gen_t = image.ImageDataGenerator(width_shift_range=0.1)\n", "batches = get_batches(path+'train', gen_t, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 11s - loss: 2.1802 - acc: 0.3316 - val_loss: 2.9037 - val_acc: 0.1038\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 1.0996 - acc: 0.6862 - val_loss: 2.1270 - val_acc: 0.2495\n", "Epoch 1/4\n", "1568/1568 [==============================] - 11s - loss: 0.6856 - acc: 0.8106 - val_loss: 2.1610 - val_acc: 0.1487\n", "Epoch 2/4\n", "1568/1568 [==============================] - 11s - loss: 0.4989 - acc: 0.8693 - val_loss: 2.0959 - val_acc: 0.2525\n", "Epoch 3/4\n", "1568/1568 [==============================] - 11s - loss: 0.3715 - acc: 0.9120 - val_loss: 2.1168 - val_acc: 0.2385\n", "Epoch 4/4\n", "1568/1568 [==============================] - 11s - loss: 0.2916 - acc: 0.9254 - val_loss: 2.1028 - val_acc: 0.3044\n" ] } ], "source": [ "model = conv1(batches)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Height shift: move the image up and down -" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n" ] } ], "source": [ "gen_t = image.ImageDataGenerator(height_shift_range=0.05)\n", "batches = get_batches(path+'train', gen_t, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 11s - loss: 1.7843 - acc: 0.4458 - val_loss: 2.1259 - val_acc: 0.2375\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 0.7028 - acc: 0.7825 - val_loss: 2.0232 - val_acc: 0.3164\n", "Epoch 1/4\n", "1568/1568 [==============================] - 11s - loss: 0.3586 - acc: 0.9152 - val_loss: 2.1772 - val_acc: 0.1806\n", "Epoch 2/4\n", "1568/1568 [==============================] - 11s - loss: 0.2335 - acc: 0.9490 - val_loss: 2.1935 - val_acc: 0.1727\n", "Epoch 3/4\n", "1568/1568 [==============================] - 11s - loss: 0.1626 - acc: 0.9656 - val_loss: 2.1944 - val_acc: 0.2106\n", "Epoch 4/4\n", "1568/1568 [==============================] - 11s - loss: 0.1214 - acc: 0.9758 - val_loss: 2.3481 - val_acc: 0.1766\n" ] } ], "source": [ "model = conv1(batches)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Random shear angles (max in radians) -" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n" ] } ], "source": [ "gen_t = image.ImageDataGenerator(shear_range=0.1)\n", "batches = get_batches(path+'train', gen_t, batch_size=batch_size)" ] }, { 
"cell_type": "code", "execution_count": 68, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 11s - loss: 1.6148 - acc: 0.5223 - val_loss: 2.2513 - val_acc: 0.2475\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 0.3915 - acc: 0.9203 - val_loss: 2.0757 - val_acc: 0.2725\n", "Epoch 1/4\n", "1568/1568 [==============================] - 11s - loss: 0.1478 - acc: 0.9821 - val_loss: 2.1869 - val_acc: 0.3084\n", "Epoch 2/4\n", "1568/1568 [==============================] - 11s - loss: 0.0831 - acc: 0.9904 - val_loss: 2.2449 - val_acc: 0.3164\n", "Epoch 3/4\n", "1568/1568 [==============================] - 11s - loss: 0.0530 - acc: 0.9955 - val_loss: 2.2426 - val_acc: 0.3154\n", "Epoch 4/4\n", "1568/1568 [==============================] - 11s - loss: 0.0343 - acc: 0.9994 - val_loss: 2.2609 - val_acc: 0.3234\n" ] } ], "source": [ "model = conv1(batches)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Rotation: max in degrees -" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n" ] } ], "source": [ "gen_t = image.ImageDataGenerator(rotation_range=15)\n", "batches = get_batches(path+'train', gen_t, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 70, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 11s - loss: 1.9734 - acc: 0.3865 - val_loss: 2.1849 - val_acc: 0.3064\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 0.8523 - acc: 0.7411 - val_loss: 2.0310 - val_acc: 0.2655\n", "Epoch 1/4\n", "1568/1568 [==============================] - 11s - loss: 0.4652 - acc: 0.8833 - val_loss: 2.0401 - val_acc: 0.2036\n", "Epoch 2/4\n", "1568/1568 [==============================] - 11s - loss: 0.3448 - acc: 0.9101 - val_loss: 2.2149 - val_acc: 0.1317\n", "Epoch 3/4\n", "1568/1568 [==============================] - 11s - loss: 0.2411 - acc: 0.9420 - val_loss: 2.2614 - val_acc: 0.1287\n", "Epoch 4/4\n", "1568/1568 [==============================] - 11s - loss: 0.1722 - acc: 0.9636 - val_loss: 2.1208 - val_acc: 0.2106\n" ] } ], "source": [ "model = conv1(batches)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Channel shift: randomly changing the R,G,B colors - " ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n" ] } ], "source": [ "gen_t = image.ImageDataGenerator(channel_shift_range=20)\n", "batches = get_batches(path+'train', gen_t, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 11s - loss: 1.6381 - acc: 0.5191 - val_loss: 2.2146 - val_acc: 0.3483\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 0.3530 - acc: 0.9305 - val_loss: 2.0966 - val_acc: 0.2665\n", "Epoch 1/4\n", "1568/1568 
[==============================] - 11s - loss: 0.1036 - acc: 0.9923 - val_loss: 2.4195 - val_acc: 0.1766\n", "Epoch 2/4\n", "1568/1568 [==============================] - 11s - loss: 0.0450 - acc: 1.0000 - val_loss: 2.6192 - val_acc: 0.1667\n", "Epoch 3/4\n", "1568/1568 [==============================] - 11s - loss: 0.0259 - acc: 0.9994 - val_loss: 2.7227 - val_acc: 0.1816\n", "Epoch 4/4\n", "1568/1568 [==============================] - 11s - loss: 0.0180 - acc: 0.9994 - val_loss: 2.7049 - val_acc: 0.2206\n" ] } ], "source": [ "model = conv1(batches)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "And finally, putting it all together!" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1568 images belonging to 10 classes.\n" ] } ], "source": [ "gen_t = image.ImageDataGenerator(rotation_range=15, height_shift_range=0.05, \n", " shear_range=0.1, channel_shift_range=20, width_shift_range=0.1)\n", "batches = get_batches(path+'train', gen_t, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n", "1568/1568 [==============================] - 12s - loss: 2.4533 - acc: 0.2258 - val_loss: 2.1042 - val_acc: 0.2265\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 1.7107 - acc: 0.4305 - val_loss: 2.1321 - val_acc: 0.2295\n", "Epoch 1/2\n", "1568/1568 [==============================] - 11s - loss: 1.4329 - acc: 0.5478 - val_loss: 2.3451 - val_acc: 0.1427\n", "Epoch 2/2\n", "1568/1568 [==============================] - 11s - loss: 1.2623 - acc: 0.5918 - val_loss: 2.4122 - val_acc: 0.1088\n" ] } ], "source": [ "model = conv1(batches)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "At first glance, this isn't looking encouraging, since the validation accuracy is poor and getting worse. But the training accuracy is improving, and still has a long way to go - so we should try annealing our learning rate and running more epochs, before we make a decision." ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "collapsed": false, "hidden": true, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "1568/1568 [==============================] - 11s - loss: 1.1570 - acc: 0.6282 - val_loss: 2.4787 - val_acc: 0.1048\n", "Epoch 2/5\n", "1568/1568 [==============================] - 11s - loss: 1.0278 - acc: 0.6582 - val_loss: 2.4211 - val_acc: 0.1267\n", "Epoch 3/5\n", "1568/1568 [==============================] - 11s - loss: 0.9459 - acc: 0.6939 - val_loss: 2.5656 - val_acc: 0.1477\n", "Epoch 4/5\n", "1568/1568 [==============================] - 11s - loss: 0.9045 - acc: 0.6996 - val_loss: 2.2994 - val_acc: 0.2365\n", "Epoch 5/5\n", "1568/1568 [==============================] - 11s - loss: 0.8346 - acc: 0.7360 - val_loss: 2.1203 - val_acc: 0.2705\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.optimizer.lr = 0.0001\n", "model.fit_generator(batches, batches.nb_sample, nb_epoch=5, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Lucky we tried that - we're starting to make progress! 
Let's keep going." ] }, { "cell_type": "code", "execution_count": 61, "metadata": { "collapsed": false, "hidden": true, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/25\n", "1568/1568 [==============================] - 11s - loss: 0.8055 - acc: 0.7423 - val_loss: 2.0895 - val_acc: 0.2984\n", "Epoch 2/25\n", "1568/1568 [==============================] - 11s - loss: 0.7538 - acc: 0.7621 - val_loss: 1.8985 - val_acc: 0.4212\n", "Epoch 3/25\n", "1568/1568 [==============================] - 11s - loss: 0.7037 - acc: 0.7774 - val_loss: 1.7200 - val_acc: 0.4411\n", "Epoch 4/25\n", "1568/1568 [==============================] - 11s - loss: 0.6865 - acc: 0.7966 - val_loss: 1.5225 - val_acc: 0.5180\n", "Epoch 5/25\n", "1568/1568 [==============================] - 11s - loss: 0.6404 - acc: 0.8036 - val_loss: 1.3924 - val_acc: 0.5319\n", "Epoch 6/25\n", "1568/1568 [==============================] - 11s - loss: 0.6116 - acc: 0.8144 - val_loss: 1.4472 - val_acc: 0.5259\n", "Epoch 7/25\n", "1568/1568 [==============================] - 11s - loss: 0.5671 - acc: 0.8361 - val_loss: 1.4703 - val_acc: 0.5549\n", "Epoch 8/25\n", "1568/1568 [==============================] - 11s - loss: 0.5559 - acc: 0.8265 - val_loss: 1.2402 - val_acc: 0.6337\n", "Epoch 9/25\n", "1568/1568 [==============================] - 11s - loss: 0.5434 - acc: 0.8406 - val_loss: 1.2765 - val_acc: 0.6297\n", "Epoch 10/25\n", "1568/1568 [==============================] - 11s - loss: 0.4877 - acc: 0.8533 - val_loss: 1.2366 - val_acc: 0.6267\n", "Epoch 11/25\n", "1568/1568 [==============================] - 11s - loss: 0.4944 - acc: 0.8406 - val_loss: 1.3992 - val_acc: 0.5349\n", "Epoch 12/25\n", "1568/1568 [==============================] - 11s - loss: 0.4694 - acc: 0.8597 - val_loss: 1.1821 - val_acc: 0.6277\n", "Epoch 13/25\n", "1568/1568 [==============================] - 11s - loss: 0.4251 - acc: 0.8858 - val_loss: 1.1803 - val_acc: 0.6427\n", "Epoch 14/25\n", "1568/1568 [==============================] - 11s - loss: 0.4501 - acc: 0.8680 - val_loss: 1.2752 - val_acc: 0.5908\n", "Epoch 15/25\n", "1568/1568 [==============================] - 11s - loss: 0.3922 - acc: 0.8846 - val_loss: 1.1758 - val_acc: 0.6457\n", "Epoch 16/25\n", "1568/1568 [==============================] - 11s - loss: 0.4406 - acc: 0.8629 - val_loss: 1.3147 - val_acc: 0.5808\n", "Epoch 17/25\n", "1568/1568 [==============================] - 11s - loss: 0.4075 - acc: 0.8788 - val_loss: 1.2941 - val_acc: 0.6148\n", "Epoch 18/25\n", "1568/1568 [==============================] - 11s - loss: 0.3890 - acc: 0.8948 - val_loss: 1.1871 - val_acc: 0.6567\n", "Epoch 19/25\n", "1568/1568 [==============================] - 11s - loss: 0.3708 - acc: 0.8890 - val_loss: 1.1560 - val_acc: 0.6756\n", "Epoch 20/25\n", "1568/1568 [==============================] - 11s - loss: 0.3539 - acc: 0.8973 - val_loss: 1.2621 - val_acc: 0.6537\n", "Epoch 21/25\n", "1568/1568 [==============================] - 11s - loss: 0.3582 - acc: 0.8909 - val_loss: 1.1357 - val_acc: 0.6677\n", "Epoch 22/25\n", "1568/1568 [==============================] - 11s - loss: 0.3232 - acc: 0.9056 - val_loss: 1.2114 - val_acc: 0.6287\n", "Epoch 23/25\n", "1568/1568 [==============================] - 11s - loss: 0.3286 - acc: 0.9011 - val_loss: 1.2917 - val_acc: 0.6377\n", "Epoch 24/25\n", "1568/1568 [==============================] - 11s - loss: 0.3080 - acc: 0.9139 - val_loss: 1.2519 - val_acc: 0.6248\n", "Epoch 25/25\n", "1568/1568 [==============================] - 11s 
- loss: 0.2999 - acc: 0.9152 - val_loss: 1.1980 - val_acc: 0.6647\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit_generator(batches, batches.nb_sample, nb_epoch=25, validation_data=val_batches, \n", " nb_val_samples=val_batches.nb_sample)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "hidden": true }, "source": [ "Amazingly, using nothing but a small sample, a simple (not pre-trained) model with no dropout, and data augmentation, we're getting results that would get us into the top 50% of the competition! This looks like a great foundation for our further experiments.\n", "\n", "To go further, we'll need to use the whole dataset - dropout and data volume are closely related, so we can't tweak dropout without using all the data." ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12" }, "nav_menu": {}, "nbpresent": { "slides": { "28b43202-5690-4169-9aca-6b9dabfeb3ec": { "id": "28b43202-5690-4169-9aca-6b9dabfeb3ec", "prev": null, "regions": { "3bba644a-cf4d-4a49-9fbd-e2554428cf9f": { "attrs": { "height": 0.8, "width": 0.8, "x": 0.1, "y": 0.1 }, "content": { "cell": "f3d3a388-7e2a-4151-9b50-c20498fceacc", "part": "whole" }, "id": "3bba644a-cf4d-4a49-9fbd-e2554428cf9f" } } }, "8104def2-4b68-44a0-8f1b-b03bf3b2a079": { "id": "8104def2-4b68-44a0-8f1b-b03bf3b2a079", "prev": "28b43202-5690-4169-9aca-6b9dabfeb3ec", "regions": { "7dded777-1ddf-4100-99ae-25cf1c15b575": { "attrs": { "height": 0.8, "width": 0.8, "x": 0.1, "y": 0.1 }, "content": { "cell": "fe47bd48-3414-4657-92e7-8b8d6cb0df00", "part": "whole" }, "id": "7dded777-1ddf-4100-99ae-25cf1c15b575" } } } }, "themes": {} }, "toc": { "nav_menu": { "height": "148px", "width": "254px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false }, "widgets": { "state": {}, "version": "1.1.2" } }, "nbformat": 4, "nbformat_minor": 0 }