{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/24_rope.ipynb)\n", "\n", "# 🔴 Hard: Rotary Position Embedding (RoPE)\n", "\n", "Implement **RoPE**, the position encoding used in LLaMA, GPT-NeoX, and many other modern LLMs.\n", "\n", "### Signature\n", "```python\n", "def apply_rope(q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:\n", "    # q, k: (B, S, D) where D is even\n", "    # Returns rotated (q, k) with same shape\n", "```\n", "\n", "### Key Idea\n", "Split each vector into consecutive pairs `(x_{2i}, x_{2i+1})`. Rotate pair `i` of the token at position `pos` by the angle `θ = pos / 10000^(2i/D)`:\n", "```\n", "[x_{2i}, x_{2i+1}] → [x_{2i}*cosθ - x_{2i+1}*sinθ, x_{2i}*sinθ + x_{2i+1}*cosθ]\n", "```\n", "This makes `dot(q_rot[m], k_rot[n])` depend on the positions `m` and `n` only through their offset `m - n` (relative position)." ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import math" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def apply_rope(q, k):\n", "    # 1. Compute position angles\n", "    # 2. Split into even/odd pairs\n", "    # 3. Apply rotation\n", "    pass" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "q = torch.randn(1, 8, 16)\n", "k = torch.randn(1, 8, 16)\n", "qr, kr = apply_rope(q, k)\n", "print('Shape preserved:', qr.shape == q.shape)\n", "print('Norm preserved:', torch.allclose(q.norm(dim=-1), qr.norm(dim=-1), atol=1e-4))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('rope')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }