{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/25_flash_attention.ipynb)\n", "\n", "# 🔴 Hard: Flash Attention (Tiled)\n", "\n", "Implement **tiled attention with online softmax**, the core idea behind Flash Attention.\n", "\n", "### Signature\n", "```python\n", "def flash_attention(Q, K, V, block_size=32) -> Tensor:\n", "    # Q, K, V: (B, S, D)\n", "    # Returns: (B, S, D), same shape as standard attention\n", "```\n", "\n", "### Key Insight\n", "Instead of materializing the full S×S attention matrix, process it in blocks:\n", "1. For each Q-block, iterate over the K/V blocks, computing the scaled scores `Q_blk @ K_blk.T / sqrt(D)`\n", "2. Use **online softmax**: track a running row `max` and a running row `sum` of exponentials\n", "3. When the max changes, rescale both the running sum and the accumulator: `acc *= exp(old_max - new_max)`\n", "4. Finally normalize: `output = acc / row_sum`\n", "\n", "The result must match standard softmax attention to within floating-point tolerance (`torch.allclose`)."
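, "\n", "As a quick sketch of step 2 (a hypothetical scalar helper, not part of the required signature): online softmax keeps a running max `m` and a running sum `s` of `exp(x - m)`, rescaling `s` whenever `m` grows, so a single pass over the scores yields the exact softmax denominator.\n", "\n", "```python\n", "import math\n", "\n", "def online_softmax_denominator(scores):\n", "    # One streaming pass: returns (row max m, sum of exp(x - m)).\n", "    m, s = float('-inf'), 0.0\n", "    for x in scores:\n", "        new_m = max(m, x)\n", "        s = s * math.exp(m - new_m) + math.exp(x - new_m)  # rescale old sum\n", "        m = new_m\n", "    return m, s  # softmax(x_i) = exp(x_i - m) / s\n", "```\n", "\n", "The same update, vectorized per row and paired with an identical rescale of the output accumulator, is what makes the tiled pass exact."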
] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import math" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def flash_attention(Q, K, V, block_size=32):\n", "    # Process Q in blocks; for each, iterate over K/V blocks with online softmax\n", "    pass" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "import math\n", "Q, K, V = torch.randn(1, 8, 4), torch.randn(1, 8, 4), torch.randn(1, 8, 4)\n", "out = flash_attention(Q, K, V, block_size=4)\n", "scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(4)\n", "ref = torch.bmm(torch.softmax(scores, dim=-1), V)\n", "print('Match:', torch.allclose(out, ref, atol=1e-4))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('flash_attention')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }