{
"cells": [
{
"cell_type": "markdown",
"id": "5d0a7769",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Batch Normalization\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f7b44765",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:58:59.890305Z",
"iopub.status.busy": "2023-08-18T19:58:59.889455Z",
"iopub.status.idle": "2023-08-18T19:59:03.088529Z",
"shell.execute_reply": "2023-08-18T19:59:03.087539Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "e32814aa",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"Implementation from Scratch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9a79b8f2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.092523Z",
"iopub.status.busy": "2023-08-18T19:59:03.092120Z",
"iopub.status.idle": "2023-08-18T19:59:03.100348Z",
"shell.execute_reply": "2023-08-18T19:59:03.099493Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
" if not torch.is_grad_enabled():\n",
" X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
" else:\n",
" assert len(X.shape) in (2, 4)\n",
" if len(X.shape) == 2:\n",
" mean = X.mean(dim=0)\n",
" var = ((X - mean) ** 2).mean(dim=0)\n",
" else:\n",
" mean = X.mean(dim=(0, 2, 3), keepdim=True)\n",
" var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n",
" X_hat = (X - mean) / torch.sqrt(var + eps)\n",
" moving_mean = (1.0 - momentum) * moving_mean + momentum * mean\n",
" moving_var = (1.0 - momentum) * moving_var + momentum * var\n",
" Y = gamma * X_hat + beta\n",
" return Y, moving_mean.data, moving_var.data"
]
},
{
"cell_type": "markdown",
"id": "ac9f8a60",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Create a proper `BatchNorm` layer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8a591dd1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.103959Z",
"iopub.status.busy": "2023-08-18T19:59:03.103597Z",
"iopub.status.idle": "2023-08-18T19:59:03.113624Z",
"shell.execute_reply": "2023-08-18T19:59:03.112645Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class BatchNorm(nn.Module):\n",
" def __init__(self, num_features, num_dims):\n",
" super().__init__()\n",
" if num_dims == 2:\n",
" shape = (1, num_features)\n",
" else:\n",
" shape = (1, num_features, 1, 1)\n",
" self.gamma = nn.Parameter(torch.ones(shape))\n",
" self.beta = nn.Parameter(torch.zeros(shape))\n",
" self.moving_mean = torch.zeros(shape)\n",
" self.moving_var = torch.ones(shape)\n",
"\n",
" def forward(self, X):\n",
" if self.moving_mean.device != X.device:\n",
" self.moving_mean = self.moving_mean.to(X.device)\n",
" self.moving_var = self.moving_var.to(X.device)\n",
" Y, self.moving_mean, self.moving_var = batch_norm(\n",
" X, self.gamma, self.beta, self.moving_mean,\n",
" self.moving_var, eps=1e-5, momentum=0.1)\n",
" return Y"
]
},
{
"cell_type": "markdown",
"id": "e2165861",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"LeNet with Batch Normalization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "21c51c36",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.118112Z",
"iopub.status.busy": "2023-08-18T19:59:03.117737Z",
"iopub.status.idle": "2023-08-18T19:59:03.124711Z",
"shell.execute_reply": "2023-08-18T19:59:03.123881Z"
},
"origin_pos": 17,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class BNLeNetScratch(d2l.Classifier):\n",
" def __init__(self, lr=0.1, num_classes=10):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
" self.net = nn.Sequential(\n",
" nn.LazyConv2d(6, kernel_size=5), BatchNorm(6, num_dims=4),\n",
" nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n",
" nn.LazyConv2d(16, kernel_size=5), BatchNorm(16, num_dims=4),\n",
" nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n",
" nn.Flatten(), nn.LazyLinear(120),\n",
" BatchNorm(120, num_dims=2), nn.Sigmoid(), nn.LazyLinear(84),\n",
" BatchNorm(84, num_dims=2), nn.Sigmoid(),\n",
" nn.LazyLinear(num_classes))"
]
},
{
"cell_type": "markdown",
"id": "ca66d52a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Train our network on the Fashion-MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "064cdd64",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:59:03.127886Z",
"iopub.status.busy": "2023-08-18T19:59:03.127595Z",
"iopub.status.idle": "2023-08-18T20:00:16.870229Z",
"shell.execute_reply": "2023-08-18T20:00:16.869283Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n",
"data = d2l.FashionMNIST(batch_size=128)\n",
"model = BNLeNetScratch(lr=0.1)\n",
"model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n",
"trainer.fit(model, data)"
]
},
{
"cell_type": "markdown",
"id": "5d788d3d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Have a look at the scale parameter `gamma`\n",
"and the shift parameter `beta`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4969fdc2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:00:16.875213Z",
"iopub.status.busy": "2023-08-18T20:00:16.874610Z",
"iopub.status.idle": "2023-08-18T20:00:16.971921Z",
"shell.execute_reply": "2023-08-18T20:00:16.970745Z"
},
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([1.4334, 1.9905, 1.8584, 2.0740, 2.0522, 1.8877], device='cuda:0',\n",
" grad_fn=),\n",
" tensor([ 0.7354, -1.3538, -0.2567, -0.9991, -0.3028, 1.3125], device='cuda:0',\n",
" grad_fn=))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.net[1].gamma.reshape((-1,)), model.net[1].beta.reshape((-1,))"
]
},
{
"cell_type": "markdown",
"id": "a601daa9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Concise Implementation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ef2ab147",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:00:16.975625Z",
"iopub.status.busy": "2023-08-18T20:00:16.975018Z",
"iopub.status.idle": "2023-08-18T20:00:16.981373Z",
"shell.execute_reply": "2023-08-18T20:00:16.980550Z"
},
"origin_pos": 30,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class BNLeNet(d2l.Classifier):\n",
" def __init__(self, lr=0.1, num_classes=10):\n",
" super().__init__()\n",
" self.save_hyperparameters()\n",
" self.net = nn.Sequential(\n",
" nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(),\n",
" nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n",
" nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(),\n",
" nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),\n",
" nn.Flatten(), nn.LazyLinear(120), nn.LazyBatchNorm1d(),\n",
" nn.Sigmoid(), nn.LazyLinear(84), nn.LazyBatchNorm1d(),\n",
" nn.Sigmoid(), nn.LazyLinear(num_classes))"
]
},
{
"cell_type": "markdown",
"id": "0a6d4742",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Use the same hyperparameters to train our model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0d6aaf49",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T20:00:16.984898Z",
"iopub.status.busy": "2023-08-18T20:00:16.984364Z",
"iopub.status.idle": "2023-08-18T20:01:21.082406Z",
"shell.execute_reply": "2023-08-18T20:01:21.081474Z"
},
"origin_pos": 33,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer = d2l.Trainer(max_epochs=10, num_gpus=1)\n",
"data = d2l.FashionMNIST(batch_size=128)\n",
"model = BNLeNet(lr=0.1)\n",
"model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)\n",
"trainer.fit(model, data)"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}