{ "cells": [ { "cell_type": "markdown", "id": "42454a9e", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Deep Convolutional Neural Networks (AlexNet)\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "29feac8e", "metadata": { "attributes": { "classes": [], "id": "", "n": "5" }, "execution": { "iopub.execute_input": "2023-08-18T19:43:36.252627Z", "iopub.status.busy": "2023-08-18T19:43:36.251926Z", "iopub.status.idle": "2023-08-18T19:43:36.259662Z", "shell.execute_reply": "2023-08-18T19:43:36.258841Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [
 "import torch\n",
 "from torch import nn\n",
 "from d2l import torch as d2l\n",
 "\n",
 "\n",
 "class AlexNet(d2l.Classifier):\n",
 "    \"\"\"AlexNet: five convolutional layers followed by three linear layers.\n",
 "\n",
 "    Every convolutional and linear layer is a lazy module, so no input\n",
 "    sizes are spelled out in this definition.\n",
 "\n",
 "    Args:\n",
 "        lr: learning rate. It is not used directly in this constructor;\n",
 "            presumably save_hyperparameters() records it (confirm in d2l).\n",
 "        num_classes: number of outputs of the final linear layer.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, lr=0.1, num_classes=10):\n",
 "        super().__init__()\n",
 "        self.save_hyperparameters()\n",
 "        layers = [\n",
 "            # Stage 1: 11x11 convolution with stride 4, then 3x3 max-pooling.\n",
 "            nn.LazyConv2d(96, kernel_size=11, stride=4, padding=1),\n",
 "            nn.ReLU(),\n",
 "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
 "            # Stage 2: 5x5 convolution (padding 2), then 3x3 max-pooling.\n",
 "            nn.LazyConv2d(256, kernel_size=5, padding=2),\n",
 "            nn.ReLU(),\n",
 "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
 "            # Stage 3: three 3x3 convolutions (padding 1), then max-pooling.\n",
 "            nn.LazyConv2d(384, kernel_size=3, padding=1),\n",
 "            nn.ReLU(),\n",
 "            nn.LazyConv2d(384, kernel_size=3, padding=1),\n",
 "            nn.ReLU(),\n",
 "            nn.LazyConv2d(256, kernel_size=3, padding=1),\n",
 "            nn.ReLU(),\n",
 "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
 "            # Classifier head: two 4096-unit hidden layers, each followed by\n",
 "            # ReLU and dropout, then the output layer.\n",
 "            nn.Flatten(),\n",
 "            nn.LazyLinear(4096),\n",
 "            nn.ReLU(),\n",
 "            nn.Dropout(p=0.5),\n",
 "            nn.LazyLinear(4096),\n",
 "            nn.ReLU(),\n",
 "            nn.Dropout(p=0.5),\n",
 "            nn.LazyLinear(num_classes),\n",
 "        ]\n",
 "        # A single flat Sequential holding the layers in the order listed.\n",
 "        self.net = nn.Sequential(*layers)\n",
 "        # NOTE(review): every parametrised layer above is still lazy at this\n",
 "        # point (no forward pass has run yet) - confirm that d2l.init_cnn\n",
 "        # actually initializes lazy modules when applied here.\n",
 "        self.net.apply(d2l.init_cnn)"
 ] }, { "cell_type": "markdown", "id": "53d3c554", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Construct a single-channel data example\n", "to observe the output shape of each layer" ] }, { "cell_type": "code", "execution_count": 3, "id": "3d5c2c0a", "metadata": { "attributes": { "classes": [], "id": "", "n": "6" }, "execution": { "iopub.execute_input": "2023-08-18T19:43:36.262984Z", "iopub.status.busy": "2023-08-18T19:43:36.262447Z",
"iopub.status.idle": "2023-08-18T19:43:36.786362Z", "shell.execute_reply": "2023-08-18T19:43:36.785437Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Conv2d output shape:\t torch.Size([1, 96, 54, 54])\n", "ReLU output shape:\t torch.Size([1, 96, 54, 54])\n", "MaxPool2d output shape:\t torch.Size([1, 96, 26, 26])\n", "Conv2d output shape:\t torch.Size([1, 256, 26, 26])\n", "ReLU output shape:\t torch.Size([1, 256, 26, 26])\n", "MaxPool2d output shape:\t torch.Size([1, 256, 12, 12])\n", "Conv2d output shape:\t torch.Size([1, 384, 12, 12])\n", "ReLU output shape:\t torch.Size([1, 384, 12, 12])\n", "Conv2d output shape:\t torch.Size([1, 384, 12, 12])\n", "ReLU output shape:\t torch.Size([1, 384, 12, 12])\n", "Conv2d output shape:\t torch.Size([1, 256, 12, 12])\n", "ReLU output shape:\t torch.Size([1, 256, 12, 12])\n", "MaxPool2d output shape:\t torch.Size([1, 256, 5, 5])\n", "Flatten output shape:\t torch.Size([1, 6400])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Linear output shape:\t torch.Size([1, 4096])\n", "ReLU output shape:\t torch.Size([1, 4096])\n", "Dropout output shape:\t torch.Size([1, 4096])\n", "Linear output shape:\t torch.Size([1, 4096])\n", "ReLU output shape:\t torch.Size([1, 4096])\n", "Dropout output shape:\t torch.Size([1, 4096])\n", "Linear output shape:\t torch.Size([1, 10])\n" ] } ], "source": [ "AlexNet().layer_summary((1, 1, 224, 224))" ] }, { "cell_type": "markdown", "id": "d155463a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Fashion-MNIST\n", "images have lower resolution\n", "than ImageNet images.\n", "We upsample them to $224 \\times 224$\n", "start training AlexNet" ] }, { "cell_type": "code", "execution_count": 4, "id": "acd6a3e0", "metadata": { "attributes": { "classes": [], "id": "", "n": "8" }, "execution": { "iopub.execute_input": "2023-08-18T19:43:36.789914Z", "iopub.status.busy": "2023-08-18T19:43:36.789357Z", 
"iopub.status.idle": "2023-08-18T19:46:37.458518Z", "shell.execute_reply": "2023-08-18T19:46:37.457483Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:46:37.358909\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = AlexNet(lr=0.01)\n", "data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))\n", "trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n", "trainer.fit(model, data)" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }