{ "cells": [ { "cell_type": "markdown", "id": "89fd15cb", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/07_batchnorm.ipynb)\n", "\n", "# 🟡 Medium: Implement BatchNorm\n", "\n", "Implement **Batch Normalization** with both **training** and **inference** behavior.\n", "\n", "In training mode, use **batch statistics** and update running estimates:\n", "\n", "$$\\text{BN}(x) = \\gamma \\cdot \\frac{x - \\mu_B}{\\sqrt{\\sigma_B^2 + \\epsilon}} + \\beta$$\n", "\n", "where $\\mu_B$ and $\\sigma_B^2$ are the mean and variance computed **across the batch** (dim=0).\n", "\n", "In inference mode, use the provided **running mean/var** instead of current batch stats.\n", "\n", "### Signature\n", "```python\n", "def my_batch_norm(\n", " x: torch.Tensor,\n", " gamma: torch.Tensor,\n", " beta: torch.Tensor,\n", " running_mean: torch.Tensor,\n", " running_var: torch.Tensor,\n", " eps: float = 1e-5,\n", " momentum: float = 0.1,\n", " training: bool = True,\n", ") -> torch.Tensor:\n", " # x: (N, D) — normalize each feature across all samples in the batch\n", " # running_mean, running_var: updated in-place during training; used as-is during inference\n", "```\n", "\n", "### Rules\n", "- Do **NOT** use `F.batch_norm`, `nn.BatchNorm1d`, etc.\n", "- Compute batch mean and variance over `dim=0` with `unbiased=False`\n", "- Update running stats like PyTorch: `running = (1 - momentum) * running + momentum * batch_stat`\n", "- Use `running_mean` / `running_var` for inference when `training=False`\n", "- Must support autograd w.r.t. 
`x`, `gamma`, `beta` (running statistics should be treated as buffers, not as parameters that require gradients)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": null, "id": "d946ca79", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def my_batch_norm(\n", "    x,\n", "    gamma,\n", "    beta,\n", "    running_mean,\n", "    running_var,\n", "    eps=1e-5,\n", "    momentum=0.1,\n", "    training=True,\n", "):\n", "    pass  # Replace this" ] }, { "cell_type": "code", "execution_count": null, "id": "26b93e71", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "x = torch.randn(8, 4)\n", "gamma = torch.ones(4)\n", "beta = torch.zeros(4)\n", "\n", "# Running stats typically live on the same device and have the same shape as the features\n", "running_mean = torch.zeros(4)\n", "running_var = torch.ones(4)\n", "\n", "# Training mode: uses batch stats and updates running_mean / running_var\n", "out_train = my_batch_norm(x, gamma, beta, running_mean, running_var, training=True)\n", "print(\"[Train] Output shape:\", out_train.shape)\n", "print(\"[Train] Column means:\", out_train.mean(dim=0))  # should be ~0\n", "print(\"[Train] Column stds: \", out_train.std(dim=0))   # should be ~1 (torch.std is unbiased, so slightly above 1 for small batches)\n", "print(\"Updated running_mean:\", running_mean)\n", "print(\"Updated running_var:\", running_var)\n", "\n", "# Inference mode: uses running_mean / running_var only\n", "out_eval = my_batch_norm(x, gamma, beta, running_mean, running_var, training=False)\n", "print(\"[Eval] Output shape:\", out_eval.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import 
check\n", "check(\"batchnorm\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }