{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# 含并行连结的网络(GoogLeNet)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-08-29T00:46:23.968827Z",
"start_time": "2019-08-29T00:46:22.660149Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"import d2l\n",
"from mxnet import gluon, np, npx\n",
"from mxnet.gluon import nn\n",
"npx.set_np()\n",
"\n",
"# Training and test iterators returned by d2l.load_data_fashion_mnist, using\n",
"# batches of 128 and resize=96 (the shape check below also uses 96 x 96 input).\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(\n",
"    batch_size=128,\n",
"    resize=96,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Inception 块。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-07T06:08:45.592275Z",
"start_time": "2019-07-07T06:08:45.583648Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "1"
},
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"class Inception(nn.Block):\n",
"    \"\"\"Inception block: four parallel paths joined along the channel axis.\n",
"\n",
"    Args:\n",
"        c1: Output channels of path 1 (a single 1 x 1 convolution).\n",
"        c2: Pair of output channels for path 2 (1 x 1 convolution, then\n",
"            3 x 3 convolution).\n",
"        c3: Pair of output channels for path 3 (1 x 1 convolution, then\n",
"            5 x 5 convolution).\n",
"        c4: Output channels of the 1 x 1 convolution in path 4, which\n",
"            follows a 3 x 3 max pooling layer.\n",
"        **kwargs: Forwarded unchanged to the parent ``nn.Block`` constructor.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, c1, c2, c3, c4, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        # Path 1: a single 1 x 1 convolution.\n",
"        self.p1_1 = nn.Conv2D(c1, kernel_size=1, activation='relu')\n",
"        # Path 2: 1 x 1 convolution followed by a 3 x 3 convolution.\n",
"        self.p2_1 = nn.Conv2D(c2[0], kernel_size=1, activation='relu')\n",
"        self.p2_2 = nn.Conv2D(c2[1], kernel_size=3, padding=1,\n",
"                              activation='relu')\n",
"        # Path 3: 1 x 1 convolution followed by a 5 x 5 convolution.\n",
"        self.p3_1 = nn.Conv2D(c3[0], kernel_size=1, activation='relu')\n",
"        self.p3_2 = nn.Conv2D(c3[1], kernel_size=5, padding=2,\n",
"                              activation='relu')\n",
"        # Path 4: 3 x 3 max pooling followed by a 1 x 1 convolution.\n",
"        self.p4_1 = nn.MaxPool2D(pool_size=3, strides=1, padding=1)\n",
"        self.p4_2 = nn.Conv2D(c4, kernel_size=1, activation='relu')\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply all four paths to ``x`` and concatenate the results on axis 1.\"\"\"\n",
"        p1 = self.p1_1(x)\n",
"        p2 = self.p2_2(self.p2_1(x))\n",
"        p3 = self.p3_2(self.p3_1(x))\n",
"        p4 = self.p4_2(self.p4_1(x))\n",
"        # Join the path outputs along the channel dimension (axis 1).\n",
"        return np.concatenate((p1, p2, p3, p4), axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Inception 模型 - 第一阶段。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-07T06:08:45.601078Z",
"start_time": "2019-07-07T06:08:45.593944Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"# Stage 1: 7 x 7 convolution (64 channels, stride 2, ReLU), then 3 x 3 max\n",
"# pooling with stride 2.\n",
"b1 = nn.Sequential()\n",
"b1.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3, activation='relu'))\n",
"b1.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inception 模型 - 第二阶段。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-07T06:08:45.606738Z",
"start_time": "2019-07-07T06:08:45.602846Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"# Stage 2: a 1 x 1 convolution (64 channels) and a 3 x 3 convolution\n",
"# (192 channels), then 3 x 3 max pooling with stride 2.\n",
"# Both convolutions use ReLU, like every other convolution in this notebook;\n",
"# without a nonlinearity between them the two layers would compose into a\n",
"# single linear map.\n",
"b2 = nn.Sequential()\n",
"b2.add(nn.Conv2D(64, kernel_size=1, activation='relu'),\n",
"       nn.Conv2D(192, kernel_size=3, padding=1, activation='relu'),\n",
"       nn.MaxPool2D(pool_size=3, strides=2, padding=1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Inception 模型 - 第三阶段。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-07T06:08:45.615934Z",
"start_time": "2019-07-07T06:08:45.608077Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"# Stage 3: two Inception blocks, then 3 x 3 max pooling with stride 2.\n",
"b3 = nn.Sequential()\n",
"# Output channels: 64 + 128 + 32 + 32 = 256\n",
"b3.add(Inception(64, (96, 128), (16, 32), 32))\n",
"# Output channels: 128 + 192 + 96 + 64 = 480\n",
"b3.add(Inception(128, (128, 192), (32, 96), 64))\n",
"b3.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Inception 模型 - 第四阶段。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-07T06:08:45.631945Z",
"start_time": "2019-07-07T06:08:45.617220Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "5"
}
},
"outputs": [],
"source": [
"# Stage 4: five Inception blocks, then 3 x 3 max pooling with stride 2.\n",
"b4 = nn.Sequential()\n",
"# Output channels: 192 + 208 + 48 + 64 = 512\n",
"b4.add(Inception(192, (96, 208), (16, 48), 64))\n",
"# Output channels: 160 + 224 + 64 + 64 = 512\n",
"b4.add(Inception(160, (112, 224), (24, 64), 64))\n",
"# Output channels: 128 + 256 + 64 + 64 = 512\n",
"b4.add(Inception(128, (128, 256), (24, 64), 64))\n",
"# Output channels: 112 + 288 + 64 + 64 = 528\n",
"b4.add(Inception(112, (144, 288), (32, 64), 64))\n",
"# Output channels: 256 + 320 + 128 + 128 = 832\n",
"b4.add(Inception(256, (160, 320), (32, 128), 128))\n",
"b4.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Inception 模型 - 第五阶段。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-07T06:08:45.641445Z",
"start_time": "2019-07-07T06:08:45.633251Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "6"
}
},
"outputs": [],
"source": [
"# Stage 5: two Inception blocks, then global average pooling.\n",
"b5 = nn.Sequential()\n",
"# Output channels: 256 + 320 + 128 + 128 = 832\n",
"b5.add(Inception(256, (160, 320), (32, 128), 128))\n",
"# Output channels: 384 + 384 + 128 + 128 = 1024\n",
"b5.add(Inception(384, (192, 384), (48, 128), 128))\n",
"b5.add(nn.GlobalAvgPool2D())\n",
"\n",
"# Full model: the five stages in order, then a Dense layer with 10 outputs.\n",
"net = nn.Sequential()\n",
"net.add(b1, b2, b3, b4, b5)\n",
"net.add(nn.Dense(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"查看网络。"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-07T06:08:45.760919Z",
"start_time": "2019-07-07T06:08:45.643249Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "7"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sequential0 output shape:\t (1, 64, 24, 24)\n",
"sequential1 output shape:\t (1, 192, 12, 12)\n",
"sequential2 output shape:\t (1, 480, 6, 6)\n",
"sequential3 output shape:\t (1, 832, 3, 3)\n",
"sequential4 output shape:\t (1, 1024, 1, 1)\n",
"dense0 output shape:\t (1, 10)\n"
]
}
],
"source": [
"# Push one random 1 x 1 x 96 x 96 input through the network and print the\n",
"# output shape after each top-level block.\n",
"X = np.random.uniform(size=(1, 1, 96, 96))\n",
"net.initialize()\n",
"for layer in net:\n",
"    # X is rebound on every iteration so each block receives the previous\n",
"    # block's output.\n",
"    X = layer(X)\n",
"    print(layer.name, 'output shape:\\t', X.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"训练。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-07T06:08:42.014Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "8"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.336, train acc 0.873, test acc 0.879\n",
"2732.2 exampes/sec on gpu(0)\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
"