{ "cells": [ { "cell_type": "code", "execution_count": 9, "id": "d81a4359-d6fb-4756-b75c-8bc56289848d", "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "import sys\n", "sys.path.append('/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/')\n", "import d2l\n", "from torchsummary import summary\n", "\n", "\n", "def conv_block(num_channels):\n", " return nn.Sequential(\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.LazyConv2d(num_channels, kernel_size=3, padding=1))\n", "\n", "def transition_block(num_channels):\n", " return nn.Sequential(\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.LazyConv2d(num_channels, kernel_size=1),\n", " nn.AvgPool2d(kernel_size=2, stride=2))\n", "\n", "class DenseBlock(nn.Module):\n", " def __init__(self, num_convs, num_channels):\n", " super(DenseBlock, self).__init__()\n", " layer = []\n", " for i in range(num_convs):\n", " layer.append(conv_block(num_channels))\n", " self.net = nn.Sequential(*layer)\n", "\n", " def forward(self, X):\n", " for blk in self.net:\n", " Y = blk(X)\n", " # Concatenate input and output of each block along the channels\n", " X = torch.cat((X, Y), dim=1)\n", " return X\n", " \n", "class DenseNet(d2l.Classifier):\n", " def b1(self):\n", " return nn.Sequential(\n", " nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", " \n", " def __init__(self, num_channels=64, growth_rate=32, arch=(4, 4, 4, 4),\n", " lr=0.1, num_classes=10):\n", " super(DenseNet, self).__init__()\n", " self.save_hyperparameters()\n", " self.net = nn.Sequential(self.b1())\n", " for i, num_convs in enumerate(arch):\n", " self.net.add_module(f'dense_blk{i+1}', DenseBlock(num_convs,\n", " growth_rate))\n", " # The number of output channels in the previous dense block\n", " num_channels += num_convs * growth_rate\n", " # A transition layer that halves the number of channels is added\n", " # between the dense blocks\n", " if i != len(arch) - 1:\n", " num_channels //= 2\n", " self.net.add_module(f'tran_blk{i+1}', transition_block(\n", " num_channels))\n", " self.net.add_module('last', nn.Sequential(\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),\n", " nn.LazyLinear(num_classes)))\n", " self.net.apply(d2l.init_cnn)" ] }, { "cell_type": "markdown", "id": "79f3d2a8-0b65-435b-9a3c-3b2e6f5f79f3", "metadata": {}, "source": [ "# 1. Why do we use average pooling rather than max-pooling in the transition layer?" ] }, { "cell_type": "markdown", "id": "4419d1bb-cda9-4a21-96a7-b6f7000ee6af", "metadata": {}, "source": [ "In DenseNet architectures, transition layers are used to reduce the spatial dimensions (width and height) of feature maps while also reducing the number of feature maps (channels) before passing them to the next dense block. The choice between average pooling and max-pooling in transition layers depends on the design goals and the desired properties of the network. In DenseNet, average pooling is often preferred over max-pooling for several reasons:\n", "\n", "1. **Feature Retention**: Average pooling computes the average value of the elements in a pooling region. This retains more information about the features compared to max-pooling, which only selects the maximum value. 
In DenseNet, where the feature maps of all preceding layers are concatenated together, average pooling helps maintain a more comprehensive summary of those features.\n", "\n", "2. **Smoothing Effect**: Average pooling smooths the output feature maps. This reduces the risk of overfitting by preventing the network from becoming too sensitive to specific details in the data.\n", "\n", "3. **Stability**: Average pooling is less sensitive to outliers than max-pooling, which makes the network more robust to noise and variations in the input data.\n", "\n", "4. **Translation Invariance**: Average pooling provides a degree of translation invariance by taking into account the overall distribution of values in the pooling region. This is useful when small translations of the input should not significantly affect the output.\n", "\n", "5. **Information Sharing**: Average pooling promotes information sharing among neighboring pixels or units, which helps capture global patterns and structures in the input data.\n", "\n", "Max-pooling can still be preferable in settings that emphasize the strongest local responses, such as feature extractors tuned to sharp local patterns. In DenseNet's transition layers, however, the goal is to compress the features accumulated from all preceding layers without discarding most of them, so average pooling aligns better with the architecture's principles.\n", "\n", "Ultimately, the choice between average pooling and max-pooling depends on the goals of the network, the characteristics of the data, and the overall design philosophy." ] }, { "cell_type": "markdown", "id": "2f878475-5b4d-4e95-9f61-2f632dc4c824", "metadata": {}, "source": [ "# 2. One of the advantages mentioned in the DenseNet paper is that its model parameters are smaller than those of ResNet. Why is this the case?"
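, "\n", "\n", "A rough per-layer count already suggests the answer (a sketch; the channel numbers are illustrative, with $k=32$ the growth rate used above): a $3\times 3$ convolution in a ResNet stage maps $C$ channels to $C$ channels and therefore needs about $9C^2$ weights, whereas a convolution in a dense block only has to produce $k$ new channels, i.e. about $9 C_{in} k$ weights.\n", "\n", "```python\n", "# Hypothetical back-of-the-envelope comparison; c_in = 512 is illustrative.\n", "k, c_in = 32, 512\n", "resnet_layer_params = 9 * c_in * c_in   # 3x3 conv mapping 512 -> 512 channels\n", "densenet_layer_params = 9 * c_in * k    # 3x3 conv mapping 512 -> 32 channels\n", "print(resnet_layer_params, densenet_layer_params)  # 2359296 vs. 147456\n", "```\n", "\n", "The parameter counts of the full models in the next cells show the same effect."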
] }, { "cell_type": "code", "execution_count": 11, "id": "c3229269-d43e-452b-8db7-35b0425b98e0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "11523338\n", "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 64, 112, 112] 3,200\n", " BatchNorm2d-2 [-1, 64, 112, 112] 128\n", " ReLU-3 [-1, 64, 112, 112] 0\n", " MaxPool2d-4 [-1, 64, 56, 56] 0\n", " Conv2d-5 [-1, 64, 56, 56] 36,928\n", " BatchNorm2d-6 [-1, 64, 56, 56] 128\n", " ReLU-7 [-1, 64, 56, 56] 0\n", " Conv2d-8 [-1, 64, 56, 56] 36,928\n", " BatchNorm2d-9 [-1, 64, 56, 56] 128\n", " Residual-10 [-1, 64, 56, 56] 0\n", " Conv2d-11 [-1, 64, 56, 56] 36,928\n", " BatchNorm2d-12 [-1, 64, 56, 56] 128\n", " ReLU-13 [-1, 64, 56, 56] 0\n", " Conv2d-14 [-1, 64, 56, 56] 36,928\n", " BatchNorm2d-15 [-1, 64, 56, 56] 128\n", " Residual-16 [-1, 64, 56, 56] 0\n", " Conv2d-17 [-1, 128, 28, 28] 73,856\n", " BatchNorm2d-18 [-1, 128, 28, 28] 256\n", " ReLU-19 [-1, 128, 28, 28] 0\n", " Conv2d-20 [-1, 128, 28, 28] 147,584\n", " BatchNorm2d-21 [-1, 128, 28, 28] 256\n", " Conv2d-22 [-1, 128, 28, 28] 8,320\n", " Residual-23 [-1, 128, 28, 28] 0\n", " Conv2d-24 [-1, 128, 28, 28] 147,584\n", " BatchNorm2d-25 [-1, 128, 28, 28] 256\n", " ReLU-26 [-1, 128, 28, 28] 0\n", " Conv2d-27 [-1, 128, 28, 28] 147,584\n", " BatchNorm2d-28 [-1, 128, 28, 28] 256\n", " Conv2d-29 [-1, 128, 28, 28] 16,512\n", " Residual-30 [-1, 128, 28, 28] 0\n", " Conv2d-31 [-1, 256, 14, 14] 295,168\n", " BatchNorm2d-32 [-1, 256, 14, 14] 512\n", " ReLU-33 [-1, 256, 14, 14] 0\n", " Conv2d-34 [-1, 256, 14, 14] 590,080\n", " BatchNorm2d-35 [-1, 256, 14, 14] 512\n", " Conv2d-36 [-1, 256, 14, 14] 33,024\n", " Residual-37 [-1, 256, 14, 14] 0\n", " Conv2d-38 [-1, 256, 14, 14] 590,080\n", " BatchNorm2d-39 [-1, 256, 14, 14] 512\n", " ReLU-40 [-1, 256, 14, 14] 0\n", " Conv2d-41 [-1, 256, 14, 14] 590,080\n", " BatchNorm2d-42 [-1, 256, 14, 14] 512\n", " Conv2d-43 [-1, 256, 14, 14] 65,792\n", " Residual-44 [-1, 256, 14, 14] 0\n", " Conv2d-45 [-1, 512, 7, 7] 1,180,160\n", " BatchNorm2d-46 [-1, 512, 7, 7] 1,024\n", " ReLU-47 [-1, 512, 7, 7] 0\n", " Conv2d-48 [-1, 512, 7, 7] 2,359,808\n", " BatchNorm2d-49 [-1, 512, 7, 7] 1,024\n", " Conv2d-50 [-1, 512, 7, 7] 131,584\n", " Residual-51 [-1, 512, 7, 7] 0\n", " Conv2d-52 [-1, 512, 7, 7] 2,359,808\n", " BatchNorm2d-53 [-1, 512, 7, 7] 1,024\n", " ReLU-54 [-1, 512, 7, 7] 0\n", " Conv2d-55 [-1, 512, 7, 7] 2,359,808\n", " BatchNorm2d-56 [-1, 512, 7, 7] 1,024\n", " Conv2d-57 [-1, 512, 7, 7] 262,656\n", " Residual-58 [-1, 512, 7, 7] 0\n", "AdaptiveAvgPool2d-59 [-1, 512, 1, 1] 0\n", " Flatten-60 [-1, 512] 0\n", " Linear-61 [-1, 10] 5,130\n", "================================================================\n", "Total params: 11,523,338\n", "Trainable params: 11,523,338\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.19\n", "Forward/backward pass size (MB): 57.05\n", "Params size (MB): 43.96\n", "Estimated Total Size (MB): 101.20\n", "----------------------------------------------------------------\n" ] } ], "source": [ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters())\n", "\n", "data = d2l.FashionMNIST(batch_size=32, resize=(224, 224))\n", "arch18 = [(2,[(64,3,1)]*2,None),(2,[(128,3,1)]*2,128),(2,[(256,3,1)]*2,256),(2,[(512,3,1)]*2,512)]\n", "resnet18 = d2l.ResNet(arch=arch18, 
lr=0.01)\n", "resnet18.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "print(count_parameters(resnet18))\n", "summary(resnet18, (1, 224, 224))" ] }, { "cell_type": "code", "execution_count": 12, "id": "dd8a4dc1-ffbc-43f8-a2f2-beb5fd8fc325", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "758226\n", "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 64, 112, 112] 3,200\n", " BatchNorm2d-2 [-1, 64, 112, 112] 128\n", " ReLU-3 [-1, 64, 112, 112] 0\n", " MaxPool2d-4 [-1, 64, 56, 56] 0\n", " BatchNorm2d-5 [-1, 64, 56, 56] 128\n", " ReLU-6 [-1, 64, 56, 56] 0\n", " Conv2d-7 [-1, 32, 56, 56] 18,464\n", " BatchNorm2d-8 [-1, 96, 56, 56] 192\n", " ReLU-9 [-1, 96, 56, 56] 0\n", " Conv2d-10 [-1, 32, 56, 56] 27,680\n", " BatchNorm2d-11 [-1, 128, 56, 56] 256\n", " ReLU-12 [-1, 128, 56, 56] 0\n", " Conv2d-13 [-1, 32, 56, 56] 36,896\n", " BatchNorm2d-14 [-1, 160, 56, 56] 320\n", " ReLU-15 [-1, 160, 56, 56] 0\n", " Conv2d-16 [-1, 32, 56, 56] 46,112\n", " DenseBlock-17 [-1, 192, 56, 56] 0\n", " BatchNorm2d-18 [-1, 192, 56, 56] 384\n", " ReLU-19 [-1, 192, 56, 56] 0\n", " Conv2d-20 [-1, 96, 56, 56] 18,528\n", " AvgPool2d-21 [-1, 96, 28, 28] 0\n", " BatchNorm2d-22 [-1, 96, 28, 28] 192\n", " ReLU-23 [-1, 96, 28, 28] 0\n", " Conv2d-24 [-1, 32, 28, 28] 27,680\n", " BatchNorm2d-25 [-1, 128, 28, 28] 256\n", " ReLU-26 [-1, 128, 28, 28] 0\n", " Conv2d-27 [-1, 32, 28, 28] 36,896\n", " BatchNorm2d-28 [-1, 160, 28, 28] 320\n", " ReLU-29 [-1, 160, 28, 28] 0\n", " Conv2d-30 [-1, 32, 28, 28] 46,112\n", " BatchNorm2d-31 [-1, 192, 28, 28] 384\n", " ReLU-32 [-1, 192, 28, 28] 0\n", " Conv2d-33 [-1, 32, 28, 28] 55,328\n", " DenseBlock-34 [-1, 224, 28, 28] 0\n", " BatchNorm2d-35 [-1, 224, 28, 28] 448\n", " ReLU-36 [-1, 224, 28, 28] 0\n", " Conv2d-37 [-1, 112, 28, 28] 25,200\n", " AvgPool2d-38 [-1, 112, 14, 14] 0\n", " BatchNorm2d-39 [-1, 112, 14, 14] 224\n", " ReLU-40 [-1, 112, 14, 14] 0\n", " Conv2d-41 [-1, 32, 14, 14] 32,288\n", " BatchNorm2d-42 [-1, 144, 14, 14] 288\n", " ReLU-43 [-1, 144, 14, 14] 0\n", " Conv2d-44 [-1, 32, 14, 14] 41,504\n", " BatchNorm2d-45 [-1, 176, 14, 14] 352\n", " ReLU-46 [-1, 176, 14, 14] 0\n", " Conv2d-47 [-1, 32, 14, 14] 50,720\n", " BatchNorm2d-48 [-1, 208, 14, 14] 416\n", " ReLU-49 [-1, 208, 14, 14] 0\n", " Conv2d-50 [-1, 32, 14, 14] 59,936\n", " DenseBlock-51 [-1, 240, 14, 14] 0\n", " BatchNorm2d-52 [-1, 240, 14, 14] 480\n", " ReLU-53 [-1, 240, 14, 14] 0\n", " Conv2d-54 [-1, 120, 14, 14] 28,920\n", " AvgPool2d-55 [-1, 120, 7, 7] 0\n", " BatchNorm2d-56 [-1, 120, 7, 7] 240\n", " ReLU-57 [-1, 120, 7, 7] 0\n", " Conv2d-58 [-1, 32, 7, 7] 34,592\n", " BatchNorm2d-59 [-1, 152, 7, 7] 304\n", " ReLU-60 [-1, 152, 7, 7] 0\n", " Conv2d-61 [-1, 32, 7, 7] 43,808\n", " BatchNorm2d-62 [-1, 184, 7, 7] 368\n", " ReLU-63 [-1, 184, 7, 7] 0\n", " Conv2d-64 [-1, 32, 7, 7] 53,024\n", " BatchNorm2d-65 [-1, 216, 7, 7] 432\n", " ReLU-66 [-1, 216, 7, 7] 0\n", " Conv2d-67 [-1, 32, 7, 7] 62,240\n", " DenseBlock-68 [-1, 248, 7, 7] 0\n", " BatchNorm2d-69 [-1, 248, 7, 7] 496\n", " ReLU-70 [-1, 248, 7, 7] 0\n", "AdaptiveAvgPool2d-71 [-1, 248, 1, 1] 0\n", " Flatten-72 [-1, 248] 0\n", " Linear-73 [-1, 10] 2,490\n", "================================================================\n", "Total params: 758,226\n", "Trainable params: 758,226\n", "Non-trainable params: 0\n", 
"----------------------------------------------------------------\n", "Input size (MB): 0.19\n", "Forward/backward pass size (MB): 77.81\n", "Params size (MB): 2.89\n", "Estimated Total Size (MB): 80.89\n", "----------------------------------------------------------------\n" ] } ], "source": [ "model = DenseNet(lr=0.01)\n", "model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "print(count_parameters(model))\n", "summary(model, (1, 224, 224))" ] }, { "cell_type": "markdown", "id": "28122a78-6c52-4a42-8de0-645d1d0667bd", "metadata": {}, "source": [ "The reason why DenseNet has smaller model parameters than ResNet is because DenseNet uses **dense connections** between layers, which means that each layer receives the feature maps of all preceding layers as input and passes its own feature maps to all subsequent layers. This way, the number of channels (filters) in each layer can be reduced, since the layer can reuse the features from previous layers. ResNet, on the other hand, uses **residual connections**, which means that each layer only receives the output of the previous layer and adds it to its own output. This requires more channels in each layer to learn new features, since the layer cannot access the features from earlier layers. According to the DenseNet paper¹, a 121-layer DenseNet has 7.98 million parameters, while a 152-layer ResNet has 60.19 million parameters. This is a significant difference in model size and complexity.\n", "\n", "- (1) [1608.06993] Densely Connected Convolutional Networks - arXiv.org. https://arxiv.org/abs/1608.06993.\n", "- (2) DenseNet Explained | Papers With Code. https://paperswithcode.com/method/densenet.\n", "- (3) CVPR2017最佳论文DenseNet(一)原理分析 - 知乎 - 知乎专栏. https://zhuanlan.zhihu.com/p/93825208.\n", "- (4) Densely Connected Convolutional Networks - arXiv.org. https://arxiv.org/pdf/1608.06993.pdf.\n", "- (5) [1512.03385] Deep Residual Learning for Image Recognition - arXiv.org. https://arxiv.org/abs/1512.03385.\n", "- (6) ResNet Explained | Papers With Code. https://paperswithcode.com/method/resnet.\n", "- (7) arXiv:2110.00476v1 [cs.CV] 1 Oct 2021. https://arxiv.org/pdf/2110.00476.pdf.\n", "- (8) undefined. https://doi.org/10.48550/arXiv.1608.06993.\n", "- (9) undefined. https://doi.org/10.48550/arXiv.1512.03385." ] }, { "cell_type": "markdown", "id": "78eb9463-1ac7-46ca-b736-cd1f32922164", "metadata": {}, "source": [ "# 3. One problem for which DenseNet has been criticized is its high memory consumption." ] }, { "cell_type": "markdown", "id": "e024d971-a075-4eb9-8df8-d3526c45db52", "metadata": {}, "source": [ "## 3.1 Is this really the case? Try to change the input shape to $224\\times 224$ to compare the actual GPU memory consumption empirically." 
] }, { "cell_type": "code", "execution_count": null, "id": "83abcc55-9961-49ce-b868-4b8e86e504c3", "metadata": {}, "outputs": [], "source": [ "data = d2l.FashionMNIST(batch_size=32, resize=(28, 28))\n", "model = DenseNet(lr=0.01)\n", "model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "torch.cuda.reset_peak_memory_stats()\n", "torch.cuda.empty_cache()\n", "trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n", "trainer.fit(model, data)\n", "memory_stats = torch.cuda.memory_stats(device=device)\n", "# Print peak memory usage and other memory statistics\n", "print(\"Peak memory usage:\", memory_stats[\"allocated_bytes.all.peak\"] / (1024 ** 2), \"MB\")\n", "print(\"Current memory usage:\", memory_stats[\"allocated_bytes.all.current\"] / (1024 ** 2), \"MB\")" ] }, { "cell_type": "code", "execution_count": null, "id": "5a641b40-b510-48d5-bf4e-99063380773b", "metadata": {}, "outputs": [], "source": [ "data = d2l.FashionMNIST(batch_size=32, resize=(224, 224))\n", "model = DenseNet(lr=0.01)\n", "model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "torch.cuda.reset_peak_memory_stats()\n", "torch.cuda.empty_cache()\n", "trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n", "trainer.fit(model, data)\n", "memory_stats = torch.cuda.memory_stats(device=device)\n", "# Print peak memory usage and other memory statistics\n", "print(\"Peak memory usage:\", memory_stats[\"allocated_bytes.all.peak\"] / (1024 ** 2), \"MB\")\n", "print(\"Current memory usage:\", memory_stats[\"allocated_bytes.all.current\"] / (1024 ** 2), \"MB\")" ] }, { "cell_type": "markdown", "id": "0b4c4bb3-4a25-42a9-b4b6-e33f3b7bdc0b", "metadata": {}, "source": [ "## 3.2 Can you think of an alternative means of reducing the memory consumption? How would you need to change the framework?" ] }, { "cell_type": "markdown", "id": "cf265e43-e19f-4a7f-8fbd-415a110e5a82", "metadata": {}, "source": [ "Reducing memory consumption in a DenseNet architecture can be achieved through various strategies. One approach is to introduce sparsity into the model, which reduces the number of active connections and parameters. Here's how you might change the framework to achieve this:\n", "\n", "**1. Sparse Connectivity in Dense Blocks:**\n", "Instead of having fully connected dense blocks, you can introduce sparse connectivity patterns. This means that not every layer connects to every other layer in the dense block. You can achieve this by randomly selecting a subset of previous layers' feature maps to concatenate with the current layer. This reduces the number of connections and memory consumption.\n", "\n", "**2. Channel Pruning:**\n", "Apply channel pruning techniques to the dense blocks. You can identify less important channels and remove them from the concatenation operation. This effectively reduces the number of active channels and saves memory.\n", "\n", "**3. Regularization and Compression:**\n", "Introduce regularization techniques like L1 regularization during training to encourage certain weights to become exactly zero. Additionally, you can explore model compression methods like knowledge distillation or quantization to reduce the memory footprint of the model.\n", "\n", "**4. Low-Rank Approximations:**\n", "Perform low-rank matrix factorization on the weight matrices in the dense blocks. This technique approximates the weight matrices with lower-dimensional factors, leading to reduced memory usage.\n", "\n", "**5. 
Dynamic Allocation:**\n", "Allocate memory dynamically during inference to only store the necessary feature maps. This technique avoids allocating memory for feature maps that are no longer needed.\n", "\n", "**6. Sparsity-Inducing Activation Functions:**\n", "Use activation functions that naturally induce sparsity, such as the ReLU6 function, which caps activations at a maximum value and can lead to some neurons becoming inactive.\n", "\n", "**7. Adaptive Dense Blocks:**\n", "Design adaptive dense blocks that dynamically adjust their connectivity patterns based on the data distribution. For example, you can use attention mechanisms to determine which previous feature maps to concatenate based on their importance.\n", "\n", "Implementing these changes would require modifications to the architecture, training procedure, and potentially custom layers or modifications to existing layers. It's important to note that these techniques might involve a trade-off between memory reduction and model performance. It's recommended to experiment and fine-tune these strategies on your specific problem domain to find the right balance." ] }, { "cell_type": "markdown", "id": "74854d69-300d-4521-92c5-a5125886286e", "metadata": {}, "source": [ "# 4. Implement the various DenseNet versions presented in Table 1 of the DenseNet paper (Huang et al., 2017)." ] }, { "cell_type": "code", "execution_count": 114, "id": "664cfc4d-a78b-4eee-899f-016409aaae81", "metadata": { "tags": [] }, "outputs": [], "source": [ "def conv_block(num_channels, kernel_size, padding):\n", " return nn.Sequential(\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.LazyConv2d(num_channels, kernel_size=kernel_size, padding=padding))\n", "\n", "def transition_block(num_channels):\n", " return nn.Sequential(\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.LazyConv2d(num_channels, kernel_size=1),\n", " nn.AvgPool2d(kernel_size=2, stride=2))\n", "\n", "class DenseBlock(nn.Module):\n", " def __init__(self, convs, num_channels):\n", " super(DenseBlock, self).__init__()\n", " layer = []\n", " for kernel_size, padding in convs:\n", " layer.append(conv_block(num_channels, kernel_size, padding))\n", " self.net = nn.Sequential(*layer)\n", "\n", " def forward(self, X):\n", " for blk in self.net:\n", " Y = blk(X)\n", " # Concatenate input and output of each block along the channels\n", " X = torch.cat((X, Y), dim=1)\n", " return X\n", "\n", "class DenseNet(d2l.Classifier):\n", " def b1(self):\n", " return nn.Sequential(\n", " nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=3, stride=2, padding=1))\n", " \n", " def __init__(self, num_channels=64, growth_rate=32, arch=[[[3,1],[3,1]],[[3,1],[3,1]]],lr=0.1, num_classes=10):\n", " super(DenseNet, self).__init__()\n", " self.save_hyperparameters()\n", " self.net = nn.Sequential(self.b1())\n", " for i, convs in enumerate(arch):\n", " self.net.add_module(f'dense_blk{i+1}', DenseBlock(convs, growth_rate))\n", " # The number of output channels in the previous dense block\n", " num_channels += len(convs) * growth_rate\n", " # A transition layer that halves the number of channels is added\n", " # between the dense blocks\n", " if i != len(arch) - 1:\n", " num_channels //= 2\n", " self.net.add_module(f'tran_blk{i+1}', transition_block(\n", " num_channels))\n", " self.net.add_module('last', nn.Sequential(\n", " nn.LazyBatchNorm2d(), nn.ReLU(),\n", " nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),\n", " nn.LazyLinear(num_classes)))\n", " 
self.net.apply(d2l.init_cnn)" ] }, { "cell_type": "code", "execution_count": 118, "id": "ee7bb0c6-0a68-46d1-aece-613dce70ae43", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 64, 112, 112] 3,200\n", " BatchNorm2d-2 [-1, 64, 112, 112] 128\n", " ReLU-3 [-1, 64, 112, 112] 0\n", " MaxPool2d-4 [-1, 64, 56, 56] 0\n", " BatchNorm2d-5 [-1, 64, 56, 56] 128\n", " ReLU-6 [-1, 64, 56, 56] 0\n", " Conv2d-7 [-1, 32, 56, 56] 2,080\n", " BatchNorm2d-8 [-1, 96, 56, 56] 192\n", " ReLU-9 [-1, 96, 56, 56] 0\n", " Conv2d-10 [-1, 32, 56, 56] 27,680\n", " BatchNorm2d-11 [-1, 128, 56, 56] 256\n", " ReLU-12 [-1, 128, 56, 56] 0\n", " Conv2d-13 [-1, 32, 56, 56] 4,128\n", " BatchNorm2d-14 [-1, 160, 56, 56] 320\n", " ReLU-15 [-1, 160, 56, 56] 0\n", " Conv2d-16 [-1, 32, 56, 56] 46,112\n", " BatchNorm2d-17 [-1, 192, 56, 56] 384\n", " ReLU-18 [-1, 192, 56, 56] 0\n", " Conv2d-19 [-1, 32, 56, 56] 6,176\n", " BatchNorm2d-20 [-1, 224, 56, 56] 448\n", " ReLU-21 [-1, 224, 56, 56] 0\n", " Conv2d-22 [-1, 32, 56, 56] 64,544\n", " BatchNorm2d-23 [-1, 256, 56, 56] 512\n", " ReLU-24 [-1, 256, 56, 56] 0\n", " Conv2d-25 [-1, 32, 56, 56] 8,224\n", " BatchNorm2d-26 [-1, 288, 56, 56] 576\n", " ReLU-27 [-1, 288, 56, 56] 0\n", " Conv2d-28 [-1, 32, 56, 56] 82,976\n", " BatchNorm2d-29 [-1, 320, 56, 56] 640\n", " ReLU-30 [-1, 320, 56, 56] 0\n", " Conv2d-31 [-1, 32, 56, 56] 10,272\n", " BatchNorm2d-32 [-1, 352, 56, 56] 704\n", " ReLU-33 [-1, 352, 56, 56] 0\n", " Conv2d-34 [-1, 32, 56, 56] 101,408\n", " BatchNorm2d-35 [-1, 384, 56, 56] 768\n", " ReLU-36 [-1, 384, 56, 56] 0\n", " Conv2d-37 [-1, 32, 56, 56] 12,320\n", " BatchNorm2d-38 [-1, 416, 56, 56] 832\n", " ReLU-39 [-1, 416, 56, 56] 0\n", " Conv2d-40 [-1, 32, 56, 56] 119,840\n", " DenseBlock-41 [-1, 448, 56, 56] 0\n", " BatchNorm2d-42 [-1, 448, 56, 56] 896\n", " ReLU-43 [-1, 448, 56, 56] 0\n", " Conv2d-44 [-1, 224, 56, 56] 100,576\n", " AvgPool2d-45 [-1, 224, 28, 28] 0\n", " BatchNorm2d-46 [-1, 224, 28, 28] 448\n", " ReLU-47 [-1, 224, 28, 28] 0\n", " Conv2d-48 [-1, 32, 28, 28] 7,200\n", " BatchNorm2d-49 [-1, 256, 28, 28] 512\n", " ReLU-50 [-1, 256, 28, 28] 0\n", " Conv2d-51 [-1, 32, 28, 28] 73,760\n", " BatchNorm2d-52 [-1, 288, 28, 28] 576\n", " ReLU-53 [-1, 288, 28, 28] 0\n", " Conv2d-54 [-1, 32, 28, 28] 9,248\n", " BatchNorm2d-55 [-1, 320, 28, 28] 640\n", " ReLU-56 [-1, 320, 28, 28] 0\n", " Conv2d-57 [-1, 32, 28, 28] 92,192\n", " BatchNorm2d-58 [-1, 352, 28, 28] 704\n", " ReLU-59 [-1, 352, 28, 28] 0\n", " Conv2d-60 [-1, 32, 28, 28] 11,296\n", " BatchNorm2d-61 [-1, 384, 28, 28] 768\n", " ReLU-62 [-1, 384, 28, 28] 0\n", " Conv2d-63 [-1, 32, 28, 28] 110,624\n", " BatchNorm2d-64 [-1, 416, 28, 28] 832\n", " ReLU-65 [-1, 416, 28, 28] 0\n", " Conv2d-66 [-1, 32, 28, 28] 13,344\n", " BatchNorm2d-67 [-1, 448, 28, 28] 896\n", " ReLU-68 [-1, 448, 28, 28] 0\n", " Conv2d-69 [-1, 32, 28, 28] 129,056\n", " BatchNorm2d-70 [-1, 480, 28, 28] 960\n", " ReLU-71 [-1, 480, 28, 28] 0\n", " Conv2d-72 [-1, 32, 28, 28] 15,392\n", " BatchNorm2d-73 [-1, 512, 28, 28] 1,024\n", " ReLU-74 [-1, 512, 28, 28] 0\n", " Conv2d-75 [-1, 32, 28, 28] 147,488\n", " BatchNorm2d-76 [-1, 544, 28, 28] 1,088\n", " ReLU-77 [-1, 544, 28, 28] 0\n", " Conv2d-78 [-1, 32, 28, 28] 17,440\n", " BatchNorm2d-79 [-1, 576, 28, 28] 1,152\n", " ReLU-80 [-1, 576, 28, 28] 0\n", " Conv2d-81 [-1, 32, 28, 28] 
165,920\n", " BatchNorm2d-82 [-1, 608, 28, 28] 1,216\n", " ReLU-83 [-1, 608, 28, 28] 0\n", " Conv2d-84 [-1, 32, 28, 28] 19,488\n", " BatchNorm2d-85 [-1, 640, 28, 28] 1,280\n", " ReLU-86 [-1, 640, 28, 28] 0\n", " Conv2d-87 [-1, 32, 28, 28] 184,352\n", " BatchNorm2d-88 [-1, 672, 28, 28] 1,344\n", " ReLU-89 [-1, 672, 28, 28] 0\n", " Conv2d-90 [-1, 32, 28, 28] 21,536\n", " BatchNorm2d-91 [-1, 704, 28, 28] 1,408\n", " ReLU-92 [-1, 704, 28, 28] 0\n", " Conv2d-93 [-1, 32, 28, 28] 202,784\n", " BatchNorm2d-94 [-1, 736, 28, 28] 1,472\n", " ReLU-95 [-1, 736, 28, 28] 0\n", " Conv2d-96 [-1, 32, 28, 28] 23,584\n", " BatchNorm2d-97 [-1, 768, 28, 28] 1,536\n", " ReLU-98 [-1, 768, 28, 28] 0\n", " Conv2d-99 [-1, 32, 28, 28] 221,216\n", " BatchNorm2d-100 [-1, 800, 28, 28] 1,600\n", " ReLU-101 [-1, 800, 28, 28] 0\n", " Conv2d-102 [-1, 32, 28, 28] 25,632\n", " BatchNorm2d-103 [-1, 832, 28, 28] 1,664\n", " ReLU-104 [-1, 832, 28, 28] 0\n", " Conv2d-105 [-1, 32, 28, 28] 239,648\n", " BatchNorm2d-106 [-1, 864, 28, 28] 1,728\n", " ReLU-107 [-1, 864, 28, 28] 0\n", " Conv2d-108 [-1, 32, 28, 28] 27,680\n", " BatchNorm2d-109 [-1, 896, 28, 28] 1,792\n", " ReLU-110 [-1, 896, 28, 28] 0\n", " Conv2d-111 [-1, 32, 28, 28] 258,080\n", " BatchNorm2d-112 [-1, 928, 28, 28] 1,856\n", " ReLU-113 [-1, 928, 28, 28] 0\n", " Conv2d-114 [-1, 32, 28, 28] 29,728\n", " BatchNorm2d-115 [-1, 960, 28, 28] 1,920\n", " ReLU-116 [-1, 960, 28, 28] 0\n", " Conv2d-117 [-1, 32, 28, 28] 276,512\n", " DenseBlock-118 [-1, 992, 28, 28] 0\n", " BatchNorm2d-119 [-1, 992, 28, 28] 1,984\n", " ReLU-120 [-1, 992, 28, 28] 0\n", " Conv2d-121 [-1, 496, 28, 28] 492,528\n", " AvgPool2d-122 [-1, 496, 14, 14] 0\n", " BatchNorm2d-123 [-1, 496, 14, 14] 992\n", " ReLU-124 [-1, 496, 14, 14] 0\n", " Conv2d-125 [-1, 32, 14, 14] 15,904\n", " BatchNorm2d-126 [-1, 528, 14, 14] 1,056\n", " ReLU-127 [-1, 528, 14, 14] 0\n", " Conv2d-128 [-1, 32, 14, 14] 152,096\n", " BatchNorm2d-129 [-1, 560, 14, 14] 1,120\n", " ReLU-130 [-1, 560, 14, 14] 0\n", " Conv2d-131 [-1, 32, 14, 14] 17,952\n", " BatchNorm2d-132 [-1, 592, 14, 14] 1,184\n", " ReLU-133 [-1, 592, 14, 14] 0\n", " Conv2d-134 [-1, 32, 14, 14] 170,528\n", " BatchNorm2d-135 [-1, 624, 14, 14] 1,248\n", " ReLU-136 [-1, 624, 14, 14] 0\n", " Conv2d-137 [-1, 32, 14, 14] 20,000\n", " BatchNorm2d-138 [-1, 656, 14, 14] 1,312\n", " ReLU-139 [-1, 656, 14, 14] 0\n", " Conv2d-140 [-1, 32, 14, 14] 188,960\n", " BatchNorm2d-141 [-1, 688, 14, 14] 1,376\n", " ReLU-142 [-1, 688, 14, 14] 0\n", " Conv2d-143 [-1, 32, 14, 14] 22,048\n", " BatchNorm2d-144 [-1, 720, 14, 14] 1,440\n", " ReLU-145 [-1, 720, 14, 14] 0\n", " Conv2d-146 [-1, 32, 14, 14] 207,392\n", " BatchNorm2d-147 [-1, 752, 14, 14] 1,504\n", " ReLU-148 [-1, 752, 14, 14] 0\n", " Conv2d-149 [-1, 32, 14, 14] 24,096\n", " BatchNorm2d-150 [-1, 784, 14, 14] 1,568\n", " ReLU-151 [-1, 784, 14, 14] 0\n", " Conv2d-152 [-1, 32, 14, 14] 225,824\n", " BatchNorm2d-153 [-1, 816, 14, 14] 1,632\n", " ReLU-154 [-1, 816, 14, 14] 0\n", " Conv2d-155 [-1, 32, 14, 14] 26,144\n", " BatchNorm2d-156 [-1, 848, 14, 14] 1,696\n", " ReLU-157 [-1, 848, 14, 14] 0\n", " Conv2d-158 [-1, 32, 14, 14] 244,256\n", " BatchNorm2d-159 [-1, 880, 14, 14] 1,760\n", " ReLU-160 [-1, 880, 14, 14] 0\n", " Conv2d-161 [-1, 32, 14, 14] 28,192\n", " BatchNorm2d-162 [-1, 912, 14, 14] 1,824\n", " ReLU-163 [-1, 912, 14, 14] 0\n", " Conv2d-164 [-1, 32, 14, 14] 262,688\n", " BatchNorm2d-165 [-1, 944, 14, 14] 1,888\n", " ReLU-166 [-1, 944, 14, 14] 0\n", " Conv2d-167 [-1, 32, 14, 14] 30,240\n", " BatchNorm2d-168 [-1, 976, 14, 14] 1,952\n", " 
ReLU-169 [-1, 976, 14, 14] 0\n", " Conv2d-170 [-1, 32, 14, 14] 281,120\n", " BatchNorm2d-171 [-1, 1008, 14, 14] 2,016\n", " ReLU-172 [-1, 1008, 14, 14] 0\n", " Conv2d-173 [-1, 32, 14, 14] 32,288\n", " BatchNorm2d-174 [-1, 1040, 14, 14] 2,080\n", " ReLU-175 [-1, 1040, 14, 14] 0\n", " Conv2d-176 [-1, 32, 14, 14] 299,552\n", " BatchNorm2d-177 [-1, 1072, 14, 14] 2,144\n", " ReLU-178 [-1, 1072, 14, 14] 0\n", " Conv2d-179 [-1, 32, 14, 14] 34,336\n", " BatchNorm2d-180 [-1, 1104, 14, 14] 2,208\n", " ReLU-181 [-1, 1104, 14, 14] 0\n", " Conv2d-182 [-1, 32, 14, 14] 317,984\n", " BatchNorm2d-183 [-1, 1136, 14, 14] 2,272\n", " ReLU-184 [-1, 1136, 14, 14] 0\n", " Conv2d-185 [-1, 32, 14, 14] 36,384\n", " BatchNorm2d-186 [-1, 1168, 14, 14] 2,336\n", " ReLU-187 [-1, 1168, 14, 14] 0\n", " Conv2d-188 [-1, 32, 14, 14] 336,416\n", " BatchNorm2d-189 [-1, 1200, 14, 14] 2,400\n", " ReLU-190 [-1, 1200, 14, 14] 0\n", " Conv2d-191 [-1, 32, 14, 14] 38,432\n", " BatchNorm2d-192 [-1, 1232, 14, 14] 2,464\n", " ReLU-193 [-1, 1232, 14, 14] 0\n", " Conv2d-194 [-1, 32, 14, 14] 354,848\n", " BatchNorm2d-195 [-1, 1264, 14, 14] 2,528\n", " ReLU-196 [-1, 1264, 14, 14] 0\n", " Conv2d-197 [-1, 32, 14, 14] 40,480\n", " BatchNorm2d-198 [-1, 1296, 14, 14] 2,592\n", " ReLU-199 [-1, 1296, 14, 14] 0\n", " Conv2d-200 [-1, 32, 14, 14] 373,280\n", " BatchNorm2d-201 [-1, 1328, 14, 14] 2,656\n", " ReLU-202 [-1, 1328, 14, 14] 0\n", " Conv2d-203 [-1, 32, 14, 14] 42,528\n", " BatchNorm2d-204 [-1, 1360, 14, 14] 2,720\n", " ReLU-205 [-1, 1360, 14, 14] 0\n", " Conv2d-206 [-1, 32, 14, 14] 391,712\n", " BatchNorm2d-207 [-1, 1392, 14, 14] 2,784\n", " ReLU-208 [-1, 1392, 14, 14] 0\n", " Conv2d-209 [-1, 32, 14, 14] 44,576\n", " BatchNorm2d-210 [-1, 1424, 14, 14] 2,848\n", " ReLU-211 [-1, 1424, 14, 14] 0\n", " Conv2d-212 [-1, 32, 14, 14] 410,144\n", " BatchNorm2d-213 [-1, 1456, 14, 14] 2,912\n", " ReLU-214 [-1, 1456, 14, 14] 0\n", " Conv2d-215 [-1, 32, 14, 14] 46,624\n", " BatchNorm2d-216 [-1, 1488, 14, 14] 2,976\n", " ReLU-217 [-1, 1488, 14, 14] 0\n", " Conv2d-218 [-1, 32, 14, 14] 428,576\n", " BatchNorm2d-219 [-1, 1520, 14, 14] 3,040\n", " ReLU-220 [-1, 1520, 14, 14] 0\n", " Conv2d-221 [-1, 32, 14, 14] 48,672\n", " BatchNorm2d-222 [-1, 1552, 14, 14] 3,104\n", " ReLU-223 [-1, 1552, 14, 14] 0\n", " Conv2d-224 [-1, 32, 14, 14] 447,008\n", " BatchNorm2d-225 [-1, 1584, 14, 14] 3,168\n", " ReLU-226 [-1, 1584, 14, 14] 0\n", " Conv2d-227 [-1, 32, 14, 14] 50,720\n", " BatchNorm2d-228 [-1, 1616, 14, 14] 3,232\n", " ReLU-229 [-1, 1616, 14, 14] 0\n", " Conv2d-230 [-1, 32, 14, 14] 465,440\n", " BatchNorm2d-231 [-1, 1648, 14, 14] 3,296\n", " ReLU-232 [-1, 1648, 14, 14] 0\n", " Conv2d-233 [-1, 32, 14, 14] 52,768\n", " BatchNorm2d-234 [-1, 1680, 14, 14] 3,360\n", " ReLU-235 [-1, 1680, 14, 14] 0\n", " Conv2d-236 [-1, 32, 14, 14] 483,872\n", " BatchNorm2d-237 [-1, 1712, 14, 14] 3,424\n", " ReLU-238 [-1, 1712, 14, 14] 0\n", " Conv2d-239 [-1, 32, 14, 14] 54,816\n", " BatchNorm2d-240 [-1, 1744, 14, 14] 3,488\n", " ReLU-241 [-1, 1744, 14, 14] 0\n", " Conv2d-242 [-1, 32, 14, 14] 502,304\n", " BatchNorm2d-243 [-1, 1776, 14, 14] 3,552\n", " ReLU-244 [-1, 1776, 14, 14] 0\n", " Conv2d-245 [-1, 32, 14, 14] 56,864\n", " BatchNorm2d-246 [-1, 1808, 14, 14] 3,616\n", " ReLU-247 [-1, 1808, 14, 14] 0\n", " Conv2d-248 [-1, 32, 14, 14] 520,736\n", " BatchNorm2d-249 [-1, 1840, 14, 14] 3,680\n", " ReLU-250 [-1, 1840, 14, 14] 0\n", " Conv2d-251 [-1, 32, 14, 14] 58,912\n", " BatchNorm2d-252 [-1, 1872, 14, 14] 3,744\n", " ReLU-253 [-1, 1872, 14, 14] 0\n", " Conv2d-254 [-1, 32, 14, 14] 
539,168\n", " BatchNorm2d-255 [-1, 1904, 14, 14] 3,808\n", " ReLU-256 [-1, 1904, 14, 14] 0\n", " Conv2d-257 [-1, 32, 14, 14] 60,960\n", " BatchNorm2d-258 [-1, 1936, 14, 14] 3,872\n", " ReLU-259 [-1, 1936, 14, 14] 0\n", " Conv2d-260 [-1, 32, 14, 14] 557,600\n", " BatchNorm2d-261 [-1, 1968, 14, 14] 3,936\n", " ReLU-262 [-1, 1968, 14, 14] 0\n", " Conv2d-263 [-1, 32, 14, 14] 63,008\n", " BatchNorm2d-264 [-1, 2000, 14, 14] 4,000\n", " ReLU-265 [-1, 2000, 14, 14] 0\n", " Conv2d-266 [-1, 32, 14, 14] 576,032\n", " DenseBlock-267 [-1, 2032, 14, 14] 0\n", " BatchNorm2d-268 [-1, 2032, 14, 14] 4,064\n", " ReLU-269 [-1, 2032, 14, 14] 0\n", " Conv2d-270 [-1, 1016, 14, 14] 2,065,528\n", " AvgPool2d-271 [-1, 1016, 7, 7] 0\n", " BatchNorm2d-272 [-1, 1016, 7, 7] 2,032\n", " ReLU-273 [-1, 1016, 7, 7] 0\n", " Conv2d-274 [-1, 32, 7, 7] 32,544\n", " BatchNorm2d-275 [-1, 1048, 7, 7] 2,096\n", " ReLU-276 [-1, 1048, 7, 7] 0\n", " Conv2d-277 [-1, 32, 7, 7] 301,856\n", " BatchNorm2d-278 [-1, 1080, 7, 7] 2,160\n", " ReLU-279 [-1, 1080, 7, 7] 0\n", " Conv2d-280 [-1, 32, 7, 7] 34,592\n", " BatchNorm2d-281 [-1, 1112, 7, 7] 2,224\n", " ReLU-282 [-1, 1112, 7, 7] 0\n", " Conv2d-283 [-1, 32, 7, 7] 320,288\n", " BatchNorm2d-284 [-1, 1144, 7, 7] 2,288\n", " ReLU-285 [-1, 1144, 7, 7] 0\n", " Conv2d-286 [-1, 32, 7, 7] 36,640\n", " BatchNorm2d-287 [-1, 1176, 7, 7] 2,352\n", " ReLU-288 [-1, 1176, 7, 7] 0\n", " Conv2d-289 [-1, 32, 7, 7] 338,720\n", " BatchNorm2d-290 [-1, 1208, 7, 7] 2,416\n", " ReLU-291 [-1, 1208, 7, 7] 0\n", " Conv2d-292 [-1, 32, 7, 7] 38,688\n", " BatchNorm2d-293 [-1, 1240, 7, 7] 2,480\n", " ReLU-294 [-1, 1240, 7, 7] 0\n", " Conv2d-295 [-1, 32, 7, 7] 357,152\n", " BatchNorm2d-296 [-1, 1272, 7, 7] 2,544\n", " ReLU-297 [-1, 1272, 7, 7] 0\n", " Conv2d-298 [-1, 32, 7, 7] 40,736\n", " BatchNorm2d-299 [-1, 1304, 7, 7] 2,608\n", " ReLU-300 [-1, 1304, 7, 7] 0\n", " Conv2d-301 [-1, 32, 7, 7] 375,584\n", " BatchNorm2d-302 [-1, 1336, 7, 7] 2,672\n", " ReLU-303 [-1, 1336, 7, 7] 0\n", " Conv2d-304 [-1, 32, 7, 7] 42,784\n", " BatchNorm2d-305 [-1, 1368, 7, 7] 2,736\n", " ReLU-306 [-1, 1368, 7, 7] 0\n", " Conv2d-307 [-1, 32, 7, 7] 394,016\n", " BatchNorm2d-308 [-1, 1400, 7, 7] 2,800\n", " ReLU-309 [-1, 1400, 7, 7] 0\n", " Conv2d-310 [-1, 32, 7, 7] 44,832\n", " BatchNorm2d-311 [-1, 1432, 7, 7] 2,864\n", " ReLU-312 [-1, 1432, 7, 7] 0\n", " Conv2d-313 [-1, 32, 7, 7] 412,448\n", " BatchNorm2d-314 [-1, 1464, 7, 7] 2,928\n", " ReLU-315 [-1, 1464, 7, 7] 0\n", " Conv2d-316 [-1, 32, 7, 7] 46,880\n", " BatchNorm2d-317 [-1, 1496, 7, 7] 2,992\n", " ReLU-318 [-1, 1496, 7, 7] 0\n", " Conv2d-319 [-1, 32, 7, 7] 430,880\n", " BatchNorm2d-320 [-1, 1528, 7, 7] 3,056\n", " ReLU-321 [-1, 1528, 7, 7] 0\n", " Conv2d-322 [-1, 32, 7, 7] 48,928\n", " BatchNorm2d-323 [-1, 1560, 7, 7] 3,120\n", " ReLU-324 [-1, 1560, 7, 7] 0\n", " Conv2d-325 [-1, 32, 7, 7] 449,312\n", " BatchNorm2d-326 [-1, 1592, 7, 7] 3,184\n", " ReLU-327 [-1, 1592, 7, 7] 0\n", " Conv2d-328 [-1, 32, 7, 7] 50,976\n", " BatchNorm2d-329 [-1, 1624, 7, 7] 3,248\n", " ReLU-330 [-1, 1624, 7, 7] 0\n", " Conv2d-331 [-1, 32, 7, 7] 467,744\n", " BatchNorm2d-332 [-1, 1656, 7, 7] 3,312\n", " ReLU-333 [-1, 1656, 7, 7] 0\n", " Conv2d-334 [-1, 32, 7, 7] 53,024\n", " BatchNorm2d-335 [-1, 1688, 7, 7] 3,376\n", " ReLU-336 [-1, 1688, 7, 7] 0\n", " Conv2d-337 [-1, 32, 7, 7] 486,176\n", " BatchNorm2d-338 [-1, 1720, 7, 7] 3,440\n", " ReLU-339 [-1, 1720, 7, 7] 0\n", " Conv2d-340 [-1, 32, 7, 7] 55,072\n", " BatchNorm2d-341 [-1, 1752, 7, 7] 3,504\n", " ReLU-342 [-1, 1752, 7, 7] 0\n", " Conv2d-343 [-1, 32, 7, 7] 
504,608\n", " BatchNorm2d-344 [-1, 1784, 7, 7] 3,568\n", " ReLU-345 [-1, 1784, 7, 7] 0\n", " Conv2d-346 [-1, 32, 7, 7] 57,120\n", " BatchNorm2d-347 [-1, 1816, 7, 7] 3,632\n", " ReLU-348 [-1, 1816, 7, 7] 0\n", " Conv2d-349 [-1, 32, 7, 7] 523,040\n", " BatchNorm2d-350 [-1, 1848, 7, 7] 3,696\n", " ReLU-351 [-1, 1848, 7, 7] 0\n", " Conv2d-352 [-1, 32, 7, 7] 59,168\n", " BatchNorm2d-353 [-1, 1880, 7, 7] 3,760\n", " ReLU-354 [-1, 1880, 7, 7] 0\n", " Conv2d-355 [-1, 32, 7, 7] 541,472\n", " BatchNorm2d-356 [-1, 1912, 7, 7] 3,824\n", " ReLU-357 [-1, 1912, 7, 7] 0\n", " Conv2d-358 [-1, 32, 7, 7] 61,216\n", " BatchNorm2d-359 [-1, 1944, 7, 7] 3,888\n", " ReLU-360 [-1, 1944, 7, 7] 0\n", " Conv2d-361 [-1, 32, 7, 7] 559,904\n", " BatchNorm2d-362 [-1, 1976, 7, 7] 3,952\n", " ReLU-363 [-1, 1976, 7, 7] 0\n", " Conv2d-364 [-1, 32, 7, 7] 63,264\n", " BatchNorm2d-365 [-1, 2008, 7, 7] 4,016\n", " ReLU-366 [-1, 2008, 7, 7] 0\n", " Conv2d-367 [-1, 32, 7, 7] 578,336\n", " DenseBlock-368 [-1, 2040, 7, 7] 0\n", " BatchNorm2d-369 [-1, 2040, 7, 7] 4,080\n", " ReLU-370 [-1, 2040, 7, 7] 0\n", "AdaptiveAvgPool2d-371 [-1, 2040, 1, 1] 0\n", " Flatten-372 [-1, 2040] 0\n", " Linear-373 [-1, 10] 20,410\n", "================================================================\n", "Total params: 23,245,586\n", "Trainable params: 23,245,586\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.19\n", "Forward/backward pass size (MB): 633.18\n", "Params size (MB): 88.67\n", "Estimated Total Size (MB): 722.05\n", "----------------------------------------------------------------\n" ] } ], "source": [ "data = d2l.FashionMNIST(batch_size=32, resize=(224, 224))\n", "arch121 = ([[[1,0],[3,1]]*6,[[1,0],[3,1]]*12,[[1,0],[3,1]]*24,[[1,0],[3,1]]*16])\n", "densenet121 = DenseNet(lr=0.01, arch=arch121)\n", "densenet121.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "# print(count_parameters(model))\n", "summary(densenet121, (1, 224, 224))" ] }, { "cell_type": "code", "execution_count": null, "id": "9ea48de3-03d1-45b2-b6bd-4ed02261c066", "metadata": {}, "outputs": [], "source": [ "arch169 = ([[[1,0],[3,1]]*6,[[1,0],[3,1]]*12,[[1,0],[3,1]]*32,[[1,0],[3,1]]*32])\n", "densenet169 = DenseNet(lr=0.01, arch=arch169)" ] }, { "cell_type": "code", "execution_count": null, "id": "04cd066b-1198-47b1-818e-4f431deacfdb", "metadata": {}, "outputs": [], "source": [ "arch201 = ([[[1,0],[3,1]]*6,[[1,0],[3,1]]*12,[[1,0],[3,1]]*48,[[1,0],[3,1]]*32])\n", "densenet201 = DenseNet(lr=0.01, arch=arch201)" ] }, { "cell_type": "code", "execution_count": null, "id": "c9d27437-6879-4f17-a1dd-420004b29cd1", "metadata": {}, "outputs": [], "source": [ "arch264 = ([[[1,0],[3,1]]*6,[[1,0],[3,1]]*12,[[1,0],[3,1]]*64,[[1,0],[3,1]]*48])\n", "densenet264 = DenseNet(lr=0.01, arch=arch264)" ] }, { "cell_type": "markdown", "id": "832e3e07-d251-46bd-8570-12bb6ed14a02", "metadata": {}, "source": [ "# 5. Design an MLP-based model by applying the DenseNet idea. Apply it to the housing price prediction task in Section 5.7." 
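, "\n", "\n", "The idea carried over to an MLP is that every hidden layer receives the concatenation of the original input and all previous hidden activations (a minimal sketch; the full model with dropout, batch norm, and k-fold training follows below):\n", "\n", "```python\n", "# Hypothetical dense MLP block: concatenate each layer's output onto its input.\n", "import torch\n", "from torch import nn\n", "\n", "class DenseMLPBlock(nn.Module):\n", "    def __init__(self, num_hiddens):\n", "        super().__init__()\n", "        self.layers = nn.ModuleList(\n", "            [nn.Sequential(nn.LazyLinear(h), nn.ReLU()) for h in num_hiddens])\n", "\n", "    def forward(self, X):\n", "        for layer in self.layers:\n", "            X = torch.cat((X, layer(X)), dim=1)  # dense concatenation\n", "        return X\n", "```"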
] }, { "cell_type": "code", "execution_count": 1, "id": "3a680253-abcb-4f2d-9db3-2e10111ae507", "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import time\n", "from tqdm import tqdm\n", "import sys\n", "import torch\n", "import torchvision\n", "from torchvision import transforms\n", "import torch.nn as nn\n", "import warnings\n", "import matplotlib.pyplot as plt\n", "import cProfile\n", "sys.path.append('/home/jovyan/work/d2l_solutions/notebooks/exercises/d2l_utils/')\n", "import d2l\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "class KaggleHouse(d2l.DataModule):\n", " def __init__(self, batch_size, train=None, val=None):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " if self.train is None:\n", " self.raw_train = pd.read_csv(d2l.download(d2l.DATA_URL+ 'kaggle_house_pred_train.csv', self.root,\n", " sha1_hash='585e9cc93e70b39160e7921475f9bcd7d31219ce'))\n", " self.raw_val = pd.read_csv(d2l.download(\n", " d2l.DATA_URL + 'kaggle_house_pred_test.csv', self.root,\n", " sha1_hash='fa19780a7b011d9b009e8bff8e99922a8ee2eb90'))\n", " \n", " def preprocess(self, std_flag=True):\n", " label = 'SalePrice'\n", " features = pd.concat((self.raw_train.drop(columns=['Id',label]),\n", " self.raw_val.drop(columns=['Id'])))\n", " numeric_features = features.dtypes[features.dtypes!='object'].index\n", " if std_flag:\n", " features[numeric_features] = features[numeric_features].apply(lambda x: (x-x.mean())/x.std())\n", " features[numeric_features] = features[numeric_features].fillna(0)\n", " features = pd.get_dummies(features, dummy_na=True)\n", " self.train = features[:self.raw_train.shape[0]].copy()\n", " self.train[label] = self.raw_train[label]\n", " self.val = features[self.raw_train.shape[0]:].copy()\n", " \n", " def get_dataloader(self, train):\n", " label = 'SalePrice'\n", " data = self.train if train else self.val\n", " if label not in data:\n", " return\n", " get_tensor = lambda x: torch.tensor(x.values.astype(float), dtype=torch.float32)\n", " # tensors = (get_tensor(data.drop(columns=[label])),\n", " # torch.log(get_tensor(data[label])).reshape(-1,1))\n", " tensors = (get_tensor(data.drop(columns=[label])), # X\n", " torch.log(get_tensor(data[label])).reshape((-1, 1))) # Y\n", " return self.get_tensorloader(tensors, train)\n", " \n", "def k_fold_data(data,k):\n", " rets = []\n", " fold_size = data.train.shape[0] // k\n", " for j in range(k):\n", " idx = range(j*fold_size,(j+1)*fold_size)\n", " rets.append(KaggleHouse(data.batch_size,data.train.drop(index=idx),data.train.iloc[idx]))\n", " return rets\n", "\n", "def k_fold(trainer, data, k, ModelClass,hparams,plot_flag=True):\n", " val_loss, models = [], []\n", " for i, data_fold in enumerate(k_fold_data(data,k)):\n", " model = ModelClass(**hparams)\n", " model.board.yscale='log'\n", " if not plot_flag or i != 0:\n", " model.board.display=False\n", " trainer.fit(model,data_fold)\n", " val_loss.append(float(model.board.data['val_loss'][-1].y))\n", " models.append(model)\n", " avg_val_loss = sum(val_loss)/len(val_loss)\n", " print(f'average validation log mse = {avg_val_loss}, params:{hparams}')\n", " return models, avg_val_loss\n", "\n", "\n", "\n", "class HouseResMLP(d2l.LinearRegression):\n", " def __init__(self, num_outputs, num_hiddens, lr, dropouts, weight_decay):\n", " super().__init__(lr)\n", " self.save_hyperparameters()\n", " layers = [nn.Flatten()]\n", " for i in range(len(num_hiddens)):\n", " layers.append(nn.Sequential(nn.LazyLinear(num_hiddens[i]),\n", " nn.ReLU(),\n", " 
nn.Dropout(dropouts[i]),\n", " nn.LazyBatchNorm1d(),\n", " ))\n", " layers.append(nn.LazyLinear(num_outputs))\n", " self.net = nn.Sequential(*layers)\n", " \n", " def forward(self, X):\n", " X = self.net[0](X)\n", " for blk in self.net[1:-1]:\n", " Y = blk(X)\n", " # Concatenate input and output of each block along the channels\n", " X = torch.cat((X, Y), dim=1)\n", " return self.net[-1](X)\n", " \n", "# class HouseDenseBlock(nn.Module):\n", "# def __init__(self, num_hiddens):\n", "# super().__init__()\n", "# layer = []\n", "# for i in range(len(num_hiddens)):\n", "# layer.append(nn.Sequential(nn.LazyLinear(num_hiddens[i]),\n", "# nn.LazyBatchNorm1d(), nn.ReLU(),\n", "# ))\n", "# self.net = nn.Sequential(*layer)\n", "\n", "# def forward(self, X):\n", "# for blk in self.net:\n", "# Y = blk(X)\n", "# # Concatenate input and output of each block along the channels\n", "# X = torch.cat((X, Y), dim=1)\n", "# return X\n", " \n", "# def transition_block():\n", "# return nn.Sequential(\n", "# nn.LazyBatchNorm1d(), nn.ReLU(),\n", "# nn.AvgPool1d(kernel_size=2, stride=2))\n", "\n", "# class HouseResMLP(d2l.LinearRegression):\n", "# def __init__(self, num_outputs, arch, lr, dropouts, weight_decay):\n", "# super().__init__(lr)\n", "# self.save_hyperparameters()\n", "# layers = [nn.Flatten()]\n", "# for num_hiddens in arch:\n", "# layers.append(HouseDenseBlock(num_hiddens))\n", "# # layers.append(nn.LazyLinear(sum(num_hiddens)//4))\n", "# layers.append(nn.LazyLinear(num_outputs))\n", "# self.net = nn.Sequential(*layers)\n", " \n", "# def forward(self, X):\n", "# return self.net(X)\n", " \n", "# def configure_optimizers(self):\n", "# return torch.optim.SGD(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)" ] }, { "cell_type": "code", "execution_count": 94, "id": "7c927898-4370-41e4-9380-365511a5378e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Flatten-1 [-1, 80] 0\n", " Linear-2 [-1, 64] 5,184\n", " ReLU-3 [-1, 64] 0\n", " Dropout-4 [-1, 64] 0\n", " BatchNorm1d-5 [-1, 64] 128\n", " Linear-6 [-1, 32] 4,640\n", " ReLU-7 [-1, 32] 0\n", " Dropout-8 [-1, 32] 0\n", " BatchNorm1d-9 [-1, 32] 64\n", " Linear-10 [-1, 16] 2,832\n", " ReLU-11 [-1, 16] 0\n", " Dropout-12 [-1, 16] 0\n", " BatchNorm1d-13 [-1, 16] 32\n", " Linear-14 [-1, 8] 1,544\n", " ReLU-15 [-1, 8] 0\n", " Dropout-16 [-1, 8] 0\n", " BatchNorm1d-17 [-1, 8] 16\n", " Linear-18 [-1, 1] 201\n", "================================================================\n", "Total params: 14,641\n", "Trainable params: 14,641\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.00\n", "Forward/backward pass size (MB): 0.00\n", "Params size (MB): 0.06\n", "Estimated Total Size (MB): 0.06\n", "----------------------------------------------------------------\n" ] } ], "source": [ "hparams = {'dropouts': [0]*5,\n", " 'lr': 0.01,\n", " 'num_hiddens': [64,32,16,8],\n", " 'num_outputs': 1,\n", " 'weight_decay': 0}\n", "model = HouseResMLP(**hparams)\n", "summary(model,(1,80))" ] }, { "cell_type": "code", "execution_count": 2, "id": "2aebb66b-836b-4868-aa6e-f668f205ee51", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1460, 81) (1459, 80)\n" ] } ], "source": [ "data = KaggleHouse(batch_size=64)\n", 
"print(data.raw_train.shape, data.raw_val.shape)\n", "data.preprocess()" ] }, { "cell_type": "code", "execution_count": null, "id": "d7ca94d0-3c19-4ac0-ae57-6899791b15f1", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-09-01T10:37:12.271703\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "hparams = {'dropouts': [0]*5,\n", " 'lr': 0.005,\n", " 'num_hiddens': [1,1,1],\n", " # 'num_hiddens': [128,64,32], \n", " 'num_outputs': 1,\n", " 'weight_decay': 0}\n", "trainer = d2l.Trainer(max_epochs=100)\n", "models,avg_val_loss = k_fold(trainer, data, k=5,ModelClass=HouseResMLP,hparams=hparams,plot_flag=True)" ] }, { "cell_type": "code", "execution_count": 17, "id": "55217ddf-cd72-467e-b510-ad3594d6e632", "metadata": { "tags": [] }, "outputs": [], "source": [ "preds = [model(torch.tensor(data.val.values.astype(float), dtype=torch.float32)) for model in models]\n", "ensemble_preds = torch.exp(torch.cat(preds,1)).mean(1)\n", "submission = pd.DataFrame({'Id':data.raw_val.Id,'SalePrice':ensemble_preds.detach().numpy()})\n", "submission.to_csv('submission.csv', index=False)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:d2l]", "language": "python", "name": "conda-env-d2l-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }