{ "cells": [ { "cell_type": "markdown", "id": "43fc4043", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Networks Using Blocks (VGG)\n", "\n" ] },
{ "cell_type": "code", "execution_count": 1, "id": "89467a5c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:56:46.393555Z", "iopub.status.busy": "2023-08-18T19:56:46.392607Z", "iopub.status.idle": "2023-08-18T19:56:49.756697Z", "shell.execute_reply": "2023-08-18T19:56:49.755534Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] },
{ "cell_type": "markdown", "id": "0d8fc06f", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "VGG Blocks" ] },
{ "cell_type": "code", "execution_count": 2, "id": "7a35971a", "metadata": { "attributes": { "classes": [], "id": "", "n": "3" }, "execution": { "iopub.execute_input": "2023-08-18T19:56:49.762934Z", "iopub.status.busy": "2023-08-18T19:56:49.761989Z", "iopub.status.idle": "2023-08-18T19:56:49.770418Z", "shell.execute_reply": "2023-08-18T19:56:49.769006Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [
 "def vgg_block(num_convs, out_channels):\n",
 "    \"\"\"Build one VGG block as an ``nn.Sequential``.\n",
 "\n",
 "    The block is ``num_convs`` repetitions of (3x3 ``LazyConv2d`` with\n",
 "    ``padding=1`` -> ``ReLU``) followed by a 2x2 max-pool with stride 2.\n",
 "    Padding 1 keeps height/width unchanged through the convolutions; the\n",
 "    pool then halves them (rounding down). The input channel count is\n",
 "    inferred lazily on the first forward pass, so only ``out_channels``\n",
 "    is specified.\n",
 "    \"\"\"\n",
 "    conv_relu = [\n",
 "        (nn.LazyConv2d(out_channels, kernel_size=3, padding=1), nn.ReLU())\n",
 "        for _ in range(num_convs)\n",
 "    ]\n",
 "    # Flatten the (conv, relu) pairs into one ordered list of layers.\n",
 "    layers = [layer for pair in conv_relu for layer in pair]\n",
 "    return nn.Sequential(*layers, nn.MaxPool2d(kernel_size=2, stride=2))"
 ] },
{ "cell_type": "markdown", "id": "bf32189a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "VGG Network" ] },
{ "cell_type": "code", "execution_count": 4, "id": "fc110465", "metadata": { "attributes": { "classes": [], "id": "", "n": "6" }, "execution": { "iopub.execute_input": "2023-08-18T19:56:49.789248Z", "iopub.status.busy": "2023-08-18T19:56:49.788373Z", "iopub.status.idle": "2023-08-18T19:56:51.334656Z", "shell.execute_reply": "2023-08-18T19:56:51.333439Z" }, "origin_pos": 15, "tab": [
"pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 64, 112, 112])\n", "Sequential output shape:\t torch.Size([1, 128, 56, 56])\n", "Sequential output shape:\t torch.Size([1, 256, 28, 28])\n", "Sequential output shape:\t torch.Size([1, 512, 14, 14])\n", "Sequential output shape:\t torch.Size([1, 512, 7, 7])\n", "Flatten output shape:\t torch.Size([1, 25088])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Linear output shape:\t torch.Size([1, 4096])\n", "ReLU output shape:\t torch.Size([1, 4096])\n", "Dropout output shape:\t torch.Size([1, 4096])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Linear output shape:\t torch.Size([1, 4096])\n", "ReLU output shape:\t torch.Size([1, 4096])\n", "Dropout output shape:\t torch.Size([1, 4096])\n", "Linear output shape:\t torch.Size([1, 10])\n" ] } ], "source": [
 "class VGG(d2l.Classifier):\n",
 "    \"\"\"VGG-style classifier: a stack of VGG blocks followed by a dense head.\n",
 "\n",
 "    Args:\n",
 "        arch: iterable of ``(num_convs, out_channels)`` pairs; each pair\n",
 "            becomes one ``vgg_block``, applied in order.\n",
 "        lr: learning rate. Not referenced directly in this class;\n",
 "            presumably recorded by ``save_hyperparameters()`` -- TODO confirm.\n",
 "        num_classes: output width of the final ``LazyLinear`` layer.\n",
 "\n",
 "    After the blocks, the head is ``Flatten`` followed by two\n",
 "    (``LazyLinear(4096)`` -> ``ReLU`` -> ``Dropout(0.5)``) stages and a\n",
 "    final ``LazyLinear(num_classes)``.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, arch, lr=0.1, num_classes=10):\n",
 "        super().__init__()\n",
 "        self.save_hyperparameters()\n",
 "        blocks = [vgg_block(n_convs, n_channels) for n_convs, n_channels in arch]\n",
 "        head = [nn.Flatten()]\n",
 "        for _ in range(2):\n",
 "            head += [nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5)]\n",
 "        head.append(nn.LazyLinear(num_classes))\n",
 "        self.net = nn.Sequential(*blocks, *head)\n",
 "        # NOTE(review): the Lazy* layers are not materialized yet here; confirm\n",
 "        # whether d2l.init_cnn has any effect on them at this point.\n",
 "        self.net.apply(d2l.init_cnn)\n",
 "\n",
 "VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))).layer_summary(\n",
 "    (1, 1, 224, 224))"
 ] }, { "cell_type": "markdown", "id": "64c9f7de", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Since VGG-11 is computationally more demanding than AlexNet,\n", "we construct a network with a smaller number of channels.\n", "\n", "Model training" ] }, { "cell_type": "code", "execution_count": 5, "id": "ed532028", "metadata": { "attributes": { "classes": [], "id": "", "n": "8" }, "execution": { "iopub.execute_input": 
"2023-08-18T19:56:51.339409Z", "iopub.status.busy": "2023-08-18T19:56:51.338575Z", "iopub.status.idle": "2023-08-18T20:01:43.617050Z", "shell.execute_reply": "2023-08-18T20:01:43.615832Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T20:01:43.519141\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
 "# Seed torch's global RNG so weight initialization and dropout masks\n",
 "# repeat across runs (GPU kernels may still be nondeterministic).\n",
 "torch.manual_seed(0)\n",
 "model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)\n",
 "trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n",
 "data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))\n",
 "# Hand apply_init the inputs of one batch from get_dataloader(True)\n",
 "# (presumably the training loader); presumably this runs a forward pass\n",
 "# so the Lazy* layers infer their shapes before d2l.init_cnn is applied.\n",
 "model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n",
 "trainer.fit(model, data)"
 ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }