{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/12_linear_attention.ipynb)\n", "\n", "# 🔴 Hard: Linear Self-Attention\n", "\n", "Implement **Linear Attention** — O(S·D²) instead of O(S²·D), enabling efficient long-sequence processing.\n", "\n", "Replace softmax with a **kernel feature map** $\\phi$:\n", "\n", "$$\\text{LinearAttn}(Q,K,V) = \\frac{\\phi(Q) \\left(\\phi(K)^T V\\right)}{\\phi(Q) \\cdot \\sum \\phi(K)}$$\n", "\n", "### Feature map\n", "Use $\\phi(x) = \\text{elu}(x) + 1$ (ensures non-negative features, so the denominator stays positive).\n", "\n", "### Signature\n", "```python\n", "def linear_attention(Q, K, V):\n", "    # Q: (B, S, D_k), K: (B, S, D_k), V: (B, S, D_v)\n", "    # Returns: (B, S, D_v)\n", "```\n", "\n", "### Key insight\n", "Instead of computing the $S \\times S$ attention matrix, compute $\\phi(K)^T V$ first (a $D_k \\times D_v$ matrix), then multiply by $\\phi(Q)$.\n", "\n", "### Rules\n", "- Must use a feature map (NOT softmax)\n", "- Must be O(S·D²) — should run fast on long sequences\n", "- You **may** use `F.elu`" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "import torch\n", "import torch.nn.functional as F" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def linear_attention(Q, K, V):\n", "    # 1. Apply the feature map: phi(x) = F.elu(x) + 1, to both Q and K\n", "    # 2. Compute KV = phi(K)^T V  -> (B, D_k, D_v), and Z = phi(K).sum over S  -> (B, D_k)\n", "    # 3. Return (phi(Q) @ KV) normalized by (phi(Q) · Z)\n", "    pass  # Replace this" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "# 🧪 Debug\n", "Q = torch.randn(1, 8, 16)\n", "K = torch.randn(1, 8, 16)\n", "V = torch.randn(1, 8, 32)\n", "out = linear_attention(Q, K, V)\n", "print(\"Output shape:\", out.shape)  # (1, 8, 32)\n", "print(\"Has NaN?\", torch.isnan(out).any().item())" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "from torch_judge import check\n", "check('linear_attention')" ], "outputs": [], "execution_count": null } ] }