{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import importlib\n",
"import utils2; importlib.reload(utils2)\n",
"from utils2 import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"limit_mem()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from keras.datasets.cifar10 import load_batch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook contains a Keras implementation of Huang et al.'s [DenseNet](https://arxiv.org/abs/1608.06993)\n",
"\n",
"Our motivation behind studying DenseNet is because of how well it works with limited data.\n",
"\n",
"DenseNet beats state-of-the-art results on CIFAR-10/CIFAR-100 w/ and w/o data augmentation, but the performance increase is most pronounced w/o data augmentation.\n",
"\n",
"Compare to FractalNet, state-of-the-art on both datasets:\n",
"* CIFAR-10: ~ 30 % performance increase w/ DenseNet\n",
"* CIFAR-100: ~ 30 % performance increase w/ DenseNet\n",
"\n",
"That increase is motivation enough."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So what is a DenseNet?\n",
"\n",
"Put simply, DenseNet is a Resnet where we replace addition with concatenation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Idea"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Recall that in broad terms, a Resnet is a Convnet that uses residual block structures.\n",
"\n",
"These \"blocks\" work as follows:\n",
"* Let Lt be the input layer to block\n",
"* Perform conv layer transformations/activations on Lt, denote by f(t)\n",
"* Call output layer of block Lt+1\n",
"* Define Lt+1 = f(Lt)+ Lt \n",
" * That is, total output is the conv layer outputs plus the original input\n",
"* We call residual block b.c. f(Lt)=Lt+1 - Lt, the residual\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As mentioned, the difference w/ DenseNet is instead of adding Lt to Lt+1, it is being concatenated.\n",
"\n",
"As with Resnet, DenseNet consists of multiple blocks.\n",
"Therefore, there is a recursive relationship across blocks:\n",
"* Block Bi takes as input the ouput of block Bi-1 concatenated with the input of Bi-1\n",
"* The input to Bi-1 is the ouput of block Bi-2 concatenated with the input of Bi-2\n",
"* So on and so forth"
]
},
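{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two update rules can be written side by side. This is only a sketch in the notation used above; the bracket notation $[\\cdot, \\cdot]$ for concatenation along the filter axis is introduced here for illustration.\n",
"\n",
"$$\\text{ResNet:}\\quad L_{t+1} = f(L_t) + L_t$$\n",
"\n",
"$$\\text{DenseNet:}\\quad L_{t+1} = [\\,L_t,\\ f(L_t)\\,]$$\n",
"\n",
"Unrolling the DenseNet rule gives $L_{t+1} = [\\,L_0,\\ f(L_0),\\ f(L_1),\\ \\dots,\\ f(L_t)\\,]$, so every layer sees the outputs of all earlier layers in the same block."
]
},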
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The number of filters added to each layer needs to be monitored, given that the input space for each block keeps growing.\n",
"\n",
"Huang et al. calls the # of filters added at each layer the *growth rate*, and appropriately denotes this number with the related letter *k*."
]
},
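{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short worked example of why the growth rate matters. The symbols $k_0$ (initial # of filters) and $c_\\ell$ (# of filters after $\\ell$ layers of a dense block) are introduced here for illustration; the values 16 and 12 are the `nb_filter` and `growth_rate` defaults used later in this notebook.\n",
"\n",
"$$c_\\ell = k_0 + \\ell\\,k$$\n",
"\n",
"With $k_0 = 16$ and $k = 12$, sixteen layers give $c_{16} = 16 + 16 \\cdot 12 = 208$ filters, even though each layer only adds 12."
]
},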
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Densenet / CIFAR 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load data."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def load_data():\n",
" path = 'data/cifar-10-batches-py'\n",
" num_train_samples = 50000\n",
" x_train = np.zeros((num_train_samples, 3, 32, 32), dtype='uint8')\n",
" y_train = np.zeros((num_train_samples,), dtype='uint8')\n",
" for i in range(1, 6):\n",
" data, labels = load_batch(os.path.join(path, 'data_batch_' + str(i)))\n",
" x_train[(i - 1) * 10000: i * 10000, :, :, :] = data\n",
" y_train[(i - 1) * 10000: i * 10000] = labels\n",
" x_test, y_test = load_batch(os.path.join(path, 'test_batch'))\n",
" y_train = np.reshape(y_train, (len(y_train), 1))\n",
" y_test = np.reshape(y_test, (len(y_test), 1))\n",
" x_train = x_train.transpose(0, 2, 3, 1)\n",
" x_test = x_test.transpose(0, 2, 3, 1)\n",
" return (x_train, y_train), (x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"(x_train, y_train), (x_test, y_test) = load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's an example of CIFAR-10"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAH/dJREFUeJztnVuQnWeVnt+1T30+t7rVklpqSZaEZNmWjVBs7BgSD9gQ\nUoaaxAUXE19Q47kgJFQmFy6mKpA7kgpMcZFQZYJrzIRwqAEGl2EyMcaDYXxCPulg2bKs86FbUkut\n3Yd93isXvV0ly9/7dcuSdsv536dKpe7v7W//X//7X/vv/b17rWXuDiFE8kgt9QKEEEuDgl+IhKLg\nFyKhKPiFSCgKfiESioJfiISi4BcioSj4hUgoCn4hEkrmSiab2X0Avg0gDeB/uvs3Yj/f1d3jA0PD\nQa1cnKPzquVicNzd6JxsrpVquRaupbM5qqVS4eMVCzN0TrlUoJrXalQz8N8tlU7zeanw63lHZxed\n0xI5H16rUq1Q4M8ZEP7kaN3rdEaxwM9VLbKO2KdUmVSt8nXU67HH4/MyGR5OmQx/zhzh6yD24ds6\nWUZhroBSqcwvnovXtJgfCmFmaQD/HcAnABwH8Acze9zdX2dzBoaG8Rff+h9B7fgbL9FjnTm0Lzhe\nq/HlD6/+ENVWr99Mtb7lq6nW2hY+3v69z9I5Rw7solplmr9opCO/W3dfD9Uyre3B8R133k3n3LCR\nn6vihXNU27vnFarV6+XgeLkSfiEHgNf37qZafuos1UrlEtUq5XDQnZvkL1wzc3yN1Ro/1rJl/VTr\n6++kWs2nw8eq0CkoFsKvDP/w9PN80iVcyZ/9OwAccPeD7l4G8CMA91/B4wkhmsiVBP9KAMcu+v54\nY0wI8QHgmm/4mdlDZrbTzHZO5y9c68MJIRbJlQT/CQCjF32/qjH2Ltz9EXff7u7bu7r5e1UhRHO5\nkuD/A4ANZrbWzHIAPg/g8auzLCHEteZ97/a7e9XM/i2Av8e81feou++NzanVasifD+8eD/TynVJf\nFrYHPdNN54ysXsfXUefbqKk63wWuz4XtpuL5STrHC3zneOXgENVWj95AtdEb1lBtxcpVwfEhYrEC\nQDbbQrVqb9g9AIDRVcv5vGp4t79Y5Hbe1Hnufpw9y12HTMTWhYV3+/sG+O/c2sHXeCF/nmotrTyc\n6s6tymwmvJb8hSk6p1wK7/Y78wADXJHP7+6/AvCrK3kMIcTSoE/4CZFQFPxCJBQFvxAJRcEvREJR\n8AuRUK5ot/+ycQcqYZutXOL229xc2DYa28g/TTwzO0u1WHJJ/2AkaSYbfq3csGEjnfPR27dTbeVw\n2JYDgJ6eZVSrZHg2YHtr2DbKRDLErBrJ3Jvl9luJPJcA0N4Wtgj7erm9uX7dFqrt2/cm1WB8HaVS\n2Lrt6e6jcyKJnbiQn6CaI3ydAvFMwfPnw9dqYY4nEbGMv8vpw6E7vxAJRcEvREJR8AuRUBT8QiQU\nBb8QCaWpu/1er6NKEjusynewW3JtwfELZ3lpp4HlfCd99Y08aWZodAXVsmwbOFJvqVLlzsIbp3hC\n0NzBM/wxU3xX+c3drwXHP7KZ76TfveMjVIvtHucj9RmOHjkZHM9lI7UVczxRa3AZd3aOHnuLPyYp\nazZT4G5QPs+vq0yWl8fr7uZJULF6h6w8YazOYEtL+Fq0RVXvm0d3fiESioJfiISi4BcioSj4hUgo\nCn4hEoqCX4iE0nSrrzQXtlg627gF1N0fTnK57ZZtdM7oug1Um44ksrx58BjV8nNhu2Zmitdam5zi\ndt6pcV4PrjuS2IMUT/h44sc/DY5nH+Cv8x+74y6qZbPcxly+nNui8LBdNnU+3J0GAF5+hXc3ykTq\nDHZ0cYuwWgtbleUZ/pylI7fEWFeeWo1bsJPnuH2YQtgijLX/6u0NJ6ClI23B3ntcIUQiUfALkVAU\n/EIkFAW/EAlFwS9EQlHwC5FQrsjqM7PDAKYB
1ABU3Z0XrANgKUNLSzaoVdJddF6hrTM4fijP2yq9\n+vsXqXZuktelO3GS12jLpsMpU9kUz74qkbZVAFAscm1kGX9qTo8foVo3yfaansrTOfsPHeLrGBmk\nWjbL1zgyGm7ltYKMA8DRcW6zvrmba0Mj3BY9fJRYbBX+nNXLXKtF6ie25rgd2ZIJX/cAUCiGH7O7\nm1uYGdLiyy7jfn41fP5/5k5MXSHEdYv+7BcioVxp8DuAX5vZS2b20NVYkBCiOVzpn/13ufsJMxsC\n8KSZveHuz1z8A40XhYcAoLePfzRSCNFcrujO7+4nGv+fBvBzADsCP/OIu2939+0dneGNOyFE83nf\nwW9mHWbW9c7XAD4JYM/VWpgQ4tpyJX/2DwP4uc1XDMwA+N/u/n9iE1KpDNrbh4Pa6SmeaXfgWNjm\neX0vf61JRWyoWqQ1WGGaF3ZME0uvUOI22tQ016YjrbAOH99HtY42botuWr8pLEQsx3/83T9Qbc3a\ntVTbuIm3KRsYCGedtbTy56Wnm1tlqSovFjpb4vcw1vKqMMWzC2s1XnS1tY1bdjN5/pjdkczDltZw\nJl65HGthF84wrde5TXkp7zv43f0ggFve73whxNIiq0+IhKLgFyKhKPiFSCgKfiESioJfiITS1AKe\n6XQGvf3hLLEDx/bTeacOh7PO2rO8kOWFWV4ccyZ/mmoWsUqmpsPW3FSBW0MZksUIAIPDQ1Rr6wpb\nZQCwcoybLKPENjr02nN0Ttq4DVip8Sy2M2d5cdKbbtocHL9hwzo6ZzSSndd5+61U2/XGUaqViuHC\nsKVsJKsP3JarO7ekx8fD/QkBINfCbcyePnYdcNu5UAhntNZ98Vaf7vxCJBQFvxAJRcEvREJR8AuR\nUBT8QiSUpu72l0qzePvtcG29N94+QOedPPV2cLwWScLp6umg2qYNY1Tbunkr1U6dCe+wHjnD17Fs\neTiRCQDWrOdJM10D3AmYOM+P52fDzsjRI3xH/EykpdjmLVTCJzaGd/QBYHaG7EZz8wBe5q7D3ue5\nW7FhE2/bNryyNzj+/IvPBMcBYHyCJ2NVKny3v1jg6z8faVPW1hleY2znfpa0vbucxB7d+YVIKAp+\nIRKKgl+IhKLgFyKhKPiFSCgKfiESSlOtvtmZPJ5/5snwQoZJ7TkA6zffFBxvi7RV2rxlA9U2bVxF\ntVoxnBgDAJ4K21ez4A2LMtlwYgkApNNhiwcAKlWeCDI7fY5qPeWwFVWtOZ1z9DRPgmrtPMGP1d1H\ntXXrx4LjHrnfFKbCdekA4I0XXqWaF/h1sPXe+4LjN93ME4wKO7nV9/aBw1Rrb+fVqXt6B6g23+3u\nveTz/HkplcLnymX1CSEWQsEvREJR8AuRUBT8QiQUBb8QCUXBL0RCWdDqM7NHAXwGwGl339oY6wfw\nYwBjAA4DeMDduS/RoFKu4vSxsC126y3/gs5raQnXduvnrhxGVvA6bOcirZqOHeA2Wrkett9SxlPV\n0hluvdSc1yBENdZuLGw5AoDXwsfr7AnXTgSAyRmeJZjK8ezIunP7cL57e2gSn9HZyp+zsRWjVGtN\n83WkEK67eNNWnlHZ28st2McL/5dq46d4CKwcWkG1moVrQGYjLefy+bAduS8bbm0XYjF3/r8CcKlZ\n+jCAp9x9A4CnGt8LIT5ALBj87v4MgEtvh/cDeKzx9WMAPnuV1yWEuMa83/f8w+5+qvH1OOY79goh\nPkBc8cd73d3NjL7pMrOHADwEANksr2EvhGgu7/fOP2FmIwDQ+J92wXD3R9x9u7tvz2SamkoghIjw\nfoP/cQAPNr5+EMAvrs5yhBDNYjFW3w8BfBzAoJkdB/A1AN8A8BMz+yKAIwAeWMzBUqkM2jv7g1o2\n4hpNTYX/sGjp55bMXJV7SkXeXQttfV1Ua6kbeUBu9XnkDBcrPIuttY1PTEXaa9VT4XmdA9xqyjm3\nN9NtPHPP
c9xrrVv4d7Matw5Taf47ZztyVGvr5Fq1FLZ1J09M0DkDHbxt2P2fvpdqO187TLWZSHHP\nYulMcLxEWnIBQG9X+NrPpCP+96U/u9APuPsXiHTPoo8ihLju0Cf8hEgoCn4hEoqCX4iEouAXIqEo\n+IVIKE391E0u14KR1eFsKkvx16FiMZzBNJHny8/18iy2SpVbQxb5FGJhJpwhVnG+9kyGF+KsprnW\n3s0z3IYGpqjm58L2UDnSY87qfP1tbW1US0VcpbqHj1ercVs0lY0UT03zNc7M8ixNIwUtWyLXW/4M\ntwHb2sNWNQDcfcfNVHvz7SNU2/P6eHB8Js+zLXOkMGy9Hsu0fDe68wuRUBT8QiQUBb8QCUXBL0RC\nUfALkVAU/EIklKZafW6AW9jOqUSsqLnpsJXTErGhpvORQpxFXjhzLs9toyxJ6uvq4Jbdsj5uDXX3\n8wy3Zb38d6tleqhWaAmfx3NreFZfqXaKaohkHtaqkexCkgFZS/FsS4tYfb39PLuwXouskVxXPT38\n/OZ4bRpMTUds1krYCgaAbZuXU623K3z9PPEELxZ6ZiJcCLcaiaNL0Z1fiISi4BcioSj4hUgoCn4h\nEoqCX4iE0txyuu4A2SHO1PnOcU84hwGjPWT7HcCH1vH6fp2tfKc3bfz1cDYf3uktzl2gc9o6KlTb\ntIE7AaNrVlEtlV1DtZmp8BpHR0b4Og7R4svo7icnH0B/H08+ymTCyVOxvBOPJAq1drRTrVrkO9wp\ncrxsLJEM3A0aGOyk2swcdx1mp8LJOwCwclm4ZuBn/+Un6Zy//eWvg+OZzOJr+OnOL0RCUfALkVAU\n/EIkFAW/EAlFwS9EQlHwC5FQFtOu61EAnwFw2t23Nsa+DuBPAbzTZ+ir7v6rhR6rq6MdH7vjw0Ft\n3ZZb6LyTJ04Ex1eu4FbZxg3rqbZ82RDV0s7tw2mS1FGKJL9Yij9eZwdP7Ons5BZbOsetyiyxTAuz\n4ZZQAHDbVm4djm0co1qlzm1MJ/eVap3bcp7m5yqd5Zdqpcj9wzpJdEll+H3PWvk6EJlXqvDzkUnz\n2pC1cvi6WhaxFe/6px8Jjj/34m4651IWc+f/KwD3Bcb/0t23Nf4tGPhCiOuLBYPf3Z8BwPNjhRAf\nSK7kPf+XzWyXmT1qZjzZWghxXfJ+g/87ANYB2AbgFIBvsh80s4fMbKeZ7ZyZ5cUOhBDN5X0Fv7tP\nuHvN3esAvgtgR+RnH3H37e6+vbODb2AIIZrL+wp+M7s4S+RzAPZcneUIIZrFYqy+HwL4OIBBMzsO\n4GsAPm5m2wA4gMMA/mwxB2tvb8OHb/5QULvxVm71FbaGbbuOHp5VxivFAW7cyklFLJn+jnAdtki3\nruira520kgIWqMUWsZRKpXC7rvU3rKZz2nLccizM8oxFT0UuHwtrHqmPV3eu1SLPWaxFVbkQPh+1\nOv+dU5nI9RF5RqcnueV75NAxqt15163B8bkKryfZTuzIiLP8HhYMfnf/QmD4e4s/hBDiekSf8BMi\noSj4hUgoCn4hEoqCX4iEouAXIqE0tYBnKpVCG8lk62zlLa862skyI8UKY4UiLWb1xSwlD1tz9Qq3\n7GL2lUWKSFYjZmXMznFSgLSzl2dAVmv8WLV6pCAkackFAI5acDwVW3yNa7UMt2AdkSebFIy1enh9\nANAS+Z2zNf6cdRT5PJ8IW44AcObgRHB81SZexPVsKvxp2cux+nTnFyKhKPiFSCgKfiESioJfiISi\n4BcioSj4hUgoTbX60uk0unrClpNHsunmSmG7xku8p1qJzAGA2ZlZqpUrfF6pFM6mq1a5VVaJZOBV\nIseai/R9m5vl2V5VkinY1d9D53T18L6GvV2DVGvNhfvxAUCN9V60SF89cK2rixc0nTzNz2OxELbE\n6nVefMrAf696jV9z3V3crl6zephqhbnw9eiRYqc9XWHLPB2xjy9Fd34hEo
qCX4iEouAXIqEo+IVI\nKAp+IRJKU3f7p6by+NvH/y6o1bK/o/POnw8nPsxcOEvnpCK5HjEnYGIifCwAqJFsof5I+6++wQGq\ntaT56Z89F27hBAD739pHtfxMeHd7dC1vyZXOcqelu4uvf+1aXhdw1Wi43uHadSvpnP4WnpXS1crX\nWI/UckQ6nGxTqfGd9HSkJVc6ssbhsYgz0s2dgIqHk4zS3HRAf3/4d85Ekt0uRXd+IRKKgl+IhKLg\nFyKhKPiFSCgKfiESioJfiISymHZdowC+D2AY8+25HnH3b5tZP4AfAxjDfMuuB9z9fOyx8tMzePLp\nZ4Na76pNdJ7XwvbVK88+TeesWcXrnw0OcPvqxPFxqlVJ3bf2fp4YU07xpJ+J47yF0z077qDatptv\npNpcqRgcT2X5U33o6BGq7X/rbart3vMK1Xp7wk1Z//hffY7OufPGjVTLRXqirRoZpVqZWH0WKXYX\nq7tYIbUJASCVidQF7OWJSW0kGaee5pY0Mz4jJSjfw2Lu/FUAf+7uWwDcDuBLZrYFwMMAnnL3DQCe\nanwvhPiAsGDwu/spd3+58fU0gH0AVgK4H8BjjR97DMBnr9UihRBXn8t6z29mYwBuBfACgGF3P9WQ\nxjH/tkAI8QFh0cFvZp0AfgrgK+6ev1hzdwfCxdPN7CEz22lmO8tlXghBCNFcFhX8ZpbFfOD/wN1/\n1hieMLORhj4C4HRorrs/4u7b3X17Lsc/3yyEaC4LBr/Nt7f5HoB97v6ti6THATzY+PpBAL+4+ssT\nQlwrFpPVdyeAPwGw28xebYx9FcA3APzEzL4I4AiABxZ6oL7+AfzrL/yboNYytIHOm5sO229v7X6N\nzhlZzu2fVKTOWVsrzxAr18MtlzZu5WvvG+EZf3ODvI7cZz71R1Rr72qj2iyx+iKdtVAlbcgAoFgN\nPx4AnD59jmpHDp0Mjre38/M7fnySaof3vkW1VJGv8eB48A9S7PjkdjpnzdgKqsWyAVOtkTS8LLcB\njdXqMz4nZ+Hn7HKsvgWD391/D4A95D2LP5QQ4npCn/ATIqEo+IVIKAp+IRKKgl+IhKLgFyKhNLWA\npxnQkgu/3ux/Yw+dl78Qtvo8ln1V5hlRM5F2XRbxSlpbwrlUlTnePuvCGb7GiaM8q+/v/j5c6BQA\nzk9HjjdzITje1c0ttp6+cAs1AOiIFJ48fjxs5wHA0GC4UGdrN7c+f/dL/jufe2sX1Wpl3hLtwHi4\nIOvxSMuzDZu5ddvT3c61Pt4Sra2dZ/X1dISvq2wrL8bZ3h5+XtwX7/Xpzi9EQlHwC5FQFPxCJBQF\nvxAJRcEvREJR8AuRUJpq9dWrFUxPhm273/zil3TesfHjwfFUJZxlBwC7duWpFkt9qlZ51hZIJtWT\nT/yGTslluVW27dbbqFbOdVEtX5qj2sGj4Sy2yUne369c5Fl9J8cPU+3QYf6Y22/9cHD8333pP9A5\nLz7/HNWqF3jGX77Ei8QUwjVmcHAnt1l/99IpqnVkuK2YzXFrLt3Cr4MuYvWtWjNG59z/x58Pjper\ni7+f684vREJR8AuRUBT8QiQUBb8QCUXBL0RCaepufzabw8jwSFDbMLaWznOEd6MzkVZY6ciOfirN\nX/O8zhNxcq0dYSHLkzZWrAgnuADAx++9l2pd7ZEEklZe++/1PeG6hvsP8LZby1eOUa0YaZOVbuNr\n3LP/jeD46/v30zntY5updvIk/537erk2lAvX1Wvv5HUQz43z9mWTJw5Q7czZcBIRABRrkSQ0UmDx\n1BQPz4/eE55T5WX/3oPu/EIkFAW/EAlFwS9EQlHwC5FQFPxCJBQFvxAJZUGrz8xGAXwf8y24HcAj\n7v5tM/s6gD8FcKbxo19191/FHqtareLcmXCLp9v/yUfpvI9+7GPB8ZYWnkiRidh5sXZd9UjrqjTC\nx6uUub9SKPMknMnjh6h2rsgTSM6d5W
2yDhJL7+TpcEIVAHQO8fZUaOE2puW41VeuhpNtnvzt7+mc\nNetvotpoP7dMW1P8Mm4niVWlIq/hdzC/l2qdXbwWYs15Utj4+RmqDQ6OBcfnKvxa/M1vXwyOT0/z\n+pSXshifvwrgz939ZTPrAvCSmT3Z0P7S3f/boo8mhLhuWEyvvlMATjW+njazfQD4y7AQ4gPBZb3n\nN7MxALcCeKEx9GUz22Vmj5oZ/5iVEOK6Y9HBb2adAH4K4CvungfwHQDrAGzD/F8G3yTzHjKznWa2\nc3qGv88SQjSXRQW/mWUxH/g/cPefAYC7T7h7zd3rAL4LYEdorrs/4u7b3X17VyevTiOEaC4LBr/N\nt7D5HoB97v6ti8YvztD5HADeckcIcd2xmN3+OwH8CYDdZvZqY+yrAL5gZtswb/8dBvBnCz1QKmXo\nIG2GJvNFOu+VXS8Fx4eG+DbD8NAg1SoVbqOdPz9FNRTDa8zU+eOtXMtttNE+/pfQif28jtzsDK9Z\nNzS8PDjePtBL56RbuX01V+DPy8jIaqqNnwzXXTw7GW4nBgAjKyJt1CKt2WZK/PwjE77eKnVuz7a0\nkexNAC2RbNHy5BmqIRWu0wcAwySrslziLefY6eBn6b0sZrf/9wBCv3HU0xdCXN/oE35CJBQFvxAJ\nRcEvREJR8AuRUBT8QiSUphbwTBnQkg1nKpWK3GJ79tmnguNe4TZUdzsv0Fip8OyrYoG3AMuQ18o1\nY6N0ztbbt1Bt/WpuA04dC1tlADB+/izVcm1ha2v9QNgCBIAzZ3jG2U2btlLtxps2Ue1H/+v7wfEM\nwgU1AaAyy5/PcplrHqta2Rp+rmPts8bWrqPa6WNv8mOleJZpWwc/3ubNG4PjxTn+vIyODAXHf5vj\nluKl6M4vREJR8AuRUBT8QiQUBb8QCUXBL0RCUfALkVCaavXV63XMFUhBy0hRzXs/9Znw45V5Flg6\nYufVa7wwoqe5XZPOhG2q1g5eyHJ8iluH01O8b925Al+/tfKimm++ejA4Pvkczzhbt5Zbdh+5YQPV\nypGMv7Zc2NrySEZlLIMwleaXKml1BwAo1Emfxxo/v2tWcauvODNJtS3dPBvwxZdeodrJI2H7sDDL\nr2+fOx8cL5d4xuel6M4vREJR8AuRUBT8QiQUBb8QCUXBL0RCUfALkVCam9WXMnR0hu2ynkjlwa5l\n4aynUsTWaI28ruWMZ5Z5G88GbGkPz6sXefbV9HSeaul2XjhzaD0vuLm+nWf1vXUo3KsPxi3MLCmq\nCgAnTh2l2sAgL6DKtHKB21elEi/uORvJ+CtFst8qpbC1nGnl9uzwimVUO3JqgmoTR8m5B1Cc4b/b\n23tfDY4PDPB1eF9/eDxS6PRSdOcXIqEo+IVIKAp+IRKKgl+IhKLgFyKhLLjbb2atAJ4B0NL4+b9x\n96+ZWT+AHwMYw3y7rgfcPZxt0KBeL2JumiSz1PnrUNY6g+MTE3wH9a3XD1OtNcN39HM9fJd9kLQH\nWzHYQ+dkIglLAz0DVIvkHqFY4Kd5aCjsIKxcEd4dBoBT4+NU279/H9XGymupxpyY6Wn+nM3N8Z30\n/AXumsR2+2vlcGJVuoUn4ezdw1u9xVpoDQ0NU23lzbwW4tCy8LzBZbzuYitZ/1P/+DSdcymLufOX\nAPxzd78F8+247zOz2wE8DOApd98A4KnG90KIDwgLBr/P885La7bxzwHcD+CxxvhjAD57TVYohLgm\nLOo9v5mlGx16TwN40t1fADDs7u+0kh0HwP/mEUJcdywq+N295u7bAKwCsMPMtl6iO0h3YDN7yMx2\nmtnO6WlSyEMI0XQua7ff3acAPA3gPgATZjYCAI3/T5M5j7j7dnff3tXFP1IphGguCwa/mS0zs97G\n120APgHgDQCPA3iw8WMPAvjFtVqkEOLqs5jEnhEAj5lZGvMvFj9x9yfM7DkAPzGzLwI4AuCBBR+p\n7q
iTtkupyOtQphJOSukmrb8A4KXnf0u18QmeGGNZnuSyY8eHg+N33bGdzrlwgVtbu15+gWqzRZ7I\nsv/oMaodPHw4OF6Y42+53HkRvNZunlySz09TbZq0FJvNc5syUooPmTRXeyJ/Ua5YG7Yj+wZG6Jyh\nFdxiW3HrTVTrj9Twy8VqQzItkowFD8dLKtIy7FIWDH533wXg1sD4JIB7Fn0kIcR1hT7hJ0RCUfAL\nkVAU/EIkFAW/EAlFwS9EQrHLqfl1xQczO4N5WxAABgFwz615aB3vRut4Nx+0daxxd+7PXkRTg/9d\nBzbb6e7cINc6tA6t45quQ3/2C5FQFPxCJJSlDP5HlvDYF6N1vBut4938f7uOJXvPL4RYWvRnvxAJ\nZUmC38zuM7M3zeyAmS1Z7T8zO2xmu83sVTPb2cTjPmpmp81sz0Vj/Wb2pJm91fif98K6tuv4upmd\naJyTV83s001Yx6iZPW1mr5vZXjP7943xpp6TyDqaek7MrNXMXjSz1xrr+M+N8at7Pty9qf8ApAG8\nDWAdgByA1wBsafY6Gms5DGBwCY57N4DbAOy5aOy/Ani48fXDAP7LEq3j6wD+Y5PPxwiA2xpfdwHY\nD2BLs89JZB1NPSeYz27ubHydBfACgNuv9vlYijv/DgAH3P2gu5cB/AjzxUATg7s/A+DcJcNNL4hK\n1tF03P2Uu7/c+HoawD4AK9HkcxJZR1Pxea550dylCP6VAC6uRnEcS3CCGziAX5vZS2b20BKt4R2u\np4KoXzazXY23Bdf87cfFmNkY5utHLGmR2EvWATT5nDSjaG7SN/zu8vnCpJ8C8CUzu3upFwTEC6I2\nge9g/i3ZNgCnAHyzWQc2s04APwXwFXd/V5eOZp6TwDqafk78CormLpalCP4TAEYv+n5VY6zpuPuJ\nxv+nAfwc829JlopFFUS91rj7ROPCqwP4Lpp0Tswsi/mA+4G7/6wx3PRzElrHUp2TxrEvu2juYlmK\n4P8DgA1mttbMcgA+j/lioE3FzDrMrOudrwF8EsCe+KxrynVREPWdi6vB59CEc2JmBuB7APa5+7cu\nkpp6Ttg6mn1OmlY0t1k7mJfsZn4a8zupbwP4iyVawzrMOw2vAdjbzHUA+CHm/3ysYH7P44sABjDf\n9uwtAL8G0L9E6/hrALsB7GpcbCNNWMddmP8TdheAVxv/Pt3scxJZR1PPCYCbAbzSON4eAP+pMX5V\nz4c+4SdEQkn6hp8QiUXBL0RCUfALkVAU/EIkFAW/EAlFwS9EQlHwC5FQFPxCJJT/ByGKsM3TKcRx\nAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(x_train[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We want to normalize pixel values (0-255) to unit interval."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"x_train = x_train/255.\n",
"x_test = x_test/255."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Densenet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The pieces"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's make some helper functions for piecing together our network using Keras' Functional API."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These components should all be familiar to you:\n",
"* Relu activation\n",
"* Dropout regularization\n",
"* Batch-normalization"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def relu(x): return Activation('relu')(x)\n",
"def dropout(x, p): return Dropout(p)(x) if p else x\n",
"def bn(x): return BatchNormalization(mode=0, axis=-1)(x)\n",
"def relu_bn(x): return relu(bn(x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional layer:\n",
"* L2 Regularization\n",
"* 'same' border mode returns same width/height\n",
"* Pass output through Dropout\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def conv(x, nf, sz, wd, p):\n",
" x = Convolution2D(nf, sz, sz, init='he_uniform', border_mode='same', \n",
" W_regularizer=l2(wd))(x)\n",
" return dropout(x,p)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define ConvBlock as sequence:\n",
"* Batchnorm\n",
"* ReLU Activation\n",
"* Conv layer (conv w/ Dropout)\n",
"\n",
"The authors also use something called a *bottleneck* layer to reduce dimensionality of inputs. \n",
"\n",
"Recall that the filter space dimensionality grows at each block. The input dimensionality will determine the dimensionality of your convolution weight matrices, i.e. # of parameters.\n",
"\n",
"At size 3x3 or larger, convolutions can become extremely costly and # of parameters can increase quickly as a function of the input feature (filter) space. Therefore, a smart approach is to reduce dimensionality of filters by using a 1x1 convolution w/ smaller # of filters before the larger convolution.\n",
"\n",
"Bottleneck consists of:\n",
"* 1x1 conv\n",
"* Compress # of filters into growth factor `nf` * 4\n",
"* Batchnorm -> ReLU"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def conv_block(x, nf, bottleneck=False, p=None, wd=0):\n",
" x = relu_bn(x)\n",
" if bottleneck: x = relu_bn(conv(x, nf * 4, 1, wd, p))\n",
" return conv(x, nf, 3, wd, p)"
]
},
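{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough weight count shows when the bottleneck pays off. This sketch ignores biases and batchnorm parameters; $c$ (the # of input filters) is a symbol introduced here for illustration, and $k$ is the growth rate (`nf` in the code).\n",
"\n",
"$$\\text{3x3 conv only:}\\quad 9\\,c\\,k$$\n",
"\n",
"$$\\text{1x1 bottleneck + 3x3 conv:}\\quad 4\\,c\\,k + 9 \\cdot 4k \\cdot k = 4ck + 36k^2$$\n",
"\n",
"The bottleneck version is smaller when $9ck > 4ck + 36k^2$, i.e. when $c > 7.2\\,k$. For $k = 12$ and $c = 196$: $9 \\cdot 196 \\cdot 12 = 21168$ weights without the bottleneck versus $4 \\cdot 196 \\cdot 12 + 36 \\cdot 144 = 9408 + 5184 = 14592$ with it."
]
},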
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can define the dense block:\n",
"* Take given input `x`\n",
"* Pass through a conv block for output `b`\n",
"* Concatenate input `x` and conv block output `b`\n",
"* Set concatenation as new input `x` for next block\n",
"* Repeat"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def dense_block(x, nb_layers, growth_rate, bottleneck=False, p=None, wd=0):\n",
" if bottleneck: nb_layers //= 2\n",
" for i in range(nb_layers):\n",
" b = conv_block(x, growth_rate, bottleneck=bottleneck, p=p, wd=wd)\n",
" x = merge([x,b], mode='concat', concat_axis=-1)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As typical for CV architectures, we'll do some pooling after computation.\n",
"\n",
"We'll define this unit as the transition block, and we'll put one between each dense block.\n",
"\n",
"Aside from BN -> ReLU and Average Pooling, there is also an option for filter *compression* in this block. This is simply feature reduction via 1x1 conv as discussed before, where the new # of filters is a percentage of the incoming # of filters.\n",
"\n",
"Together with bottleneck, compression has been shown to improve performance and computational efficiency of DenseNet architectures. (the authors call this DenseNet-BC)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def transition_block(x, compression=1.0, p=None, wd=0):\n",
" nf = int(x.get_shape().as_list()[-1] * compression)\n",
" x = relu_bn(x)\n",
" x = conv(x, nf, 1, wd, p)\n",
" return AveragePooling2D((2, 2), strides=(2, 2))(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build the DenseNet model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've now defined all the building blocks (literally) to put together a DenseNet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- nb_classes: number of classes\n",
"- img_input: tuple of shape (channels, rows, columns) or (rows, columns, channels)\n",
"- depth: total number of layers \n",
" - Includes 4 extra non-block layers\n",
" - 1 input conv, 3 output layers\n",
"- nb_block: number of dense blocks (generally = 3). \n",
" - NOTE: Layers / block are evenly allocated. Therefore nb_block must be a factor of (Depth - 4)\n",
"- growth_rate: number of filters to add per dense block\n",
"- nb_filter: initial number of filters\n",
"- bottleneck: add bottleneck blocks\n",
"- Compression: Filter compression factor in transition blocks.\n",
"- p: dropout rate\n",
"- wd: weight decay\n",
"- activation: Type of activation at the top layer. Can be one of 'softmax' or 'sigmoid'. Note that if sigmoid is used, classes must be 1.\n",
"\n",
"Returns: keras tensor with nb_layers of conv_block appended"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From start to finish, this generates:\n",
"* Conv input layer\n",
"* Alternate between Dense/Transition blocks `nb_block` times, ommitting Transition block after last Dense block\n",
" * Each Dense block has `(Depth-4)/nb_block` layers\n",
"* Pass final Dense block to BN -> ReLU\n",
"* Global Avg Pooling\n",
"* Dense layer w/ desired output activation"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def create_dense_net(nb_classes, img_input, depth=40, nb_block=3, \n",
" growth_rate=12, nb_filter=16, bottleneck=False, compression=1.0, p=None, wd=0, activation='softmax'):\n",
" \n",
" assert activation == 'softmax' or activation == 'sigmoid'\n",
" assert (depth - 4) % nb_block == 0\n",
" nb_layers_per_block = int((depth - 4) / nb_block)\n",
" nb_layers = [nb_layers_per_block] * nb_block\n",
"\n",
" x = conv(img_input, nb_filter, 3, wd, 0)\n",
" for i,block in enumerate(nb_layers):\n",
" x = dense_block(x, block, growth_rate, bottleneck=bottleneck, p=p, wd=wd)\n",
" if i != len(nb_layers)-1:\n",
" x = transition_block(x, compression=compression, p=p, wd=wd)\n",
"\n",
" x = relu_bn(x)\n",
" x = GlobalAveragePooling2D()(x)\n",
" return Dense(nb_classes, activation=activation, W_regularizer=l2(wd))(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"### Train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can test it out on CIFAR-10."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"input_shape = (32,32,3)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"img_input = Input(shape=input_shape)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
"source": [
"x = create_dense_net(10, img_input, depth=100, nb_filter=16, compression=0.5, \n",
" bottleneck=True, p=0.2, wd=1e-4)"
]
},
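{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for the resulting network, here is a hand trace of the filter counts and spatial sizes for the call above (`depth=100`, `nb_block=3`, `growth_rate=12`, `nb_filter=16`, `bottleneck=True`, `compression=0.5`). This is a sketch derived from the helper functions in this notebook, not a printout of `model.summary()`.\n",
"\n",
"Each dense block is allotted $(100 - 4)/3 = 32$ layers, which `dense_block` halves to 16 conv blocks because `bottleneck=True`. Each conv block adds 12 filters:\n",
"\n",
"$$16 \\xrightarrow{\\text{dense 1}} 16 + 16 \\cdot 12 = 208 \\xrightarrow{\\text{transition}} 104$$\n",
"\n",
"$$104 \\xrightarrow{\\text{dense 2}} 104 + 192 = 296 \\xrightarrow{\\text{transition}} 148$$\n",
"\n",
"$$148 \\xrightarrow{\\text{dense 3}} 148 + 192 = 340$$\n",
"\n",
"Each transition also halves the spatial size via 2x2 average pooling (32x32 -> 16x16 -> 8x8), so global average pooling reduces an 8x8x340 tensor to 340 features before the final 10-way Dense layer."
]
},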
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model = Model(img_input, x)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"model.compile(loss='sparse_categorical_crossentropy', \n",
" optimizer=keras.optimizers.SGD(0.1, 0.9, nesterov=True), metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"parms = {'verbose': 2, 'callbacks': [TQDMNotebookCallback()]}"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"K.set_value(model.optimizer.lr, 0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This will likely need to run overnight + lr annealing..."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/20\n",
"561s - loss: 1.9801 - acc: 0.4810 - val_loss: 2.0473 - val_acc: 0.5045\n",
"Epoch 2/20\n",
"556s - loss: 1.4368 - acc: 0.6571 - val_loss: 1.8446 - val_acc: 0.5864\n",
"Epoch 3/20\n",
"547s - loss: 1.2204 - acc: 0.7122 - val_loss: 1.3181 - val_acc: 0.6696\n",
"Epoch 4/20\n",
"556s - loss: 1.0634 - acc: 0.7547 - val_loss: 1.3620 - val_acc: 0.6658\n",
"Epoch 5/20\n",
"560s - loss: 0.9536 - acc: 0.7829 - val_loss: 2.6235 - val_acc: 0.4702\n",
"Epoch 6/20\n",
"557s - loss: 0.8835 - acc: 0.8025 - val_loss: 2.4969 - val_acc: 0.4981\n",
"Epoch 7/20\n",
"551s - loss: 0.8293 - acc: 0.8155 - val_loss: 1.1944 - val_acc: 0.7281\n",
"Epoch 8/20\n",
"551s - loss: 0.7949 - acc: 0.8244 - val_loss: 1.1396 - val_acc: 0.7366\n",
"Epoch 9/20\n",
"551s - loss: 0.7620 - acc: 0.8340 - val_loss: 1.9196 - val_acc: 0.5916\n",
"Epoch 10/20\n",
"551s - loss: 0.7472 - acc: 0.8389 - val_loss: 2.6207 - val_acc: 0.4900\n",
"Epoch 11/20\n",
"550s - loss: 0.7251 - acc: 0.8449 - val_loss: 1.4957 - val_acc: 0.6859\n",
"Epoch 12/20\n",
"551s - loss: 0.7117 - acc: 0.8503 - val_loss: 1.0381 - val_acc: 0.7751\n",
"Epoch 13/20\n",
"552s - loss: 0.7006 - acc: 0.8547 - val_loss: 1.6471 - val_acc: 0.6685\n",
"Epoch 14/20\n",
"556s - loss: 0.6945 - acc: 0.8555 - val_loss: 0.9267 - val_acc: 0.8087\n",
"Epoch 15/20\n",
"551s - loss: 0.6859 - acc: 0.8592 - val_loss: 1.0987 - val_acc: 0.7642\n",
"Epoch 16/20\n",
"550s - loss: 0.6756 - acc: 0.8645 - val_loss: 0.9704 - val_acc: 0.7940\n",
"Epoch 17/20\n",
"551s - loss: 0.6730 - acc: 0.8642 - val_loss: 0.9401 - val_acc: 0.7800\n",
"Epoch 18/20\n",
"551s - loss: 0.6666 - acc: 0.8700 - val_loss: 0.9759 - val_acc: 0.7830\n",
"Epoch 19/20\n",
"550s - loss: 0.6654 - acc: 0.8709 - val_loss: 0.8896 - val_acc: 0.8044\n",
"Epoch 20/20\n",
"551s - loss: 0.6617 - acc: 0.8712 - val_loss: 1.1052 - val_acc: 0.7570\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, 64, 20, validation_data=(x_test, y_test), **parms)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"K.set_value(model.optimizer.lr, 0.01)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/4\n",
"550s - loss: 0.5463 - acc: 0.9128 - val_loss: 0.5737 - val_acc: 0.9033\n",
"Epoch 2/4\n",
"551s - loss: 0.4833 - acc: 0.9311 - val_loss: 0.5695 - val_acc: 0.9033\n",
"Epoch 3/4\n",
"551s - loss: 0.4575 - acc: 0.9366 - val_loss: 0.5590 - val_acc: 0.9051\n",
"Epoch 4/4\n",
"550s - loss: 0.4361 - acc: 0.9429 - val_loss: 0.5656 - val_acc: 0.9048\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, 64, 4, validation_data=(x_test, y_test), **parms)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"K.set_value(model.optimizer.lr, 0.1)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"\n",
"Epoch 1/20\n",
"551s - loss: 0.6589 - acc: 0.8728 - val_loss: 1.3259 - val_acc: 0.6935\n",
"Epoch 2/20\n",
"551s - loss: 0.6510 - acc: 0.8766 - val_loss: 0.9672 - val_acc: 0.7880\n",
"Epoch 3/20\n",
"551s - loss: 0.6508 - acc: 0.8784 - val_loss: 1.1104 - val_acc: 0.7581\n",
"Epoch 4/20\n",
"551s - loss: 0.6462 - acc: 0.8793 - val_loss: 1.0601 - val_acc: 0.7877\n",
"Epoch 5/20\n",
"550s - loss: 0.6456 - acc: 0.8816 - val_loss: 0.9799 - val_acc: 0.7876\n",
"Epoch 6/20\n",
"551s - loss: 0.6427 - acc: 0.8830 - val_loss: 0.9377 - val_acc: 0.8028\n",
"Epoch 7/20\n",
"551s - loss: 0.6409 - acc: 0.8837 - val_loss: 1.8484 - val_acc: 0.5932\n",
"Epoch 8/20\n",
"551s - loss: 0.6378 - acc: 0.8831 - val_loss: 1.1806 - val_acc: 0.7420\n",
"Epoch 9/20\n",
"550s - loss: 0.6381 - acc: 0.8843 - val_loss: 1.0799 - val_acc: 0.7774\n",
"Epoch 10/20\n",
"551s - loss: 0.6344 - acc: 0.8870 - val_loss: 0.9114 - val_acc: 0.8163\n",
"Epoch 11/20\n",
"561s - loss: 0.6394 - acc: 0.8858 - val_loss: 0.9710 - val_acc: 0.7982\n",
"Epoch 12/20\n",
"560s - loss: 0.6367 - acc: 0.8863 - val_loss: 0.8751 - val_acc: 0.8249\n",
"Epoch 13/20\n",
"561s - loss: 0.6230 - acc: 0.8899 - val_loss: 1.2588 - val_acc: 0.7254\n",
"Epoch 14/20\n",
"561s - loss: 0.6298 - acc: 0.8895 - val_loss: 0.9942 - val_acc: 0.7801\n",
"Epoch 15/20\n",
"560s - loss: 0.6321 - acc: 0.8888 - val_loss: 0.8516 - val_acc: 0.8378\n",
"Epoch 16/20\n",
"559s - loss: 0.6268 - acc: 0.8893 - val_loss: 0.8288 - val_acc: 0.8301\n",
"Epoch 17/20\n",
"561s - loss: 0.6279 - acc: 0.8904 - val_loss: 1.2768 - val_acc: 0.7219\n",
"Epoch 18/20\n",
"561s - loss: 0.6248 - acc: 0.8920 - val_loss: 0.9362 - val_acc: 0.8015\n",
"Epoch 19/20\n",
"561s - loss: 0.6184 - acc: 0.8941 - val_loss: 0.9204 - val_acc: 0.8181\n",
"Epoch 20/20\n",
"561s - loss: 0.6254 - acc: 0.8915 - val_loss: 1.0211 - val_acc: 0.7706\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, 64, 20, validation_data=(x_test, y_test), **parms)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"K.set_value(model.optimizer.lr, 0.01)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/40\n",
"556s - loss: 0.5141 - acc: 0.9320 - val_loss: 0.5652 - val_acc: 0.9165\n",
"Epoch 2/40\n",
"560s - loss: 0.4530 - acc: 0.9477 - val_loss: 0.5451 - val_acc: 0.9199\n",
"Epoch 3/40\n",
"560s - loss: 0.4290 - acc: 0.9546 - val_loss: 0.5409 - val_acc: 0.9188\n",
"Epoch 4/40\n",
"559s - loss: 0.4101 - acc: 0.9584 - val_loss: 0.5259 - val_acc: 0.9224\n",
"Epoch 5/40\n",
"549s - loss: 0.3934 - acc: 0.9620 - val_loss: 0.5365 - val_acc: 0.9198\n",
"Epoch 6/40\n",
"551s - loss: 0.3813 - acc: 0.9631 - val_loss: 0.5150 - val_acc: 0.9209\n",
"Epoch 7/40\n",
"556s - loss: 0.3685 - acc: 0.9644 - val_loss: 0.5238 - val_acc: 0.9197\n",
"Epoch 8/40\n",
"556s - loss: 0.3565 - acc: 0.9668 - val_loss: 0.5188 - val_acc: 0.9204\n",
"Epoch 9/40\n",
"555s - loss: 0.3430 - acc: 0.9693 - val_loss: 0.5078 - val_acc: 0.9206\n",
"Epoch 10/40\n",
"553s - loss: 0.3325 - acc: 0.9707 - val_loss: 0.5107 - val_acc: 0.9191\n",
"Epoch 11/40\n",
"556s - loss: 0.3220 - acc: 0.9721 - val_loss: 0.5091 - val_acc: 0.9191\n",
"Epoch 12/40\n",
"556s - loss: 0.3121 - acc: 0.9738 - val_loss: 0.5033 - val_acc: 0.9212\n",
"Epoch 13/40\n",
"556s - loss: 0.3082 - acc: 0.9723 - val_loss: 0.4970 - val_acc: 0.9226\n",
"Epoch 14/40\n",
"556s - loss: 0.2986 - acc: 0.9749 - val_loss: 0.5553 - val_acc: 0.9058\n",
"Epoch 15/40\n",
"555s - loss: 0.2913 - acc: 0.9746 - val_loss: 0.5065 - val_acc: 0.9203\n",
"Epoch 16/40\n",
"552s - loss: 0.2824 - acc: 0.9762 - val_loss: 0.4912 - val_acc: 0.9218\n",
"Epoch 17/40\n",
"554s - loss: 0.2774 - acc: 0.9764 - val_loss: 0.5191 - val_acc: 0.9125\n",
"Epoch 18/40\n",
"554s - loss: 0.2722 - acc: 0.9769 - val_loss: 0.5023 - val_acc: 0.9184\n",
"Epoch 19/40\n",
"550s - loss: 0.2654 - acc: 0.9771 - val_loss: 0.4965 - val_acc: 0.9183\n",
"Epoch 20/40\n",
"547s - loss: 0.2603 - acc: 0.9778 - val_loss: 0.5552 - val_acc: 0.9061\n",
"Epoch 21/40\n",
"547s - loss: 0.2549 - acc: 0.9779 - val_loss: 0.4868 - val_acc: 0.9168\n",
"Epoch 22/40\n",
"547s - loss: 0.2494 - acc: 0.9793 - val_loss: 0.4754 - val_acc: 0.9242\n",
"Epoch 23/40\n",
"547s - loss: 0.2462 - acc: 0.9785 - val_loss: 0.5014 - val_acc: 0.9136\n",
"Epoch 24/40\n",
"548s - loss: 0.2427 - acc: 0.9792 - val_loss: 0.5226 - val_acc: 0.9075\n",
"Epoch 25/40\n",
"547s - loss: 0.2376 - acc: 0.9794 - val_loss: 0.4829 - val_acc: 0.9159\n",
"Epoch 26/40\n",
"547s - loss: 0.2325 - acc: 0.9800 - val_loss: 0.5066 - val_acc: 0.9125\n",
"Epoch 27/40\n",
"548s - loss: 0.2312 - acc: 0.9790 - val_loss: 0.4887 - val_acc: 0.9155\n",
"Epoch 28/40\n",
"548s - loss: 0.2277 - acc: 0.9792 - val_loss: 0.4959 - val_acc: 0.9107\n",
"Epoch 29/40\n",
"547s - loss: 0.2255 - acc: 0.9788 - val_loss: 0.6025 - val_acc: 0.8956\n",
"Epoch 30/40\n",
"548s - loss: 0.2216 - acc: 0.9798 - val_loss: 0.4708 - val_acc: 0.9180\n",
"Epoch 31/40\n",
"548s - loss: 0.2238 - acc: 0.9772 - val_loss: 0.5193 - val_acc: 0.9084\n",
"Epoch 32/40\n",
"548s - loss: 0.2174 - acc: 0.9790 - val_loss: 0.5216 - val_acc: 0.9100\n",
"Epoch 33/40\n",
"547s - loss: 0.2176 - acc: 0.9782 - val_loss: 0.4960 - val_acc: 0.9153\n",
"Epoch 34/40\n",
"548s - loss: 0.2128 - acc: 0.9790 - val_loss: 0.4644 - val_acc: 0.9188\n",
"Epoch 35/40\n",
"548s - loss: 0.2113 - acc: 0.9795 - val_loss: 0.4759 - val_acc: 0.9196\n",
"Epoch 36/40\n",
"547s - loss: 0.2090 - acc: 0.9789 - val_loss: 0.5176 - val_acc: 0.9066\n",
"Epoch 37/40\n",
"548s - loss: 0.2078 - acc: 0.9802 - val_loss: 0.4602 - val_acc: 0.9208\n",
"Epoch 38/40\n",
"547s - loss: 0.2112 - acc: 0.9772 - val_loss: 0.4998 - val_acc: 0.9096\n",
"Epoch 39/40\n",
"548s - loss: 0.2051 - acc: 0.9794 - val_loss: 0.5156 - val_acc: 0.9066\n",
"Epoch 40/40\n",
"547s - loss: 0.2046 - acc: 0.9781 - val_loss: 0.4961 - val_acc: 0.9108\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, 64, 40, validation_data=(x_test, y_test), **parms)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Lower the learning rate to 0.001 for the next 20 epochs\n",
"K.set_value(model.optimizer.lr, 0.001)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/20\n",
"547s - loss: 0.1885 - acc: 0.9845 - val_loss: 0.4287 - val_acc: 0.9256\n",
"Epoch 2/20\n",
"548s - loss: 0.1772 - acc: 0.9886 - val_loss: 0.4198 - val_acc: 0.9279\n",
"Epoch 3/20\n",
"547s - loss: 0.1734 - acc: 0.9901 - val_loss: 0.4181 - val_acc: 0.9283\n",
"Epoch 4/20\n",
"547s - loss: 0.1706 - acc: 0.9910 - val_loss: 0.4188 - val_acc: 0.9280\n",
"Epoch 5/20\n",
"548s - loss: 0.1679 - acc: 0.9918 - val_loss: 0.4127 - val_acc: 0.9298\n",
"Epoch 6/20\n",
"548s - loss: 0.1670 - acc: 0.9921 - val_loss: 0.4159 - val_acc: 0.9301\n",
"Epoch 7/20\n",
"548s - loss: 0.1650 - acc: 0.9926 - val_loss: 0.4139 - val_acc: 0.9300\n",
"Epoch 8/20\n",
"547s - loss: 0.1631 - acc: 0.9933 - val_loss: 0.4087 - val_acc: 0.9304\n",
"Epoch 9/20\n",
"548s - loss: 0.1619 - acc: 0.9934 - val_loss: 0.4150 - val_acc: 0.9302\n",
"Epoch 10/20\n",
"547s - loss: 0.1609 - acc: 0.9939 - val_loss: 0.4154 - val_acc: 0.9294\n",
"Epoch 11/20\n",
"547s - loss: 0.1611 - acc: 0.9933 - val_loss: 0.4102 - val_acc: 0.9310\n",
"Epoch 12/20\n",
"547s - loss: 0.1584 - acc: 0.9943 - val_loss: 0.4105 - val_acc: 0.9306\n",
"Epoch 13/20\n",
"547s - loss: 0.1594 - acc: 0.9934 - val_loss: 0.4093 - val_acc: 0.9309\n",
"Epoch 14/20\n",
"547s - loss: 0.1582 - acc: 0.9940 - val_loss: 0.4110 - val_acc: 0.9298\n",
"Epoch 15/20\n",
"547s - loss: 0.1567 - acc: 0.9942 - val_loss: 0.4080 - val_acc: 0.9315\n",
"Epoch 16/20\n",
"547s - loss: 0.1565 - acc: 0.9940 - val_loss: 0.4113 - val_acc: 0.9304\n",
"Epoch 17/20\n",
"548s - loss: 0.1558 - acc: 0.9942 - val_loss: 0.4093 - val_acc: 0.9292\n",
"Epoch 18/20\n",
"548s - loss: 0.1561 - acc: 0.9939 - val_loss: 0.4079 - val_acc: 0.9310\n",
"Epoch 19/20\n",
"548s - loss: 0.1552 - acc: 0.9942 - val_loss: 0.4153 - val_acc: 0.9297\n",
"Epoch 20/20\n",
"547s - loss: 0.1535 - acc: 0.9951 - val_loss: 0.4069 - val_acc: 0.9313\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, 64, 20, validation_data=(x_test, y_test), **parms)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Raise the learning rate back to 0.01 for 10 more epochs\n",
"K.set_value(model.optimizer.lr, 0.01)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/10\n",
"548s - loss: 0.1819 - acc: 0.9842 - val_loss: 0.4929 - val_acc: 0.9092\n",
"Epoch 2/10\n",
"547s - loss: 0.2018 - acc: 0.9751 - val_loss: 0.5761 - val_acc: 0.8880\n",
"Epoch 3/10\n",
"548s - loss: 0.2046 - acc: 0.9742 - val_loss: 0.5411 - val_acc: 0.8950\n",
"Epoch 4/10\n",
"548s - loss: 0.2008 - acc: 0.9765 - val_loss: 0.5607 - val_acc: 0.8957\n",
"Epoch 5/10\n",
"548s - loss: 0.1956 - acc: 0.9778 - val_loss: 0.4991 - val_acc: 0.9049\n",
"Epoch 6/10\n",
"548s - loss: 0.1996 - acc: 0.9760 - val_loss: 0.4714 - val_acc: 0.9112\n",
"Epoch 7/10\n",
"548s - loss: 0.1947 - acc: 0.9779 - val_loss: 0.5921 - val_acc: 0.8855\n",
"Epoch 8/10\n",
"547s - loss: 0.1958 - acc: 0.9770 - val_loss: 0.5096 - val_acc: 0.9058\n",
"Epoch 9/10\n",
"547s - loss: 0.1976 - acc: 0.9754 - val_loss: 0.5129 - val_acc: 0.9041\n",
"Epoch 10/10\n",
"548s - loss: 0.1940 - acc: 0.9767 - val_loss: 0.5693 - val_acc: 0.8869\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, 64, 10, validation_data=(x_test, y_test), **parms)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Drop the learning rate to 0.001 again for the final 20 epochs\n",
"K.set_value(model.optimizer.lr, 0.001)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/20\n",
"548s - loss: 0.1879 - acc: 0.9801 - val_loss: 0.4073 - val_acc: 0.9270\n",
"Epoch 2/20\n",
"548s - loss: 0.1631 - acc: 0.9893 - val_loss: 0.4040 - val_acc: 0.9265\n",
"Epoch 3/20\n",
"547s - loss: 0.1601 - acc: 0.9905 - val_loss: 0.4007 - val_acc: 0.9295\n",
"Epoch 4/20\n",
"547s - loss: 0.1560 - acc: 0.9919 - val_loss: 0.4016 - val_acc: 0.9294\n",
"Epoch 5/20\n",
"548s - loss: 0.1540 - acc: 0.9921 - val_loss: 0.3988 - val_acc: 0.9293\n",
"Epoch 6/20\n",
"547s - loss: 0.1529 - acc: 0.9926 - val_loss: 0.4013 - val_acc: 0.9283\n",
"Epoch 7/20\n",
"548s - loss: 0.1497 - acc: 0.9937 - val_loss: 0.3984 - val_acc: 0.9312\n",
"Epoch 8/20\n",
"548s - loss: 0.1508 - acc: 0.9929 - val_loss: 0.3993 - val_acc: 0.9304\n",
"Epoch 9/20\n",
"547s - loss: 0.1486 - acc: 0.9937 - val_loss: 0.3988 - val_acc: 0.9303\n",
"Epoch 10/20\n",
"547s - loss: 0.1471 - acc: 0.9938 - val_loss: 0.3978 - val_acc: 0.9302\n",
"Epoch 11/20\n",
"547s - loss: 0.1460 - acc: 0.9942 - val_loss: 0.3945 - val_acc: 0.9306\n",
"Epoch 12/20\n",
"547s - loss: 0.1453 - acc: 0.9943 - val_loss: 0.3988 - val_acc: 0.9292\n",
"Epoch 13/20\n",
"547s - loss: 0.1456 - acc: 0.9939 - val_loss: 0.4004 - val_acc: 0.9298\n",
"Epoch 14/20\n",
"547s - loss: 0.1434 - acc: 0.9946 - val_loss: 0.3978 - val_acc: 0.9314\n",
"Epoch 15/20\n",
"547s - loss: 0.1427 - acc: 0.9946 - val_loss: 0.3974 - val_acc: 0.9311\n",
"Epoch 16/20\n",
"547s - loss: 0.1417 - acc: 0.9949 - val_loss: 0.3978 - val_acc: 0.9320\n",
"Epoch 17/20\n",
"548s - loss: 0.1403 - acc: 0.9954 - val_loss: 0.4010 - val_acc: 0.9317\n",
"Epoch 18/20\n",
"548s - loss: 0.1395 - acc: 0.9955 - val_loss: 0.3989 - val_acc: 0.9324\n",
"Epoch 19/20\n",
"547s - loss: 0.1409 - acc: 0.9951 - val_loss: 0.3997 - val_acc: 0.9312\n",
"Epoch 20/20\n",
"548s - loss: 0.1402 - acc: 0.9948 - val_loss: 0.3973 - val_acc: 0.9323\n",
"\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, 64, 20, validation_data=(x_test, y_test), **parms)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With validation accuracy settling at roughly 93.2% after the final 20 epochs, we're able to replicate their state-of-the-art results!"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 31.1 s, sys: 452 ms, total: 31.6 s\n",
"Wall time: 31.1 s\n"
]
}
],
"source": [
"%time model.save_weights('models/93.h5')"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"## End"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"hidden": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}