{ "cells": [ { "cell_type": "markdown", "id": "5260fd99", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# 批量规范化\n", "\n", "从零实现" ] }, { "cell_type": "code", "execution_count": 1, "id": "042456d3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:16:16.084651Z", "iopub.status.busy": "2023-08-18T07:16:16.083898Z", "iopub.status.idle": "2023-08-18T07:16:18.925904Z", "shell.execute_reply": "2023-08-18T07:16:18.924662Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "\n", "def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):\n", " if not torch.is_grad_enabled():\n", " X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n", " else:\n", " assert len(X.shape) in (2, 4)\n", " if len(X.shape) == 2:\n", " mean = X.mean(dim=0)\n", " var = ((X - mean) ** 2).mean(dim=0)\n", " else:\n", " mean = X.mean(dim=(0, 2, 3), keepdim=True)\n", " var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n", " X_hat = (X - mean) / torch.sqrt(var + eps)\n", " moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n", " moving_var = momentum * moving_var + (1.0 - momentum) * var\n", " Y = gamma * X_hat + beta\n", " return Y, moving_mean.data, moving_var.data" ] }, { "cell_type": "markdown", "id": "935c0000", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "创建一个正确的`BatchNorm`层" ] }, { "cell_type": "code", "execution_count": 2, "id": "f9b0ce07", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:16:18.930541Z", "iopub.status.busy": "2023-08-18T07:16:18.929642Z", "iopub.status.idle": "2023-08-18T07:16:18.937402Z", "shell.execute_reply": "2023-08-18T07:16:18.936597Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class BatchNorm(nn.Module):\n", " def __init__(self, num_features, num_dims):\n", " super().__init__()\n", " if num_dims == 2:\n", " shape = (1, num_features)\n", " else:\n", " shape = (1, num_features, 1, 1)\n", " self.gamma = nn.Parameter(torch.ones(shape))\n", " self.beta = nn.Parameter(torch.zeros(shape))\n", " self.moving_mean = torch.zeros(shape)\n", " self.moving_var = torch.ones(shape)\n", "\n", " def forward(self, X):\n", " if self.moving_mean.device != X.device:\n", " self.moving_mean = self.moving_mean.to(X.device)\n", " self.moving_var = self.moving_var.to(X.device)\n", " Y, self.moving_mean, self.moving_var = batch_norm(\n", " X, self.gamma, self.beta, self.moving_mean,\n", " self.moving_var, eps=1e-5, momentum=0.9)\n", " return Y" ] }, { "cell_type": "markdown", "id": "3eca2c7a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "应用`BatchNorm`\n", "于LeNet模型" ] }, { "cell_type": "code", "execution_count": 3, "id": "89ca8ab0", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:16:18.940903Z", "iopub.status.busy": "2023-08-18T07:16:18.940366Z", "iopub.status.idle": "2023-08-18T07:16:18.966572Z", "shell.execute_reply": "2023-08-18T07:16:18.965740Z" }, "origin_pos": 12, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "net = nn.Sequential(\n", " nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),\n", " nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),\n", " nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),\n", " nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),\n", " nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),\n", " nn.Linear(84, 10))" ] }, { "cell_type": "markdown", "id": "1edfe76d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "在Fashion-MNIST数据集上训练网络" ] }, { "cell_type": "code", "execution_count": 4, "id": "a0c4988d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:16:18.970436Z", "iopub.status.busy": "2023-08-18T07:16:18.969896Z", "iopub.status.idle": "2023-08-18T07:17:04.740786Z", "shell.execute_reply": "2023-08-18T07:17:04.739449Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.273, train acc 0.899, test acc 0.807\n", "32293.9 examples/sec on cuda:0\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:17:04.694266\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "lr, num_epochs, batch_size = 1.0, 10, 256\n", "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n", "d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ] }, { "cell_type": "markdown", "id": "3cd6e49b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "拉伸参数`gamma`和偏移参数`beta`" ] }, { "cell_type": "code", "execution_count": 5, "id": "055a3583", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:04.745528Z", "iopub.status.busy": "2023-08-18T07:17:04.744678Z", "iopub.status.idle": "2023-08-18T07:17:04.755775Z", "shell.execute_reply": "2023-08-18T07:17:04.754582Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(tensor([0.4863, 2.8573, 2.3190, 4.3188, 3.8588, 1.7942], device='cuda:0',\n", " grad_fn=),\n", " tensor([-0.0124, 1.4839, -1.7753, 2.3564, -3.8801, -2.1589], device='cuda:0',\n", " grad_fn=))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,))" ] }, { "cell_type": "markdown", "id": "cc89ef77", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "简明实现" ] }, { "cell_type": "code", "execution_count": 6, "id": "b8604933", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:04.759625Z", "iopub.status.busy": "2023-08-18T07:17:04.758859Z", "iopub.status.idle": "2023-08-18T07:17:04.769251Z", "shell.execute_reply": "2023-08-18T07:17:04.768076Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "net = nn.Sequential(\n", " nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),\n", " nn.AvgPool2d(kernel_size=2, stride=2),\n", " nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),\n", " nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),\n", " nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),\n", " nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),\n", " nn.Linear(84, 10))" ] }, { "cell_type": "markdown", "id": "b96b59b2", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "使用相同超参数来训练模型" ] }, { "cell_type": "code", "execution_count": 7, "id": "add53e76", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:17:04.772567Z", "iopub.status.busy": "2023-08-18T07:17:04.772282Z", "iopub.status.idle": "2023-08-18T07:17:54.677901Z", "shell.execute_reply": "2023-08-18T07:17:54.676931Z" }, "origin_pos": 29, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.267, train acc 0.902, test acc 0.708\n", "50597.3 examples/sec on cuda:0\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:17:54.611775\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }