{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/14_kv_cache.ipynb)\n", "\n", "# 🔴 Hard: KV Cache Attention\n", "\n", "Implement **multi-head attention with KV caching** for efficient autoregressive generation.\n", "\n", "During LLM inference, recomputing the key/value projections for every past token at every step is wasteful.\n", "A **KV cache** stores the previously computed K and V tensors so that only the new token(s) need to be projected.\n", "\n", "### Signature\n", "```python\n", "class KVCacheAttention(nn.Module):\n", "    def __init__(self, d_model: int, num_heads: int): ...\n", "    def forward(self, x: torch.Tensor, cache=None) -> tuple[torch.Tensor, tuple]:\n", "        # x: (B, S_new, D) — new tokens\n", "        # cache: None or (K_past, V_past), each (B, num_heads, S_past, d_k)\n", "        # Returns: (output, (K_all, V_all))\n", "```\n", "\n", "### Requirements\n", "- Inherit from `nn.Module`\n", "- `self.W_q`, `self.W_k`, `self.W_v`, `self.W_o`: `nn.Linear` projections\n", "- When `cache=None` (prefill): apply a **causal mask**, return all K/V as the cache\n", "- When `cache` is provided (decode): concat the new K/V with the cached K/V; no causal mask is needed for single-token decode, since the new token may attend to every cached position\n", "- Incremental decode must produce results **identical** to a full forward pass\n", "\n", "### Key Idea\n", "```\n", "Prefill: [t0 t1 t2 t3] → full causal attention → cache = (K_{0:3}, V_{0:3})\n", "Decode:  [t4] → Q=t4, K/V=cache+t4 → cache = (K_{0:4}, V_{0:4})\n", "Decode:  [t5] → Q=t5, K/V=cache+t5 → cache = (K_{0:5}, V_{0:5})\n", "```" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", "    import google.colab\n", "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", "    pass\n" ], "outputs": [], "execution_count": null }, { 
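"cell_type": "markdown", "metadata": {}, "source": [ "Before implementing the class, it can help to see the cache-concatenation step on plain tensors. The sketch below is illustrative only: it uses random tensors in place of real projections and assumes the cache layout `(B, num_heads, S_past, d_k)` from the signature above. Note that after the concat, the single decode query scores against all `S_past + 1` keys, which is why no causal mask is needed in that step.\n" ] }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "B, H, d_k = 1, 4, 16\n", "K_past = torch.randn(B, H, 4, d_k)  # cached keys for tokens t0..t3\n", "K_new = torch.randn(B, H, 1, d_k)   # key for the single decode token t4\n", "K_all = torch.cat([K_past, K_new], dim=2)  # concatenate along the sequence axis\n", "print(K_all.shape)  # torch.Size([1, 4, 5, 16])\n" ], "execution_count": null }, { 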
"cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import math" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✍️ YOUR IMPLEMENTATION HERE\n", "\n", "class KVCacheAttention(nn.Module):\n", "    def __init__(self, d_model, num_heads):\n", "        super().__init__()\n", "        pass  # Initialize W_q, W_k, W_v, W_o\n", "\n", "    def forward(self, x, cache=None):\n", "        # 1. Project Q, K, V from x\n", "        # 2. Reshape to multi-head: (B, num_heads, S, d_k)\n", "        # 3. If cache exists, concat new K/V with cached K/V\n", "        # 4. Compute attention (causal mask needed during prefill)\n", "        # 5. Return (output, (K_all, V_all))\n", "        pass" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "torch.manual_seed(0)\n", "attn = KVCacheAttention(d_model=64, num_heads=4)\n", "x = torch.randn(1, 6, 64)\n", "\n", "# Full forward\n", "full_out, _ = attn(x)\n", "print(\"Full output shape:\", full_out.shape)  # (1, 6, 64)\n", "\n", "# Incremental: prefill 4, decode 1, decode 1\n", "out1, cache = attn(x[:, :4])\n", "print(\"Cache K shape:\", cache[0].shape)  # (1, 4, 4, 16)\n", "out2, cache = attn(x[:, 4:5], cache=cache)\n", "out3, cache = attn(x[:, 5:6], cache=cache)\n", "inc_out = torch.cat([out1, out2, out3], dim=1)\n", "print(\"Match:\", torch.allclose(full_out, inc_out, atol=1e-5))" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check('kv_cache')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }