{ "cells": [ { "cell_type": "markdown", "id": "f4cf78e1", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Densely Connected Networks (DenseNet)\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "898d7836", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:25.812288Z", "iopub.status.busy": "2023-08-18T19:47:25.811221Z", "iopub.status.idle": "2023-08-18T19:47:28.844635Z", "shell.execute_reply": "2023-08-18T19:47:28.843703Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "dcb6db8a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Dense Blocks" ] }, { "cell_type": "code", "execution_count": 3, "id": "805394e3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:28.856547Z", "iopub.status.busy": "2023-08-18T19:47:28.855981Z", "iopub.status.idle": "2023-08-18T19:47:28.862956Z", "shell.execute_reply": "2023-08-18T19:47:28.861783Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"def conv_block(num_channels):\n",
"    \"\"\"BN -> ReLU -> 3x3 conv; padding=1 keeps height and width unchanged.\"\"\"\n",
"    return nn.Sequential(\n",
"        nn.LazyBatchNorm2d(), nn.ReLU(),\n",
"        nn.LazyConv2d(num_channels, kernel_size=3, padding=1))\n",
"\n",
"class DenseBlock(nn.Module):\n",
"    \"\"\"`num_convs` conv blocks whose outputs are concatenated onto their inputs.\n",
"\n",
"    Every conv block adds `num_channels` channels, so the output has\n",
"    in_channels + num_convs * num_channels channels.\n",
"    \"\"\"\n",
"    def __init__(self, num_convs, num_channels):\n",
"        super().__init__()\n",
"        self.net = nn.Sequential(\n",
"            *(conv_block(num_channels) for _ in range(num_convs)))\n",
"\n",
"    def forward(self, X):\n",
"        for blk in self.net:\n",
"            Y = blk(X)\n",
"            # Dense connectivity: stack the new features onto the input\n",
"            # along the channel dimension (dim 1).\n",
"            X = torch.cat((X, Y), dim=1)\n",
"        return X"
] }, { "cell_type": "markdown", "id": "ff6dae74", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Define a `DenseBlock` instance" ] }, { "cell_type": "code", "execution_count": 4, "id": "e369c1ad", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:28.867108Z",
"iopub.status.busy": "2023-08-18T19:47:28.866407Z", "iopub.status.idle": "2023-08-18T19:47:28.909936Z", "shell.execute_reply": "2023-08-18T19:47:28.908954Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 23, 8, 8])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "blk = DenseBlock(2, 10)\n", "X = torch.randn(4, 3, 8, 8)\n", "Y = blk(X)\n", "Y.shape" ] }, { "cell_type": "markdown", "id": "dc32fbb1", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Transition Layers" ] }, { "cell_type": "code", "execution_count": 5, "id": "6160cc48", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:28.915937Z", "iopub.status.busy": "2023-08-18T19:47:28.914796Z", "iopub.status.idle": "2023-08-18T19:47:28.920281Z", "shell.execute_reply": "2023-08-18T19:47:28.919184Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"def transition_block(num_channels):\n",
"    \"\"\"Bridge between dense blocks that shrinks the feature maps.\n",
"\n",
"    The 1x1 conv sets the channel count to `num_channels`; the 2x2 average\n",
"    pool with stride 2 halves the height and width.\n",
"    \"\"\"\n",
"    layers = [nn.LazyBatchNorm2d(), nn.ReLU(),\n",
"              nn.LazyConv2d(num_channels, kernel_size=1),\n",
"              nn.AvgPool2d(kernel_size=2, stride=2)]\n",
"    return nn.Sequential(*layers)"
] }, { "cell_type": "markdown", "id": "9f88387c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Apply a transition layer" ] }, { "cell_type": "code", "execution_count": 6, "id": "fc0cacfc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:28.924597Z", "iopub.status.busy": "2023-08-18T19:47:28.924323Z", "iopub.status.idle": "2023-08-18T19:47:28.938373Z", "shell.execute_reply": "2023-08-18T19:47:28.937285Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 10, 4, 4])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "blk = transition_block(10)\n", "blk(Y).shape" ] }, { "cell_type": "markdown", "id": "e0ea910e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [
"DenseNet Model" ] }, { "cell_type": "code", "execution_count": 8, "id": "58137883", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:28.950027Z", "iopub.status.busy": "2023-08-18T19:47:28.949746Z", "iopub.status.idle": "2023-08-18T19:47:28.958660Z", "shell.execute_reply": "2023-08-18T19:47:28.957444Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"class DenseNet(d2l.Classifier):\n",
"    def b1(self):\n",
"        \"\"\"Stem: 7x7/2 conv (64 ch) -> BN -> ReLU -> 3x3/2 max-pool.\"\"\"\n",
"        return nn.Sequential(\n",
"            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),\n",
"            nn.LazyBatchNorm2d(), nn.ReLU(),\n",
"            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"\n",
"@d2l.add_to_class(DenseNet)\n",
"def __init__(self, num_channels=64, growth_rate=32, arch=(4, 4, 4, 4),\n",
"             lr=0.1, num_classes=10):\n",
"    \"\"\"Stem, dense blocks joined by transition layers, then the classifier head.\n",
"\n",
"    arch[i] is the number of conv blocks in dense block i, each adding\n",
"    growth_rate channels. num_channels is the channel count entering the\n",
"    first dense block (b1 produces 64); it only sizes the transition layers.\n",
"    \"\"\"\n",
"    # Explicit super(): this function is attached with add_to_class rather\n",
"    # than defined in the class body, so zero-argument super() cannot work.\n",
"    super(DenseNet, self).__init__()\n",
"    # Call before creating any other local; d2l's save_hyperparameters\n",
"    # presumably records the caller's local variables as hyperparameters.\n",
"    self.save_hyperparameters()\n",
"    self.net = nn.Sequential(self.b1())\n",
"    for i, num_convs in enumerate(arch):\n",
"        self.net.add_module(f'dense_blk{i+1}', DenseBlock(num_convs,\n",
"                                                          growth_rate))\n",
"        # Channels coming out of this dense block\n",
"        num_channels += num_convs * growth_rate\n",
"        # Between blocks (not after the last one) a transition layer halves\n",
"        # the channel count\n",
"        if i != len(arch) - 1:\n",
"            num_channels //= 2\n",
"            self.net.add_module(f'tran_blk{i+1}', transition_block(\n",
"                num_channels))\n",
"    self.net.add_module('last', nn.Sequential(\n",
"        nn.LazyBatchNorm2d(), nn.ReLU(),\n",
"        nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),\n",
"        nn.LazyLinear(num_classes)))\n",
"    # NOTE(review): all layers are still Lazy* (no weights yet) at this point,\n",
"    # so this may initialize nothing; apply init after a first forward pass\n",
"    # (e.g. model.apply_init) if Xavier init is intended. Confirm against\n",
"    # d2l.init_cnn.\n",
"    self.net.apply(d2l.init_cnn)"
] }, { "cell_type": "markdown", "id": "ab4d03b5", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 9, "id": "ef87c44e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:47:28.963624Z", "iopub.status.busy": "2023-08-18T19:47:28.962964Z", "iopub.status.idle": "2023-08-18T19:50:01.060105Z", "shell.execute_reply": "2023-08-18T19:50:01.059052Z" }, "origin_pos": 36, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 
2023-08-18T19:50:00.878365\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"model = DenseNet(lr=0.01)\n",
"trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n",
"data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))\n",
"# DenseNet is built from Lazy* layers, so the init_cnn call in __init__ runs\n",
"# before any weights exist. apply_init runs one training batch through the\n",
"# model to materialize the layers and then applies d2l.init_cnn to them.\n",
"model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n",
"trainer.fit(model, data)"
] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }