{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get everything in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nb_007 import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('../data/cifar10/')\n", "torch.backends.cudnn.benchmark = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model definition" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Lambda(nn.Module):\n", " def __init__(self, func):\n", " super().__init__()\n", " self.func=func\n", " \n", " def forward(self, x): return self.func(x)\n", "\n", "def ResizeBatch(*size): return Lambda(lambda x: x.view((-1,)+size))\n", "def Flatten(): return Lambda(lambda x: x.view((x.size(0), -1)))\n", "def PoolFlatten(): return nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv_2d(ni, nf, ks, stride): return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=False)\n", "\n", "def bn(ni, init_zero=False):\n", " m = nn.BatchNorm2d(ni)\n", " m.weight.data.fill_(0 if init_zero else 1)\n", " m.bias.data.zero_()\n", " return m\n", "\n", "def bn_relu_conv(ni, nf, ks, stride, init_zero=False):\n", " bn_initzero = bn(ni, init_zero=init_zero)\n", " return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv_2d(ni, nf, ks, stride))\n", "\n", "def noop(x): return x\n", "\n", "class BasicBlock(nn.Module):\n", " def __init__(self, ni, nf, stride, drop_p=0.0):\n", " super().__init__()\n", " self.bn = nn.BatchNorm2d(ni)\n", " self.conv1 = conv_2d(ni, nf, 3, stride)\n", " self.conv2 = bn_relu_conv(nf, nf, 3, 1)\n", " self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None\n", " self.shortcut = conv_2d(ni, nf, 1, stride) if ni != nf else noop\n", "\n", " def forward(self, x):\n", " x2 = F.relu(self.bn(x), inplace=True)\n", " r = self.shortcut(x2)\n", " x = self.conv1(x2)\n", " if self.drop: x = self.drop(x)\n", " x = self.conv2(x) * 0.2\n", " return x.add_(r)\n", "\n", "\n", "def _make_group(N, ni, nf, block, stride, drop_p):\n", " return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]\n", "\n", "class WideResNet(nn.Module):\n", " def __init__(self, num_groups, N, num_classes, k=1, drop_p=0.0, start_nf=16):\n", " super().__init__()\n", " n_channels = [start_nf]\n", " for i in range(num_groups): n_channels.append(start_nf*(2**i)*k)\n", "\n", " layers = [conv_2d(3, n_channels[0], 3, 1)] # conv1\n", " for i in range(num_groups):\n", " layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock, (1 if i==0 else 2), drop_p)\n", "\n", " layers += [nn.BatchNorm2d(n_channels[3]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),\n", " Flatten(), nn.Linear(n_channels[3], num_classes)]\n", " self.features = nn.Sequential(*layers)\n", "\n", " def forward(self, x): return self.features(x)\n", "\n", "\n", "def wrn_22(): return WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = wrn_22()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the way to create datasets in fastai" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds,valid_ds = ImageDataset.from_folder(PATH/'train'), ImageDataset.from_folder(PATH/'test')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cifar_mean,cifar_std = map(tensor, ([0.4914, 0.48216, 0.44653], [0.24703, 0.24349, 0.26159]))\n", "cifar_norm,cifar_denorm = normalize_funcs(cifar_mean,cifar_std)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_tfms = [pad(padding=4), \n", " crop(size=32, row_pct=(0,1), col_pct=(0,1)), \n", " flip_lr(p=0.5)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch.create(train_ds, valid_ds, bs=512, train_tfm=train_tfms, tfms=cifar_norm, num_workers=8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A learner wraps together the data and the model, like in fastai. Here we test the usual training to 94% accuracy with AdamW." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = wrn_22()\n", "learn = Learner(data, model)\n", "learn.metrics = [accuracy]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "warm up" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value=''))), HTML(value='epoch train loss va…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:31\n", "epoch train loss valid loss accuracy\n", "0 1.337065 1.082762 0.609500 (00:31)\n", "\n" ] } ], "source": [ "learn.fit_one_cycle(1, 3e-3, wd=0.4, div_factor=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD8CAYAAABpcuN4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4FWX6//H3nZNGCQkkAUISSCChhBbgBAsWBAtWhEUFG23FVVh0rei6u6671lVZVsEVpYkFsYKugojoqouQhJqAgdATILQQavrz+yPj/vLNhuRAypxyv64rF2fmPPOcz2RC7szMMzNijEEppZQ6Ez+7AyillHJvWiiUUkrVSAuFUkqpGmmhUEopVSMtFEoppWqkhUIppVSNtFAopZSqkRYKpZRSNdJCoZRSqkb+dgeoDxERESYuLs7uGEop5VHS09MPGWMia2vnFYUiLi6OtLQ0u2MopZRHEZFdrrTTQ09KKaVqpIVCKaVUjbRQKKWUqpEWCqWUUjXSQqGUUqpGLhUKERkiIlkiki0iU6p5P0hE3rfeXyUicZXee8yanyUiV1nzgkVktYisF5FMEflzpfbxVh/ZVp+BdV9NpZRS56rWQiEiDmA6cDWQBIwSkaQqzcYD+caYBGAq8Ly1bBIwEugODAFmWP0VAYOMMb2BZGCIiJxv9fU8MNXqK9/qWymllE1cuY6iP5BtjNkOICILgKHApkpthgJPWq8/BF4VEbHmLzDGFAE7RCQb6G+MWQmcsNoHWF/GWmYQcKv13jyr39fOae2UcgPfbTlIRm4BoU0CaNEkgHahwXSLakGzIK+4jEn5AFd+UqOBPZWmc4DzztTGGFMqIgVAuDX/pyrLRsN/91TSgQRgujFmlYhEAEeNMaVV21clIhOACQDt27d3YTWUanxfZe7n7rfTqfpoehGIj2hGn9iWXNm9DZd2jiQ4wGFPSKVqYdufNMaYMiBZRMKAT0SkB7D/LJafCcwEcDqdppbmSjW6DTlHuW/BOnrFhDFvbArFpeUUnC5hT/4pMnKPkZFbwNeb8/hoTQ5NAhwM6taasRfG0a9DSyp2rpVyD64UilwgttJ0jDWvujY5IuIPhAKHXVnWGHNURFZQcQ7jJSBMRPytvYrqPkspt5eTf4pxc9MIbx7Im3c6CWtaMSajdYtgEtuEMKhrGwBKyspZtf0ISzL38dn6ffxrwz6SY8OYcElHhnRvi5+fFgxlP1dGPaUCidZopEAqTk4vrtJmMTDaej0C+MYYY6z5I61RUfFAIrBaRCKtPQlEpAlwBfCztcwKqw+sPhed++op1fgKTpcwbm4qRaVlzBmTQmRI0BnbBjj8uCgxgr/e2JOVjw3iqaHdyT9VzL3vrGHYjB9J33WkEZMrVb1aC4X1l/0kYCmwGVhojMkUkadE5Aar2Swg3DpZ/QAwxVo2E1hIxYnvJcBE65BTFLBCRDZQUYiWGWM+t/p6FHjA6ivc6lspj1BSVs6976Sz/eBJXr+9H4ltQlxetmmgP3deEMc3Dw7k5Zt7s/9YIb96bSWT31vLgeOFDZhaqZqJqXqWzQM5nU6jd49VdjPG8OhHG1iYlsOLN/VmRL+YOvV3sqiUf363jdf/vZ3mQf48O7wnV3VvW09plQIRSTfGOGtrp1dmK1VPZny7jYVpOUwelFDnIgHQLMifB6/swheTLyIqNJi756fz8AfrOVFUWvvCStUjLRRK1YNF63L529IshvWJ5ndXdK7XvhNah/DJvQO4d2AnPlqTw7DpP7Lr8Ml6/QylaqKFQqk6St15hIc/2ED/+FY896ueDTK0NdDfj0eGdGX++PM4eKKIG179kR+zD9X75yhVHS0UStXBjkMnueutNGJaNmHmHf0I8m/Yi+YGJESwaOIA2rQI4s7Zq3n7J5ceUKZUnWihUOocHTlZzNg5q/ETYc7YlP9eK9HQOoQ34+N7B3Bp50ie+DSD6Suy8YZBKcp9aaFQ6hwUlpQx4a009hYU8sad/egQ3qxRP795kD+v39GPG5Pb8belWTz75c9aLFSD0buSKXWWyssND3+4gbRd+Uy/tS/9OrSyJUeAw4+Xb06mRZMAZv57OyeKSnn6xh56+w9V77RQKHWWXlqWxWfr9/LokK5c2yvK1ix+fsKfb+hO00B//vndNoL8/fjjdUlaLFS90kKh1FlYmLqH6Su2Map/LL+5tKPdcQAQER4d0qXiliE/7qS5df2FUvVFC4VSLvph6yEe/2QjFydG8NRQ9zrEIyL88bokTheX8co32TQJdHDvwAS7YykvoYVCKRdk7T/OPW+nk9C6OTNu60uAw/3GgYgITw/ryemSMl5YkkVUaDDD+tT9CnGltFAoVYsDxwoZNzeVJoEOZo1JISQ4wO5IZ+TwE/42ojcHjhXxyIcbiAptwvkdw+2OpTyc+/1ZpJQbOVVcyvh5aeSfKmb2mBSiw5rYHalWgf5+/PP2frRv1ZQJb6WRfeBE7QspVQMtFEqdQVm5YfJ768jcW8Aro/rQIzrU7kguC20awNyx/Qn092Ps3NXknyy2O5LyYFoolDqDv/5rE19vzuNP13dncLc2dsc5a7GtmvLGnU7yCoqYvGAtZeV6QZ46N1oolKrG3B93MOfHnYwbEM/oC+PsjnPO+rRvyV9u7M73Ww/xt6VZdsdRHkoLhVJVfL0pj6c+38QVSW34/bXd7I5TZ7ektOfW89rzz++28cXGfXbHUR5IC4VSlWzMKeC3762lR3Qo00Ym4/Bzn2sl6uJP1yfRp30YD32wXk9uq7OmhUIpS+7R04ybl0qrZoG8OdpJ00DvGT0e5O/gtdv6ERzgYNK7aygsKbM7kvIgWiiUAo4VljBuTiqFJWXMGZtC65BguyPVu7ahwbx0c29+3n+cp/+12e44yoNooVA+r6SsnInvrGHbwRO8dls/OrcJsTtSg7msS2smXNKR+T/tYkmGnq9QrtFCoXyaMYY/fJrB91sP8czwnlyUGGF3pAb30JVd6B0TyiMfbiAn/5TdcZQH0EKhfNpr321jQeoefjsogZudsXbHaRSB/n68Mqov5QYe+mA95Xp9haqFFgrlsz5bv5cXlmQxNLkdD1zR2e44jap9eFP+eF0SP20/wpz/7LQ7jnJzLhUKERkiIlkiki0iU6p5P0hE3rfeXyUicZXee8yanyUiV1nzYkVkhYhsEpFMEbmvUvsnRSRXRNZZX9fUfTWV+r/Sdh7hwQ/WkxLXkhdG9HKrW4Y3lpucMVzerTXPL/mZrXnH7Y6j3FithUJEHMB04GogCRglIklVmo0H8o0xCcBU4Hlr2SRgJNAdGALMsPorBR40xiQB5wMTq/Q51RiTbH19Uac1VKqKnYdOctdbaUSHNWHmHU6C/B12R7KFiPDs8F40D/LngYXrKSkrtzuSclOu7FH0B7KNMduNMcXAAmBolTZDgXnW6w+BwVLxJ9pQYIExpsgYswPIBvobY/YZY9YAGGOOA5uB6LqvjlI1yz9ZzNi5qQDMGZNCy2aBNieyV2RIEM8M68nG3AJmrNhmdxzlplwpFNHAnkrTOfzvL/X/tjHGlAIFQLgry1qHqfoAqyrNniQiG0Rktoi0rC6UiEwQkTQRSTt48KALq6F8XVFpGXfPTyc3/zRv3OkkLqKZ3ZHcwpAebRma3I5XV2wla78eglL/y9aT2SLSHPgIuN8Yc8ya/RrQCUgG9gEvVbesMWamMcZpjHFGRkY2Sl7luYwxPPLhBlbvPMJLN/fGGdfK7khu5Y/XJRESHMAjH23Qu8yq/+FKocgFKo8bjLHmVdtGRPyBUOBwTcuKSAAVReIdY8zHvzQwxuQZY8qMMeXAG1Qc+lKqTl5etoVF6/by8FVduL53O7vjuJ3w5kE8eUN31u85ypwfd9gdR7kZVwpFKpAoIvEiEkjFyenFVdosBkZbr0cA3xhjjDV/pDUqKh5IBFZb5y9mAZuNMS9X7khEoipNDgMyznallKpsYdoeXvkmm1ucsdw7sJPdcdzW9b2iuLxba178Kotdh0/aHUe5kVoLhXXOYRKwlIqTzguNMZki8pSI3GA1mwWEi0g28AAwxVo2E1gIbAKWABONMWXAAOAOYFA1w2BfEJGNIrIBuAz4XX2trPI9P2Yf4vGPN3JRQgR/HdbDJ4fBukpE+OuNPQnw8+P3n2RQ8beeUiDe8MPgdDpNWlqa3TGUm9mad5zhr/2HdqFN+OCeC2gRHGB3JI8wf+VO/rAok2kjkxmarIMRvZmIpBtjnLW10yuzlVc6cLyQMXNSCQ5wMHtsihaJs3DreR3oHRPKXz7fTMHpErvjKDeghUJ5ndPFZdw1L40jJ4uZPTqF6LAmdkfyKA4/4elhPTlysogX9fGpCi0UysuUlRvuW7CWDbkF/GNUH3rGhNodySP1iA5l9IVxvL1qF+v2HLU7jrKZFgrlVZ75YjNfbcrjj9clcUVSG7vjeLQHruhM65Agnvh0o15b4eO0UCiv8dbKncz6YQdjB8QxdkC83XE8XkhwAL+/NomM3GMsSN1tdxxlIy0Uyiss35zHk4szubxbG564tuo9K9W5ur5XFOfFt+LFpVkcPVVsdxxlEy0UyuNl5BYw6d219IgO5R+jknH46bUS9UVEePKG7hScLuGlr7bYHUfZRAuF8mh7j55m3NxUWjUL5M3RTpoG+tsdyet0i2rBHed34J1Vu8jcW2B3HGUDLRTKYx0vLGHc3FROF5cxe0wKrUOC7Y7ktR64ogthTQN5cnGmXrHtg7RQKI9UUlbOve+sIfvACV67vR9d2obYHcmrhTYN4KEru5C6M58vNu63O45qZFoolMcxxvDHRRl8v/UQTw/rwUWJEXZH8gm3pMTStW0Iz365mcKSMrvjqEakhUJ5nH9+t533Vu/h3oGduCWlvd1xfIbDT/jDdUnk5J9mzo877Y6jGpEWCuVRPt+wl+eX/Mx1vaJ46MoudsfxOQMSIri8W2umr8jm4PEiu+OoRqKFQnmM9F1HeGDhepwdWvLiTb3x02Gwtnj8mm4UlpTx8jIdLusrtFAoj7Dr8EnueiuddqHBzLzTSXCAw+5IPqtjZHPuvCCO91N36zO2fYQWCuX2jp4qZuycVMqNYc7Y/rRqFmh3JJ83eXACzYP8ee7LzXZHUY1AC4Vya0WlZUx4K52c/NPMvMNJfEQzuyMpIKxpIBMvS2BF1kH+k33I7jiqgWmhUG7LGMOjH25g9c4j/O2mXvSPb2V3JFXJ6AvjiA5rwjNfbqZc7y7r1bRQKLc19eutfLpuLw9d2VkfyemGggMcPHRVZzJyj7F4/V6746gGpIVCuaUP03P4x/Kt3OyMYeJlCXbHUWcwtHc03du14G9Ls/QiPC+mhUK5nf9kH2LKRxu4KCGCp4f1RESHwborPz/h8Wu6kXv0NG//tMvuOKqBaKFQbmVr3nHufjudjpHNmHF7XwIc+iPq7gYkRHBxYgTTV2RzrLDE7jiqAej/QuU2Dh4vYuzcVIIDHMwek0KL4AC7IykXPXJVV/JPlfDGv7fbHUU1AC0Uyi2cLi7j1/NSOXyimFmjncS0bGp3JHUWesaEcm2vKN78fgcHjhfaHUfVM5cKhYgMEZEsEckWkSnVvB8kIu9b768SkbhK7z1mzc8SkausebEiskJENolIpojcV6l9KxFZJiJbrX9b1n01lTsrKzfc//5aNuQWMG1kMr1iwuyOpM7BQ1d2oaSsnFe/ybY7iqpntRYKEXEA04GrgSRglIhUfSjxeCDfGJMATAWet5ZNAkYC3YEhwAyrv1LgQWNMEnA+MLFSn1OA5caYRGC5Na282LNfbGZpZh5PXJvEld3b2h1HnaP4iGbckhLLu6t2s+vwSbvjqHrkyh5FfyDbGLPdGFMMLACGVmkzFJhnvf4QGCwVQ1WGAguMMUXGmB1ANtDfGLPPGLMGwBhzHNgMRFfT1zzgxnNbNeUJ5q/cyZs/7GD0BR0YNyDO7jiqju4bnIi/Q/j711vtjqLqkSuFIhrYU2k6h///S/1/2hhjSoECINyVZa3DVH2AVdasNsaYfdbr/UCb6kKJyAQRSRORtIMHD7qwGsrdfPNzHn9anMngrq354/XddRisF2jdIpjRF8bx6bpctuTpDQO9ha0ns0WkOfARcL8x5ljV903Fw3mrvTeAMWamMcZpjHFGRkY2cFJV3zJyC5j07lqS2rXgH6P64NBbhnuN31zSieaB/rz0VZbdUVQ9caVQ5AKxlaZjrHnVthERfyAUOFzTsiISQEWReMcY83GlNnkiEmW1iQIOuLoyyjPsKzjN+HmphDUJYPboFJoF+dsdSdWjls0C+fXFHVmamceGnKN2x1H1wJVCkQokiki8iARScXJ6cZU2i4HR1usRwDfW3sBiYKQ1KioeSARWW+cvZgGbjTEv19DXaGDR2a6Ucl/HC0sYOyeVk0VlzB6bQusWwXZHUg1g3EVxtGwawItf6cONvEGthcI65zAJWErFSeeFxphMEXlKRG6wms0CwkUkG3gAa6SSMSYTWAhsApYAE40xZcAA4A5gkIiss76usfp6DrhCRLYCl1vTyguUlpUz6d21bD1wghm39aVr2xZ2R1INJCQ4gHsGduLfWw6yavthu+OoOpKKP/w9m9PpNGlpaXbHUDUwxvDEpxm8s2o3zw3vycj+7e2OpBpYYUkZl7ywgriIZrw/4XwdrOCGRCTdGOOsrZ1ema0axRvfb+edVbu5Z2AnLRI+IjjAwcTLEli94wg/ZutehSfTQqEa3Bcb9/HMFz9zba8oHr6yi91xVCMa2T+WdqHBvLQsC284euGrtFCoBrVmdz6/e38d/Tq05KWbeuOnw2B9SpC/g0mDElm7+yjfZun1Tp5KC4VqMLsPn+KueWm0DQ3mjTudBAc47I6kbHCTM4bYVk14edkW3avwUFooVIM4eqqYMXNXU2YMc8ak0KpZoN2RlE0CHH5MHpTIxtwCvtqUZ3ccdQ60UKh6V1Raxt3z08k5cprXb+9Hx8jmdkdSNhvWJ5qOEc2YumwL5eW6V+FptFCoemWM4bGPNrJqxxH+dlMvzusYbnck5Qb8HX5MHpzIz/uPszRzv91x1FnSQqHq1d+/3srHa3N58IrODE2ueu9I5cuu792OTpHN+PvXW3WvwsNooVD15qP0HKYt38qIfjFMGpRgdxzlZhx+wuTBiWTlHefLDN2r8CRaKFS9WLntMFM+3sCFncJ5ZlhPvQpXVeu6Xu1IaN2cacv1XIUn0UKh6iz7wHHunp9GXHgzXru9H4H++mOlqufwE+4bnMiWvBP8a+O+2hdQbkH/R6s6OXSiiLFzUwn0dzB7TAqhTQLsjqTc3LU9o+jcpjnTlm+lTPcqPIIWCnXOCkvK+PW8NA4eL+LN0U5iWzW1O5LyAH7WuYrsAyf4QvcqPIIWCnVOyssN9y9Yx/qco0wb2Yfk2DC7IykPck2PKBJbN+cfy3UElCfQQqHOybNfbmZJ5n5+f003rure1u44ysP8slex9cAJvsjQvQp3p4VCnbX5P+3ije93cOcFHRh/UbzdcZSHuqZnxV7FNL2uwu1poVBnZcXPB/jTogwGdW3NH69L0mGw6pw5dK/CY2ihUC7btPcYk95dQ7eoFrwyqg/+Dv3xUXXzy16Fnqtwb/o/Xblkf0Eh4+am0qJJALPHpNAsyN/uSMoLOPyE31rXVSzRe0C5LS0UqlYnikoZOzeV44UlzB6TQpsWwXZHUl7k2p5RdIpspnsVbkwLhapRaVk5k95dw5a840y/rS/dolrYHUl5GYef8NtBFXeW/WqT7lW4Iy0U6oyMMTz5WSbfZh3kL0N7MLBLa7sjKS91fe92dIzQO8u6Ky0U6oze/H4Hb/+0m7sv7cit57W3O47yYg4/YdKgBGuvQp+C5260UKhqfblxH898uZlrerbl0au62h1H+YAbercjLrwp/1i+VZ+t7WZcKhQiMkREskQkW0SmVPN+kIi8b72/SkTiKr33mDU/S0SuqjR/togcEJGMKn09KSK5IrLO+rrm3FdPnYu1u/O5//119IkN4+Wbk/Hz02slVMPzd/gxaVAim/Yd4+vNB+yOoyqptVCIiAOYDlwNJAGjRCSpSrPxQL4xJgGYCjxvLZsEjAS6A0OAGVZ/AHOtedWZaoxJtr6+OLtVUnWx58gp7norjTYtgnnjTifBAY7aF1KqntyY3I72rZoybfkW3atwI67sUfQHso0x240xxcACYGiVNkOBedbrD4HBUnHJ7lBggTGmyBizA8i2+sMY82/gSD2sg6onBadKGDNnNSVlhjljUwhvHmR3JOVj/B1+TLosgYzcY6zI0r0Kd+FKoYgG9lSazrHmVdvGGFMKFADhLi5bnUkissE6PNWyugYiMkFE0kQk7eDBgy50qWpSXFrOb95OZ/eRU8y8ox+dIpvbHUn5qGF9o4lp2YRpX+u5CnfhjiezXwM6AcnAPuCl6hoZY2YaY5zGGGdkZGRj5vM6xhimfLyBldsP88KIXpzXMdzuSMqHBTj8mHhZAutzCvhui/4R6A5cKRS5QGyl6RhrXrVtRMQfCAUOu7js/2GMyTPGlBljyoE3sA5VqYYzbflWPl6Ty+8u78ywPjF2x1GKX/WNITqsCdN0BJRbcKVQpAKJIhIvIoFUnJxeXKXNYmC09XoE8I2p2LqLgZHWqKh4IBFYXdOHiUhUpclhQMaZ2qq6+3hNDn//eivD+0YzeXCC3XGUAiDQ3497BnZi7e6jfL/1kN1xfF6thcI65zAJWApsBhYaYzJF5CkRucFqNgsIF5Fs4AFgirVsJrAQ2AQsASYaY8oAROQ9YCXQRURyRGS81dcLIrJRRDYAlwG/q6d1VVWs3HaYRz/awAUdw3lueC+9ZbhyKzc5Y4gKDda9Cjcg3rABnE6nSUtLszuGR8k+cILhM34kMiSIj+8ZQGjTALsjKfU/5q/cyR8WZfL2+PO4KDHC7jheR0TSjTHO2tq548ls1cAOnShi3NxUAv39mDu2vxYJ5bZuTomlbYtgva7CZloofExhSRl3vZVG3rFC3rjTSWyrpnZHUuqMgvwd3DOwE6k781m5/bDdcXyWFgofUl5ueGDhOtbtOcq0kcn0aV/tJSpKuZVbUmJp0yKIaV9vtTuKz9JC4UOeX/ozX2zcz+NXd2NIj6jaF1DKDQQHOPjNpZ1YteMIK7fpXoUdtFD4iHdX7eb177Zz+/nt+fXF8XbHUeqsjOrfnsiQIKYt32J3FJ+khcIHfJt1gD8syuCyLpE8eX13HQarPE5wgIN7Lu3ET9uP8JOeq2h0Wii83OZ9x5j07lq6tAnhlVv74u/QTa48063nWXsVeq6i0elvDS+Wd6yQcXNTaR7kz+wxKTQP8rc7klLn7Je9ipXbD+teRSPTQuGlThaVMm5uKsdOlzB7TAptQ4PtjqRUnelehT20UHih0rJyfvveWn7ef5zpt/UlqV0LuyMpVS9+GQG1cvthVuleRaPRQuFljDH8+bNNfPPzAf58Q3cGdmltdySl6tVt1l7F33WvotFoofAys37YwfyfdjHhko7cfn4Hu+MoVe/0XEXj00LhRZZk7OfpLzZzdY+2TBnS1e44SjWYW89rT+uQIF5epveAagxaKLzEuj1Huf/9tfSOCWPqLcn4+em1Esp7BQc4uHdgJ1br1dqNQguFF9hz5BS/npdKZEgQb452EhzgsDuSUg1uZP/2tG0RzNSvda+ioWmh8HAFp0sYOzeV4tJy5oxJIaJ5kN2RlGoUwQEO7r2s4s6yP2brXkVD0kLhwYpLy7nn7XR2HT7J63c4SWgdYnckpRrVLSmxRIUG8/KyLN2raEBaKDyUMYbHP9nIf7Yd5rnhvbigU7jdkZRqdEH+DiYNSmDN7qN8u+Wg3XG8lhYKD/XqN9l8mJ7D/Zcn8qt+MXbHUco2N/WLJaZlE6bqCKgGo4XCAy1al8tLy7YwvG809w1OtDuOUrYK9Pdj8uBENuQUsGxTnt1xvJIWCg+zescRHv5gA+d3bMVzw3vpLcOVAob3iSYuvCkvL9tCebnuVdQ3LRQeZNvBE0yYn0ZMqya8fruTQH/dfEoB+Dv8uO/yRH7ef5wvM/bbHcfr6G8aD3H4RBFj56TiEGHumP6ENg2wO5JSbuWG3tEktG7O1K+3UKZ7FfVKC4UHKCwp46630sg7Vsgbo520D29qdySl3I7DT3jgis5kHzjBp2tz7Y7jVVwqFCIyRESyRCRbRKZU836QiLxvvb9KROIqvfeYNT9LRK6qNH+2iBwQkYwqfbUSkWUistX6t+W5r57nKy83PLhwPWv3HOXvtyTTt71PfzuUqtGQ7m3pEd2CqV9vobi03O44XqPWQiEiDmA6cDWQBIwSkaQqzcYD+caYBGAq8Ly1bBIwEugODAFmWP0BzLXmVTUFWG6MSQSWW9M+64WlWfxr4z4eu7orV/eMsjuOUm7Nz0948Mou5OSf5v20PXbH8Rqu7FH0B7KNMduNMcXAAmBolTZDgXnW6w+BwVIxHGcosMAYU2SM2QFkW/1hjPk3cKSaz6vc1zzgxrNYH6/y3urd/PO7bdx2Xnvuurij3XGU8ggDO0fi7NCSV5Zv5XRxmd1xvIIrhSIaqFyac6x51bYxxpQCBUC4i8tW1cYYs896vR9oU10jEZkgImkiknbwoPddkfndloM88WkGA7tE8ucbuuswWKVcJCI8fFUXDhwvYv5PO+2O4xXc+mS2qbjMstrhC8aYmcYYpzHGGRkZ2cjJGtbmfceY+M4aOrcJ4dVb++LvcOvNpJTbOa9jOBcnRjDj220cKyyxO47Hc+U3UC4QW2k6xppXbRsR8QdCgcMuLltVnohEWX1FAQdcyOg18o4VMm5uKs2CHMwe46R5kL/dkZTySI8O6crRUyXM/G673VE8niuFIhVIFJF4EQmk4uT04iptFgOjrdcjgG+svYHFwEhrVFQ8kAisruXzKvc1GljkQkavcLKolHFzUyk4XcKs0SlEhTaxO5JSHqtHdCjX9Ypi1g87OHC80O44Hq3WQmGdc5gELAU2AwuNMZki8pSI3GA1mwWEi0g28ADWSCVjTCawENgELAEmGmPKAETkPWAl0EVEckRkvNXXc8AVIrIVuNya9npl5YbJ761ap37gAAAQ/klEQVRl875jTL+1Lz2iQ+2OpJTHe+jKLpSUlfPK8my7o3g08Ya7LTqdTpOWlmZ3jHNmjOHJxZnMW7mLvwztzh0XxNkdSSmv8cSnG1mweg9fP3ApcRHN7I7jVkQk3RjjrK2dniV1A7N/3Mm8lbu46+J4LRJK1bPJgxIJcPjx0rItdkfxWFoobPZV5n7++q9NDOnelseu7mZ3HKW8TusWwYy/KJ7P1u9lY06B3XE8khYKG63fc5TJC9bSKyaMqbck4+en10oo1RDuvrQjrZoF8swXm/XhRudAC4VN9hw5xfh5aUQ0D+LNO500CXTUvpBS6pyEBAdw3+BEVm4/zLdZ3neBbkPTQmGDgtMljJubSlFpGXPGpBAZEmR3JKW83q3ntSc+ohnPfrmZ0jK9YeDZ0ELRyIpLy7n3nXR2HDrJ67f3I7FNiN2RlPIJAQ4/HrmqC1vyTvDRmhy743gULRSNyBjD7z/ZyI/Zh3l2eE8uTIiwO5JSPmVIj7b0bR/GS19t4VRxqd1xPIYWikY0fUU2H6TnMHlQAjc5Y2tfQClVr0SE31+bxIHjRfxTb+3hMi0UjWTRulxe/GoLNya343dXdLY7jlI+q1+HllzXK4qZ/97G3qOn7Y7jEbRQNILVO47w8Acb6B/fiudH9NJbhitlsylXd6XcwAtLfrY7ikfQQtHAdhw6yYT5acS0asLMO/oR5K/DYJWyW0zLptx1cTyfrtvLuj1H7Y7j9rRQNKAjJ4sZO2c1fiLMGZNCWNNAuyMppSz3DEwgMiSIpz7L1IvwaqGFooEUlpQx4a009hYU8sadTjqE683IlHInzYP8efjKLqzZfZRF6/baHcetaaFoAOXlhoc+WE/arnym3pxMvw4t7Y6klKrGiH4x9IoJ5ZkvNnOiSIfLnokWigbw0rIsPt+wjylXd+XaXlF2x1FKnYGfn/DnG7pz4HgRr3yz1e44bksLRT1bmLqH6Su2Map/e+6+pKPdcZRStejTviUj+sUw+4cdbDt4wu44bkkLRT36YeshHv9kI5d0juQvQ7vrMFilPMSjQ7oS7O/gqc826YntamihqCdZ+49zz9vpJLRuzvRb++Dv0G+tUp4iMiSI+6/ozHdbDvLVpjy747gd/W1WDw4cK2TsnNU0DXIwe0wKIcEBdkdSSp2lOy/oQJc2Ifx5cSYn9cT2/6GFoo5OFZcyfl4aR0+XMGt0Cu3CmtgdSSl1DgIcfjw9rAd7CwqZtlxPbFemhaIOysoNk99bS+beAl69tQ89okPtjqSUqgNnXCtGpsQy64cd/Lz/mN1x3IYWijr4y+eb+HrzAf50fXcGdW1jdxylVD14dEhXQpsE8PtPMigv1xPboIXinM35cQdz/7OTcQPiGX1hnN1xlFL1pGWzQB6/phvpu/JZkLrH7jhuwaVCISJDRCRLRLJFZEo17weJyPvW+6tEJK7Se49Z87NE5Kra+hSRuSKyQ0TWWV/JdVvF+rdsUx5Pfb6JK5La8Ptru9kdRylVz37VN5oLOobz7Beb2V9QaHcc29VaKETEAUwHrgaSgFEiklSl2Xgg3xiTAEwFnreWTQJGAt2BIcAMEXG40OfDxphk62tdndawnm3MKWDye2vpGR3KtJHJOPz0WgmlvI2I8OzwnpSUl/PEpxt9/toKV/Yo+gPZxpjtxphiYAEwtEqbocA86/WHwGCpuNpsKLDAGFNkjNkBZFv9udKn28k9eppx81Jp1SyQN0c7aRrob3ckpVQDiYtoxoNXdOHrzQf4bMM+u+PYypVCEQ1UPlCXY82rto0xphQoAMJrWLa2Pp8WkQ0iMlVEglzI2OCOFZYwds5qCkvKmDs2hdYhwXZHUko1sHEXxdM7NownF2dy+ESR3XFs444nsx8DugIpQCvg0eoaicgEEUkTkbSDBw82aKCSsnLufXsN2w+e5J+39yOxTUiDfp5Syj04/IQXftWL44Ul/Glxpt1xbONKocgFYitNx1jzqm0jIv5AKHC4hmXP2KcxZp+pUATMoeIw1f8wxsw0xjiNMc7IyEgXVuPcGGN44pMMfsg+xLPDezIgIaLBPksp5X66tA3hvsGJfL5hH4vX++ZzK1wpFKlAoojEi0ggFSenF1dpsxgYbb0eAXxjKs7+LAZGWqOi4oFEYHVNfYpIlPWvADcCGXVZwbqa8e023k/bw28HJXCTM7b2BZRSXuc3l3YiOTaMP3ya4ZOjoGotFNY5h0nAUmAzsNAYkykiT4nIDVazWUC4iGQDDwBTrGUzgYXAJmAJMNEYU3amPq2+3hGRjcBGIAL4a/2s6tlbvH4vf1uaxdDkdjxwRWe7YiilbObv8OPlm3tTVFrGIx9t8LlRUOINK+x0Ok1aWlq99pm68wi3vbGK5Ngw5v+6P0H+jnrtXynleeav3MkfFmXylxt7cMf5HeyOU2cikm6McdbWzh1PZttux6GTTHgrjeiWTXj9jn5aJJRSANx+fgcu6RzJ0//axJa843bHaTRaKKrIP1nM2DmrAZgzJoWWzQJtTqSUchciwos39aJ5kD+T3l3D6eIyuyM1Ci0UlRSWlDFhfhp7Cwp5c7STuIhmdkdSSrmZ1iHBTL0lmS15J3jq8012x2kUWigs5eWGRz7cQOrOfF6+uTf9OrSyO5JSyk1dnBjJPQM78d7q3Xy+wfuHzGqhsLy8bAuL1+/lkSFduK5XO7vjKKXc3ANXdKZv+zCmfLSR7AMn7I7ToLRQAAtT9/DqimxGpsRyz6Wd7I6jlPIAAQ4/Xr21L0H+ftw9P43jhSV2R2owPl8ofth6iMc/2cjFiRH85cYeVFznp5RStWsX1oRXb+3LzsOneHDheq990JFPF4qs/ce55+10OkU2Z/ptfQlw+PS3Qyl1Di7oFM7j13Tjq015zPg22+44DcKnfzO++f12ggMdzB6bQovgALvjKKU81LgBcdyY3I6Xlm1hScZ+u+PUO59+oMIzw3uy9+hposOa2B1FKeXBRITnftWLXUdOcf/7a1kQegHJsWF2x6o3Pr1HEeDwo0O4XiuhlKq74AAHb9zpJDIkiF/PS2XPkVN2R6o3Pl0olFKqPkU0D2LOmP6UlBnGzk0l/2Sx3ZHqhRYKpZSqRwmtm/P6Hf3YfeQUo+es9ophs1oolFKqnp3fMZzXbuvLpr3HGDc3lVPFpXZHqhMtFEop1QAGd2vD30cmk74rn7vnp1NY4rk3ENRCoZRSDeS6Xu14YURvvt96iHFzUzlR5Jl7FloolFKqAY3oF8PLN/dm1Y4j3PbmKo6e8rwT3FoolFKqgQ3vG8Nrt/Vl875j3Pz6So977rYWCqWUagRXdm/L3LEp5Oaf5oZXf2Dt7ny7I7lMC4VSSjWSCztF8NG9FxIU4MctM3/io/QcuyO5RAuFUko1oq5tW7Bo4kX0a9+SBz9Yz58WZbj9iCgtFEop1chaNQvkrfH9GTcgnnkrd3HdKz+QkVtgd6wz0kKhlFI2CHD48cfrk5g/vj/HC0u4cfqPTF22hdPF7rd3oYVCKaVsdHFiJEvvv4Sre0YxbflWLn/5Oz5bvxdj3OchSC4VChEZIiJZIpItIlOqeT9IRN633l8lInGV3nvMmp8lIlfV1qeIxFt9ZFt9BtZtFZVSyr2FNQ3klVF9WDDhfFo0CeC3761l2Iz/sCRjH2Vu8NS8WguFiDiA6cDVQBIwSkSSqjQbD+QbYxKAqcDz1rJJwEigOzAEmCEijlr6fB6YavWVb/WtlFJe7/yO4Xz+24t4bnhPDp8s4jdvr2HQS98y+4cd5B2z79oLqW33RkQuAJ40xlxlTT8GYIx5tlKbpVablSLiD+wHIoEpldv+0s5a7H/6BJ4DDgJtjTGlVT/7TJxOp0lLS3N5pZVSyt2VlRuWZu7nje+3s3b3UQD6tg/jyu5tSY4NI6ldizo/mVNE0o0xztraufKEu2hgT6XpHOC8M7WxfsEXAOHW/J+qLBttva6uz3DgqDGmtJr2SinlMxx+wjU9o7imZxTZB46zJGM/X2bs57kvf/5vm7jwpjw7vBcXdApv0Cwe+yhUEZkATABo3769zWmUUqrhJLQOYdKgECYNSuTg8SIy9haQmVtA5t5jRIY0/GlcVwpFLhBbaTrGmlddmxzr0FMocLiWZaubfxgIExF/a6+ius8CwBgzE5gJFYeeXFgPpZTyeJEhQVzWpTWXdWndaJ/pyqinVCDRGo0USMXJ6cVV2iwGRluvRwDfmIqTH4uBkdaoqHggEVh9pj6tZVZYfWD1uejcV08ppVRd1bpHYZ1zmAQsBRzAbGNMpog8BaQZYxYDs4D5IpINHKHiFz9Wu4XAJqAUmGiMKQOork/rIx8FFojIX4G1Vt9KKaVsUuuoJ0+go56UUursuTrqSa/MVkopVSMtFEoppWqkhUIppVSNtFAopZSqkRYKpZRSNfKKUU8ichDYdY6LRwCH6jGOJ9F19z2+ut6g617duncwxkTWtrBXFIq6EJE0V4aHeSNdd99bd19db9B1r8u666EnpZRSNdJCoZRSqkZaKKwbC/ooXXff46vrDbru58znz1EopZSqme5RKKWUqpFPFwoRGSIiWSKSLSJT7M7TUEQkVkRWiMgmEckUkfus+a1EZJmIbLX+bWl31oZiPat9rYh8bk3Hi8gqa9u/b93u3uuISJiIfCgiP4vIZhG5wFe2u4j8zvp5zxCR90Qk2Fu3u4jMFpEDIpJRaV6121kq/MP6HmwQkb619e+zhUJEHMB04GogCRglIkn2pmowpcCDxpgk4HxgorWuU4DlxphEYLk17a3uAzZXmn4emGqMSQDygfG2pGp404AlxpiuQG8qvgdev91FJBqYDDiNMT2oeJzBSLx3u88FhlSZd6btfDUVzwZKpOIpoa/V1rnPFgqgP5BtjNlujCkGFgBDbc7UIIwx+4wxa6zXx6n4ZRFNxfrOs5rNA260J2HDEpEY4FrgTWtagEHAh1YTr1x3EQkFLsF6posxptgYcxQf2e5UPG+nifXUzabAPrx0uxtj/k3Fs4AqO9N2Hgq8ZSr8RMVTRaNq6t+XC0U0sKfSdI41z6uJSBzQB1gFtDHG7LPe2g+0sSlWQ/s78AhQbk2HA0etx+2C9277eOAgMMc67PamiDTDB7a7MSYXeBHYTUWBKADS8Y3t/oszbeez/t3ny4XC54hIc+Aj4H5jzLHK71mPofW6IXAich1wwBiTbncWG/gDfYHXjDF9gJNUOczkxdu9JRV/OccD7YBm/O+hGZ9R1+3sy4UiF4itNB1jzfNKIhJARZF4xxjzsTU775ddTuvfA3bla0ADgBtEZCcVhxcHUXHcPsw6JAHeu+1zgBxjzCpr+kMqCocvbPfLgR3GmIPGmBLgYyp+Fnxhu//iTNv5rH/3+XKhSAUSrVEQgVSc6Fpsc6YGYR2TnwVsNsa8XOmtxcBo6/VoYFFjZ2toxpjHjDExxpg4KrbxN8aY24AVwAirmbeu+35gj4h0sWYNpuL59V6/3ak45HS+iDS1fv5/WXev3+6VnGk7LwbutEY/nQ8UVDpEVS2fvuBORK6h4vi1A5htjHna5kgNQkQuAr4HNvL/j9M/TsV5ioVAeyruvnuzMabqCTGvISIDgYeMMdeJSEcq9jBaAWuB240xRXbmawgikkzFSfxAYDswloo/EL1+u4vIn4FbqBj1txb4NRXH4r1uu4vIe8BAKu4Smwf8CfiUarazVThfpeJQ3ClgrDEmrcb+fblQKKWUqp0vH3pSSinlAi0USimlaqSFQimlVI20UCillKqRFgqllFI10kKhlFKqRloolFJK1UgLhVJKqRr9P/cObhHJVPWcAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## FP16" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same but in mixed-precision." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = wrn_22()\n", "model = model2half(model)\n", "learn = Learner(data, model)\n", "learn.metrics = [accuracy]\n", "learn.callbacks.append(MixedPrecision(learn))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=25), HTML(value=''))), HTML(value='epoch train loss v…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 06:00\n", "epoch train loss valid loss accuracy\n", "0 1.430891 1.398491 0.501000 (00:14)\n", "1 1.061705 0.986364 0.644000 (00:14)\n", "2 0.838563 1.008399 0.657300 (00:14)\n", "3 0.695539 0.791527 0.730300 (00:14)\n", "4 0.605823 0.812625 0.732100 (00:14)\n", "5 0.534148 0.623167 0.789800 (00:14)\n", "6 0.496663 0.826545 0.736900 (00:14)\n", "7 0.461931 0.654031 0.781400 (00:14)\n", "8 0.441907 0.614082 0.799900 (00:14)\n", "9 0.407904 0.626947 0.795200 (00:14)\n", "10 0.390932 0.522992 0.826900 (00:14)\n", "11 0.377011 0.682757 0.787800 (00:14)\n", "12 0.361617 0.942910 0.721000 (00:14)\n", "13 0.338757 0.722718 0.783000 (00:14)\n", "14 0.314459 0.601760 0.812800 (00:14)\n", "15 0.280130 0.491810 0.841900 (00:14)\n", "16 0.254688 0.497176 0.845200 (00:14)\n", "17 0.218617 0.348053 0.885900 (00:14)\n", "18 0.188661 0.304984 0.899300 (00:14)\n", "19 0.146677 0.272370 0.914500 (00:14)\n", "20 0.105810 0.250875 0.920300 (00:14)\n", "21 0.074922 0.229407 0.931700 (00:14)\n", "22 0.052838 0.218979 0.935900 (00:14)\n", "23 0.036591 0.217552 0.938000 (00:14)\n", "24 0.030799 0.214543 0.939900 (00:14)\n", "\n", "CPU times: user 3min 54s, sys: 1min 36s, total: 5min 31s\n", "Wall time: 6min\n" ] } ], "source": [ "%time learn.fit_one_cycle(25, 3e-3, wd=0.4, div_factor=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=30), HTML(value=''))), HTML(value='epoch train loss v…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 07:11\n", "epoch train loss valid loss accuracy\n", "0 1.393063 1.417441 0.493300 (00:17)\n", "1 1.027538 1.088767 0.616700 (00:14)\n", "2 0.836060 0.941692 0.683400 (00:14)\n", "3 0.688142 0.690871 0.759500 (00:14)\n", "4 0.594530 1.033483 0.675200 (00:14)\n", "5 0.542689 0.816458 0.725300 (00:14)\n", "6 0.498397 0.887681 0.721900 (00:14)\n", "7 0.449883 0.661108 0.785300 (00:14)\n", "8 0.434427 0.634441 0.786900 (00:14)\n", "9 0.410082 0.609202 0.785200 (00:14)\n", "10 0.380734 0.644742 0.794400 (00:14)\n", "11 0.375207 0.616527 0.795500 (00:14)\n", "12 0.364760 0.578489 0.805400 (00:14)\n", "13 0.349154 0.496333 0.835500 (00:14)\n", "14 0.338663 0.626630 0.794900 (00:14)\n", "15 0.336282 0.674507 0.787300 (00:14)\n", "16 0.307156 0.471588 0.835900 (00:14)\n", "17 0.285396 0.436213 0.855200 (00:14)\n", "18 0.262049 0.426600 0.853800 (00:14)\n", "19 0.237897 0.357533 0.880400 (00:14)\n", "20 0.219161 0.369858 0.882100 (00:14)\n", "21 0.186900 0.320745 0.898100 (00:14)\n", "22 0.149875 0.305023 0.905600 (00:14)\n", "23 0.118723 0.251503 0.920200 (00:14)\n", "24 0.094528 0.249991 0.921800 (00:14)\n", "25 0.063562 0.224991 0.934100 (00:14)\n", "26 0.044421 0.225542 0.935900 (00:14)\n", "27 0.030114 0.214371 0.938900 (00:14)\n", "28 0.022671 0.211160 0.943600 (00:14)\n", "29 0.017958 0.210311 0.942600 (00:14)\n", "\n", "CPU times: user 4min 41s, sys: 1min 55s, total: 6min 36s\n", "Wall time: 7min 11s\n" ] } ], "source": [ "%time learn.fit_one_cycle(30, 3e-3, wd=0.4, div_factor=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }