{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|default_exp resnet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# ResNets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#|export\n", "import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt\n", "import fastcore.all as fc\n", "from collections.abc import Mapping\n", "from pathlib import Path\n", "from operator import attrgetter,itemgetter\n", "from functools import partial\n", "from copy import copy\n", "from contextlib import contextmanager\n", "\n", "import torchvision.transforms.functional as TF,torch.nn.functional as F\n", "from torch import tensor,nn,optim\n", "from torch.utils.data import DataLoader,default_collate\n", "from torch.nn import init\n", "from torch.optim import lr_scheduler\n", "from torcheval.metrics import MulticlassAccuracy\n", "from datasets import load_dataset,load_dataset_builder\n", "\n", "from miniai.datasets import *\n", "from miniai.conv import *\n", "from miniai.learner import *\n", "from miniai.activations import *\n", "from miniai.init import *\n", "from miniai.sgd import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastcore.test import test_close\n", "\n", "torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)\n", "torch.manual_seed(1)\n", "mpl.rcParams['image.cmap'] = 'gray'\n", "\n", "import logging\n", "logging.disable(logging.WARNING)\n", "\n", "set_seed(42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "10218e9f603241a7906a86c053b3b4dd", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00\n", " /* Turns off some styling */\n", " progress {\n", " /* gets rid of default border in Firefox and Opera. 
*/\n", " border: none;\n", " /* Needs to be in here for Safari polyfill so background images work as expected. */\n", " background-size: auto;\n", " }\n", " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", " background: #F44336;\n", " }\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
accuracylossepochtrain
0.8060.7030train
0.8470.4560eval
0.8840.3331train
0.8560.4151eval
0.9060.2632train
0.8820.3252eval
0.9230.2153train
0.9100.2513eval
0.9400.1704train
0.9170.2324eval
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(epochs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Skip Connections" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The ResNet (*residual network*) was introduced in 2015 by Kaiming He et al. in the article [\"Deep Residual Learning for Image Recognition\"](https://arxiv.org/abs/1512.03385). The key idea is using a *skip connection* to allow deeper networks to train successfully." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"#|export\n",
"def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):\n",
"    \"Two stacked `conv`s: the first (stride 1) applies `act`; the second applies `stride` with no activation, so the residual sum is activated once, in `ResBlock`.\"\n",
"    return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),\n",
"                         conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))\n",
"\n",
"class ResBlock(nn.Module):\n",
"    \"Residual block computing `act(convs(x) + idconv(pool(x)))`.\"\n",
"    def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):\n",
"        super().__init__()\n",
"        self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)\n",
"        # Identity path: a 1x1 conv is only needed when the channel count changes.\n",
"        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)\n",
"        # Downsample the identity path by the same factor as `convs`: pool with kernel `stride`\n",
"        # (not a fixed 2) and `ceil_mode=True` so odd sizes round up (e.g. 7 -> 4 for stride 2).\n",
"        self.pool = fc.noop if stride==1 else nn.AvgPool2d(stride, ceil_mode=True)\n",
"        self.act = act()\n",
"\n",
"    def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Post-lesson update**: Piotr Czapla noticed that we forgot to include `norm=norm` in the call to `_conv_block` above, so the resnets in the lesson didn't have batchnorm in the resblocks! After fixing this, we discovered that initializing the `conv2` batchnorm weights to zero makes things worse in every model we tried, so we removed that. That init method was originally introduced to handle training extremely deep models (much deeper than we use here) -- it appears from this little test that it might be worse for less deep models."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def get_model(act=nn.ReLU, nfs=(8,16,32,64,128,256), norm=nn.BatchNorm2d):\n",
"    \"Single-channel-input classifier: a stride-1 stem `ResBlock`, one stride-2 `ResBlock` per consecutive pair in `nfs`, then a 10-way linear head.\"\n",
"    # Stem width is `nfs[0]` so it always matches the `ni` of the first stride-2 block.\n",
"    layers = [ResBlock(1, nfs[0], stride=1, act=act, norm=norm)]\n",
"    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]\n",
"    # Flatten -> Linear(nfs[-1]) assumes the stride-2 blocks reduce the input to 1x1\n",
"    # (true for 28x28 inputs with the default five downsamples, per the shape printout below).\n",
"    layers += [nn.Flatten(), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]\n",
"    return nn.Sequential(*layers).to(def_device)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ResBlock torch.Size([1024, 1, 28, 28]) torch.Size([1024, 8, 28, 28])\n", "ResBlock torch.Size([1024, 8, 28, 28]) torch.Size([1024, 16, 14, 14])\n", "ResBlock torch.Size([1024, 16, 14, 14]) torch.Size([1024, 32, 7, 7])\n", "ResBlock torch.Size([1024, 32, 7, 7]) torch.Size([1024, 64, 4, 4])\n", "ResBlock torch.Size([1024, 64, 4, 4]) torch.Size([1024, 128, 2, 2])\n", "ResBlock torch.Size([1024, 128, 2, 2]) torch.Size([1024, 256, 1, 1])\n", "Flatten torch.Size([1024, 256, 1, 1]) torch.Size([1024, 256])\n", "Linear torch.Size([1024, 256]) torch.Size([1024, 10])\n", "BatchNorm1d torch.Size([1024, 10]) torch.Size([1024, 10])\n" ] } ], "source": [ "def _print_shape(hook, mod, inp, outp): print(type(mod).__name__, inp[0].shape, outp.shape)\n", "model = get_model()\n", "learn = TrainLearner(model, dls, F.cross_entropy, cbs=[DeviceCB(), SingleBatchCB()])\n", "with Hooks(model, _print_shape) as hooks: learn.fit(1, train=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"@fc.patch\n",
"def summary(self:Learner):\n",
"    \"Hook the model's modules during one non-training batch; print the total parameter count and report each module's input/output shape and parameter count as a markdown table.\"\n",
"    res = '|Module|Input|Output|Num params|\\n|--|--|--|--|\\n'\n",
"    tot = 0\n",
"    def _f(hook, mod, inp, outp):\n",
"        # Closure accumulates the table text and running total across hook calls.\n",
"        nonlocal res,tot\n",
"        nparms = sum(o.numel() for o in mod.parameters())\n",
"        tot += nparms\n",
"        res += f'|{type(mod).__name__}|{tuple(inp[0].shape)}|{tuple(outp.shape)}|{nparms}|\\n'\n",
"    # `train=False` + `SingleBatchCB`: one forward batch is enough to fire every hook; presumably `lr` is unused here.\n",
"    with Hooks(self.model, _f) as hooks: self.fit(1, lr=1, train=False, cbs=SingleBatchCB())\n",
"    print(\"Tot params: \", tot)\n",
"    if fc.IN_NOTEBOOK:\n",
"        from IPython.display import Markdown\n",
"        return Markdown(res)\n",
"    else: print(res)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tot params: 1228908\n" ] }, { "data": { "text/markdown": [ "|Module|Input|Output|Num params|\n", "|--|--|--|--|\n", "|ResBlock|(1024, 1, 28, 28)|(1024, 8, 28, 28)|712|\n", "|ResBlock|(1024, 8, 28, 28)|(1024, 16, 14, 14)|3696|\n", "|ResBlock|(1024, 16, 14, 14)|(1024, 32, 7, 7)|14560|\n", "|ResBlock|(1024, 32, 7, 7)|(1024, 64, 4, 4)|57792|\n", "|ResBlock|(1024, 64, 4, 4)|(1024, 128, 2, 2)|230272|\n", "|ResBlock|(1024, 128, 2, 2)|(1024, 256, 1, 1)|919296|\n", "|Flatten|(1024, 256, 1, 1)|(1024, 256)|0|\n", "|Linear|(1024, 256)|(1024, 10)|2560|\n", "|BatchNorm1d|(1024, 10)|(1024, 10)|20|\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "TrainLearner(get_model(), dls, F.cross_entropy, cbs=DeviceCB()).summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)\n", "MomentumLearner(model, dls, F.cross_entropy, cbs=DeviceCB()).lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2\n", "tmax = epochs * len(dls.train)\n", "sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)\n", "xtra = [BatchSchedCB(sched)]\n", "model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)\n", "learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
accuracylossepochtrain
0.8270.6810train
0.8140.6090eval
0.8940.3521train
0.8870.3261eval
0.9130.2622train
0.9100.2652eval
0.9340.1963train
0.9220.2273eval
0.9500.1494train
0.9280.2114eval
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import timm\n", "from timm.models.resnet import BasicBlock, ResNet, Bottleneck" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'cspresnet50 cspresnet50d cspresnet50w eca_resnet33ts ecaresnet26t ecaresnet50d ecaresnet50d_pruned ecaresnet50t ecaresnet101d ecaresnet101d_pruned ecaresnet200d ecaresnet269d ecaresnetlight ens_adv_inception_resnet_v2 gcresnet33ts gcresnet50t gluon_resnet18_v1b gluon_resnet34_v1b gluon_resnet50_v1b gluon_resnet50_v1c gluon_resnet50_v1d gluon_resnet50_v1s gluon_resnet101_v1b gluon_resnet101_v1c gluon_resnet101_v1d gluon_resnet101_v1s gluon_resnet152_v1b gluon_resnet152_v1c gluon_resnet152_v1d gluon_resnet152_v1s inception_resnet_v2 lambda_resnet26rpt_256 lambda_resnet26t lambda_resnet50ts legacy_seresnet18 legacy_seresnet34 legacy_seresnet50 legacy_seresnet101 legacy_seresnet152 nf_ecaresnet26 nf_ecaresnet50 nf_ecaresnet101 nf_resnet26 nf_resnet50 nf_resnet101 nf_seresnet26 nf_seresnet50 nf_seresnet101 resnet10t resnet14t resnet18 resnet18d resnet26 resnet26d resnet26t resnet32ts resnet33ts resnet34 resnet34d resnet50 resnet50_gn resnet50d resnet50t resnet51q resnet61q resnet101 resnet101d resnet152 resnet152d resnet200 resnet200d resnetaa50 resnetaa50d resnetaa101d resnetblur18 resnetblur50 resnetblur50d resnetblur101d resnetrs50 resnetrs101 resnetrs152 resnetrs200 resnetrs270 resnetrs350 resnetrs420 resnetv2_50 resnetv2_50d resnetv2_50d_evob resnetv2_50d_evos resnetv2_50d_frn resnetv2_50d_gn resnetv2_50t resnetv2_50x1_bit_distilled resnetv2_50x1_bitm resnetv2_50x1_bitm_in21k resnetv2_50x3_bitm resnetv2_50x3_bitm_in21k resnetv2_101 resnetv2_101d resnetv2_101x1_bitm resnetv2_101x1_bitm_in21k resnetv2_101x3_bitm resnetv2_101x3_bitm_in21k resnetv2_152 resnetv2_152d resnetv2_152x2_bit_teacher 
resnetv2_152x2_bit_teacher_384 resnetv2_152x2_bitm resnetv2_152x2_bitm_in21k resnetv2_152x4_bitm resnetv2_152x4_bitm_in21k seresnet18 seresnet33ts seresnet34 seresnet50 seresnet50t seresnet101 seresnet152 seresnet152d seresnet200d seresnet269d seresnetaa50d skresnet18 skresnet34 skresnet50 skresnet50d ssl_resnet18 ssl_resnet50 swsl_resnet18 swsl_resnet50 tresnet_l tresnet_l_448 tresnet_m tresnet_m_448 tresnet_m_miil_in21k tresnet_v2_l tresnet_xl tresnet_xl_448 tv_resnet34 tv_resnet50 tv_resnet101 tv_resnet152 vit_base_resnet26d_224 vit_base_resnet50_224_in21k vit_base_resnet50_384 vit_base_resnet50d_224 vit_small_resnet26d_224 vit_small_resnet50d_s16_224 wide_resnet50_2 wide_resnet101_2'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join(timm.list_models('*resnet*'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "resnet18: block=BasicBlock, layers=[2, 2, 2, 2]\n", "resnet18d: block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True\n", "resnet10t: block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = timm.create_model('resnet18d', in_chans=1, num_classes=10)\n", "# model = ResNet(in_chans=1, block=BasicBlock, layers=[2,2,2,2], stem_width=32, avg_down=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2\n", "tmax = epochs * len(dls.train)\n", "sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)\n", "xtra = [BatchSchedCB(sched)]\n", "learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
accuracylossepochtrain
0.7810.6330train
0.6641.3160eval
0.8780.3291train
0.8700.3621eval
0.9050.2552train
0.8890.3072eval
0.9260.1973train
0.9110.2443eval
0.9450.1504train
0.9200.2234eval
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(epochs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import nbdev; nbdev.nbdev_export()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }