{ "cells": [ { "cell_type": "markdown", "id": "5d0a7769", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Batch Normalization\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "f7b44765", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:58:59.890305Z", "iopub.status.busy": "2023-08-18T19:58:59.889455Z", "iopub.status.idle": "2023-08-18T19:59:03.088529Z", "shell.execute_reply": "2023-08-18T19:59:03.087539Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "e32814aa", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "Implementation from Scratch" ] }, { "cell_type": "code", "execution_count": 2, "id": "9a79b8f2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:59:03.092523Z", "iopub.status.busy": "2023-08-18T19:59:03.092120Z", "iopub.status.idle": "2023-08-18T19:59:03.100348Z", "shell.execute_reply": "2023-08-18T19:59:03.099493Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):\n", " if not torch.is_grad_enabled():\n", " X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n", " else:\n", " assert len(X.shape) in (2, 4)\n", " if len(X.shape) == 2:\n", " mean = X.mean(dim=0)\n", " var = ((X - mean) ** 2).mean(dim=0)\n", " else:\n", " mean = X.mean(dim=(0, 2, 3), keepdim=True)\n", " var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n", " X_hat = (X - mean) / torch.sqrt(var + eps)\n", " moving_mean = (1.0 - momentum) * moving_mean + momentum * mean\n", " moving_var = (1.0 - momentum) * moving_var + momentum * var\n", " Y = gamma * X_hat + beta\n", " return Y, moving_mean.data, moving_var.data" ] }, { "cell_type": "markdown", "id": "ac9f8a60", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Create a proper `BatchNorm` layer" ] }, { "cell_type": "code", "execution_count": 3, "id": "8a591dd1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:59:03.103959Z", "iopub.status.busy": "2023-08-18T19:59:03.103597Z", "iopub.status.idle": "2023-08-18T19:59:03.113624Z", "shell.execute_reply": "2023-08-18T19:59:03.112645Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class BatchNorm(nn.Module):\n", " def __init__(self, num_features, num_dims):\n", " super().__init__()\n", " if num_dims == 2:\n", " shape = (1, num_features)\n", " else:\n", " shape = (1, num_features, 1, 1)\n", " self.gamma = nn.Parameter(torch.ones(shape))\n", " self.beta = nn.Parameter(torch.zeros(shape))\n", " self.moving_mean = torch.zeros(shape)\n", " self.moving_var = torch.ones(shape)\n", "\n", " def forward(self, X):\n", " if self.moving_mean.device != X.device:\n", " self.moving_mean = self.moving_mean.to(X.device)\n", " self.moving_var = self.moving_var.to(X.device)\n", " Y, self.moving_mean, self.moving_var = batch_norm(\n", " X, self.gamma, self.beta, self.moving_mean,\n", " self.moving_var, eps=1e-5, momentum=0.1)\n", " return Y" ] }, { "cell_type": "markdown", "id": "e2165861", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "LeNet with Batch Normalization" ] }, { "cell_type": "code", "execution_count": 4, "id": "21c51c36", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:59:03.118112Z", "iopub.status.busy": "2023-08-18T19:59:03.117737Z", "iopub.status.idle": "2023-08-18T19:59:03.124711Z", "shell.execute_reply": "2023-08-18T19:59:03.123881Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class BNLeNetScratch(d2l.Classifier):\n", " def __init__(self, lr=0.1, num_classes=10):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " self.net = nn.Sequential(\n", " nn.LazyConv2d(6, kernel_size=5), BatchNorm(6, num_dims=4),\n", " nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.LazyConv2d(16, kernel_size=5), BatchNorm(16, num_dims=4),\n", " nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Flatten(), nn.LazyLinear(120),\n", " BatchNorm(120, num_dims=2), nn.Sigmoid(), nn.LazyLinear(84),\n", " BatchNorm(84, num_dims=2), nn.Sigmoid(),\n", " nn.LazyLinear(num_classes))" ] }, { "cell_type": "markdown", "id": "ca66d52a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Train our network on the Fashion-MNIST dataset" ] }, { "cell_type": "code", "execution_count": 5, "id": "064cdd64", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:59:03.127886Z", "iopub.status.busy": "2023-08-18T19:59:03.127595Z", "iopub.status.idle": "2023-08-18T20:00:16.870229Z", "shell.execute_reply": "2023-08-18T20:00:16.869283Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T20:00:16.743614\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n", "data = d2l.FashionMNIST(batch_size=128)\n", "model = BNLeNetScratch(lr=0.1)\n", "model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "5d788d3d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Have a look at the scale parameter `gamma`\n", "and the shift parameter `beta`" ] }, { "cell_type": "code", "execution_count": 6, "id": "4969fdc2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:00:16.875213Z", "iopub.status.busy": "2023-08-18T20:00:16.874610Z", "iopub.status.idle": "2023-08-18T20:00:16.971921Z", "shell.execute_reply": "2023-08-18T20:00:16.970745Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(tensor([1.4334, 1.9905, 1.8584, 2.0740, 2.0522, 1.8877], device='cuda:0',\n", " grad_fn=),\n", " tensor([ 0.7354, -1.3538, -0.2567, -0.9991, -0.3028, 1.3125], device='cuda:0',\n", " grad_fn=))" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.net[1].gamma.reshape((-1,)), model.net[1].beta.reshape((-1,))" ] }, { "cell_type": "markdown", "id": "a601daa9", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Concise Implementation" ] }, { "cell_type": "code", "execution_count": 7, "id": "ef2ab147", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:00:16.975625Z", "iopub.status.busy": "2023-08-18T20:00:16.975018Z", "iopub.status.idle": "2023-08-18T20:00:16.981373Z", "shell.execute_reply": "2023-08-18T20:00:16.980550Z" }, "origin_pos": 30, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class BNLeNet(d2l.Classifier):\n", " def __init__(self, lr=0.1, num_classes=10):\n", " super().__init__()\n", " self.save_hyperparameters()\n", " self.net = nn.Sequential(\n", " nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(),\n", " nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(),\n", " nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Flatten(), nn.LazyLinear(120), nn.LazyBatchNorm1d(),\n", " nn.Sigmoid(), nn.LazyLinear(84), nn.LazyBatchNorm1d(),\n", " nn.Sigmoid(), nn.LazyLinear(num_classes))" ] }, { "cell_type": "markdown", "id": "0a6d4742", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Use the same hyperparameters to train our model" ] }, { "cell_type": "code", "execution_count": 8, "id": "0d6aaf49", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T20:00:16.984898Z", "iopub.status.busy": "2023-08-18T20:00:16.984364Z", "iopub.status.idle": "2023-08-18T20:01:21.082406Z", "shell.execute_reply": "2023-08-18T20:01:21.081474Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T20:01:20.955982\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n", "data = d2l.FashionMNIST(batch_size=128)\n", "model = BNLeNet(lr=0.1)\n", "model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n", "trainer.fit(model, data)" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }