{ "cells": [ { "cell_type": "markdown", "id": "2bfeea64", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/06_multihead_attention.ipynb)\n", "\n", "# ๐Ÿ”ด Hard: Multi-Head Attention\n", "\n", "Implement **Multi-Head Attention** from scratch โ€” the core building block of the Transformer.\n", "\n", "$$\\text{MultiHead}(Q, K, V) = \\text{Concat}(\\text{head}_1, \\dots, \\text{head}_h) W^O$$\n", "$$\\text{head}_i = \\text{Attention}(Q W_i^Q,\\; K W_i^K,\\; V W_i^V)$$\n", "\n", "### Signature\n", "```python\n", "class MultiHeadAttention:\n", " def __init__(self, d_model: int, num_heads: int): ...\n", " def forward(self, Q, K, V) -> torch.Tensor: ...\n", "```\n", "\n", "### Requirements\n", "- Use `nn.Linear(d_model, d_model)` for `self.W_q`, `self.W_k`, `self.W_v`, `self.W_o`\n", "- `d_k = d_model // num_heads` per head\n", "- `forward(Q, K, V)`: Q is `(B, seq_q, d_model)`, K/V are `(B, seq_k, d_model)`\n", "- Must support **cross-attention** (`seq_q != seq_k`)\n", "- Do **NOT** use `torch.nn.MultiheadAttention`\n", "- You **may** use `torch.softmax` and `torch.matmul`\n", "\n", "### Steps\n", "1. Project: `q = self.W_q(Q)`, `k = self.W_k(K)`, `v = self.W_v(V)`\n", "2. Reshape to `(B, num_heads, seq, d_k)`\n", "3. Scaled dot-product attention per head\n", "4. Concat heads โ†’ `(B, seq_q, d_model)`\n", "5. 
Output projection: `self.W_o(concat)`" ] }, { "cell_type": "code", "execution_count": null, "id": "02a059c4", "metadata": {}, "outputs": [], "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2f0c22cb", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import math" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "class MultiHeadAttention:\n", " def __init__(self, d_model: int, num_heads: int):\n", " pass # Initialize W_q, W_k, W_v, W_o\n", "\n", " def forward(self, Q, K, V):\n", " pass # Implement multi-head attention" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "torch.manual_seed(0)\n", "mha = MultiHeadAttention(d_model=32, num_heads=4)\n", "print(\"W_q type:\", type(mha.W_q)) # should be nn.Linear\n", "print(\"W_q.weight shape:\", mha.W_q.weight.shape) # (32, 32)\n", "\n", "x = torch.randn(2, 6, 32)\n", "out = mha.forward(x, x, x)\n", "print(\"Output shape:\", out.shape) # (2, 6, 32)\n", "\n", "# Cross-attention\n", "Q = torch.randn(1, 3, 32)\n", "K = torch.randn(1, 7, 32)\n", "V = torch.randn(1, 7, 32)\n", "out2 = mha.forward(Q, K, V)\n", "print(\"Cross-attn shape:\", out2.shape) # (1, 3, 32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check(\"mha\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }