{ "cells": [ { "cell_type": "markdown", "id": "cf843007", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Residual Networks (ResNet) and ResNeXt\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "d6e5d075", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:50.915858Z", "iopub.status.busy": "2023-08-18T19:50:50.915085Z", "iopub.status.idle": "2023-08-18T19:50:53.897064Z", "shell.execute_reply": "2023-08-18T19:50:53.895755Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "db292341", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "Residual Blocks" ] }, { "cell_type": "code", "execution_count": 2, "id": "35fa7497", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:53.901535Z", "iopub.status.busy": "2023-08-18T19:50:53.900638Z", "iopub.status.idle": "2023-08-18T19:50:53.909065Z", "shell.execute_reply": "2023-08-18T19:50:53.907927Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class Residual(nn.Module): \n", " \"\"\"The Residual block of ResNet models.\"\"\"\n", " def __init__(self, num_channels, use_1x1conv=False, strides=1):\n", " super().__init__()\n", " self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1,\n", " stride=strides)\n", " self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1)\n", " if use_1x1conv:\n", " self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1,\n", " stride=strides)\n", " else:\n", " self.conv3 = None\n", " self.bn1 = nn.LazyBatchNorm2d()\n", " self.bn2 = nn.LazyBatchNorm2d()\n", "\n", " def forward(self, X):\n", " Y = F.relu(self.bn1(self.conv1(X)))\n", " Y = self.bn2(self.conv2(Y))\n", " if self.conv3:\n", " X = self.conv3(X)\n", " Y += X\n", " return F.relu(Y)" ] }, { "cell_type": "markdown", "id": "bbbb596d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "A situation where the input and output are of the same shape" ] }, { "cell_type": "code", "execution_count": 3, "id": "2057b8bc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:53.913050Z", "iopub.status.busy": "2023-08-18T19:50:53.912286Z", "iopub.status.idle": "2023-08-18T19:50:53.955152Z", "shell.execute_reply": "2023-08-18T19:50:53.953792Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 3, 6, 6])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "blk = Residual(3)\n", "X = torch.randn(4, 3, 6, 6)\n", "blk(X).shape" ] }, { "cell_type": "markdown", "id": "aa00b7ad", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Halve the output height and width while increasing the number of output channels" ] }, { "cell_type": "code", "execution_count": 4, "id": "341c1c55", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:53.958860Z", "iopub.status.busy": "2023-08-18T19:50:53.958579Z", "iopub.status.idle": "2023-08-18T19:50:53.983195Z", "shell.execute_reply": "2023-08-18T19:50:53.981643Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 6, 3, 3])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "blk = Residual(6, use_1x1conv=True, strides=2)\n", "blk(X).shape" ] }, { "cell_type": "markdown", "id": "b0873698", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "ResNet Model" ] }, { "cell_type": "code", "execution_count": 7, "id": "0019ee3f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:54.010309Z", "iopub.status.busy": "2023-08-18T19:50:54.009135Z", "iopub.status.idle": "2023-08-18T19:50:54.017906Z", "shell.execute_reply": "2023-08-18T19:50:54.016848Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class ResNet(d2l.Classifier):\n", " def b1(self):\n", " return nn.Sequential(\n", " nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", "\n", "@d2l.add_to_class(ResNet)\n", "def block(self, num_residuals, num_channels, first_block=False):\n", " blk = []\n", " for i in range(num_residuals):\n", " if i == 0 and not first_block:\n", " blk.append(Residual(num_channels, use_1x1conv=True, strides=2))\n", " else:\n", " blk.append(Residual(num_channels))\n", " return nn.Sequential(*blk)\n", "\n", "@d2l.add_to_class(ResNet)\n", "def __init__(self, arch, lr=0.1, num_classes=10):\n", " super(ResNet, self).__init__()\n", " self.save_hyperparameters()\n", " self.net = nn.Sequential(self.b1())\n", " for i, b in enumerate(arch):\n", " self.net.add_module(f'b{i+2}', self.block(*b, first_block=(i==0)))\n", " self.net.add_module('last', nn.Sequential(\n", " nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),\n", " nn.LazyLinear(num_classes)))\n", " self.net.apply(d2l.init_cnn)" ] }, { "cell_type": "markdown", "id": "9967d7a1", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Observe how the input shape changes across different modules in ResNet" ] }, { "cell_type": "code", "execution_count": 9, "id": "f153f6ed", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:54.031902Z", "iopub.status.busy": "2023-08-18T19:50:54.030981Z", "iopub.status.idle": "2023-08-18T19:50:54.188619Z", "shell.execute_reply": "2023-08-18T19:50:54.187488Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 64, 24, 24])\n", "Sequential output shape:\t torch.Size([1, 64, 24, 24])\n", "Sequential output shape:\t torch.Size([1, 128, 12, 12])\n", "Sequential output shape:\t torch.Size([1, 256, 6, 6])\n", "Sequential output shape:\t torch.Size([1, 512, 3, 3])\n", "Sequential output shape:\t torch.Size([1, 10])\n" ] } ], "source": [ "class ResNet18(ResNet):\n", " def __init__(self, lr=0.1, num_classes=10):\n", " super().__init__(((2, 64), (2, 128), (2, 256), (2, 512)),\n", " lr, num_classes)\n", "\n", "ResNet18().layer_summary((1, 1, 96, 96))" ] }, { "cell_type": "markdown", "id": "c5c1f7a1", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 10, "id": "61b87bb9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:54.192632Z", "iopub.status.busy": "2023-08-18T19:50:54.191821Z", "iopub.status.idle": "2023-08-18T19:53:34.753784Z", "shell.execute_reply": "2023-08-18T19:53:34.752565Z" }, "origin_pos": 36, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:53:34.579212\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = ResNet18(lr=0.01)\n", "trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n", "data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))\n", "model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "trainer.fit(model, data)" ] }, { "cell_type": "code", "execution_count": 12, "id": "ec6906e4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:53:34.771844Z", "iopub.status.busy": "2023-08-18T19:53:34.771565Z", "iopub.status.idle": "2023-08-18T19:53:34.803609Z", "shell.execute_reply": "2023-08-18T19:53:34.802407Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 32, 96, 96])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class ResNeXtBlock(nn.Module): \n", " \"\"\"The ResNeXt block.\"\"\"\n", " def __init__(self, num_channels, groups, bot_mul, use_1x1conv=False,\n", " strides=1):\n", " super().__init__()\n", " bot_channels = int(round(num_channels * bot_mul))\n", " self.conv1 = nn.LazyConv2d(bot_channels, kernel_size=1, stride=1)\n", " self.conv2 = nn.LazyConv2d(bot_channels, kernel_size=3,\n", " stride=strides, padding=1,\n", " groups=bot_channels//groups)\n", " self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1, stride=1)\n", " self.bn1 = nn.LazyBatchNorm2d()\n", " self.bn2 = nn.LazyBatchNorm2d()\n", " self.bn3 = nn.LazyBatchNorm2d()\n", " if use_1x1conv:\n", " self.conv4 = nn.LazyConv2d(num_channels, kernel_size=1,\n", " stride=strides)\n", " self.bn4 = nn.LazyBatchNorm2d()\n", " else:\n", " self.conv4 = None\n", "\n", " def forward(self, X):\n", " Y = F.relu(self.bn1(self.conv1(X)))\n", " Y = F.relu(self.bn2(self.conv2(Y)))\n", " Y = self.bn3(self.conv3(Y))\n", " if self.conv4:\n", " X = self.bn4(self.conv4(X))\n", " return F.relu(Y + X)\n", "\n", "blk = ResNeXtBlock(32, 16, 1)\n", "X = torch.randn(4, 32, 96, 96)\n", "blk(X).shape" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }