{
"cells": [
{
"cell_type": "markdown",
"id": "42454a9e",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Deep Convolutional Neural Networks (AlexNet)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "29feac8e",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "5"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:43:36.252627Z",
"iopub.status.busy": "2023-08-18T19:43:36.251926Z",
"iopub.status.idle": "2023-08-18T19:43:36.259662Z",
"shell.execute_reply": "2023-08-18T19:43:36.258841Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"class AlexNet(d2l.Classifier):\n",
" def __init__(self, lr=0.1, num_classes=10):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
" self.net = nn.Sequential(\n",
" nn.LazyConv2d(96, kernel_size=11, stride=4, padding=1),\n",
" nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2),\n",
" nn.LazyConv2d(256, kernel_size=5, padding=2), nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2),\n",
" nn.LazyConv2d(384, kernel_size=3, padding=1), nn.ReLU(),\n",
" nn.LazyConv2d(384, kernel_size=3, padding=1), nn.ReLU(),\n",
" nn.LazyConv2d(256, kernel_size=3, padding=1), nn.ReLU(),\n",
" nn.MaxPool2d(kernel_size=3, stride=2), nn.Flatten(),\n",
" nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(p=0.5),\n",
" nn.LazyLinear(4096), nn.ReLU(),nn.Dropout(p=0.5),\n",
" nn.LazyLinear(num_classes))\n",
" self.net.apply(d2l.init_cnn)"
]
},
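{
"cell_type": "markdown",
"id": "a9f3c1d2",
"metadata": {},
"source": [
"The first convolution uses an $11 \\times 11$ kernel with stride 4 and padding 1,\n",
"so the standard output-size formula gives\n",
"$\\lfloor (224 - 11 + 2 \\cdot 1)/4 \\rfloor + 1 = 54$,\n",
"i.e. a $54 \\times 54$ feature map.\n",
"The short check below is not part of the original notebook;\n",
"it simply applies a standalone `nn.Conv2d` with the same hyperparameters\n",
"to confirm the arithmetic."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4b7e8f1",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (not from the original text): a single Conv2d with AlexNet's\n",
"# first-layer hyperparameters maps a 224x224 input to 54x54, matching\n",
"# floor((224 - 11 + 2 * 1) / 4) + 1 = 54.\n",
"conv1 = nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1)\n",
"X = torch.randn(1, 1, 224, 224)\n",
"print(conv1(X).shape)  # expected: torch.Size([1, 96, 54, 54])"
]
},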
{
"cell_type": "markdown",
"id": "53d3c554",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Construct a single-channel data example\n",
"to observe the output shape of each layer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3d5c2c0a",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "6"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:43:36.262984Z",
"iopub.status.busy": "2023-08-18T19:43:36.262447Z",
"iopub.status.idle": "2023-08-18T19:43:36.786362Z",
"shell.execute_reply": "2023-08-18T19:43:36.785437Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Conv2d output shape:\t torch.Size([1, 96, 54, 54])\n",
"ReLU output shape:\t torch.Size([1, 96, 54, 54])\n",
"MaxPool2d output shape:\t torch.Size([1, 96, 26, 26])\n",
"Conv2d output shape:\t torch.Size([1, 256, 26, 26])\n",
"ReLU output shape:\t torch.Size([1, 256, 26, 26])\n",
"MaxPool2d output shape:\t torch.Size([1, 256, 12, 12])\n",
"Conv2d output shape:\t torch.Size([1, 384, 12, 12])\n",
"ReLU output shape:\t torch.Size([1, 384, 12, 12])\n",
"Conv2d output shape:\t torch.Size([1, 384, 12, 12])\n",
"ReLU output shape:\t torch.Size([1, 384, 12, 12])\n",
"Conv2d output shape:\t torch.Size([1, 256, 12, 12])\n",
"ReLU output shape:\t torch.Size([1, 256, 12, 12])\n",
"MaxPool2d output shape:\t torch.Size([1, 256, 5, 5])\n",
"Flatten output shape:\t torch.Size([1, 6400])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear output shape:\t torch.Size([1, 4096])\n",
"ReLU output shape:\t torch.Size([1, 4096])\n",
"Dropout output shape:\t torch.Size([1, 4096])\n",
"Linear output shape:\t torch.Size([1, 4096])\n",
"ReLU output shape:\t torch.Size([1, 4096])\n",
"Dropout output shape:\t torch.Size([1, 4096])\n",
"Linear output shape:\t torch.Size([1, 10])\n"
]
}
],
"source": [
"AlexNet().layer_summary((1, 1, 224, 224))"
]
},
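{
"cell_type": "markdown",
"id": "f2e6b9a4",
"metadata": {},
"source": [
"Most of AlexNet's parameters sit in the dense layers:\n",
"the first fully connected layer alone maps the flattened 6400-dimensional\n",
"feature vector to 4096 units, i.e. roughly $6400 \\times 4096 \\approx 26$ million weights.\n",
"The sketch below (not part of the original notebook) materializes the lazy layers\n",
"with a dummy forward pass and then counts the learnable parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8a5c3e7",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming a dummy forward pass is enough to materialize\n",
"# all Lazy* layers so that their parameters can be counted.\n",
"model = AlexNet()\n",
"_ = model.net(torch.randn(1, 1, 224, 224))\n",
"total = sum(p.numel() for p in model.net.parameters())\n",
"print(f'total learnable parameters: {total}')"
]
},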
{
"cell_type": "markdown",
"id": "d155463a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Fashion-MNIST\n",
"images have lower resolution\n",
"than ImageNet images.\n",
"We upsample them to $224 \\times 224$\n",
"start training AlexNet"
]
},
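{
"cell_type": "markdown",
"id": "e1c7d5b3",
"metadata": {},
"source": [
"As a quick check before training (not part of the original notebook,\n",
"and assuming `d2l.FashionMNIST` accepts a `resize` argument as used elsewhere in the book),\n",
"we can load one upsampled batch and confirm its shape."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6f4a2c9",
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: load Fashion-MNIST upsampled to 224x224 and inspect one\n",
"# training batch (assumes the d2l.FashionMNIST(batch_size, resize) API).\n",
"data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))\n",
"X, y = next(iter(data.get_dataloader(train=True)))\n",
"print(X.shape)  # expected: torch.Size([128, 1, 224, 224])"
]
},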
{
"cell_type": "code",
"execution_count": 4,
"id": "acd6a3e0",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "8"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:43:36.789914Z",
"iopub.status.busy": "2023-08-18T19:43:36.789357Z",
"iopub.status.idle": "2023-08-18T19:46:37.458518Z",
"shell.execute_reply": "2023-08-18T19:46:37.457483Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"