{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/23_cross_attention.ipynb)\n", "\n", "# ๐ŸŸ  Medium: Multi-Head Cross-Attention\n", "\n", "Implement **multi-head cross-attention** (encoder-decoder attention).\n", "\n", "### Signature\n", "```python\n", "class MultiHeadCrossAttention(nn.Module):\n", " def __init__(self, d_model: int, num_heads: int): ...\n", " def forward(self, x_q: Tensor, x_kv: Tensor) -> Tensor:\n", " # x_q: (B, S_q, D) โ€” decoder queries\n", " # x_kv: (B, S_kv, D) โ€” encoder keys/values\n", "```\n", "\n", "### Key Differences from Self-Attention\n", "- Q comes from the decoder, K and V come from the encoder\n", "- No causal mask (all encoder positions visible)" ], "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import math" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# โœ๏ธ YOUR IMPLEMENTATION HERE\n", "\n", "class MultiHeadCrossAttention(nn.Module):\n", " def __init__(self, d_model, num_heads):\n", " super().__init__()\n", " pass # W_q, W_k, W_v, W_o\n", "\n", " def forward(self, x_q, x_kv):\n", " pass # Q from x_q, K/V from x_kv, no causal mask" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# ๐Ÿงช Debug\n", "attn = MultiHeadCrossAttention(64, 4)\n", "x_q = torch.randn(2, 6, 64)\n", "x_kv = torch.randn(2, 10, 64)\n", "print('Output:', attn(x_q, x_kv).shape)" ], "execution_count": null }, { "cell_type": "code", "metadata": {}, "outputs": [], "source": [ "# โœ… SUBMIT\n", "from torch_judge import check\n", "check('cross_attention')" ], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }