{
"cells": [
{
"cell_type": "markdown",
"id": "cf843007",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Residual Networks (ResNet) and ResNeXt\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d6e5d075",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:50.915858Z",
"iopub.status.busy": "2023-08-18T19:50:50.915085Z",
"iopub.status.idle": "2023-08-18T19:50:53.897064Z",
"shell.execute_reply": "2023-08-18T19:50:53.895755Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "db292341",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"## Residual Blocks"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "35fa7497",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:53.901535Z",
"iopub.status.busy": "2023-08-18T19:50:53.900638Z",
"iopub.status.idle": "2023-08-18T19:50:53.909065Z",
"shell.execute_reply": "2023-08-18T19:50:53.907927Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class Residual(nn.Module):\n",
"    \"\"\"The Residual block of ResNet models.\n",
"\n",
"    The main branch applies two 3x3 convolutions, each followed by batch\n",
"    normalization. The block input is added to the result of that branch\n",
"    before the final ReLU.\n",
"\n",
"    Args:\n",
"        num_channels: Number of output channels of every convolution.\n",
"        use_1x1conv: If True, pass the block input through a 1x1\n",
"            convolution before the addition, so that its channel count and\n",
"            spatial size match the main branch.\n",
"        strides: Stride of the first 3x3 convolution and of the optional\n",
"            1x1 convolution.\n",
"    \"\"\"\n",
"    def __init__(self, num_channels, use_1x1conv=False, strides=1):\n",
"        super().__init__()\n",
"        self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1,\n",
"                                   stride=strides)\n",
"        self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1)\n",
"        if use_1x1conv:\n",
"            self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1,\n",
"                                       stride=strides)\n",
"        else:\n",
"            # No projection: forward() adds the input unchanged.\n",
"            self.conv3 = None\n",
"        self.bn1 = nn.LazyBatchNorm2d()\n",
"        self.bn2 = nn.LazyBatchNorm2d()\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Return ReLU(main_branch(X) + shortcut(X)).\"\"\"\n",
"        Y = F.relu(self.bn1(self.conv1(X)))\n",
"        Y = self.bn2(self.conv2(Y))\n",
"        # Test for presence explicitly rather than through module truthiness.\n",
"        if self.conv3 is not None:\n",
"            X = self.conv3(X)\n",
"        Y += X\n",
"        return F.relu(Y)"
]
},
{
"cell_type": "markdown",
"id": "bbbb596d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"A situation where the input and output are of the same shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2057b8bc",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:53.913050Z",
"iopub.status.busy": "2023-08-18T19:50:53.912286Z",
"iopub.status.idle": "2023-08-18T19:50:53.955152Z",
"shell.execute_reply": "2023-08-18T19:50:53.953792Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 3, 6, 6])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.randn(4, 3, 6, 6)\n",
"blk = Residual(num_channels=3)\n",
"blk(X).shape"
]
},
{
"cell_type": "markdown",
"id": "aa00b7ad",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Halve the output height and width while increasing the number of output channels"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "341c1c55",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:53.958860Z",
"iopub.status.busy": "2023-08-18T19:50:53.958579Z",
"iopub.status.idle": "2023-08-18T19:50:53.983195Z",
"shell.execute_reply": "2023-08-18T19:50:53.981643Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 6, 3, 3])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Stride 2 on the first 3x3 convolution and on the 1x1 shortcut convolution.\n",
"blk = Residual(num_channels=6, use_1x1conv=True, strides=2)\n",
"blk(X).shape"
]
},
{
"cell_type": "markdown",
"id": "b0873698",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## ResNet Model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0019ee3f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:54.010309Z",
"iopub.status.busy": "2023-08-18T19:50:54.009135Z",
"iopub.status.idle": "2023-08-18T19:50:54.017906Z",
"shell.execute_reply": "2023-08-18T19:50:54.016848Z"
},
"origin_pos": 27,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class ResNet(d2l.Classifier):\n",
"    \"\"\"ResNet classifier; further methods are attached with d2l.add_to_class.\"\"\"\n",
"    def b1(self):\n",
"        \"\"\"Build the stem: a 7x7 stride-2 convolution with 64 channels, batch\n",
"        normalization, ReLU and a 3x3 stride-2 max pooling.\"\"\"\n",
"        stem = [\n",
"            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),\n",
"            nn.LazyBatchNorm2d(),\n",
"            nn.ReLU(),\n",
"            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
"        ]\n",
"        return nn.Sequential(*stem)\n",
"\n",
"@d2l.add_to_class(ResNet)\n",
"def block(self, num_residuals, num_channels, first_block=False):\n",
"    \"\"\"Stack num_residuals Residual blocks with num_channels channels.\n",
"\n",
"    Unless first_block is set, the leading Residual uses a 1x1 convolution\n",
"    with stride 2; every other Residual keeps the default arguments.\n",
"    \"\"\"\n",
"    layers = [\n",
"        Residual(num_channels, use_1x1conv=True, strides=2)\n",
"        if idx == 0 and not first_block else Residual(num_channels)\n",
"        for idx in range(num_residuals)\n",
"    ]\n",
"    return nn.Sequential(*layers)\n",
"\n",
"@d2l.add_to_class(ResNet)\n",
"def __init__(self, arch, lr=0.1, num_classes=10):\n",
"    \"\"\"Assemble the network from the stem, one stage per entry of arch and a\n",
"    classification head.\n",
"\n",
"    Each entry of arch is unpacked into the positional arguments of\n",
"    self.block, i.e. (num_residuals, num_channels).\n",
"    \"\"\"\n",
"    # Two-argument super(): this function is defined outside the class body,\n",
"    # so the zero-argument form is not available.\n",
"    super(ResNet, self).__init__()\n",
"    # Kept ahead of any other local variable -- presumably this records the\n",
"    # constructor arguments (arch, lr, num_classes); verify against d2l.\n",
"    self.save_hyperparameters()\n",
"    self.net = nn.Sequential(self.b1())\n",
"    # Stages are registered as b2, b3, ...; only b2 gets first_block=True.\n",
"    for stage, spec in enumerate(arch, start=2):\n",
"        self.net.add_module(f'b{stage}',\n",
"                            self.block(*spec, first_block=(stage == 2)))\n",
"    head = nn.Sequential(\n",
"        nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),\n",
"        nn.LazyLinear(num_classes))\n",
"    self.net.add_module('last', head)\n",
"    # NOTE(review): the layers are still lazy at this point -- confirm that\n",
"    # d2l.init_cnn has an effect before the first forward pass.\n",
"    self.net.apply(d2l.init_cnn)"
]
},
{
"cell_type": "markdown",
"id": "9967d7a1",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Observe how the input shape changes across different modules in ResNet"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f153f6ed",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:54.031902Z",
"iopub.status.busy": "2023-08-18T19:50:54.030981Z",
"iopub.status.idle": "2023-08-18T19:50:54.188619Z",
"shell.execute_reply": "2023-08-18T19:50:54.187488Z"
},
"origin_pos": 32,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential output shape:\t torch.Size([1, 64, 24, 24])\n",
"Sequential output shape:\t torch.Size([1, 64, 24, 24])\n",
"Sequential output shape:\t torch.Size([1, 128, 12, 12])\n",
"Sequential output shape:\t torch.Size([1, 256, 6, 6])\n",
"Sequential output shape:\t torch.Size([1, 512, 3, 3])\n",
"Sequential output shape:\t torch.Size([1, 10])\n"
]
}
],
"source": [
"class ResNet18(ResNet):\n",
"    \"\"\"ResNet with four stages of two residual blocks each, using 64, 128,\n",
"    256 and 512 channels.\"\"\"\n",
"    def __init__(self, lr=0.1, num_classes=10):\n",
"        # One (num_residuals, num_channels) pair per stage.\n",
"        stages = ((2, 64), (2, 128), (2, 256), (2, 512))\n",
"        super().__init__(stages, lr, num_classes)\n",
"\n",
"ResNet18().layer_summary((1, 1, 96, 96))"
]
},
{
"cell_type": "markdown",
"id": "c5c1f7a1",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "61b87bb9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:54.192632Z",
"iopub.status.busy": "2023-08-18T19:50:54.191821Z",
"iopub.status.idle": "2023-08-18T19:53:34.753784Z",
"shell.execute_reply": "2023-08-18T19:53:34.752565Z"
},
"origin_pos": 36,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"