{
"cells": [
{
"cell_type": "markdown",
"id": "cf843007",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Residual Networks (ResNet) and ResNeXt\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d6e5d075",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:50.915858Z",
"iopub.status.busy": "2023-08-18T19:50:50.915085Z",
"iopub.status.idle": "2023-08-18T19:50:53.897064Z",
"shell.execute_reply": "2023-08-18T19:50:53.895755Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "db292341",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"Residual Blocks"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "35fa7497",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:53.901535Z",
"iopub.status.busy": "2023-08-18T19:50:53.900638Z",
"iopub.status.idle": "2023-08-18T19:50:53.909065Z",
"shell.execute_reply": "2023-08-18T19:50:53.907927Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class Residual(nn.Module): \n",
" \"\"\"The Residual block of ResNet models.\"\"\"\n",
" def __init__(self, num_channels, use_1x1conv=False, strides=1):\n",
" super().__init__()\n",
" self.conv1 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1,\n",
" stride=strides)\n",
" self.conv2 = nn.LazyConv2d(num_channels, kernel_size=3, padding=1)\n",
" if use_1x1conv:\n",
" self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1,\n",
" stride=strides)\n",
" else:\n",
" self.conv3 = None\n",
" self.bn1 = nn.LazyBatchNorm2d()\n",
" self.bn2 = nn.LazyBatchNorm2d()\n",
"\n",
" def forward(self, X):\n",
" Y = F.relu(self.bn1(self.conv1(X)))\n",
" Y = self.bn2(self.conv2(Y))\n",
" if self.conv3:\n",
" X = self.conv3(X)\n",
" Y += X\n",
" return F.relu(Y)"
]
},
{
"cell_type": "markdown",
"id": "bbbb596d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"A situation where the input and output are of the same shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2057b8bc",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:53.913050Z",
"iopub.status.busy": "2023-08-18T19:50:53.912286Z",
"iopub.status.idle": "2023-08-18T19:50:53.955152Z",
"shell.execute_reply": "2023-08-18T19:50:53.953792Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 3, 6, 6])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"blk = Residual(3)\n",
"X = torch.randn(4, 3, 6, 6)\n",
"blk(X).shape"
]
},
{
"cell_type": "markdown",
"id": "aa00b7ad",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Halve the output height and width while increasing the number of output channels"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "341c1c55",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:53.958860Z",
"iopub.status.busy": "2023-08-18T19:50:53.958579Z",
"iopub.status.idle": "2023-08-18T19:50:53.983195Z",
"shell.execute_reply": "2023-08-18T19:50:53.981643Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 6, 3, 3])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"blk = Residual(6, use_1x1conv=True, strides=2)\n",
"blk(X).shape"
]
},
{
"cell_type": "markdown",
"id": "b0873698",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"ResNet Model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0019ee3f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:54.010309Z",
"iopub.status.busy": "2023-08-18T19:50:54.009135Z",
"iopub.status.idle": "2023-08-18T19:50:54.017906Z",
"shell.execute_reply": "2023-08-18T19:50:54.016848Z"
},
"origin_pos": 27,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class ResNet(d2l.Classifier):\n",
" def b1(self):\n",
" return nn.Sequential(\n",
" nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),\n",
" nn.LazyBatchNorm2d(), nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"\n",
"@d2l.add_to_class(ResNet)\n",
"def block(self, num_residuals, num_channels, first_block=False):\n",
" blk = []\n",
" for i in range(num_residuals):\n",
" if i == 0 and not first_block:\n",
" blk.append(Residual(num_channels, use_1x1conv=True, strides=2))\n",
" else:\n",
" blk.append(Residual(num_channels))\n",
" return nn.Sequential(*blk)\n",
"\n",
"@d2l.add_to_class(ResNet)\n",
"def __init__(self, arch, lr=0.1, num_classes=10):\n",
" super(ResNet, self).__init__()\n",
" self.save_hyperparameters()\n",
" self.net = nn.Sequential(self.b1())\n",
" for i, b in enumerate(arch):\n",
" self.net.add_module(f'b{i+2}', self.block(*b, first_block=(i==0)))\n",
" self.net.add_module('last', nn.Sequential(\n",
" nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),\n",
" nn.LazyLinear(num_classes)))\n",
" self.net.apply(d2l.init_cnn)"
]
},
{
"cell_type": "markdown",
"id": "9967d7a1",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Observe how the input shape changes across different modules in ResNet"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f153f6ed",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:54.031902Z",
"iopub.status.busy": "2023-08-18T19:50:54.030981Z",
"iopub.status.idle": "2023-08-18T19:50:54.188619Z",
"shell.execute_reply": "2023-08-18T19:50:54.187488Z"
},
"origin_pos": 32,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential output shape:\t torch.Size([1, 64, 24, 24])\n",
"Sequential output shape:\t torch.Size([1, 64, 24, 24])\n",
"Sequential output shape:\t torch.Size([1, 128, 12, 12])\n",
"Sequential output shape:\t torch.Size([1, 256, 6, 6])\n",
"Sequential output shape:\t torch.Size([1, 512, 3, 3])\n",
"Sequential output shape:\t torch.Size([1, 10])\n"
]
}
],
"source": [
"class ResNet18(ResNet):\n",
" def __init__(self, lr=0.1, num_classes=10):\n",
" super().__init__(((2, 64), (2, 128), (2, 256), (2, 512)),\n",
" lr, num_classes)\n",
"\n",
"ResNet18().layer_summary((1, 1, 96, 96))"
]
},
{
"cell_type": "markdown",
"id": "c5c1f7a1",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "61b87bb9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:50:54.192632Z",
"iopub.status.busy": "2023-08-18T19:50:54.191821Z",
"iopub.status.idle": "2023-08-18T19:53:34.753784Z",
"shell.execute_reply": "2023-08-18T19:53:34.752565Z"
},
"origin_pos": 36,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = ResNet18(lr=0.01)\n",
"trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n",
"data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))\n",
"model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ec6906e4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:34.771844Z",
"iopub.status.busy": "2023-08-18T19:53:34.771565Z",
"iopub.status.idle": "2023-08-18T19:53:34.803609Z",
"shell.execute_reply": "2023-08-18T19:53:34.802407Z"
},
"origin_pos": 44,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 32, 96, 96])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class ResNeXtBlock(nn.Module): \n",
" \"\"\"The ResNeXt block.\"\"\"\n",
" def __init__(self, num_channels, groups, bot_mul, use_1x1conv=False,\n",
" strides=1):\n",
" super().__init__()\n",
" bot_channels = int(round(num_channels * bot_mul))\n",
" self.conv1 = nn.LazyConv2d(bot_channels, kernel_size=1, stride=1)\n",
" self.conv2 = nn.LazyConv2d(bot_channels, kernel_size=3,\n",
" stride=strides, padding=1,\n",
" groups=bot_channels//groups)\n",
" self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1, stride=1)\n",
" self.bn1 = nn.LazyBatchNorm2d()\n",
" self.bn2 = nn.LazyBatchNorm2d()\n",
" self.bn3 = nn.LazyBatchNorm2d()\n",
" if use_1x1conv:\n",
" self.conv4 = nn.LazyConv2d(num_channels, kernel_size=1,\n",
" stride=strides)\n",
" self.bn4 = nn.LazyBatchNorm2d()\n",
" else:\n",
" self.conv4 = None\n",
"\n",
" def forward(self, X):\n",
" Y = F.relu(self.bn1(self.conv1(X)))\n",
" Y = F.relu(self.bn2(self.conv2(Y)))\n",
" Y = self.bn3(self.conv3(Y))\n",
" if self.conv4:\n",
" X = self.bn4(self.conv4(X))\n",
" return F.relu(Y + X)\n",
"\n",
"blk = ResNeXtBlock(32, 16, 1)\n",
"X = torch.randn(4, 32, 96, 96)\n",
"blk(X).shape"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}