{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## MNIST CNN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/ubuntu/.fastai/data/mnist_png/training'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/testing')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "il = ImageList.from_folder(path, convert_mode='L')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/4/44688.png')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "il.items[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "defaults.cmap='binary'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ImageList (70000 items)\n", "[Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28)]...\n", "Path: /home/ubuntu/.fastai/data/mnist_png" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "il" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAMUAAADDCAYAAAAyYdXtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABLZJREFUeJzt3b9LlW0cx3HNh4aEaAijQoSg1lZFGoLqL2hwaGvrH3CXpgKJFBrEwcXB/oeGoGiLIJwaEpGw3JxKy2d8eD7ne8M5cX55fL3GL4fba+jthVfnvu/xk5OTMeA/5wa9ABg2ooAgCgiigCAKCKKAIAoI//T55/lPEYbJeDW0U0AQBQRRQBAFBFFAEAUEUUAQBQRRQBAFBFFAEAUEUUAQBQRRQBAFBFFAEAUEUUAQBQRRQOj30zw4pT58+FDO5+bmyvnbt2/L+Z07d7q2pl6xU0AQBQRRQBAFBFFAcPpEW16+fFnOx8fLx7GOrayslHOnT3AKiQKCKCCIAoIoIDh94n9+/PhRzt+9e9fRda5evdqN5QyEnQKCKCCIAoIoIIgCgtOnNn379q2cf/r0qZzfv3+/ZTYxMdHVNfXC9vZ2Od/d3e3oOo8ePerGcgbCTgFBFBBEAUEUEEQBwelTm169elXOnz59Ws4PDw9bZpOTk11dUy+8fv160EsYODsFBFFAEAUEUUAQBQSnT2F5ebmcP3v2rM8r6b2jo6OW2efPnzu6xvz8fDm/ffv2X61pGNgpIIgCgiggiAKCKCCc2dOnvb29cr66ulrOf/36Vc5Pw910Td68edMya3pXXZO7d++W8/Pnz//VmoaBnQKCKCCIAoIoIIz8H9rHx8flfG1trZx//fq1o+tvbW2V89NwQ1EnD02+fv16OX/8+HG3ljM07BQQRAFBFBBEAUEUEEb+9KnpNGlpaamj68zOzpbzBw8edLqkoXFwcND2Zy9fvlzOZ2ZmurWcoWGngCAKCKKAIAoIooAw8qdP3dL0eqvnz5+X83v37rV97Vu3bpXzqamptq/xN9bX19v+7PT0dA9XMlzsFBBEAUEUEEQBQRQQxk9OTvr58/r6w8bGxsa+fPlSzptOfAbh5s2b5bzXp0/v379vmTX9e9jc3CznCwsLXV1Tn41XQzsFBFFAEAUEUUAQBYSRP33a398v50130u3s7PRyOUPv2rVr5Xx7e7ucX7x4sZfL6TWnT9AOUUAQBQRRQBAFhJG/8+7KlSvl/OPHj+W86XlQL1686Nqahtnc3Fw5P+WnTB2xU0AQBQRRQBAFBFFAGPnvPnXqz58/5fz3798dXWd5eblldnR0VH626Ynely5dKudN39tqulPvxo0b5fz79+8ts9XV1fKzT548KeennO8+QTtEAUEUEEQBYeS/5tGpc+fq3xNN8yaLi4vdWE5HNjY2ynn1BzXN7BQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUENx5d4ZduHChZfbw4cMBrGS42CkgiAKCKCCIAoIoIDh9OsOqh0YfHByUn216ePMoslNAEAUEUUAQBQRRQPB6rxHy8+fPcj4zM1POvd7L672gLaKAIAoIooAgCghOnzjLnD5BO0QBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAGh36/3Kh8pAsPETgFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQB4V995LZxp9qfvwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "il[0].show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sd = il.split_by_folder(train='training', valid='testing')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ItemLists;\n", "\n", "Train: ImageList (60000 items)\n", "[Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28)]...\n", "Path: /home/ubuntu/.fastai/data/mnist_png;\n", "\n", "Valid: ImageList (10000 items)\n", "[Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28)]...\n", "Path: /home/ubuntu/.fastai/data/mnist_png;\n", "\n", "Test: None" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/4'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/6'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/8'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/0'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/9'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/1'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/3'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/2'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/5'),\n", " PosixPath('/home/ubuntu/.fastai/data/mnist_png/training/7')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(path/'training').ls()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ll = sd.label_from_folder()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "LabelLists;\n", "\n", "Train: LabelList\n", "y: CategoryList (60000 items)\n", "[Category 4, Category 4, Category 4, Category 4, Category 4]...\n", "Path: /home/ubuntu/.fastai/data/mnist_png\n", "x: ImageList (60000 items)\n", "[Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28)]...\n", "Path: /home/ubuntu/.fastai/data/mnist_png;\n", "\n", "Valid: LabelList\n", "y: CategoryList (10000 items)\n", "[Category 4, Category 4, Category 4, Category 4, Category 4]...\n", "Path: /home/ubuntu/.fastai/data/mnist_png\n", "x: ImageList (10000 items)\n", "[Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28), Image (1, 28, 28)]...\n", "Path: /home/ubuntu/.fastai/data/mnist_png;\n", "\n", "Test: None" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ll" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = ll.train[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4 torch.Size([1, 28, 28])\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAMUAAADDCAYAAAAyYdXtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABLZJREFUeJzt3b9LlW0cx3HNh4aEaAijQoSg1lZFGoLqL2hwaGvrH3CXpgKJFBrEwcXB/oeGoGiLIJwaEpGw3JxKy2d8eD7ne8M5cX55fL3GL4fba+jthVfnvu/xk5OTMeA/5wa9ABg2ooAgCgiigCAKCKKAIAoI//T55/lPEYbJeDW0U0AQBQRRQBAFBFFAEAUEUUAQBQRRQBAFBFFAEAUEUUAQBQRRQBAFBFFAEAUEUUAQBQRRQOj30zw4pT58+FDO5+bmyvnbt2/L+Z07d7q2pl6xU0AQBQRRQBAFBFFAcPpEW16+fFnOx8fLx7GOrayslHOnT3AKiQKCKCCIAoIoIDh94n9+/PhRzt+9e9fRda5evdqN5QyEnQKCKCCIAoIoIIgCgtOnNn379q2cf/r0qZzfv3+/ZTYxMdHVNfXC9vZ2Od/d3e3oOo8ePerGcgbCTgFBFBBEAUEUEEQBwelTm169elXOnz59Ws4PDw9bZpOTk11dUy+8fv160EsYODsFBFFAEAUEUUAQBQSnT2F5ebmcP3v2rM8r6b2jo6OW2efPnzu6xvz8fDm/ffv2X61pGNgpIIgCgiggiAKCKCCc2dOnvb29cr66ulrOf/36Vc5Pw910Td68edMya3pXXZO7d++W8/Pnz//VmoaBnQKCKCCIAoIoIIz8H9rHx8flfG1trZx//fq1o+tvbW2V89NwQ1EnD02+fv16OX/8+HG3ljM07BQQRAFBFBBEAUEUEEb+9KnpNGlpaamj68zOzpbzBw8edLqkoXFwcND2Zy9fvlzOZ2ZmurWcoWGngCAKCKKAIAoIooAw8qdP3dL0eqvnz5+X83v37rV97Vu3bpXzqamptq/xN9bX19v+7PT0dA9XMlzsFBBEAUEUEEQBQRQQxk9OTvr58/r6w8bGxsa+fPlSzptOfAbh5s2b5bzXp0/v379vmTX9e9jc3CznCwsLXV1Tn41XQzsFBFFAEAUEUUAQBYSRP33a398v50130u3s7PRyOUPv2rVr5Xx7e7ucX7x4sZfL6TWnT9AOUUAQBQRRQBAFhJG/8+7KlSvl/OPHj+W86XlQL1686Nqahtnc3Fw5P+WnTB2xU0AQBQRRQBAFBFFAGPnvPnXqz58/5fz3798dXWd5eblldnR0VH626Ynely5dKudN39tqulPvxo0b5fz79+8ts9XV1fKzT548KeennO8+QTtEAUEUEEQBYeS/5tGpc+fq3xNN8yaLi4vdWE5HNjY2ynn1BzXN7BQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUENx5d4ZduHChZfbw4cMBrGS42CkgiAKCKCCIAoIoIDh9OsOqh0YfHByUn216ePMoslNAEAUEUUAQBQRRQPB6rxHy8+fPcj4zM1POvd7L672gLaKAIAoIooAgCghOnzjLnD5BO0QBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAGh36/3Kh8pAsPETgFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQBQRQQRAFBFBBEAUEUEEQB4V995LZxp9qfvwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x.show()\n", "print(y,x.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = ([*rand_pad(padding=3, size=28, mode='zeros')], [])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ll = ll.transform(tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs = 128" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# not using imagenet_stats because not using pretrained model\n", "data = ll.databunch(bs=bs).normalize()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = data.train_ds[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAMUAAADDCAYAAAAyYdXtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABKtJREFUeJzt3bFLVX0cx/Gsh4aCaIhEI4RAV9ckGoTsL3BoaGvrH2iXpgSJuuAgDi0N9T80BEVbBOHUUIRI5eakZjY+PJ/7vXFPnnu9Pb1e4xc59zf09oe/7jln7PDw8ATwr5PHvQAYNaKAIAoIooAgCgiigCAKCP8M+fP8pwijZKwa2ikgiAKCKCCIAoIoIIgCgiggiAKCKCCIAoIoIIgCgiggiAKCKCCIAoIoIIgCgiggiAKCKCCIAoIoIIgCgiggiAKCKCCIAoIoIIgCgiggiAKCKCAM+01G/OXevHlTzufm5sr5y5cvu2bXr19vdU3JTgFBFBBEAUEUEEQBwekTQ/Xo0aNyPjZWvtL6xOPHj7tmTp9gyEQBQRQQRAFBFBCcPjEQ3759K+evXr1qdJ2JiYk2ltOInQKCKCCIAoIoIIgCgtOnIdra2irn7969K+cLCwvl/NSpU62taVA2NjbK+efPnxtd5/bt220spxE7BQRRQBAFBFFAEAUEp09DtLq6Ws7v379fznd2dsr52bNnW1vToDx//vy4l/Db7BQQRAFBFBBEAcEf2gOysrLSNXvw4MExrGSw9vf3y/n79+8bXefatWvlfHZ2tvGajspOAUEUEEQBQRQQRAHB6dMRbW5ulvNOp9M129vbK3/2T7hpqJcXL16U8+q1XL8yPz9fzk+fPt14TUdlp4AgCgiigCAKCKKA4PSpT9+/fy/na2tr5fzjx499X/vZs2fl/E+4majpA5MvXbpUzu/cudPGclphp4AgCgiigCAKCKKA4PSpT71Ok5aWlvq+xtWrV8v5zZs3f2dJI2F7e7vRz1+4cKGcT01NtbGcVtgpIIgCgiggiAKCKCA4fRqiXq+2Wl5eLuc3btxodP2ZmZlyfvHixUbXaWJ9fb3Rz1++fHlAK2mPnQKCKCCIAoIoIIgCwtjh4eEwP2+oH9amDx8+lPNeJz7HYXp6upwP8vTp9evX5bzXv6unT5+W81u3brW2pgbGqqGdAoIoIIgCgiggiAKC06c+ffnypZz3upvu06dPg1zOyJucnCznGxsb5fzcuXODXE4vTp+gH6KAIAoIooDgJqM+jY+Pl/O3b9+W8+rRNw8fPmx1TaNsbm6unB/TH9SN2CkgiAKCKCCIAoIoIPiax4D8+PGja3ZwcNDoGisrK+V8f3+/nPd6ePH58+fLefUVlV43JF25cqWcf/36tZx3Op1yfvfu3XJ+THzNA/ohCgiigCAKCKKA4LtPA3LyZPfvm2r2K/fu3WtrOX178uRJOe91yvR/ZKeAIAoIooAgCgiigCAKCKKAIAoIooAgCgiigCAKCKKAIAoIooAgCgiigODOO47kzJkz5XxxcXHIK2mPnQKCKCCIAoIoIIgCgtMnjqTXk9S3t7fLea+nmo8SOwUEUUAQBQRRQPB6L/5jd3e3nE9NTZVzr/eCv4AoIIgCgiggiAKC0yf+Zk6foB+igCAKCKKAIAoIooAgCgiigCAKCKKAIAoIw37ETfldExgldgoIooAgCgiigCAKCKKAIAoIooAgCgiigCAKCKKAIAoIooAgCgiigCAKCKKAIAoIooAgCgiigCAKCD8Bxu62ceBPCLoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x.show()\n", "print(y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def _plot(i, j, ax):\n", "    \"Show one draw of training item 0 on `ax`; the grid indices `i`, `j` are unused.\"\n", "    # Same item in every panel: presumably the random `rand_pad` train transform is re-applied on each access, so the panels differ.\n", "    img, _ = data.train_ds[0]\n", "    img.show(ax=ax, cmap='gray')\n", "\n", "plot_multi(_plot, 3, 3, figsize=(8,8))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([128, 1, 28, 28]), torch.Size([128]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xb,yb = data.one_batch()\n", "xb.shape,yb.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "data.show_batch(rows=3, figsize=(5,5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic CNN with batchnorm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv(ni, nf):\n", "    \"3x3 convolution with stride 2 and padding 1: each call halves the grid, rounding up (28->14->7->4->2->1).\"\n", "    return nn.Conv2d(ni, nf, kernel_size=3, stride=2, padding=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " conv(1, 8), # 14\n", " nn.BatchNorm2d(8),\n", " nn.ReLU(),\n", " conv(8, 16), # 7\n", " nn.BatchNorm2d(16),\n", " nn.ReLU(),\n", " conv(16, 32), # 4\n", " nn.BatchNorm2d(32),\n", " nn.ReLU(),\n", " conv(32, 16), # 2\n", " nn.BatchNorm2d(16),\n", " nn.ReLU(),\n", " conv(16, 10), # 1\n", " nn.BatchNorm2d(10),\n", " Flatten() # remove (1,1) grid\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "======================================================================\n", "Layer (type) Output Shape Param # Trainable \n", "======================================================================\n", "Conv2d [128, 8, 14, 14] 80 True \n", "______________________________________________________________________\n", "BatchNorm2d [128, 8, 14, 14] 16 True \n", "______________________________________________________________________\n", "ReLU [128, 8, 14, 14] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 16, 7, 7] 1168 True \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 7, 7] 32 True \n", "______________________________________________________________________\n", "ReLU [128, 16, 
7, 7] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 32, 4, 4] 4640 True \n", "______________________________________________________________________\n", "BatchNorm2d [128, 32, 4, 4] 64 True \n", "______________________________________________________________________\n", "ReLU [128, 32, 4, 4] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 16, 2, 2] 4624 True \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 2, 2] 32 True \n", "______________________________________________________________________\n", "ReLU [128, 16, 2, 2] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 10, 1, 1] 1450 True \n", "______________________________________________________________________\n", "BatchNorm2d [128, 10, 1, 1] 20 True \n", "______________________________________________________________________\n", "Flatten [128, 10] 0 False \n", "______________________________________________________________________\n", "\n", "Total params: 12126\n", "Total trainable params: 12126\n", "Total non-trainable params: 0\n", "\n" ] } ], "source": [ "print(learn.summary())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Move the batch to the device the Learner placed `model` on (data.device) instead of hard-requiring CUDA.\n", "xb = xb.to(data.device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 10])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model(xb).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ 
"learn.lr_find(end_lr=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:30

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
10.2231670.2178590.930500
20.1361790.0786510.976400
30.0720800.0386640.988600
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(3, max_lr=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Refactor" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv2(ni, nf):\n", "    \"fastai `conv_layer` with stride 2 (Conv2d -> ReLU -> BatchNorm2d), halving the grid size.\"\n", "    return conv_layer(ni, nf, stride=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " conv2(1, 8), # 14\n", " conv2(8, 16), # 7\n", " conv2(16, 32), # 4\n", " conv2(32, 16), # 2\n", " conv2(16, 10), # 1\n", " Flatten() # remove (1,1) grid\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 01:12

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
10.2283320.2063250.937500
20.1899660.1925580.940800
30.1567650.0928100.969100
40.1358710.0839140.973300
50.1088440.0715820.978000
60.1058870.1285860.960200
70.0806990.0527540.983200
80.0660070.0375880.988600
90.0475130.0302550.990200
100.0447050.0283730.991600
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, max_lr=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Resnet-ish" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ResBlock(nn.Module):\n", "    \"Hand-written residual block computing `x + conv2(conv1(x))`; the models below use fastai's `res_block` instead.\"\n", "    def __init__(self, nf):\n", "        super().__init__()\n", "        self.conv1 = conv_layer(nf,nf)\n", "        self.conv2 = conv_layer(nf,nf)\n", "\n", "    def forward(self, x): return x + self.conv2(self.conv1(x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on function res_block in module fastai.layers:\n", "\n", "res_block(nf, dense:bool=False, norm_type:Union[fastai.layers.NormType, NoneType]=, bottle:bool=False, **kwargs)\n", " Resnet block of `nf` features.\n", "\n" ] } ], "source": [ "help(res_block)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " conv2(1, 8),\n", " res_block(8),\n", " conv2(8, 16),\n", " res_block(16),\n", " conv2(16, 32),\n", " res_block(32),\n", " conv2(32, 16),\n", " res_block(16),\n", " conv2(16, 10),\n", " Flatten()\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv_and_res(ni, nf):\n", "    \"Stride-2 `conv2` from `ni` to `nf` channels, followed by a fastai `res_block` at `nf` channels.\"\n", "    return nn.Sequential(conv2(ni, nf), res_block(nf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " conv_and_res(1, 8),\n", " conv_and_res(8, 16),\n", " conv_and_res(16, 32),\n", " conv_and_res(32, 16),\n", " conv2(16, 10),\n", " Flatten()\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.lr_find(end_lr=100)\n", "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 02:00

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
10.2488070.1215820.972800
20.1209270.3605830.890300
30.1040130.0749160.977800
40.0811810.0657170.980000
50.0685140.0964480.967200
60.0612740.0879550.971800
70.0516730.0339110.989400
80.0480900.0332340.988800
90.0390950.0246380.992400
100.0236700.0212150.993400
110.0191280.0161590.994500
120.0213650.0161200.995200
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(12, max_lr=0.05)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "======================================================================\n", "Layer (type) Output Shape Param # Trainable \n", "======================================================================\n", "Conv2d [128, 8, 14, 14] 72 True \n", "______________________________________________________________________\n", "ReLU [128, 8, 14, 14] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 8, 14, 14] 16 True \n", "______________________________________________________________________\n", "Conv2d [128, 8, 14, 14] 576 True \n", "______________________________________________________________________\n", "ReLU [128, 8, 14, 14] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 8, 14, 14] 16 True \n", "______________________________________________________________________\n", "Conv2d [128, 8, 14, 14] 576 True \n", "______________________________________________________________________\n", "ReLU [128, 8, 14, 14] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 8, 14, 14] 16 True \n", "______________________________________________________________________\n", "MergeLayer [128, 8, 14, 14] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 16, 7, 7] 1152 True \n", "______________________________________________________________________\n", "ReLU [128, 16, 7, 7] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 7, 7] 32 True \n", "______________________________________________________________________\n", "Conv2d [128, 16, 7, 7] 2304 
True \n", "______________________________________________________________________\n", "ReLU [128, 16, 7, 7] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 7, 7] 32 True \n", "______________________________________________________________________\n", "Conv2d [128, 16, 7, 7] 2304 True \n", "______________________________________________________________________\n", "ReLU [128, 16, 7, 7] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 7, 7] 32 True \n", "______________________________________________________________________\n", "MergeLayer [128, 16, 7, 7] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 32, 4, 4] 4608 True \n", "______________________________________________________________________\n", "ReLU [128, 32, 4, 4] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 32, 4, 4] 64 True \n", "______________________________________________________________________\n", "Conv2d [128, 32, 4, 4] 9216 True \n", "______________________________________________________________________\n", "ReLU [128, 32, 4, 4] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 32, 4, 4] 64 True \n", "______________________________________________________________________\n", "Conv2d [128, 32, 4, 4] 9216 True \n", "______________________________________________________________________\n", "ReLU [128, 32, 4, 4] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 32, 4, 4] 64 True \n", "______________________________________________________________________\n", "MergeLayer [128, 32, 4, 4] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 16, 2, 2] 4608 True \n", 
"______________________________________________________________________\n", "ReLU [128, 16, 2, 2] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 2, 2] 32 True \n", "______________________________________________________________________\n", "Conv2d [128, 16, 2, 2] 2304 True \n", "______________________________________________________________________\n", "ReLU [128, 16, 2, 2] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 2, 2] 32 True \n", "______________________________________________________________________\n", "Conv2d [128, 16, 2, 2] 2304 True \n", "______________________________________________________________________\n", "ReLU [128, 16, 2, 2] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 16, 2, 2] 32 True \n", "______________________________________________________________________\n", "MergeLayer [128, 16, 2, 2] 0 False \n", "______________________________________________________________________\n", "Conv2d [128, 10, 1, 1] 1440 True \n", "______________________________________________________________________\n", "ReLU [128, 10, 1, 1] 0 False \n", "______________________________________________________________________\n", "BatchNorm2d [128, 10, 1, 1] 20 True \n", "______________________________________________________________________\n", "Flatten [128, 10] 0 False \n", "______________________________________________________________________\n", "\n", "Total params: 41132\n", "Total trainable params: 41132\n", "Total non-trainable params: 0\n", "\n" ] } ], "source": [ "print(learn.summary())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## fin" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 
4, "nbformat_minor": 1 }