{
"cells": [
{
"cell_type": "markdown",
"id": "5260fd99",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# 批量规范化\n",
"\n",
"从零实现"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "042456d3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:16.084651Z",
"iopub.status.busy": "2023-08-18T07:16:16.083898Z",
"iopub.status.idle": "2023-08-18T07:16:18.925904Z",
"shell.execute_reply": "2023-08-18T07:16:18.924662Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"\n",
"def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
"    \"\"\"Batch-normalize `X`, then scale by `gamma` and shift by `beta`.\n",
"\n",
"    The mode is chosen from the autograd state: with gradients disabled the\n",
"    supplied moving statistics are used unchanged; with gradients enabled\n",
"    the statistics of the current batch are used and the moving statistics\n",
"    are updated.\n",
"\n",
"    Args:\n",
"        X: Input tensor. With gradients enabled it must have rank 2\n",
"            (statistics over dim 0) or rank 4 (statistics over dims 0, 2, 3).\n",
"        gamma: Scale multiplied onto the normalized input.\n",
"        beta: Shift added after scaling.\n",
"        moving_mean: Moving mean carried over from earlier calls.\n",
"        moving_var: Moving variance carried over from earlier calls.\n",
"        eps: Constant added to the variance before taking the square root.\n",
"        momentum: Weight kept on the previous moving statistic; the current\n",
"            batch statistic is weighted by `1 - momentum`.\n",
"\n",
"    Returns:\n",
"        Tuple `(Y, moving_mean, moving_var)`. Both statistics are detached\n",
"        from the autograd graph.\n",
"\n",
"    Raises:\n",
"        AssertionError: If gradients are enabled and `X` is not rank 2 or 4.\n",
"    \"\"\"\n",
"    if not torch.is_grad_enabled():\n",
"        # Gradients are off (prediction): normalize with the moving statistics\n",
"        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
"    else:\n",
"        assert X.dim() in (2, 4)\n",
"        if X.dim() == 2:\n",
"            # Rank 2: one mean/variance per column\n",
"            mean = X.mean(dim=0)\n",
"            var = ((X - mean) ** 2).mean(dim=0)\n",
"        else:\n",
"            # Rank 4: one mean/variance per index of dim 1; keepdim keeps the\n",
"            # result broadcastable against X\n",
"            mean = X.mean(dim=(0, 2, 3), keepdim=True)\n",
"            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n",
"        # Gradients are on (training): normalize with the batch statistics\n",
"        X_hat = (X - mean) / torch.sqrt(var + eps)\n",
"        # Exponential moving average of the batch statistics\n",
"        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
"        moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
"    Y = gamma * X_hat + beta\n",
"    # detach() rather than the legacy `.data`: same values and shared storage,\n",
"    # but in-place edits stay visible to autograd's correctness checks\n",
"    return Y, moving_mean.detach(), moving_var.detach()"
]
},
{
"cell_type": "markdown",
"id": "935c0000",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"创建一个正确的`BatchNorm`层"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f9b0ce07",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:18.930541Z",
"iopub.status.busy": "2023-08-18T07:16:18.929642Z",
"iopub.status.idle": "2023-08-18T07:16:18.937402Z",
"shell.execute_reply": "2023-08-18T07:16:18.936597Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class BatchNorm(nn.Module):\n",
"    \"\"\"Batch normalization layer built on the `batch_norm` function.\n",
"\n",
"    Args:\n",
"        num_features: Size of the normalized feature dimension.\n",
"        num_dims: 2 gives parameters of shape `(1, num_features)`; any other\n",
"            value gives `(1, num_features, 1, 1)`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_features, num_dims):\n",
"        super().__init__()\n",
"        if num_dims == 2:\n",
"            shape = (1, num_features)\n",
"        else:\n",
"            shape = (1, num_features, 1, 1)\n",
"        # Learnable scale and shift, initialized to 1 and 0\n",
"        self.gamma = nn.Parameter(torch.ones(shape))\n",
"        self.beta = nn.Parameter(torch.zeros(shape))\n",
"        # Non-learnable running statistics, initialized to 0 and 1. Registered\n",
"        # as buffers so they follow Module.to()/cuda() and are saved in\n",
"        # state_dict() together with gamma and beta\n",
"        self.register_buffer('moving_mean', torch.zeros(shape))\n",
"        self.register_buffer('moving_var', torch.ones(shape))\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Normalize `X` and store the updated moving statistics.\"\"\"\n",
"        # No manual device transfer is needed: buffers move with the module.\n",
"        # Assigning a tensor to a registered buffer name replaces the buffer\n",
"        Y, self.moving_mean, self.moving_var = batch_norm(\n",
"            X, self.gamma, self.beta, self.moving_mean,\n",
"            self.moving_var, eps=1e-5, momentum=0.9)\n",
"        return Y"
]
},
{
"cell_type": "markdown",
"id": "3eca2c7a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"应用`BatchNorm`\n",
"于LeNet模型"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "89ca8ab0",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:18.940903Z",
"iopub.status.busy": "2023-08-18T07:16:18.940366Z",
"iopub.status.idle": "2023-08-18T07:16:18.966572Z",
"shell.execute_reply": "2023-08-18T07:16:18.965740Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"net = nn.Sequential(\n",
"    # Convolutional stage 1: conv -> batch norm -> sigmoid -> pooling\n",
"    nn.Conv2d(1, 6, kernel_size=5),\n",
"    BatchNorm(6, num_dims=4),\n",
"    nn.Sigmoid(),\n",
"    nn.AvgPool2d(kernel_size=2, stride=2),\n",
"    # Convolutional stage 2, followed by flattening\n",
"    nn.Conv2d(6, 16, kernel_size=5),\n",
"    BatchNorm(16, num_dims=4),\n",
"    nn.Sigmoid(),\n",
"    nn.AvgPool2d(kernel_size=2, stride=2),\n",
"    nn.Flatten(),\n",
"    # Fully connected head; batch norm sits between each linear layer and\n",
"    # its activation\n",
"    nn.Linear(16 * 4 * 4, 120),\n",
"    BatchNorm(120, num_dims=2),\n",
"    nn.Sigmoid(),\n",
"    nn.Linear(120, 84),\n",
"    BatchNorm(84, num_dims=2),\n",
"    nn.Sigmoid(),\n",
"    nn.Linear(84, 10),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1edfe76d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"在Fashion-MNIST数据集上训练网络"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a0c4988d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:18.970436Z",
"iopub.status.busy": "2023-08-18T07:16:18.969896Z",
"iopub.status.idle": "2023-08-18T07:17:04.740786Z",
"shell.execute_reply": "2023-08-18T07:17:04.739449Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.273, train acc 0.899, test acc 0.807\n",
"32293.9 examples/sec on cuda:0\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Hyperparameters; the later training cell reuses these names\n",
"lr = 1.0\n",
"num_epochs = 10\n",
"batch_size = 256\n",
"\n",
"train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
"d2l.train_ch6(\n",
"    net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu()\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3cd6e49b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"拉伸参数`gamma`和偏移参数`beta`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "055a3583",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:17:04.745528Z",
"iopub.status.busy": "2023-08-18T07:17:04.744678Z",
"iopub.status.idle": "2023-08-18T07:17:04.755775Z",
"shell.execute_reply": "2023-08-18T07:17:04.754582Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0.4863, 2.8573, 2.3190, 4.3188, 3.8588, 1.7942], device='cuda:0',\n",
" grad_fn=),\n",
" tensor([-0.0124, 1.4839, -1.7753, 2.3564, -3.8801, -2.1589], device='cuda:0',\n",
" grad_fn=))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Scale and shift of the first batch-norm layer (net[1]), flattened to 1-D\n",
"net[1].gamma.reshape(-1), net[1].beta.reshape(-1)"
]
},
{
"cell_type": "markdown",
"id": "cc89ef77",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"简明实现"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b8604933",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:17:04.759625Z",
"iopub.status.busy": "2023-08-18T07:17:04.758859Z",
"iopub.status.idle": "2023-08-18T07:17:04.769251Z",
"shell.execute_reply": "2023-08-18T07:17:04.768076Z"
},
"origin_pos": 25,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"net = nn.Sequential(\n",
"    # Convolutional stage 1: conv -> batch norm -> sigmoid -> pooling\n",
"    nn.Conv2d(1, 6, kernel_size=5),\n",
"    nn.BatchNorm2d(6),\n",
"    nn.Sigmoid(),\n",
"    nn.AvgPool2d(kernel_size=2, stride=2),\n",
"    # Convolutional stage 2, followed by flattening\n",
"    nn.Conv2d(6, 16, kernel_size=5),\n",
"    nn.BatchNorm2d(16),\n",
"    nn.Sigmoid(),\n",
"    nn.AvgPool2d(kernel_size=2, stride=2),\n",
"    nn.Flatten(),\n",
"    # Fully connected head; 16 * 4 * 4 == 256 flattened features\n",
"    nn.Linear(16 * 4 * 4, 120),\n",
"    nn.BatchNorm1d(120),\n",
"    nn.Sigmoid(),\n",
"    nn.Linear(120, 84),\n",
"    nn.BatchNorm1d(84),\n",
"    nn.Sigmoid(),\n",
"    nn.Linear(84, 10),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "b96b59b2",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"使用相同超参数来训练模型"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "add53e76",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:17:04.772567Z",
"iopub.status.busy": "2023-08-18T07:17:04.772282Z",
"iopub.status.idle": "2023-08-18T07:17:54.677901Z",
"shell.execute_reply": "2023-08-18T07:17:54.676931Z"
},
"origin_pos": 29,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.267, train acc 0.902, test acc 0.708\n",
"50597.3 examples/sec on cuda:0\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Reuses the data iterators and hyperparameters from the earlier training cell\n",
"d2l.train_ch6(\n",
"    net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu()\n",
")"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}