{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Residual Networks (ResNet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T22:37:50.067206Z",
     "start_time": "2019-07-03T22:37:46.213755Z"
    }
   },
   "outputs": [],
   "source": [
    "import d2l\n",
    "from mxnet import gluon, np, npx\n",
    "from mxnet.gluon import nn\n",
    "npx.set_np()\n",
    "\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=256, resize=96)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Residual block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T22:37:50.083212Z",
     "start_time": "2019-07-03T22:37:50.070948Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "1"
    }
   },
   "outputs": [],
   "source": [
    "class Residual(nn.Block):\n",
    "    \"\"\"ResNet residual block: two 3x3 conv + BatchNorm layers plus a shortcut.\n",
    "\n",
    "    num_channels: output channels of every convolution in the block.\n",
    "    use_1x1conv: if True, route the shortcut through a 1x1 convolution so it\n",
    "        matches the main branch's channel count and spatial size.\n",
    "    strides: stride of the first 3x3 convolution (and of the 1x1 shortcut).\n",
    "        If strides > 1 or the input channel count differs from num_channels,\n",
    "        use_1x1conv must be True, otherwise the shapes of the two branches\n",
    "        will not match for the addition in forward().\n",
    "    \"\"\"\n",
    "    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.conv1 = nn.Conv2D(num_channels, kernel_size=3, padding=1,\n",
    "                               strides=strides)\n",
    "        self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)\n",
    "        # None means an identity shortcut.\n",
    "        self.conv3 = None\n",
    "        if use_1x1conv:\n",
    "            self.conv3 = nn.Conv2D(num_channels, kernel_size=1,\n",
    "                                   strides=strides)\n",
    "        self.bn1 = nn.BatchNorm()\n",
    "        self.bn2 = nn.BatchNorm()\n",
    "\n",
    "    def forward(self, X):\n",
    "        Y = npx.relu(self.bn1(self.conv1(X)))\n",
    "        Y = self.bn2(self.conv2(Y))\n",
    "        if self.conv3 is not None:\n",
    "            X = self.conv3(X)\n",
    "        # The shortcut is added before the final ReLU.\n",
    "        return npx.relu(Y + X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "A situation where the input and output are of the same shape."
]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T22:37:50.119721Z",
     "start_time": "2019-07-03T22:37:50.085475Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 3, 6, 6)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = Residual(3)\n",
    "blk.initialize()\n",
    "X = np.random.uniform(size=(4, 3, 6, 6))\n",
    "blk(X).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "Halve the output height and width while increasing the number of output channels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T22:37:50.171045Z",
     "start_time": "2019-07-03T22:37:50.122531Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 6, 3, 3)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = Residual(6, use_1x1conv=True, strides=2)\n",
    "blk.initialize()\n",
    "blk(X).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The ResNet block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T22:37:50.187029Z",
     "start_time": "2019-07-03T22:37:50.175512Z"
    },
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    }
   },
   "outputs": [],
   "source": [
    "def resnet_block(num_channels, num_residuals, first_block=False):\n",
    "    \"\"\"Return an nn.Sequential of num_residuals Residual blocks.\n",
    "\n",
    "    Every Residual outputs num_channels channels. Unless first_block is\n",
    "    True, the first Residual uses strides=2 and a 1x1 shortcut convolution,\n",
    "    halving height and width and changing the channel count on entry.\n",
    "    With first_block=True all shortcuts are identities, so the input must\n",
    "    already have num_channels channels at the desired resolution.\n",
    "    \"\"\"\n",
    "    blk = nn.Sequential()\n",
    "    for i in range(num_residuals):\n",
    "        if i == 0 and not first_block:\n",
    "            # Downsample and project the shortcut at the start of the stage.\n",
    "            blk.add(Residual(num_channels, use_1x1conv=True, strides=2))\n",
    "        else:\n",
    "            blk.add(Residual(num_channels))\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
"end_time": "2019-07-03T22:37:50.213229Z",
     "start_time": "2019-07-03T22:37:50.194294Z"
    }
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential()\n",
    "# Stem (7x7 conv, BatchNorm, ReLU, max-pool), four ResNet stages of two\n",
    "# Residual blocks each, global average pooling and a 10-way Dense layer.\n",
    "net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),\n",
    "        nn.BatchNorm(), nn.Activation('relu'),\n",
    "        nn.MaxPool2D(pool_size=3, strides=2, padding=1),\n",
    "        resnet_block(64, 2, first_block=True),\n",
    "        resnet_block(128, 2),\n",
    "        resnet_block(256, 2),\n",
    "        resnet_block(512, 2),\n",
    "        nn.GlobalAvgPool2D(),\n",
    "        nn.Dense(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-03T22:40:13.823984Z",
     "start_time": "2019-07-03T22:37:50.216572Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.011, train acc 0.997, test acc 0.894\n",
      "12520.9 exampes/sec on gpu(0)\n"
     ]
    }
   ],
   "source": [
    "d2l.train_ch5(net, train_iter, test_iter, num_epochs=10, lr=0.05)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}