{
"cells": [
{
"cell_type": "markdown",
"id": "5260fd99",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# 批量规范化\n",
"\n",
"从零实现"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "042456d3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:16.084651Z",
"iopub.status.busy": "2023-08-18T07:16:16.083898Z",
"iopub.status.idle": "2023-08-18T07:16:18.925904Z",
"shell.execute_reply": "2023-08-18T07:16:18.924662Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"\n",
"def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
" if not torch.is_grad_enabled():\n",
" X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
" else:\n",
" assert len(X.shape) in (2, 4)\n",
" if len(X.shape) == 2:\n",
" mean = X.mean(dim=0)\n",
" var = ((X - mean) ** 2).mean(dim=0)\n",
" else:\n",
" mean = X.mean(dim=(0, 2, 3), keepdim=True)\n",
" var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n",
" X_hat = (X - mean) / torch.sqrt(var + eps)\n",
" moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
" moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
" Y = gamma * X_hat + beta\n",
" return Y, moving_mean.data, moving_var.data"
]
},
{
"cell_type": "markdown",
"id": "935c0000",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"创建一个正确的`BatchNorm`层"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f9b0ce07",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:18.930541Z",
"iopub.status.busy": "2023-08-18T07:16:18.929642Z",
"iopub.status.idle": "2023-08-18T07:16:18.937402Z",
"shell.execute_reply": "2023-08-18T07:16:18.936597Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class BatchNorm(nn.Module):\n",
" def __init__(self, num_features, num_dims):\n",
" super().__init__()\n",
" if num_dims == 2:\n",
" shape = (1, num_features)\n",
" else:\n",
" shape = (1, num_features, 1, 1)\n",
" self.gamma = nn.Parameter(torch.ones(shape))\n",
" self.beta = nn.Parameter(torch.zeros(shape))\n",
" self.moving_mean = torch.zeros(shape)\n",
" self.moving_var = torch.ones(shape)\n",
"\n",
" def forward(self, X):\n",
" if self.moving_mean.device != X.device:\n",
" self.moving_mean = self.moving_mean.to(X.device)\n",
" self.moving_var = self.moving_var.to(X.device)\n",
" Y, self.moving_mean, self.moving_var = batch_norm(\n",
" X, self.gamma, self.beta, self.moving_mean,\n",
" self.moving_var, eps=1e-5, momentum=0.9)\n",
" return Y"
]
},
{
"cell_type": "markdown",
"id": "3eca2c7a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"应用`BatchNorm`\n",
"于LeNet模型"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "89ca8ab0",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:18.940903Z",
"iopub.status.busy": "2023-08-18T07:16:18.940366Z",
"iopub.status.idle": "2023-08-18T07:16:18.966572Z",
"shell.execute_reply": "2023-08-18T07:16:18.965740Z"
},
"origin_pos": 12,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"net = nn.Sequential(\n",
" nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),\n",
" nn.AvgPool2d(kernel_size=2, stride=2),\n",
" nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),\n",
" nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),\n",
" nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),\n",
" nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),\n",
" nn.Linear(84, 10))"
]
},
{
"cell_type": "markdown",
"id": "1edfe76d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"在Fashion-MNIST数据集上训练网络"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a0c4988d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:16:18.970436Z",
"iopub.status.busy": "2023-08-18T07:16:18.969896Z",
"iopub.status.idle": "2023-08-18T07:17:04.740786Z",
"shell.execute_reply": "2023-08-18T07:17:04.739449Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss 0.273, train acc 0.899, test acc 0.807\n",
"32293.9 examples/sec on cuda:0\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"