{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## MNIST CNN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai2.vision.all import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#4) [/home/sgugger/.fastai/data/mnist_png/models,/home/sgugger/.fastai/data/mnist_png/testing,/home/sgugger/.fastai/data/mnist_png/resnet34_mnist.pkl,/home/sgugger/.fastai/data/mnist_png/training]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/sgugger/.fastai/data/mnist_png/testing/1/6540.png')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "items = get_image_files(path)\n", "items[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAEi0lEQVR4nO3dvUqcWxSA4eNh8kNAEFJIGovUWuQibIZ0XlAuJwETUsUqeBl2KRUsBCWk0YQ5N2D24HFG30+ep3QhbAgvC7LZnxuLxeIfoOffxz4AcDtxQpQ4IUqcECVOiJotmfuvXFi/jdt+aHNClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6Jmj30A7ub6+no4f/369XD+/v374fzjx493PhPrYXNClDghSpwQJU6IEidEiROixAlRG4vFYjQfDnl4Ozs7w/np6elwvrm5OZxfXl7e+Uzc28ZtP7Q5IUqcECVOiBInRIkTosQJUZ6MTczu7u5wfnZ29kAnYd1sTogSJ0SJE6LECVHihChxQpQ4IcqTsYn5/v37cD6fz4fzV69eDeeejD0KT8ZgSsQJUeKEKHFClDghSpwQJU6I8p5zYpa911xyb82E2JwQJU6IEidEiROixAlR4oQocUKU95xPzNbW1r1+33vOR+E9J0yJOCFKnBAlTogSJ0SJE6LECVHec07MxcXFcH5zczOcP3v2bJXHYY1sTogSJ0SJE6LECVHihChxQpSrlIk5PDwczn///j2cL/t05vn5+V9n29vbw99ltWxOiBInRIkTosQJUeKEKHFClDghyqcxn5hln8a8uroazj9//vzX2cHBwf86E0v5NCZMiTghSpwQJU6IEidEiROixAlR3nM+MfP5fDj/9OnTA52E+7I5IUqcECVOiBInRIkTosQJUeKEKPecT8yy79LOZuN/8nfv3q3yONyDzQlR4oQocUKUOCFKnBAlTogSJ0S555yY6+vr4fzbt2/D+Z8/f4bzHz9+/HX29u3b4e+yWjYnRIkTosQJUeKEKHFClDghylXKxJycnAznNzc3w/nm5uZwvr+/f+czsR42J0SJE6LECVHihChxQpQ4IUqcEOWec2K+fv06nC97UvbixYtVHoc1sjkhSpwQJU6IEidEiROixAlR4oQo95wTs7e399hH4IHYnBAlTogSJ0SJE6LECVHihChxQtTGYrEYzYdDHt7Pnz+H8zdv3gzns9n4avvy8vLOZ+LeNm77oc0JUeKEKHFClDghSpwQJU6I8mRsYpZ9+nI+nw/nFxcXqzwOa2RzQpQ4IUqcECVOiBInRIkTosQJUe45J+bDhw/D+ZcvX4bzX79+rfI4rJHNCVHihChxQpQ4IUqcECVOiBInRLnnnJjnz58P58fHx8P5y5cvV3kc1sjmhChxQpQ4IUqcECVOiBInRIkTovwJwInZ2dkZzo+Ojobz3d3dVR6H1fAnAGFKxAlR4oQocUKUOCFKnBAlTohyzwmPzz0nTIk4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6JmS+a3frIPWD+bE6LECVHihChxQpQ4IUqcEPUfI+WMO0DgEWoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "im = PILImageBW.create(items[0])\n", "im.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "splits = GrandparentSplitter(train_name='training', valid_name='testing')(items)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dsets = Datasets(items, tfms=[[PILImageBW.create], [parent_label, Categorize]], splits=splits)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#70000) [(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1))...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dsets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#60000) [(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1))...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dsets.train" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#10000) [(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1)),(, tensor(1))...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dsets.valid" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAOcAAAD3CAYAAADmIkO7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAFA0lEQVR4nO3dv2qVZxzA8ec1IhYEcXALgk7OUvAS6pLFuZfQwbFDcJDu3kL1Ajp6EVIkm5OIIJ0CGhBK/MPp0FIoJG+q5+j55uTzmUJ+JHlAvv5CHt5zpsViMYCec+s+AHA0cUKUOCFKnBAlTogSJ0SJE6LEuQGmafppmqbfp2k6nKbp13Wfh9U4v+4DsBJ/jDF+GWP8MMb4bs1nYUXEuQEWi8VvY4wxTdP3Y4ztNR+HFfFrLUSJE6LECVHihCh/ENoA0zSdH3//W26NMbamabo4xvi4WCw+rvdkLMPm3Ay7Y4w/xxg/jzF+/Ofj3bWeiKVNHraGJpsTosQJUeKEKHFC1ElXKf5aBF/fdNQnbU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQtT5dR+A//r06dPs/MWLF7Pze/fuzc6fPHny2WdiPWxOiBInRIkTosQJUeKEKHFClDghyj1nzOHh4ez85s2bs/Pt7e3Z+bt372bnly5dmp3z7dicECVOiBInRIkTosQJUeKEKFcpG+b169ez84ODg9m5q5QOmxOixAlR4oQocUKUOCFKnBAlTohyz7lhFovFuo/AiticECVOiBInRIkTosQJUeKEKHFClHvODTNN0+z8pJfepMPmhChxQpQ4IUqcECVOiBInRIkTotxznjF7e3uz8xs3bnyjk3ASmxOixAlR4oQocUKUOCFKnBAlTohyzxlz7tz8/5dXrlyZnb9582Z2/vz5888+E+thc0KUOCFKnBAlTogSJ0SJE6JcpcRcvHhxdr6zszM7f/z48SqPwxrZnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClOc5z5j9/f11H4H/yeaEKHFClDghSpwQJU6IEidEiROi3HOeMY8ePZqdP3z48BudhJPYnBAlTogSJ0SJE6LECVHihChxQpR7zlPmzp07s3Pvz7k5bE6IEidEiROixAlR4oQocUKUq5RT5vr160t9/fv372fnBwcHx84uX7681M/m89icECVOiBInRIkTosQJUeKEKHFClHvOU2Zra2upr18sFrPzDx8+LPX9WR2bE6LECVHihChxQpQ4IUqcECVOiJpOuPeavxQj59atW7Pzvb292fnu7u6xswcPHnzRmTjRdNQnbU6IEidEiROixAlR4oQocUKUOCHK85wb5u7du7Pzly9fzs7v37+/yuOwBJsTosQJUeKEKHFClDghSpwQ5SrljJmmI59O+teyL73J6ticECVOiBInRIkTosQJUeKEKHFClHvOM+bt27ez86dPnx47u3379qqPwwybE6LECVHihChxQpQ4IUqcECVOiPIWgBvm2rVrs/P9/f3Z+atXr46dXb169YvOxIm8BSCcJuKEKHFClDghSpwQJU6IEidEeZ5zw+zs7MzOnz17Nju/cOHCKo/DEmxOiBInRIkTosQJUeKEKHFClDghyvOcsH6e54TTRJwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUSe9BeCRL9kHfH02J0SJE6LECVHihChxQpQ4IeovA3x8nCTpVrIAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_at(dsets.train, 0);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Item-level augmentation: CropPad to 34x34 with zero padding, then take a random 28x28 crop\n", "tfms = [ToTensor(), CropPad(size=34, pad_mode=PadMode.Zeros), RandomCrop(size=28)]\n", "bs = 128" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = dsets.dataloaders(bs=bs, after_item=tfms, after_batch=[IntToFloatTensor, Normalize])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# `bs` copies of the same image, all in the training split (validation split left empty),\n", "# so one batch shows the random augmentation applied to a single digit.\n", "# Sized with `bs` rather than a repeated literal so it stays exactly one batch if `bs` changes.\n", "dsrc1 = Datasets([items[0]]*bs, tfms=[[PILImageBW.create], [parent_label, Categorize]], splits=[list(range(bs)), []])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dbunch1 = dsrc1.dataloaders(bs=bs, after_item=tfms, after_batch=[IntToFloatTensor()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAckAAAHRCAYAAAABukKHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAUh0lEQVR4nO3dvUtdW/oH8LV/GpjxrfIFwpQ2sb0jQ1IpQQX/gBQy4A3ccpjKYuDKLXQgkHaqgJBpJE2YP+D2VwNOHTAWdno7IeBMRvTuX3FPs/c6zzHmxrPPy+dTZT1uyCKs5MvKXvtZRVmWCQDI/V/TEwCAXiUkASAgJAEgICQBICAkASAgJAEgICQBICAk2yiK4i9FUfy7KIr/FUXxz6bnw+Cz5miCdXe70aYn0KPOUkp/TymtpZR+3/BcGA7WHE2w7m4hJNsoy/JfKaVUFMUfU0p/aHg6DAFrjiZYd7fz360AEBCSABAQkgAQEJIAEHBwp42iKEbTr382IymlkaIofpdSui7L8rrZmTGorDmaYN3dzk6yve2U0n9TSn9LKf259evtRmfEoLPmaIJ1d4vCpcsA0J6dJAAEhCQABIQkAASEJAAEhCQABDp+J1kUhaOvQ6wsy6KJ39e6G25NrDtrbrh1WnN2kgAQEJIAEBCSABAQkgAQEJIAEBCSABAQkgAQEJIAEHDpMvShiYmJrHZ+fl4ZHxwcZM+sra3d25xgENlJAkBASAJAQEgCQEBIAkCgKMu4+b3O+MPNLSC96+zsLKvNzc1VxpeXl9kzU1NT9zanr8UtIHSbW0AA4AsISQAICEkACGgmAH3o9PQ0q83OzjYwExhsdpIAEBCSABAQkgAQEJIAEHBw547qty/Ub15IKb99wc0LfG17e3tZbXFxsYGZwGCzkwSAgJAEgICQBICAd5J39OHDh8p4bGwse+bx48fdmg5DamZmJqsVRSP96GGg2UkCQEBIAkBASAJAQEgCQKAoy/hCbrd153766afK+E9/+lP2zH/+85/KuB9ug2+niRviU7LuvtTHjx9vfaYf1mIT626Y1ly9IUpKmqJ0WnN2kgAQEJIAEBCSABAQkgAQ0HHnjuq3L7h5gSbMz89ntdHR6l/n6+vrbk2HPlLvGpaSzmGd2EkCQEBIAkBASAJAwDvJO6rfvuDmBZqwsbGR1UZGRirjdmtzYWGhMn7//v3XnRg97/T0NKvNzs42MJP+YCcJAAEhCQABIQkAASEJAAG3gPxGg3LzQjtuAekv9bU4OTmZPfPXv/61Mv7HP/5xr3P6Em4BuV/Pnz/Paq9evcpqnz59qoz79d+xz+EWEAD4AkISAAJCEgACmgncUb2xdL2pdEoaS9OM+k3yq6urDc2EXlZviJKSpiid2EkCQEBIAkBASAJAQEgCQMDBnTuq375Qv3khpfwleP3mhZTcvsDXV1937Q6QHR0ddWs69KiXL19mte3t7QZm0h/sJAEgICQBICAkASDgneQd7ezsVMZbW1vZM/XG0k+fPs2e8U6S32J8fDyrPXnypDJu97780aNHlfG7d+++7sToefWGKClpitKJnSQABIQkAASEJAAEhCQABBzc+Y3qNy+k5PYF7t/S0lJWe/DgQWV8eXmZPfP69ev7mhJ9ot4QJSVNUTqxkwSAgJAEgICQBICAkASAgIM7v1H95XZKeacKNy/wta2vr2e1+sGdq6urbk2HPlLvGpaSzmGd2EkCQEBIAkBASAJAwDvJO6rfvlC/eSGl/MPc+s0LKbl9gd/m+Pi46SkwQDRFidlJAkBASAJAQEgCQEBIAkDAwZ07qt++UP+AO6X89gU3L/C1vX37Nqu9ePGigZkwCDRFidlJAkBASAJAQEgCQMA7yTuqN5Zu905SY2nu29jYWFY7PDysjKempro1HfpIvSFKSpqidGInCQABIQkAASEJAAEhCQABB3fuyO0L9ILd3d2stry8XBlPT093azr0kXpDlJQ
0RenEThIAAkISAAJCEgAC3kneUb2xtKbSNKFdw4rNzc3K+OLiolvToY/UG6KkpClKJ3aSABAQkgAQEJIAEBCSABBwcOeO6rcv1G9eSMntC9y/lZWVrLa/v9/ATOg3GqLcjZ0kAASEJAAEhCQABIQkAASKsizjHxZF/MMh9ebNm8r42bNn2TP12xf6tfNJWZZFE7+vdTfcmlh3w7TmHj58mNVOTk6y2s3NTWU8yAcSO605O0kACAhJAAgISQAIaCZwR/XO+PWbF1Lq33eQwOCrN0RJSVOUTuwkASAgJAEgICQBICAkASDg4M4d1W9fcPMC0E92d3ez2vLyclarN0UZVnaSABAQkgAQEJIAEPBO8o7aNQcG6Bf1higpaYrSiZ0kAASEJAAEhCQABIQkAAQc3AEYIvWGKClpitKJnSQABIQkAASEJAAEirIs4x8WRfxDBl5ZlkUTv691N9yaWHfW3HDrtObsJAEgICQBICAkASAgJAEg0PHgDgAMMztJAAgISQAICEkACAhJAAgISQAICEkACAhJAAgISQAICEkACAhJAAgIyTaKovhLURT/Lorif0VR/LPp+TD4rDmaYN3dbrTpCfSos5TS31NKayml3zc8F4aDNUcTrLtbCMk2yrL8V0opFUXxx5TSHxqeDkPAmqMJ1t3t/HcrAASEJAAEhCQABIQkAAQc3GmjKIrR9OufzUhKaaQoit+llK7LsrxudmYMKmuOJlh3t7OTbG87pfTflNLfUkp/bv16u9EZMeisOZpg3d2iKMuy6TkAQE+ykwSAgJAEgICQBICAkASAQMdPQIqicKpniJVlWTTx+1p3w62JdWfNDbdOa85OEgACQhIAAkISAAJCEgACQhIAAkISAAJCEgACQhIAAu6TbJmYmMhq5+fnWe3g4KAyXltbu7c5AdAsO0kACAhJAAgISQAICEkACBRlGTe/H6bO+GdnZ1ltbm4uq11eXlbGU1NT9zanprkFhCa4BYRucwsIAHwBIQkAASEJAAHNBFpOT0+z2uzsbAMzYZDVm1Z8TsOKlDStgKbYSQJAQEgCQEBIAkBASAJAwMGdlr29vay2uLjYwEwYZB8+fKiMx8bGsmceP37crekAt7CTBICAkASAgJAEgIB3ki0zMzNZrSga6e/NAKs3rdCwAnqbnSQABIQkAASEJAAEhCQABIqyjC/kHvbbuj9+/HjrM1NTU12YSTOauCE+pcFed8+fP6+MX716lT3z6dOnrDbI66yuiXU3yGuO23Vac3aSABAQkgAQEJIAEBCSABDQcadlfn4+q42O5n8819fX3ZgOA6re2UlXJ/rFxMREVjs/P6+MDw4OsmfW1tbubU7dYCcJAAEhCQABIQkAAc0EWn744Yes9v3332e1X375pTL+5ptvsmfev3//9SbWIM0E7t/nNKxISTOB+zZMa+5LnZ2dZbW5ubnK+PLyMnumH9auZgIA8AWEJAAEhCQABIQkAAQ0E2jZ2dnJaltbW1ltcnKyMn769Gn2zKAc3OHrqzet0LCCfnF6eprVZmdnG5hJd9lJAkBASAJAQEgCQMA7yQ7aNetdXV1tYCYMio2Njcp4ZGQke6Zd0/OFhYXK2Htvum1vby+rLS4uNjCT7rKTBICAkASAgJAEgICQBICAgzsdtDtAUf/Q++joqFvTYQDUm1Z8TsOKlPKmFQ7u0G0zMzNZrd2/kYPGThIAAkISAAJCEgAC3km2jI+PZ7UnT55ktfrH348ePcqeeffu3debGANNwwr6xcuXL7Pa9vZ2AzPpLjtJAAgISQAICEkACAhJAAg4uNOytLSU1R48eJDVLi8vK+PXr1/f15QYAp/TsCIlTSto3vz8fFYbHa1GSLu12+/sJAEgICQBICAkASAgJAEg4OBOy/r6elZrd3Dn6uqqG9NhQNU7O31OV6eU8s5OujrRbRsbG1mtvlbbHURbWFiojPvtBhs7SQAICEkACAhJAAh4J9lyfHzc9BQYAvWmFZ/TsCI
lTSto3s7OTlbb2tqqjCcnJ7Nnnj59Whl7JwkAA0JIAkBASAJAQEgCQMDBnZa3b99mtRcvXjQwEwZZvWmFhhX0s4ODg8p4dXW1oZncHztJAAgISQAICEkACHgn2TI2NpbVDg8Ps9rU1FQ3psOA0rSCQVJvaH59fZ09c3R01K3p3As7SQAICEkACAhJAAgISQAIFGVZxj8siviHA+bNmzdZ7dmzZ1ltenq6Mr64uLi3OTWtLMv8mvEuGOR19/Dhw8r45OQke+bm5iarDdOBsSbW3SCvua9lfHw8q/3888+VcbsDkN99911l3Is32nRac3aSABAQkgAQEJIAENBMoKVdU+nNzc2sNsjvILl/9Xc2GlbQL5aWlrJavUH/5eVl9kwvvoO8CztJAAgISQAICEkACAhJAAg4uNOysrKS1fb39xuYCYNsd3e3Ml5eXs6eqTesgF6wvr6e1eoHd9odgOx3dpIAEBCSABAQkgAQEJIAEHBwp6V+OwPch/rBBl2d6BfHx8dNT6ERdpIAEBCSABAQkgAQKMoyvpDbbd3DrYkb4lMa7HV3dnZWGX/77bfZMz/++GOXZtObmlh3g7zmvpZ25zZOTk4q45ubm+yZfrjVptOas5MEgICQBICAkASAgJAEgIBmAtBFmlbQr8bGxrLa4eFhZdwPh3Tuyk4SAAJCEgACQhIAAt5JAnCr3d3drLa8vFwZT09Pd2s6XWMnCQABIQkAASEJAAEhCQABB3cAuNXV1VVW29zcrIwvLi66NZ2usZMEgICQBICAkASAgHeSANxqZWUlq+3v7zcwk+6ykwSAgJAEgICQBICAkASAQFGWZdNzAICeZCcJAAEhCQABIQkAASEJAAEhCQABIQkAASEJAAEhCQABIQkAASEJAAEh2UZRFH8piuLfRVH8ryiKfzY9HwafNUcTrLvbuXS5vbOU0t9TSmsppd83PBeGgzVHE6y7WwjJNsqy/FdKKRVF8ceU0h8ang5DwJqjCdbd7fx3KwAEhCQABIQkAASEJAAEHNxpoyiK0fTrn81ISmmkKIrfpZSuy7K8bnZmDCprjiZYd7ezk2xvO6X035TS31JKf279ervRGTHorDmaYN3doijLsuk5AEBPspMEgICQBICAkASAgJAEgICQBIBAx+8ki6Jw9HWIlWVZNPH7WnfDrYl1Z80Nt05rzk4SAAJCEgACQhIAAkISAAJCEgACQhIAAkISAAJCEgACLl2GPjAxMVEZn5+fZ88cHBxUxmtra/c6JxgGdpIAEBCSABAQkgAQEJIAECjKMm5+rzP+cHMLSO84OzurjOfm5rJnLi8vK+Opqal7ndN9cQsI3eYWEAD4AkISAAJCEgACmgm01D/WTskH2/SO09PTynh2drahmcBwsZMEgICQBICAkASAgJAEgICDOy0fPnzIamNjY1nt8ePH3ZgOVOzt7VXGi4uLDc0EhoudJAAEhCQABIQkAAS8k2ypf6ydkg+26R0zMzOVcVE00nseho6dJAAEhCQABIQkAASEJAAEirKML+Qeptu6nz9/ntVevXqV1T59+lQZ9+vt75+jiRviUxqudfelPn78eOsz/bo2m1h31txw67Tm7CQBICAkASAgJAEgICQBIKDjTku9o0lKuprQO+bn5yvj0dH8r+719XW3pkMfm5iYyGrn5+dZ7eDgoDJeW1u7tzn1MjtJAAgISQAICEkACGgm0MEgf7D9OTQT6B0//PBDZfz9999nz/zyyy+V8TfffJM98/79+687sXugmcD9Ojs7y2pzc3NZ7fLysjIe1n/r7CQBICAkASAgJAEgICQBIKCZQEv9Y+2UfLBN79jZ2amMt7a2smcmJycr46dPn2bP9MPBHe7X6elpVpudnW1gJv3BThIAAkISAAJCEgAC3km2bGxsZLWRkZGsVm96vrCwkD3jvQ/3rd58OqWUVldXG5gJ/WZvby+rLS4uNjCT/mAnCQABIQkAASEJAAEhCQABB3da6h9rp+SDbXp
X/QBZSnmji6Ojo25Nhz4yMzOT1dqtJ35lJwkAASEJAAEhCQAB7yQ78ME2vWJ8fLwyfvLkSfZMvfnFo0ePsmfevXv3dSdG33n58mVW297ebmAm/cFOEgACQhIAAkISAAJCEgACDu504INtesXS0lJl/ODBg+yZy8vLyvj169f3OSX61Pz8fFYbHc2joP5v3bCykwSAgJAEgICQBICAkASAgIM7LfWOJinpakLvWF9fr4zbHdy5urrq1nToYxsbG1mt/u9aSvnBxYWFheyZYbjxyE4SAAJCEgACQhIAAt5JttQ/1k7JB9v0juPj46anwIDY2dnJaltbW1ltcnKyMn769Gn2jHeSADDEhCQABIQkAASEJAAEHNxpqX+snZIPtukdb9++rYxfvHjR0EwYRAcHB1ltdXW1gZn0HjtJAAgISQAICEkACHgn2eJjbXrZ2NhYZXx4eJg9MzU11a3pMGDqzcxTSun6+royPjo66tZ0eoqdJAAEhCQABIQkAASEJAAEirIs4x8WRfzDAfPw4cOsdnJyktVubm4q40E+LFGWZf42vwuGad19rjdv3lTGz549y56Znp6ujC8uLu51TveliXU3TGtufHw8q/38889ZrX5Y7LvvvsueGZRbkDqtOTtJAAgISQAICEkACGgm0FL///eUfLBN76g31t/c3Mye6dd3kHTX0tJSVmt3mcPl5WVlPCjvH+/KThIAAkISAAJCEgACQhIAAg7utOzu7ma15eXlrFb/YBu6YWVlpTLe399vaCb0u/X19azW7uBO/bDYsLKTBICAkASAgJAEgICQBICAgzst7V5S62pCr2h3Sw18iePj46an0FfsJAEgICQBICAkASBQlGV8Ifcw3dZ9dnaW1b799tus9uOPP3ZhNr2hiRviUxqudUeuiXU3TGuu3fvtk5OTrHZzc1MZD/INSJ3WnJ0kAASEJAAEhCQABIQkAAQ0E2jxsTYwDMbGxrLa4eFhVhvkgzp3YScJAAEhCQABIQkAAe8kAYbI7u5uVlteXs5q09PT3ZhOz7OTBICAkASAgJAEgICQBICAgzsAQ+Tq6iqrbW5uZrWLi4tuTKfn2UkCQEBIAkBASAJAwDtJgCGysrKS1fb39xuYSX+wkwSAgJAEgICQBICAkASAQFGWZdNzAICeZCcJAAEhCQABIQkAASEJAAEhCQABIQkAgf8HkYyeGOpGsyQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "dbunch1.show_batch(figsize=(8,8), cmap='gray')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "dls.show_batch(figsize=(5,5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([128, 1, 28, 28]), torch.Size([128]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xb,yb = dls.one_batch()\n", "xb.shape,yb.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic CNN with batchnorm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv(ni, nf):\n", "    \"3x3 convolution with stride 2 and padding 1, mapping `ni` input channels to `nf` output channels.\"\n", "    return nn.Conv2d(ni, nf, kernel_size=3, stride=2, padding=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Each stride-2 conv halves the grid (rounding up); the trailing comments give the grid size after that layer\n", "model = nn.Sequential(\n", "    conv(1, 8),    # 14x14\n", "    nn.BatchNorm2d(8),\n", "    nn.ReLU(),\n", "    conv(8, 16),   # 7x7\n", "    nn.BatchNorm2d(16),\n", "    nn.ReLU(),\n", "    conv(16, 32),  # 4x4\n", "    nn.BatchNorm2d(32),\n", "    nn.ReLU(),\n", "    conv(32, 16),  # 2x2\n", "    nn.BatchNorm2d(16),\n", "    nn.ReLU(),\n", "    conv(16, 10),  # 1x1 (10 output channels)\n", "    nn.BatchNorm2d(10),\n", "    Flatten()      # remove (1,1) grid\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential (Input shape: 128 x 1 x 28 x 28)\n", "================================================================\n", "Layer (type) Output Shape Param # Trainable \n", "================================================================\n", "Conv2d 128 x 8 x 14 x 14 80 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 8 x 14 x 14 16 True \n", "________________________________________________________________\n", "ReLU 128 x 8 x 14 x 14 0 False \n", 
"________________________________________________________________\n", "Conv2d 128 x 16 x 7 x 7 1,168 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 7 x 7 32 True \n", "________________________________________________________________\n", "ReLU 128 x 16 x 7 x 7 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 32 x 4 x 4 4,640 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 32 x 4 x 4 64 True \n", "________________________________________________________________\n", "ReLU 128 x 32 x 4 x 4 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 16 x 2 x 2 4,624 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 2 x 2 32 True \n", "________________________________________________________________\n", "ReLU 128 x 16 x 2 x 2 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 10 x 1 x 1 1,450 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 10 x 1 x 1 20 True \n", "________________________________________________________________\n", "Flatten 128 x 10 0 False \n", "________________________________________________________________\n", "\n", "Total params: 12,126\n", "Total trainable params: 12,126\n", "Total non-trainable params: 0\n", "\n", "Optimizer used: \n", "Loss function: CrossEntropyLoss()\n", "\n", "Callbacks:\n", " - TrainEvalCallback\n", " - Recorder\n", " - ProgressCallback\n" ] } ], "source": [ "print(learn.summary())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Move the batch to whichever device the model's parameters are on, instead of hard-coding\n", "# `.cuda()`, so the forward pass in the next cell also runs on a machine without a GPU.\n", "xb = xb.to(next(model.parameters()).device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 10])" ] }, "execution_count": null, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "model(xb).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.lr_find(end_lr=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.2138120.3546800.89220000:29
10.1230050.1145180.96390000:14
20.0804210.0431240.98610000:13
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(3, lr_max=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Refactor" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv2(ni,nf): return ConvLayer(ni,nf,stride=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " conv2(1, 8), # 14\n", " conv2(8, 16), # 7\n", " conv2(16, 32), # 4\n", " conv2(32, 16), # 2\n", " conv2(16, 10), # 1\n", " Flatten() # remove (1,1) grid\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.2233110.1931030.93870000:12
10.1771940.1456100.95130000:12
20.1477100.1198390.96160000:13
30.1217550.0932310.97030000:12
40.1128890.0734950.97560000:12
50.0942160.0652820.97930000:12
60.0787940.0481460.98560000:12
70.0582870.0324820.98930000:12
80.0461360.0259680.99160000:12
90.0410370.0264250.99070000:13
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, lr_max=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Resnet-ish" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ResBlock(Module):\n", "    \"Residual block: two `ConvLayer`s with `nf` channels in and out, whose result is added to the block input.\"\n", "    # NOTE(review): no explicit super().__init__() call -- this relies on the `Module` base class\n", "    # brought in by the fastai star import handling it; confirm before swapping in plain nn.Module.\n", "    def __init__(self, nf):\n", "        self.conv1 = ConvLayer(nf, nf)\n", "        self.conv2 = ConvLayer(nf, nf)\n", "\n", "    def forward(self, x):\n", "        # identity shortcut plus the two-conv branch\n", "        return x + self.conv2(self.conv1(x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", "    conv2(1, 8),\n", "    ResBlock(8),\n", "    conv2(8, 16),\n", "    ResBlock(16),\n", "    conv2(16, 32),\n", "    ResBlock(32),\n", "    conv2(32, 16),\n", "    ResBlock(16),\n", "    # Final layer: the same stride-2 ConvLayer that conv2 builds, but with `act_cls=None` so no\n", "    # ReLU clips the logits at 0 before CrossEntropyLoss.\n", "    ConvLayer(16, 10, stride=2, act_cls=None),\n", "    Flatten()\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv_and_res(ni, nf):\n", "    \"Stride-2 `conv2` from `ni` to `nf` channels followed by a `ResBlock` on `nf` channels.\"\n", "    return nn.Sequential(conv2(ni, nf), ResBlock(nf))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", "    conv_and_res(1, 8),\n", "    conv_and_res(8, 16),\n", "    conv_and_res(16, 32),\n", "    conv_and_res(32, 16),\n", "    # Final layer: the same stride-2 ConvLayer that conv2 builds, but with `act_cls=None` so no\n", "    # ReLU clips the logits at 0 before CrossEntropyLoss.\n", "    ConvLayer(16, 10, stride=2, act_cls=None),\n", "    Flatten()\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, model, loss_func = nn.CrossEntropyLoss(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.lr_find(end_lr=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
00.1834730.4921550.85410000:18
10.1337310.1881460.94500000:17
20.0891430.0806000.97410000:18
30.0845150.0524630.98170000:18
40.0619130.0448650.98710000:18
50.0570880.0597050.98140000:18
60.0548640.0345800.98860000:18
70.0359860.0314460.98990000:17
80.0325350.0222800.99280000:19
90.0263290.0186590.99430000:17
100.0203020.0165710.99450000:18
110.0201050.0162520.99530000:17
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(12, lr_max=0.05)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential (Input shape: 128 x 1 x 28 x 28)\n", "================================================================\n", "Layer (type) Output Shape Param # Trainable \n", "================================================================\n", "Conv2d 128 x 8 x 14 x 14 72 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 8 x 14 x 14 16 True \n", "________________________________________________________________\n", "ReLU 128 x 8 x 14 x 14 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 8 x 14 x 14 576 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 8 x 14 x 14 16 True \n", "________________________________________________________________\n", "ReLU 128 x 8 x 14 x 14 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 8 x 14 x 14 576 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 8 x 14 x 14 16 True \n", "________________________________________________________________\n", "ReLU 128 x 8 x 14 x 14 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 16 x 7 x 7 1,152 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 7 x 7 32 True \n", "________________________________________________________________\n", "ReLU 128 x 16 x 7 x 7 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 16 x 7 x 7 2,304 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 7 x 7 32 True \n", 
"________________________________________________________________\n", "ReLU 128 x 16 x 7 x 7 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 16 x 7 x 7 2,304 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 7 x 7 32 True \n", "________________________________________________________________\n", "ReLU 128 x 16 x 7 x 7 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 32 x 4 x 4 4,608 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 32 x 4 x 4 64 True \n", "________________________________________________________________\n", "ReLU 128 x 32 x 4 x 4 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 32 x 4 x 4 9,216 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 32 x 4 x 4 64 True \n", "________________________________________________________________\n", "ReLU 128 x 32 x 4 x 4 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 32 x 4 x 4 9,216 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 32 x 4 x 4 64 True \n", "________________________________________________________________\n", "ReLU 128 x 32 x 4 x 4 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 16 x 2 x 2 4,608 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 2 x 2 32 True \n", "________________________________________________________________\n", "ReLU 128 x 16 x 2 x 2 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 16 x 2 x 2 2,304 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 2 x 2 32 True \n", 
"________________________________________________________________\n", "ReLU 128 x 16 x 2 x 2 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 16 x 2 x 2 2,304 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 16 x 2 x 2 32 True \n", "________________________________________________________________\n", "ReLU 128 x 16 x 2 x 2 0 False \n", "________________________________________________________________\n", "Conv2d 128 x 10 x 1 x 1 1,440 True \n", "________________________________________________________________\n", "BatchNorm2d 128 x 10 x 1 x 1 20 True \n", "________________________________________________________________\n", "ReLU 128 x 10 x 1 x 1 0 False \n", "________________________________________________________________\n", "Flatten 128 x 10 0 False \n", "________________________________________________________________\n", "\n", "Total params: 41,132\n", "Total trainable params: 41,132\n", "Total non-trainable params: 0\n", "\n", "Optimizer used: \n", "Loss function: CrossEntropyLoss()\n", "\n", "Model unfrozen\n", "\n", "Callbacks:\n", " - TrainEvalCallback\n", " - Recorder\n", " - ProgressCallback\n" ] } ], "source": [ "print(learn.summary())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## fin" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 1 }