{
"cells": [
{
"cell_type": "markdown",
"id": "43fc4043",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Networks Using Blocks (VGG)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "89467a5c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:56:46.393555Z",
"iopub.status.busy": "2023-08-18T19:56:46.392607Z",
"iopub.status.idle": "2023-08-18T19:56:49.756697Z",
"shell.execute_reply": "2023-08-18T19:56:49.755534Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "0d8fc06f",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"VGG Blocks"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7a35971a",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "3"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:56:49.762934Z",
"iopub.status.busy": "2023-08-18T19:56:49.761989Z",
"iopub.status.idle": "2023-08-18T19:56:49.770418Z",
"shell.execute_reply": "2023-08-18T19:56:49.769006Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def vgg_block(num_convs, out_channels):\n",
"    \"\"\"Return one VGG block as an ``nn.Sequential``.\n",
"\n",
"    The block is ``num_convs`` repetitions of (3x3 conv with padding 1 ->\n",
"    ReLU), each conv producing ``out_channels`` channels, followed by a\n",
"    2x2 max-pool with stride 2. The padded 3x3 convs keep height/width\n",
"    unchanged; the pool halves them (rounding down).\n",
"    \"\"\"\n",
"    stages = []\n",
"    for _ in range(num_convs):\n",
"        stages += [nn.LazyConv2d(out_channels, kernel_size=3, padding=1),\n",
"                   nn.ReLU()]\n",
"    stages.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
"    return nn.Sequential(*stages)"
]
},
{
"cell_type": "markdown",
"id": "bf32189a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"VGG Network"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fc110465",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "6"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:56:49.789248Z",
"iopub.status.busy": "2023-08-18T19:56:49.788373Z",
"iopub.status.idle": "2023-08-18T19:56:51.334656Z",
"shell.execute_reply": "2023-08-18T19:56:51.333439Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential output shape:\t torch.Size([1, 64, 112, 112])\n",
"Sequential output shape:\t torch.Size([1, 128, 56, 56])\n",
"Sequential output shape:\t torch.Size([1, 256, 28, 28])\n",
"Sequential output shape:\t torch.Size([1, 512, 14, 14])\n",
"Sequential output shape:\t torch.Size([1, 512, 7, 7])\n",
"Flatten output shape:\t torch.Size([1, 25088])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear output shape:\t torch.Size([1, 4096])\n",
"ReLU output shape:\t torch.Size([1, 4096])\n",
"Dropout output shape:\t torch.Size([1, 4096])\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear output shape:\t torch.Size([1, 4096])\n",
"ReLU output shape:\t torch.Size([1, 4096])\n",
"Dropout output shape:\t torch.Size([1, 4096])\n",
"Linear output shape:\t torch.Size([1, 10])\n"
]
}
],
"source": [
"class VGG(d2l.Classifier):\n",
"    \"\"\"VGG-style classifier: a stack of VGG blocks followed by an MLP head.\n",
"\n",
"    Args:\n",
"        arch: iterable of ``(num_convs, out_channels)`` pairs; each pair\n",
"            becomes one ``vgg_block``, in the given order.\n",
"        lr: not read in this class; presumably captured by\n",
"            ``save_hyperparameters()`` as the training learning rate\n",
"            (TODO confirm against d2l).\n",
"        num_classes: output width of the final linear layer.\n",
"    \"\"\"\n",
"    def __init__(self, arch, lr=0.1, num_classes=10):\n",
"        super().__init__()\n",
"        self.save_hyperparameters()\n",
"        blocks = [vgg_block(n_convs, n_channels) for n_convs, n_channels in arch]\n",
"        # Two 4096-wide hidden layers (ReLU + 50% dropout) and the output\n",
"        # layer; Lazy* layers infer their input sizes on the first forward.\n",
"        head = [\n",
"            nn.Flatten(),\n",
"            nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),\n",
"            nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),\n",
"            nn.LazyLinear(num_classes),\n",
"        ]\n",
"        self.net = nn.Sequential(*blocks, *head)\n",
"        # NOTE(review): the Lazy* layers have no materialized weights yet at\n",
"        # this point; confirm d2l.init_cnn actually takes effect here.\n",
"        self.net.apply(d2l.init_cnn)\n",
"\n",
"# VGG-11: 1+1+2+2+2 = 8 conv layers plus 3 linear layers = 11 weight layers.\n",
"# Show each top-level layer's output shape for one 1-channel 224x224 input.\n",
"VGG(\n",
"    arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512)),\n",
").layer_summary((1, 1, 224, 224))"
]
},
{
"cell_type": "markdown",
"id": "64c9f7de",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Since VGG-11 is computationally more demanding than AlexNet,\n",
"we construct a network with a smaller number of channels.\n",
"\n",
"Model training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ed532028",
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "8"
},
"execution": {
"iopub.execute_input": "2023-08-18T19:56:51.339409Z",
"iopub.status.busy": "2023-08-18T19:56:51.338575Z",
"iopub.status.idle": "2023-08-18T20:01:43.617050Z",
"shell.execute_reply": "2023-08-18T20:01:43.615832Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"