{
 "nbformat": 4,
 "nbformat_minor": 5,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/10_gqa.ipynb)\n",
    "\n",
    "# 🔴 Hard: Grouped Query Attention (GQA)\n",
    "\n",
    "Implement **Grouped Query Attention**, used in models such as LLaMA 2 70B and Mistral 7B to reduce the size of the KV cache.\n",
    "\n",
    "GQA works like multi-head attention (MHA), but with **fewer KV heads** than Q heads. The Q heads are split into `num_kv_heads` groups of `num_heads // num_kv_heads` heads, and every Q head in a group attends with the same K/V head. Because only `num_kv_heads` K/V heads are computed and cached, the KV cache shrinks by a factor of `num_heads / num_kv_heads`.\n",
    "\n",
    "### Signature\n",
    "```python\n",
    "class GroupQueryAttention:\n",
    "    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int): ...\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor: ...  # self-attention\n",
    "```\n",
    "\n",
    "### Requirements\n",
    "- `self.W_q`: `nn.Linear(d_model, d_model)` (full Q projection)\n",
    "- `self.W_k`: `nn.Linear(d_model, num_kv_heads * d_k)` (reduced K projection)\n",
    "- `self.W_v`: `nn.Linear(d_model, num_kv_heads * d_k)` (reduced V projection)\n",
    "- `self.W_o`: `nn.Linear(d_model, d_model)` (output projection)\n",
    "- `d_k = d_model // num_heads`; `num_heads` must be divisible by `num_kv_heads`\n",
    "- Input `x` has shape `(batch, seq_len, d_model)`, and the output has the same shape\n",
    "- Expand the KV heads with `repeat_interleave` along the head dimension, repeating each K/V head `num_heads // num_kv_heads` times so the count matches the Q heads\n",
    "- When `num_kv_heads == num_heads`, the layer should behave like standard MHA"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n",
    "try:\n",
    "    import google.colab\n",
    "    get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
    "except ImportError:\n",
    "    pass\n"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import math"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# ✏️ YOUR IMPLEMENTATION HERE\n",
    "\n",
    "class GroupQueryAttention:\n",
    "    def __init__(self, d_model, num_heads, num_kv_heads):\n",
    "        pass  # Initialize projections\n",
    "\n",
    "    def forward(self, x):\n",
    "        pass  # Self-attention with grouped KV"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# 🧪 Debug\n",
    "torch.manual_seed(0)\n",
    "gqa = GroupQueryAttention(d_model=32, num_heads=8, num_kv_heads=2)\n",
    "print(\"W_q shape:\", gqa.W_q.weight.shape)  # (32, 32)\n",
    "print(\"W_k shape:\", gqa.W_k.weight.shape)  # (8, 32): num_kv_heads=2 * d_k=4 output features\n",
    "\n",
    "x = torch.randn(2, 6, 32)\n",
    "out = gqa.forward(x)\n",
    "print(\"Output shape:\", out.shape)  # (2, 6, 32)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from torch_judge import check\n",
    "check('gqa')"
   ],
   "outputs": [],
   "execution_count": null
  }
 ]
}