{
"cells": [
{
"cell_type": "markdown",
"id": "5d0a7769",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Batch Normalization\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f7b44765",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:58:59.890305Z",
"iopub.status.busy": "2023-08-18T19:58:59.889455Z",
"iopub.status.idle": "2023-08-18T19:59:03.088529Z",
"shell.execute_reply": "2023-08-18T19:59:03.087539Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "e32814aa",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"Implementation from Scratch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9a79b8f2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.092523Z",
"iopub.status.busy": "2023-08-18T19:59:03.092120Z",
"iopub.status.idle": "2023-08-18T19:59:03.100348Z",
"shell.execute_reply": "2023-08-18T19:59:03.099493Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
"    \"\"\"Batch-normalize `X` and return the refreshed moving statistics.\n",
"\n",
"    The mode is picked from the autograd state: with gradients disabled the\n",
"    supplied moving statistics are used (prediction mode); with gradients\n",
"    enabled the statistics of the current minibatch are used and the moving\n",
"    averages are updated (training mode).\n",
"\n",
"    Args:\n",
"        X: Input of rank 2 or rank 4. Statistics are taken over dim 0 for\n",
"            rank 2 and over dims (0, 2, 3) for rank 4.\n",
"        gamma: Scale multiplied with the normalized input.\n",
"        beta: Shift added after scaling.\n",
"        moving_mean: Running estimate of the mean.\n",
"        moving_var: Running estimate of the variance.\n",
"        eps: Constant added to the variance before the square root.\n",
"        momentum: Weight of the current batch statistic in the\n",
"            moving-average update.\n",
"\n",
"    Returns:\n",
"        Tuple `(Y, moving_mean, moving_var)`. The two statistics are\n",
"        detached from the autograd graph.\n",
"    \"\"\"\n",
"    if not torch.is_grad_enabled():\n",
"        # Prediction mode: normalize with the accumulated statistics\n",
"        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
"    else:\n",
"        assert len(X.shape) in (2, 4), (\n",
"            f'expected a 2-D or 4-D input, got {len(X.shape)}-D')\n",
"        if len(X.shape) == 2:\n",
"            mean = X.mean(dim=0)\n",
"            var = ((X - mean) ** 2).mean(dim=0)\n",
"        else:\n",
"            # keepdim=True keeps size-1 dims so the statistics broadcast\n",
"            # against the rank-4 input\n",
"            mean = X.mean(dim=(0, 2, 3), keepdim=True)\n",
"            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n",
"        # Training mode: normalize with the current batch statistics\n",
"        X_hat = (X - mean) / torch.sqrt(var + eps)\n",
"        # The running averages are bookkeeping, not part of the loss, so\n",
"        # they are updated from detached statistics (no graph is built)\n",
"        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean.detach()\n",
"        moving_var = (1.0 - momentum) * moving_var + momentum * var.detach()\n",
"    Y = gamma * X_hat + beta\n",
"    # detach() instead of .data: same values, but autograd can still detect\n",
"    # unsafe in-place modification of the returned tensors\n",
"    return Y, moving_mean.detach(), moving_var.detach()"
]
},
{
"cell_type": "markdown",
"id": "ac9f8a60",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Create a `BatchNorm` layer that wraps `batch_norm` and owns the scale `gamma`, the shift `beta`, and the moving statistics"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8a591dd1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.103959Z",
"iopub.status.busy": "2023-08-18T19:59:03.103597Z",
"iopub.status.idle": "2023-08-18T19:59:03.113624Z",
"shell.execute_reply": "2023-08-18T19:59:03.112645Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class BatchNorm(nn.Module):\n",
"    \"\"\"Batch normalization layer built on top of `batch_norm`.\n",
"\n",
"    Args:\n",
"        num_features: Size of the normalized dimension (dim 1 of the input).\n",
"        num_dims: 2 selects parameters of shape `(1, num_features)`; any\n",
"            other value selects `(1, num_features, 1, 1)`.\n",
"        eps: Constant added to the variance before the square root.\n",
"        momentum: Weight of the current batch statistic in the\n",
"            moving-average update.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.1):\n",
"        super().__init__()\n",
"        if num_dims == 2:\n",
"            shape = (1, num_features)\n",
"        else:\n",
"            shape = (1, num_features, 1, 1)\n",
"        self.eps = eps\n",
"        self.momentum = momentum\n",
"        # Learnable scale and shift, initialized to 1 and 0 respectively\n",
"        self.gamma = nn.Parameter(torch.ones(shape))\n",
"        self.beta = nn.Parameter(torch.zeros(shape))\n",
"        # Non-learnable state. Registered as buffers so that it is saved in\n",
"        # `state_dict()` and moved together with the module by `.to()`\n",
"        self.register_buffer('moving_mean', torch.zeros(shape))\n",
"        self.register_buffer('moving_var', torch.ones(shape))\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Normalize `X` and refresh the stored moving statistics.\"\"\"\n",
"        # Kept as a safeguard for inputs that live on a different device\n",
"        # than the module\n",
"        if self.moving_mean.device != X.device:\n",
"            self.moving_mean = self.moving_mean.to(X.device)\n",
"            self.moving_var = self.moving_var.to(X.device)\n",
"        # Rebinding a registered buffer name keeps the new tensor registered\n",
"        Y, self.moving_mean, self.moving_var = batch_norm(\n",
"            X, self.gamma, self.beta, self.moving_mean,\n",
"            self.moving_var, eps=self.eps, momentum=self.momentum)\n",
"        return Y"
]
},
{
"cell_type": "markdown",
"id": "e2165861",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"LeNet with Batch Normalization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "21c51c36",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.118112Z",
"iopub.status.busy": "2023-08-18T19:59:03.117737Z",
"iopub.status.idle": "2023-08-18T19:59:03.124711Z",
"shell.execute_reply": "2023-08-18T19:59:03.123881Z"
},
"origin_pos": 17,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class BNLeNetScratch(d2l.Classifier):\n",
"    \"\"\"LeNet-style classifier with a scratch `BatchNorm` before each sigmoid.\"\"\"\n",
"\n",
"    def __init__(self, lr=0.1, num_classes=10):\n",
"        super().__init__()\n",
"        # NOTE(review): presumably records `lr` and `num_classes` on the\n",
"        # instance - confirm against d2l. Kept ahead of any new local\n",
"        # variable, exactly where the original called it.\n",
"        self.save_hyperparameters()\n",
"        # Two conv -> batch norm -> sigmoid -> average-pool stages\n",
"        conv_stages = [\n",
"            nn.LazyConv2d(6, kernel_size=5), BatchNorm(6, num_dims=4),\n",
"            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n",
"            nn.LazyConv2d(16, kernel_size=5), BatchNorm(16, num_dims=4),\n",
"            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n",
"        ]\n",
"        # Flatten, two linear -> batch norm -> sigmoid stages, output layer\n",
"        dense_stages = [\n",
"            nn.Flatten(),\n",
"            nn.LazyLinear(120), BatchNorm(120, num_dims=2), nn.Sigmoid(),\n",
"            nn.LazyLinear(84), BatchNorm(84, num_dims=2), nn.Sigmoid(),\n",
"            nn.LazyLinear(num_classes),\n",
"        ]\n",
"        self.net = nn.Sequential(*conv_stages, *dense_stages)"
]
},
{
"cell_type": "markdown",
"id": "ca66d52a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Train our network on the Fashion-MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "064cdd64",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.127886Z",
"iopub.status.busy": "2023-08-18T19:59:03.127595Z",
"iopub.status.idle": "2023-08-18T20:00:16.870229Z",
"shell.execute_reply": "2023-08-18T20:00:16.869283Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"