{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/08_rmsnorm.ipynb)\n", "\n", "# 🟡 Medium: Implement RMSNorm\n", "\n", "Implement **Root Mean Square Layer Normalization**, the normalization used in LLaMA, Gemma, etc.\n", "\n", "$$\\text{RMSNorm}(x) = \\frac{x}{\\text{RMS}(x)} \\cdot w, \\quad \\text{RMS}(x) = \\sqrt{\\frac{1}{d}\\sum x_i^2 + \\epsilon}$$\n", "\n", "### Signature\n", "```python\n", "def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:\n", "    # Normalize over the last dimension. No mean subtraction (unlike LayerNorm).\n", "```\n", "\n", "### Rules\n", "- Do **NOT** use any built-in norm layers\n", "- Normalize over `dim=-1`\n", "- Must support autograd" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "import torch" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def rms_norm(x, weight, eps=1e-6):\n", "    pass  # Replace this" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "# 🧪 Debug\n", "x = torch.randn(2, 8)\n", "w = torch.ones(8)\n", "out = rms_norm(x, w)\n", "print(\"Output shape:\", out.shape)\n", "print(\"RMS of output:\", out.pow(2).mean(dim=-1).sqrt())  # should be ~1" ],
"outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "from torch_judge import check\n", "check('rmsnorm')" ], "outputs": [], "execution_count": null } ] }