{ "cells": [ { "cell_type": "markdown", "id": "a1065370", "metadata": {}, "source": [ "The following additional libraries are needed to run this\n", "notebook. Note that running on Colab is experimental; please report a GitHub\n", "issue if you have any problems." ] }, { "cell_type": "code", "execution_count": null, "id": "3687a475", "metadata": {}, "outputs": [], "source": [ "!pip install d2l==0.17.6\n" ] }, { "cell_type": "markdown", "id": "2b34f4d6", "metadata": { "origin_pos": 0 }, "source": [ "# Network in Network (NiN)\n", ":label:`sec_nin`\n", "\n", "LeNet, AlexNet, and VGG all share a common design pattern:\n", "extract features exploiting *spatial* structure\n", "via a sequence of convolution and pooling layers\n", "and then post-process the representations via fully-connected layers.\n", "The improvements upon LeNet by AlexNet and VGG mainly lie\n", "in how these later networks widen and deepen these two modules.\n", "Alternatively, one could imagine using fully-connected layers\n", "earlier in the process.\n", "However, a careless use of dense layers might give up the\n", "spatial structure of the representation entirely;\n", "*network in network* (*NiN*) blocks offer an alternative.\n", "They were proposed based on a very simple insight:\n", "to use an MLP on the channels for each pixel separately :cite:`Lin.Chen.Yan.2013`.\n", "\n", "\n", "## (**NiN Blocks**)\n", "\n", "Recall that the inputs and outputs of convolutional layers\n", "consist of four-dimensional tensors with axes\n", "corresponding to the example, channel, height, and width.\n", "Also recall that the inputs and outputs of fully-connected layers\n", "are typically two-dimensional tensors corresponding to the example and feature.\n", "The idea behind NiN is to apply a fully-connected layer\n", "at each pixel location (for each height and width).\n", "If we tie the weights across each spatial location,\n", "we could think of this as a $1\\times 1$ convolutional layer\n", "(as described in :numref:`sec_channels`)\n", "or as a fully-connected layer acting independently on each pixel location.\n", "Another way to view this is to think of each element in the spatial dimension\n", "(height and width) as equivalent to an example\n", "and a channel as equivalent to a feature.\n", "\n", ":numref:`fig_nin` illustrates the main structural differences\n", "between VGG and NiN, and their blocks.\n", "The NiN block consists of one convolutional layer\n", "followed by two $1\\times 1$ convolutional layers that act as\n", "per-pixel fully-connected layers with ReLU activations.\n", "The convolution window shape of the first layer is typically set by the user.\n", "The subsequent window shapes are fixed to $1 \\times 1$.\n", "\n", "![Comparing architectures of VGG and NiN, and their blocks.](http://d2l.ai/_images/nin.svg)\n", ":width:`600px`\n", ":label:`fig_nin`\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "f1867fd9", "metadata": { "execution": { "iopub.execute_input": "2022-11-12T21:02:14.658780Z", "iopub.status.busy": "2022-11-12T21:02:14.657866Z", "iopub.status.idle": "2022-11-12T21:02:17.700392Z", "shell.execute_reply": "2022-11-12T21:02:17.699520Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "\n", "def nin_block(in_channels, out_channels, kernel_size, strides, padding):\n", " return nn.Sequential(\n", " nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),\n", " nn.ReLU(),\n", " 
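# The two 1x1 convolutions act as per-pixel fully-connected layers\n", " 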
nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),\n", " nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())" ] }, { "cell_type": "markdown", "id": "47e1c42d", "metadata": { "origin_pos": 4 }, "source": [ "## [**NiN Model**]\n", "\n", "The original NiN network was proposed shortly after AlexNet\n", "and clearly draws some inspiration from it.\n", "NiN uses convolutional layers with window shapes\n", "of $11\\times 11$, $5\\times 5$, and $3\\times 3$,\n", "and the corresponding numbers of output channels are the same as in AlexNet. Each NiN block is followed by a maximum pooling layer\n", "with a stride of 2 and a window shape of $3\\times 3$.\n", "\n", "One significant difference between NiN and AlexNet\n", "is that NiN avoids fully-connected layers altogether.\n", "Instead, NiN uses an NiN block with a number of output channels equal to the number of label classes, followed by a *global* average pooling layer,\n", "yielding a vector of logits.\n", "One advantage of NiN's design is that it significantly\n", "reduces the number of required model parameters.\n", "However, in practice, this design sometimes requires\n", "increased model training time.\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "4205e242", "metadata": { "execution": { "iopub.execute_input": "2022-11-12T21:02:17.705805Z", "iopub.status.busy": "2022-11-12T21:02:17.705415Z", "iopub.status.idle": "2022-11-12T21:02:17.751625Z", "shell.execute_reply": "2022-11-12T21:02:17.750711Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "net = nn.Sequential(\n", " nin_block(1, 96, kernel_size=11, strides=4, padding=0),\n", " nn.MaxPool2d(3, stride=2),\n", " nin_block(96, 256, kernel_size=5, strides=1, padding=2),\n", " nn.MaxPool2d(3, stride=2),\n", " nin_block(256, 384, kernel_size=3, strides=1, padding=1),\n", " nn.MaxPool2d(3, stride=2),\n", " nn.Dropout(0.5),\n", " # There are 10 label classes\n", " nin_block(384, 10, kernel_size=3, strides=1, padding=1),\n", " nn.AdaptiveAvgPool2d((1, 1)),\n", " # Transform the four-dimensional output into two-dimensional output with a\n", " # shape of (batch size, 10)\n", " nn.Flatten())" ] }, { "cell_type": "markdown", "id": "47238c57", "metadata": { "origin_pos": 8 }, "source": [ "We create a data example to see [**the output shape of each block**].\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "0076c3b2", "metadata": { "execution": { "iopub.execute_input": "2022-11-12T21:02:17.758635Z", "iopub.status.busy": "2022-11-12T21:02:17.757918Z", "iopub.status.idle": "2022-11-12T21:02:17.783091Z", "shell.execute_reply": "2022-11-12T21:02:17.782106Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential output shape:\t torch.Size([1, 96, 54, 54])\n", "MaxPool2d output shape:\t torch.Size([1, 96, 26, 26])\n", "Sequential output shape:\t torch.Size([1, 256, 26, 26])\n", "MaxPool2d output shape:\t torch.Size([1, 256, 12, 12])\n", "Sequential output shape:\t torch.Size([1, 384, 12, 12])\n", "MaxPool2d output shape:\t torch.Size([1, 384, 5, 5])\n", "Dropout output shape:\t torch.Size([1, 384, 5, 5])\n", "Sequential output shape:\t torch.Size([1, 10, 5, 5])\n", "AdaptiveAvgPool2d output shape:\t torch.Size([1, 10, 1, 1])\n", "Flatten output shape:\t torch.Size([1, 10])\n" ] } ], "source": [ "X = torch.rand(size=(1, 1, 224, 224))\n", "for layer in net:\n", " X = layer(X)\n", " print(layer.__class__.__name__,'output shape:\\t', X.shape)" ] }, { "cell_type": "markdown", "id": 
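"nin-gap-check-md", "metadata": {}, "source": [ "As the printed shapes show, the final NiN block emits one $5\\times 5$ feature map per class, and the global average pooling layer reduces each map to a single logit. Concretely, `nn.AdaptiveAvgPool2d((1, 1))` followed by `nn.Flatten()` is nothing more than a per-channel mean over all spatial locations, which we can confirm on a small random tensor of the same shape.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "nin-gap-check-code", "metadata": {}, "outputs": [], "source": [ "# Global average pooling is a per-channel mean over all spatial locations\n", "Y = torch.rand(size=(1, 10, 5, 5))\n", "global_avg = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten())\n", "torch.allclose(global_avg(Y), Y.mean(dim=(2, 3)))" ] }, { "cell_type": "markdown", "id": "nin-param-count-md", "metadata": {}, "source": [ "Because NiN dispenses with fully-connected layers entirely, its parameter count stays comparatively small. A quick way to verify this claim is to sum the number of elements over all parameter tensors of `net`.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "nin-param-count-code", "metadata": {}, "outputs": [], "source": [ "# Count the learnable parameters of NiN; without any fully-connected layers,\n", "# the total is far smaller than that of AlexNet or VGG\n", "sum(p.numel() for p in net.parameters())" ] }, { "cell_type": "markdown", "id": 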
"e0883921", "metadata": { "origin_pos": 12 }, "source": [ "## [**Training**]\n", "\n", "As before we use Fashion-MNIST to train the model.\n", "NiN's training is similar to that for AlexNet and VGG.\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "00d331f5", "metadata": { "execution": { "iopub.execute_input": "2022-11-12T21:02:17.786721Z", "iopub.status.busy": "2022-11-12T21:02:17.786369Z", "iopub.status.idle": "2022-11-12T21:06:35.956154Z", "shell.execute_reply": "2022-11-12T21:06:35.954701Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.336, train acc 0.874, test acc 0.867\n", "3192.8 examples/sec on cuda:0\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2022-11-12T21:06:35.912252\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "lr, num_epochs, batch_size = 0.1, 10, 128\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n", "d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ] }, { "cell_type": "markdown", "id": "c1cd643e", "metadata": { "origin_pos": 14 }, "source": [ "## Summary\n", "\n", "* NiN uses blocks consisting of a convolutional layer and multiple $1\\times 1$ convolutional layers. This can be used within the convolutional stack to allow for more per-pixel nonlinearity.\n", "* NiN removes the fully-connected layers and replaces them with global average pooling (i.e., summing over all locations) after reducing the number of channels to the desired number of outputs (e.g., 10 for Fashion-MNIST).\n", "* Removing the fully-connected layers reduces overfitting. NiN has dramatically fewer parameters.\n", "* The NiN design influenced many subsequent CNN designs.\n", "\n", "## Exercises\n", "\n", "1. Tune the hyperparameters to improve the classification accuracy.\n", "1. Why are there two $1\\times 1$ convolutional layers in the NiN block? Remove one of them, and then observe and analyze the experimental phenomena.\n", "1. Calculate the resource usage for NiN.\n", " 1. What is the number of parameters?\n", " 1. What is the amount of computation?\n", " 1. What is the amount of memory needed during training?\n", " 1. What is the amount of memory needed during prediction?\n", "1. What are possible problems with reducing the $384 \\times 5 \\times 5$ representation to a $10 \\times 5 \\times 5$ representation in one step?\n" ] }, { "cell_type": "markdown", "id": "c9ecbd66", "metadata": { "origin_pos": 16, "tab": [ "pytorch" ] }, "source": [ "[Discussions](https://discuss.d2l.ai/t/80)\n" ] } ], "metadata": { "accelerator": "GPU", "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }