{ "cells": [ { "cell_type": "markdown", "id": "741aaf07", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Multi-Branch Networks (GoogLeNet)\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "066f8e89", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:12:43.017176Z", "iopub.status.busy": "2023-08-18T20:12:43.016624Z", "iopub.status.idle": "2023-08-18T20:12:46.230431Z", "shell.execute_reply": "2023-08-18T20:12:46.228850Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "e6c8822d", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "Inception Blocks" ] }, { "cell_type": "code", "execution_count": 2, "id": "7deda9e0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:12:46.235999Z", "iopub.status.busy": "2023-08-18T20:12:46.235477Z", "iopub.status.idle": "2023-08-18T20:12:46.245899Z", "shell.execute_reply": "2023-08-18T20:12:46.244271Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"class Inception(nn.Module):\n",
"    \"\"\"Inception block: four parallel branches concatenated along channels.\n",
"\n",
"    Args:\n",
"        c1: output channels of the 1x1 convolution branch.\n",
"        c2: (reduction, output) channels of the 1x1 -> 3x3 convolution branch.\n",
"        c3: (reduction, output) channels of the 1x1 -> 5x5 convolution branch.\n",
"        c4: output channels of the 1x1 convolution applied after 3x3 max pooling.\n",
"\n",
"    The block outputs c1 + c2[1] + c3[1] + c4 channels at the input's\n",
"    spatial size.\n",
"    \"\"\"\n",
"    def __init__(self, c1, c2, c3, c4, **kwargs):\n",
"        super(Inception, self).__init__(**kwargs)\n",
"        # Branch 1: 1x1 convolution\n",
"        self.b1_1 = nn.LazyConv2d(c1, kernel_size=1)\n",
"        # Branch 2: 1x1 reduction, then 3x3 convolution\n",
"        self.b2_1 = nn.LazyConv2d(c2[0], kernel_size=1)\n",
"        self.b2_2 = nn.LazyConv2d(c2[1], kernel_size=3, padding=1)\n",
"        # Branch 3: 1x1 reduction, then 5x5 convolution\n",
"        self.b3_1 = nn.LazyConv2d(c3[0], kernel_size=1)\n",
"        self.b3_2 = nn.LazyConv2d(c3[1], kernel_size=5, padding=2)\n",
"        # Branch 4: 3x3 max pooling, then 1x1 convolution\n",
"        self.b4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
"        self.b4_2 = nn.LazyConv2d(c4, kernel_size=1)\n",
"\n",
"    def forward(self, x):\n",
"        b1 = F.relu(self.b1_1(x))\n",
"        b2 = F.relu(self.b2_2(F.relu(self.b2_1(x))))\n",
"        b3 = F.relu(self.b3_2(F.relu(self.b3_1(x))))\n",
"        b4 = F.relu(self.b4_2(self.b4_1(x)))\n",
"        # Every branch keeps height and width (stride 1 with matching padding),\n",
"        # so the outputs can be concatenated along the channel dimension.\n",
"        return torch.cat((b1, b2, b3, b4), dim=1)"
] }, { "cell_type": 
"markdown", "id": "5f988986", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "GoogLeNet Model" ] }, { "cell_type": "code", "execution_count": 8, "id": "fe47c47d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:12:46.306521Z", "iopub.status.busy": "2023-08-18T20:12:46.305540Z", "iopub.status.idle": "2023-08-18T20:12:46.313597Z", "shell.execute_reply": "2023-08-18T20:12:46.312191Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [
"class GoogleNet(d2l.Classifier):\n",
"    \"\"\"GoogLeNet: five stages (b1-b5) followed by a linear classifier.\n",
"\n",
"    b1 is defined in the class body; b2-b5 and __init__ are attached\n",
"    afterwards with d2l.add_to_class.\n",
"    \"\"\"\n",
"    def b1(self):\n",
"        \"\"\"Stage 1: 7x7 stride-2 convolution (64 channels), then 3x3 stride-2 max pooling.\"\"\"\n",
"        return nn.Sequential(\n",
"            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),\n",
"            nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"\n",
"@d2l.add_to_class(GoogleNet)\n",
"def b2(self):\n",
" \"\"\"Stage 2: 1x1 convolution (64 channels), 3x3 convolution (192 channels), stride-2 max pooling.\"\"\"\n",
" return nn.Sequential(\n",
" nn.LazyConv2d(64, kernel_size=1), nn.ReLU(),\n",
" nn.LazyConv2d(192, kernel_size=3, padding=1), nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"\n",
"@d2l.add_to_class(GoogleNet)\n",
"def b3(self):\n",
" \"\"\"Stage 3: two Inception blocks (256, then 480 output channels), stride-2 max pooling.\"\"\"\n",
" return nn.Sequential(Inception(64, (96, 128), (16, 32), 32),\n",
" Inception(128, (128, 192), (32, 96), 64),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"\n",
"@d2l.add_to_class(GoogleNet)\n",
"def b4(self):\n",
" \"\"\"Stage 4: five Inception blocks (ending at 832 output channels), stride-2 max pooling.\"\"\"\n",
" return nn.Sequential(Inception(192, (96, 208), (16, 48), 64),\n",
" Inception(160, (112, 224), (24, 64), 64),\n",
" Inception(128, (128, 256), (24, 64), 64),\n",
" Inception(112, (144, 288), (32, 64), 64),\n",
" Inception(256, (160, 320), (32, 128), 128),\n",
" nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n",
"\n",
"@d2l.add_to_class(GoogleNet)\n",
"def b5(self):\n",
" \"\"\"Stage 5: two Inception blocks (1024 output channels), global average pooling, flatten.\"\"\"\n",
" return nn.Sequential(Inception(256, (160, 320), (32, 128), 128),\n",
" Inception(384, (192, 384), (48, 128), 128),\n",
" nn.AdaptiveAvgPool2d((1,1)), nn.Flatten())\n",
"\n",
"@d2l.add_to_class(GoogleNet)\n",
"def __init__(self, lr=0.1, num_classes=10):\n",
" \"\"\"Stack stages b1-b5 and a LazyLinear head with num_classes outputs.\"\"\"\n",
" super(GoogleNet, self).__init__()\n",
" self.save_hyperparameters()\n",
" self.net = 
nn.Sequential(self.b1(), self.b2(), self.b3(), self.b4(),\n", " self.b5(), nn.LazyLinear(num_classes))\n", " self.net.apply(d2l.init_cnn)" ] }, { "cell_type": "markdown", "id": "a6886825", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Reduce the input height and width from 224 to 96\n", "to have a reasonable training time on Fashion-MNIST" ] }, { "cell_type": "code", "execution_count": 9, "id": "83b695b7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:12:46.317969Z", "iopub.status.busy": "2023-08-18T20:12:46.316995Z", "iopub.status.idle": "2023-08-18T20:12:46.501717Z", "shell.execute_reply": "2023-08-18T20:12:46.500827Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 64, 24, 24])\n", "Sequential output shape:\t torch.Size([1, 192, 12, 12])\n", "Sequential output shape:\t torch.Size([1, 480, 6, 6])\n", "Sequential output shape:\t torch.Size([1, 832, 3, 3])\n", "Sequential output shape:\t torch.Size([1, 1024])\n", "Linear output shape:\t torch.Size([1, 10])\n" ] } ], "source": [ "model = GoogleNet().layer_summary((1, 1, 96, 96))" ] }, { "cell_type": "markdown", "id": "ed1f8ec6", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 10, "id": "d52cffee", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:12:46.505565Z", "iopub.status.busy": "2023-08-18T20:12:46.504978Z", "iopub.status.idle": "2023-08-18T20:16:06.908422Z", "shell.execute_reply": "2023-08-18T20:16:06.907202Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T20:16:06.812746\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# Seed torch's global RNG so weight initialization (and, presumably, the\n",
"# shuffling of the training DataLoader) repeats from run to run.\n",
"torch.manual_seed(0)\n",
"model = GoogleNet(lr=0.01)\n",
"trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n",
"data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))\n",
"# The Lazy* layers only get concrete shapes once they see data, so apply_init\n",
"# is handed one training batch before d2l.init_cnn is applied.\n",
"init_batch = next(iter(data.get_dataloader(True)))[0]\n",
"model.apply_init([init_batch], d2l.init_cnn)\n",
"trainer.fit(model, data)"
] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }