{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/04_layernorm.ipynb)\n", "\n", "# 🟡 Medium: Implement LayerNorm\n", "\n", "Implement **Layer Normalization** from scratch.\n", "\n", "$$\\text{LayerNorm}(x) = \\gamma \\cdot \\frac{x - \\mu}{\\sqrt{\\sigma^2 + \\epsilon}} + \\beta$$\n", "\n", "where $\\mu$ and $\\sigma^2$ are the mean and the **biased** variance (divide by $N$, not $N-1$) computed over the **last dimension**; $\\gamma$ and $\\beta$ are broadcast over all leading dimensions.\n", "\n", "### Signature\n", "```python\n", "def my_layer_norm(\n", "    x: torch.Tensor,      # input\n", "    gamma: torch.Tensor,  # scale (same size as last dim)\n", "    beta: torch.Tensor,   # shift (same size as last dim)\n", "    eps: float = 1e-5\n", ") -> torch.Tensor:\n", "    ...\n", "```\n", "\n", "### Rules\n", "- Do **NOT** use `F.layer_norm` or `torch.nn.LayerNorm`\n", "- Normalize over the last dimension only\n", "- Must support autograd" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def my_layer_norm(x, gamma, beta, eps=1e-5):\n", "    pass  # Replace this" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "torch.manual_seed(0)  # reproducible debug input\n", "x = torch.randn(2, 8)\n", "gamma = torch.ones(8)\n", "beta = torch.zeros(8)\n", "\n", "out = my_layer_norm(x, gamma, beta)\n", "assert out is not None, \"my_layer_norm returned None: implement it in the cell above first\"\n", "ref = torch.nn.functional.layer_norm(x, [8], gamma, beta)\n", "\n", "print(\"Your output mean:\", out.mean(dim=-1))  # should be ~0\n", "print(\"Your output std: \", out.std(dim=-1, unbiased=False))  # should be ~1 (biased std, as LayerNorm uses)\n", "print(\"Match ref? \", torch.allclose(out, ref, atol=1e-4))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check(\"layernorm\")" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }