{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/09_causal_attention.ipynb)\n", "\n", "# 🔴 Hard: Causal Self-Attention\n", "\n", "Implement **causal (masked) self-attention** — the attention used in GPT-style decoders.\n", "\n", "Same as softmax attention, but each position can **only attend to itself and earlier positions** (no peeking at future tokens).\n", "\n", "$$\\text{scores}_{ij} = \\begin{cases} \\frac{Q_i \\cdot K_j}{\\sqrt{d_k}} & \\text{if } j \\le i \\\\ -\\infty & \\text{if } j > i \\end{cases}$$\n", "\n", "### Signature\n", "```python\n", "def causal_attention(Q, K, V):\n", " # Q, K, V: (batch, seq, d) → output: (batch, seq, d_v)\n", "```\n", "\n", "### Rules\n", "- Do **NOT** use `F.scaled_dot_product_attention`\n", "- Position $i$ can only attend to positions $\\le i$\n", "- You **may** use `torch.softmax`, `torch.bmm`, `torch.triu`" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "import torch\n", "import math" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "# ✍️ YOUR IMPLEMENTATION HERE\n", "\n", "def causal_attention(Q, K, V):\n", " pass # Replace this" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "# 🧪 Debug\n", "torch.manual_seed(0)\n", "Q = 
torch.randn(1, 4, 8)\n", "K = torch.randn(1, 4, 8)\n", "V = torch.randn(1, 4, 8)\n", "out = causal_attention(Q, K, V)\n", "print(\"Output shape:\", out.shape) # (1, 4, 8)\n", "print(\"Pos 0 == V[0]?\", torch.allclose(out[:, 0], V[:, 0], atol=1e-5)) # should be True: position 0 can only attend to itself" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "from torch_judge import check\n", "check('causal_attention')" ], "outputs": [], "execution_count": null } ] }