{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## CIFAR 10" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2\n", "\n", "from fastai.conv_learner import *\n", "PATH = 'data/cifar/'\n", "os.makedirs(PATH, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can get the data via:\n", "\n", " wget http://pjreddie.com/media/files/cifar.tgz" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", "stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def to_label_subdirs(path, subdirs, classes, labelfn):\n", " for sd in subdirs:\n", " for rf in os.listdir(os.path.join(path, sd)):\n", " af = os.path.join(path, sd, rf)\n", " if not os.path.isfile(af):\n", " continue\n", " lb = labelfn(rf)\n", " if not lb:\n", " continue\n", " os.renames(af, os.path.join(path, sd, lb, rf))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "to_label_subdirs(PATH, 'train test'.split(), classes, lambda f: f[f.find('_') + 1 : f.find('.')])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def get_data(sz,bs):\n", " tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz // 8)\n", " return ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "bs=256" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Look at data" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "data = get_data(32, 4)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "x, y = next(iter(data.trn_dl))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(data.trn_ds.denorm(x)[0])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(data.trn_ds.denorm(x)[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fully connected model" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "data = get_data(32, bs)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "lr = 1e-2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From [this notebook](https://github.com/KeremTurgutlu/deeplearning/blob/master/Exploring%20Optimizers.ipynb) by our student Kerem Turgutlu:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class SimpleNet(nn.Module):\n", " def __init__(self, layers):\n", " super().__init__()\n", " self.layers = nn.ModuleList([\n", " nn.Linear(layers[i], layers[i + 1]) for i in range(len(layers) - 1)])\n", " \n", " def forward(self, x):\n", " x = x.view(x.size(0), -1)\n", " for l in self.layers:\n", " l_x = l(x)\n", " x = F.relu(l_x)\n", " return F.log_softmax(l_x, dim=-1)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(SimpleNet([32*32*3, 40, 10]), data)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(SimpleNet(\n", " (layers): ModuleList(\n", " (0): Linear(in_features=3072, out_features=40)\n", " (1): Linear(in_features=40, out_features=10)\n", " )\n", " ), [122880, 40, 400, 10])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn, [o.numel() for o in learn.model.parameters()]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OrderedDict([('Linear-1',\n", " OrderedDict([('input_shape', [-1, 3072]),\n", " ('output_shape', [-1, 40]),\n", " ('trainable', True),\n", " ('nb_params', 122920)])),\n", " ('Linear-2',\n", " OrderedDict([('input_shape', [-1, 40]),\n", " ('output_shape', [-1, 10]),\n", " ('trainable', True),\n", " ('nb_params', 410)]))])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.summary()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "61a8c98dab2042049de01ea0db471779", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 76%|███████▌ | 148/196 [00:15<00:05, 9.40it/s, loss=10] " ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 76%|███████▌ | 148/196 [00:30<00:09, 4.93it/s, loss=10]" ] } ], "source": [ "learn.sched.plot()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dcbf4554bf6345479a904afdc84eb02b", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 13%|█▎ | 26/196 [00:03<00:20, 8.28it/s, loss=2.06]\n", " 13%|█▎ | 26/196 [00:03<00:20, 8.22it/s, loss=2.05]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Exception in thread Thread-4:\n", "Traceback (most recent call last):\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n", " self.run()\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/site-packages/tqdm/_tqdm.py\", line 144, in run\n", " for instance in self.tqdm_cls._instances:\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n", " for itemref in self.data:\n", "RuntimeError: Set changed size during iteration\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.766196 1.642816 0.419 \n", " 1 1.675517 1.568509 0.4466 \n", "\n", "CPU times: user 1min 38s, sys: 2min 51s, total: 4min 29s\n", "Wall time: 44.3 s\n" ] }, { "data": { "text/plain": [ "[array([1.56851]), 0.4466]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(lr, 2)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8717af90861d4c74af82cc5b87af1457", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.617357 1.515796 0.4654 \n", " 1 1.582096 1.496592 0.4684 \n", "\n", "CPU times: user 1min 37s, sys: 2min 46s, total: 4min 23s\n", "Wall time: 43.5 s\n" ] }, { "data": { "text/plain": [ "[array([1.49659]), 0.4684]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(lr, 2, cycle_len=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## CNN" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "class ConvNet(nn.Module):\n", " def __init__(self, layers, c):\n", " super().__init__()\n", " self.layers = nn.ModuleList([\n", " nn.Conv2d(layers[i], layers[i + 1], kernel_size=3, stride=2)\n", " for i in range(len(layers) - 1)])\n", " self.pool = nn.AdaptiveMaxPool2d(1)\n", " self.out = nn.Linear(layers[-1], c)\n", " \n", " def forward(self, x):\n", " for l in self.layers: x = F.relu(l(x))\n", " x = self.pool(x)\n", " x = x.view(x.size(0), -1)\n", " return F.log_softmax(self.out(x), dim=-1)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(ConvNet([3, 20, 40, 80], 10), data)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OrderedDict([('Conv2d-1',\n", " OrderedDict([('input_shape', [-1, 3, 32, 32]),\n", " ('output_shape', [-1, 20, 15, 15]),\n", " ('trainable', True),\n", " ('nb_params', 560)])),\n", " ('Conv2d-2',\n", " OrderedDict([('input_shape', [-1, 20, 15, 15]),\n", " ('output_shape', [-1, 40, 7, 7]),\n", " ('trainable', True),\n", " ('nb_params', 7240)])),\n", " ('Conv2d-3',\n", " OrderedDict([('input_shape', [-1, 40, 7, 7]),\n", " ('output_shape', [-1, 80, 3, 3]),\n", " ('trainable', True),\n", " ('nb_params', 28880)])),\n", " ('AdaptiveMaxPool2d-4',\n", " OrderedDict([('input_shape', [-1, 80, 3, 3]),\n", " ('output_shape', [-1, 80, 1, 1]),\n", " ('nb_params', 0)])),\n", " ('Linear-5',\n", " OrderedDict([('input_shape', [-1, 80]),\n", " ('output_shape', [-1, 10]),\n", " ('trainable', True),\n", " ('nb_params', 810)]))])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.summary()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "95db0c19899d459da8c873cb990dc061", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 98%|█████████▊| 192/196 [00:18<00:00, 10.29it/s, loss=10.1]" ] } ], "source": [ "learn.lr_find(end_lr=100)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r", " 98%|█████████▊| 192/196 [00:30<00:00, 6.39it/s, loss=10.1]" ] } ], "source": [ "learn.sched.plot()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "419db733bd1544da848214f87d983030", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 15%|█▍ | 29/196 [00:03<00:18, 9.18it/s, loss=2.21] \n", " 16%|█▋ | 32/196 [00:03<00:17, 9.50it/s, loss=2.2] " ] }, { "name": "stderr", "output_type": "stream", "text": [ "Exception in thread Thread-10:\n", "Traceback (most recent call last):\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n", " self.run()\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/site-packages/tqdm/_tqdm.py\", line 144, in run\n", " for instance in self.tqdm_cls._instances:\n", " File \"/home/paperspace/anaconda3/envs/fastai/lib/python3.6/_weakrefset.py\", line 60, in __iter__\n", " for itemref in self.data:\n", "RuntimeError: Set changed size during iteration\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.711504 1.737088 0.3824 \n", " 1 1.52142 1.558574 0.4381 \n", "\n", "CPU times: user 1min 38s, sys: 2min 51s, total: 4min 29s\n", "Wall time: 44 s\n" ] }, { "data": { "text/plain": [ "[array([1.55857]), 0.4381]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-1, 2)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "800ec2126d1847a1be3ce641f495f068", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.456638 1.38682 0.5034 \n", " 1 1.357452 1.284294 0.5388 \n", " 2 1.296569 1.239791 0.5547 \n", " 3 1.264639 1.205657 0.5701 \n", "\n", "CPU times: user 3min 21s, sys: 5min 45s, total: 9min 7s\n", "Wall time: 1min 27s\n" ] }, { "data": { "text/plain": [ "[array([1.20566]), 0.5701]" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-1, 4, cycle_len=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Refactored" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "class ConvLayer(nn.Module):\n", " def __init__(self, ni, nf):\n", " super().__init__()\n", " self.conv = nn.Conv2d(ni, nf, kernel_size=3, stride=2, padding=1)\n", " \n", " def forward(self, x): return F.relu(self.conv(x))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "class ConvNet2(nn.Module):\n", " def __init__(self, layers, c):\n", " super().__init__()\n", " self.layers = nn.ModuleList([ConvLayer(layers[i], layers[i + 1])\n", " for i in range(len(layers) - 1)])\n", " self.out = nn.Linear(layers[-1], c)\n", " \n", " def forward(self, x):\n", " for l in self.layers: x = l(x)\n", " x = F.adaptive_max_pool2d(x, 1)\n", " x = x.view(x.size(0), -1)\n", " return F.log_softmax(self.out(x), dim=-1)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(ConvNet2([3, 20, 40, 80], 10), data)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OrderedDict([('Conv2d-1',\n", " OrderedDict([('input_shape', [-1, 3, 32, 32]),\n", " ('output_shape', [-1, 20, 16, 16]),\n", " ('trainable', True),\n", " ('nb_params', 560)])),\n", " ('ConvLayer-2',\n", " OrderedDict([('input_shape', [-1, 3, 32, 32]),\n", " ('output_shape', [-1, 20, 16, 16]),\n", " ('nb_params', 0)])),\n", " ('Conv2d-3',\n", " OrderedDict([('input_shape', [-1, 20, 16, 16]),\n", " ('output_shape', [-1, 40, 8, 8]),\n", " ('trainable', True),\n", " ('nb_params', 7240)])),\n", " ('ConvLayer-4',\n", " OrderedDict([('input_shape', [-1, 20, 16, 16]),\n", " ('output_shape', [-1, 40, 8, 8]),\n", " ('nb_params', 0)])),\n", " ('Conv2d-5',\n", " OrderedDict([('input_shape', [-1, 40, 8, 8]),\n", " ('output_shape', [-1, 80, 4, 4]),\n", " ('trainable', True),\n", " ('nb_params', 28880)])),\n", " ('ConvLayer-6',\n", " OrderedDict([('input_shape', [-1, 40, 8, 8]),\n", " ('output_shape', [-1, 80, 4, 4]),\n", " ('nb_params', 0)])),\n", " ('Linear-7',\n", " OrderedDict([('input_shape', [-1, 80]),\n", " ('output_shape', [-1, 10]),\n", " ('trainable', True),\n", " ('nb_params', 810)]))])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.summary()" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fe4104c61a234e75a4bd64b53fe39081", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.728346 1.639025 0.4117 \n", " 1 1.513297 1.399134 0.4903 \n", "\n", "CPU times: user 1min 42s, sys: 2min 48s, total: 4min 31s\n", "Wall time: 44.2 s\n" ] }, { "data": { "text/plain": [ "[array([1.39913]), 0.4903]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-1, 2)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a4f3f7246fb54ba08a5e9702becaa95b", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.322032 1.260205 0.5485 \n", " 1 1.274674 1.20203 0.5723 \n", "\n", "CPU times: user 1min 38s, sys: 2min 52s, total: 4min 30s\n", "Wall time: 44.3 s\n" ] }, { "data": { "text/plain": [ "[array([1.20203]), 0.5723]" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-1, 2, cycle_len=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## BatchNorm" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "class BnLayer(nn.Module):\n", " def __init__(self, ni, nf, stride=2, kernel_size=3):\n", " super().__init__()\n", " self.conv = nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride,\n", " bias=False, padding=1)\n", " self.a = nn.Parameter(torch.zeros(nf, 1, 1))\n", " self.m = nn.Parameter(torch.ones(nf, 1, 1))\n", " \n", " def forward(self, x):\n", " x = F.relu(self.conv(x))\n", " x_chan = x.transpose(0, 1).contiguous().view(x.size(1), -1)\n", " if self.training:\n", " self.means = x_chan.mean(1)[:, None, None]\n", " self.stds = x_chan.std (1)[:, None, None]\n", " return (x-self.means) / self.stds * self.m + self.a" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "class ConvBnNet(nn.Module):\n", " def __init__(self, layers, c):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(3, 10, kernel_size=5, stride=1, padding=2)\n", " self.layers = nn.ModuleList([BnLayer(layers[i], layers[i + 1])\n", " for i in range(len(layers) - 1)])\n", " self.out = nn.Linear(layers[-1], c)\n", " \n", " def forward(self, x):\n", " x = self.conv1(x)\n", " for l in self.layers: x = l(x)\n", " x = F.adaptive_max_pool2d(x, 1)\n", " x = x.view(x.size(0), -1)\n", " return F.log_softmax(self.out(x), dim=-1)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(ConvBnNet([10, 20, 40, 80, 160], 10), data)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OrderedDict([('Conv2d-1',\n", " OrderedDict([('input_shape', [-1, 3, 32, 32]),\n", " ('output_shape', [-1, 10, 32, 32]),\n", " ('trainable', True),\n", " ('nb_params', 760)])),\n", " ('Conv2d-2',\n", " OrderedDict([('input_shape', [-1, 10, 32, 32]),\n", " ('output_shape', [-1, 20, 16, 16]),\n", " ('trainable', True),\n", " ('nb_params', 1800)])),\n", " ('BnLayer-3',\n", " OrderedDict([('input_shape', [-1, 10, 32, 32]),\n", " ('output_shape', [-1, 20, 16, 16]),\n", " ('nb_params', 0)])),\n", " ('Conv2d-4',\n", " OrderedDict([('input_shape', [-1, 20, 16, 16]),\n", " ('output_shape', [-1, 40, 8, 8]),\n", " ('trainable', True),\n", " ('nb_params', 7200)])),\n", " ('BnLayer-5',\n", " OrderedDict([('input_shape', [-1, 20, 16, 16]),\n", " ('output_shape', [-1, 40, 8, 8]),\n", " ('nb_params', 0)])),\n", " ('Conv2d-6',\n", " OrderedDict([('input_shape', [-1, 40, 8, 8]),\n", " ('output_shape', [-1, 80, 4, 4]),\n", " ('trainable', True),\n", " ('nb_params', 28800)])),\n", " ('BnLayer-7',\n", " OrderedDict([('input_shape', [-1, 40, 8, 8]),\n", " ('output_shape', [-1, 80, 4, 4]),\n", " ('nb_params', 0)])),\n", " ('Conv2d-8',\n", " OrderedDict([('input_shape', [-1, 80, 4, 4]),\n", " ('output_shape', [-1, 160, 2, 2]),\n", " ('trainable', True),\n", " ('nb_params', 115200)])),\n", " ('BnLayer-9',\n", " OrderedDict([('input_shape', [-1, 80, 4, 4]),\n", " ('output_shape', [-1, 160, 2, 2]),\n", " ('nb_params', 0)])),\n", " ('Linear-10',\n", " OrderedDict([('input_shape', [-1, 160]),\n", " ('output_shape', [-1, 10]),\n", " ('trainable', True),\n", " ('nb_params', 1610)]))])" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.summary()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1f1acb285f2c47aba9a6a1cd688a557d", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.474475 1.421361 0.4984 \n", " 1 1.264842 1.144034 0.5881 \n", "\n", "CPU times: user 1min 56s, sys: 3min 24s, total: 5min 21s\n", "Wall time: 47.6 s\n" ] }, { "data": { "text/plain": [ "[array([1.14403]), 0.5881]" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(3e-2, 2)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "93b61cc1195446a0ba7e55e8aeeac4e3", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.168034 1.030074 0.6267 \n", " 1 1.030772 0.96697 0.6655 \n", " 2 0.964813 0.872289 0.696 \n", " 3 0.905667 0.837793 0.7079 \n", "\n", "CPU times: user 3min 53s, sys: 6min 43s, total: 10min 36s\n", "Wall time: 1min 34s\n" ] }, { "data": { "text/plain": [ "[array([0.83779]), 0.7079]" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-1, 4, cycle_len=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deep BatchNorm" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "class ConvBnNet2(nn.Module):\n", " def __init__(self, layers, c):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(3, 10, kernel_size=5, stride=1, padding=2)\n", " self.layers = nn.ModuleList([BnLayer(layers[i], layers[i + 1])\n", " for i in range(len(layers) - 1)])\n", " self.layers2 = nn.ModuleList([BnLayer(layers[i + 1], layers[i + 1], 1)\n", " for i in range(len(layers) - 1)])\n", " self.out = nn.Linear(layers[-1], c)\n", " \n", " def forward(self, x):\n", " x = self.conv1(x)\n", " for l,l2 in zip(self.layers, self.layers2):\n", " x = l(x)\n", " x = l2(x)\n", " x = F.adaptive_max_pool2d(x, 1)\n", " x = x.view(x.size(0), -1)\n", " return F.log_softmax(self.out(x), dim=-1)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(ConvBnNet2([10, 20, 40, 80, 160], 10), data)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1d32e84eb0af4e2ea27310bb794b7c6a", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.505403 1.369886 0.4972 \n", " 1 1.292517 1.193988 0.5743 \n", "\n", "CPU times: user 2min 7s, sys: 3min 40s, total: 5min 47s\n", "Wall time: 50.9 s\n" ] }, { "data": { "text/plain": [ "[array([1.19399]), 0.5743]" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 2)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2305d35f716c481cbe358bda04d3604d", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.114137 1.040168 0.6291 \n", " 1 1.034688 0.982892 0.6514 \n", "\n", "CPU times: user 2min 5s, sys: 3min 42s, total: 5min 47s\n", "Wall time: 51.3 s\n" ] }, { "data": { "text/plain": [ "[array([0.98289]), 0.6514]" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 2, cycle_len=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Resnet" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "class ResnetLayer(BnLayer):\n", " def forward(self, x): return x + super().forward(x)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "class Resnet(nn.Module):\n", " def __init__(self, layers, c):\n", " super().__init__()\n", " self.conv1 = nn.Conv2d(3, 10, kernel_size=5, stride=1, padding=2)\n", " self.layers = nn.ModuleList([BnLayer(layers[i], layers[i + 1])\n", " for i in range(len(layers) - 1)])\n", " self.layers2 = nn.ModuleList([ResnetLayer(layers[i + 1], layers[i + 1], 1)\n", " for i in range(len(layers) - 1)])\n", " self.layers3 = nn.ModuleList([ResnetLayer(layers[i + 1], layers[i + 1], 1)\n", " for i in range(len(layers) - 1)])\n", " self.out = nn.Linear(layers[-1], c)\n", " \n", " def forward(self, x):\n", " x = self.conv1(x)\n", " for l,l2,l3 in zip(self.layers, self.layers2, self.layers3):\n", " x = l3(l2(l(x)))\n", " x = F.adaptive_max_pool2d(x, 1)\n", " x = x.view(x.size(0), -1)\n", " return F.log_softmax(self.out(x), dim=-1)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(Resnet([10, 20, 40, 80, 160], 10), data)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "wd = 1e-5" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa9d90215bb948eb9691c7fbea8020fb", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.555916 1.497878 0.4593 \n", " 1 1.286486 1.179735 0.5811 \n", "\n", "CPU times: user 2min 11s, sys: 3min 50s, total: 6min 2s\n", "Wall time: 56.1 s\n" ] }, { "data": { "text/plain": [ "[array([1.17973]), 0.5811]" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 2, wds=wd)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5814d5dedcb54752bffe08266467f768", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.075466 1.038314 0.6273 \n", " 1 1.047439 0.991257 0.6495 \n", " 2 0.931519 0.911734 0.6783 \n", " 3 0.96682 0.917621 0.6752 \n", " 4 0.860942 0.831846 0.7079 \n", " 5 0.760845 0.758946 0.7312 \n", " 6 0.723117 0.757247 0.7335 \n", "\n", "CPU times: user 7min 40s, sys: 13min 20s, total: 21min\n", "Wall time: 3min 14s\n" ] }, { "data": { "text/plain": [ "[array([0.75725]), 0.7335]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 3, cycle_len=1, cycle_mult=2, wds=wd)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aa1e415a596e4955bd199c07778152c2", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 0.821457 0.828655 0.7095 \n", " 1 0.735682 0.729878 0.7426 \n", " 2 0.67023 0.709068 0.7555 \n", " 3 0.614033 0.68396 0.7587 \n", " 4 0.72356 0.731519 0.7446 \n", " 5 0.646835 0.666082 0.7637 \n", " 6 0.587808 0.630712 0.7811 \n", " 7 0.538311 0.63509 0.7782 \n", " 8 0.653226 0.719071 0.7544 \n", " 9 0.587378 0.638724 0.779 \n", " 10 0.53147 0.606534 0.791 \n", " 11 0.486855 0.574349 0.8018 \n", " 12 0.60116 0.674546 0.7682 \n", " 13 0.536271 0.590718 0.793 \n", " 14 0.478524 0.577702 0.8039 \n", " 15 0.439396 0.589477 0.7972 \n", "\n", "CPU times: user 17min 51s, sys: 30min 39s, total: 48min 31s\n", "Wall time: 7min 37s\n" ] }, { "data": { "text/plain": [ "[array([0.58948]), 0.7972]" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 4, cycle_len=4, wds=wd)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Resnet 2" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "hidden": true }, "outputs": [], "source": [ "class Resnet2(nn.Module):\n", " def __init__(self, layers, c, p=0.5):\n", " super().__init__()\n", " self.conv1 = BnLayer(3, 16, stride=1, kernel_size=7)\n", " self.layers = nn.ModuleList([BnLayer(layers[i], layers[i + 1])\n", " for i in range(len(layers) - 1)])\n", " self.layers2 = nn.ModuleList([ResnetLayer(layers[i + 1], layers[i + 1], 1)\n", " for i in range(len(layers) - 1)])\n", " self.layers3 = nn.ModuleList([ResnetLayer(layers[i + 1], layers[i + 1], 1)\n", " for i in range(len(layers) - 1)])\n", " self.out = nn.Linear(layers[-1], c)\n", " self.drop = nn.Dropout(p)\n", " \n", " def forward(self, x):\n", " x = self.conv1(x)\n", " for l,l2,l3 in zip(self.layers, self.layers2, self.layers3):\n", " x = l3(l2(l(x)))\n", " x = F.adaptive_max_pool2d(x, 1)\n", " x = x.view(x.size(0), -1)\n", " x = self.drop(x)\n", " return F.log_softmax(self.out(x), dim=-1)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "hidden": true }, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(Resnet2([16, 32, 64, 128, 256], 10, 0.2), data)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "hidden": true }, "outputs": [], "source": [ "wd = 1e-6" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "hidden": true, "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9d68521acef0497fbb9e22097935f14f", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.726595 1.476193 0.477 \n", " 1 1.511903 1.594056 0.5248 \n", "\n", "CPU times: user 2min 33s, sys: 4min 11s, total: 6min 44s\n", "Wall time: 1min 7s\n" ] }, { "data": { "text/plain": [ "[array([1.59406]), 0.5248]" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 2, wds=wd)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "hidden": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ec4b9952e70d4b18912edb32a7e0f371", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=7), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 1.246219 1.123039 0.5994 \n", " 1 1.205521 1.079559 0.6165 \n", " 2 1.047397 0.982982 0.6518 \n", " 3 1.11306 1.042084 0.643 \n", " 4 0.986444 0.938702 0.6705 \n", " 5 0.86359 0.827887 0.7107 \n", " 6 0.820836 0.859191 0.6998 \n", "\n", "CPU times: user 8min 53s, sys: 14min 50s, total: 23min 43s\n", "Wall time: 3min 58s\n" ] }, { "data": { "text/plain": [ "[array([0.85919]), 0.6998]" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 3, cycle_len=1, cycle_mult=2, wds=wd)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "hidden": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d855b76ada04385a95bfafe788e6faf", "version_major": 2, "version_minor": 0 }, "text/html": [ "

Failed to display Jupyter Widget of type HBox.

\n", "

\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "

\n", "

\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "

\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 0.922289 0.922782 0.6783 \n", " 1 0.843655 0.815614 0.7125 \n", " 2 0.733954 0.732746 0.7458 \n", " 3 0.695225 0.732575 0.7457 \n", " 4 0.826921 0.759739 0.7323 \n", " 5 0.731842 0.704877 0.7553 \n", " 6 0.642404 0.659563 0.7721 \n", " 7 0.605025 0.728076 0.7616 \n", " 8 0.72092 0.719592 0.7534 \n", " 9 0.652721 0.653841 0.776 \n", " 10 0.583139 0.606309 0.7903 \n", " 11 0.535503 0.64212 0.7817 \n", " 12 0.656404 0.654129 0.7783 \n", " 13 0.587746 0.655965 0.7777 \n", " 14 0.518367 0.601116 0.7925 \n", " 15 0.489359 0.597911 0.7945 \n", "\n", "CPU times: user 20min 17s, sys: 33min 52s, total: 54min 9s\n", "Wall time: 8min 58s\n" ] }, { "data": { "text/plain": [ "[array([0.59791]), 0.7945]" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time learn.fit(1e-2, 4, cycle_len=4, wds=wd)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "toc": { "colors": { "hover_highlight": "#DAA520", "navigate_num": "#000000", "navigate_text": "#333333", "running_highlight": "#FF0000", "selected_highlight": "#FFD700", "sidebar_border": "#EEEEEE", "wrapper_background": "#FFFFFF" }, "moveMenuLeft": true, "nav_menu": { "height": "266px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 4, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false, "widenNotebook": false } }, "nbformat": 4, "nbformat_minor": 2 }