{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/templates/05_attention.ipynb)\n", "\n", "# 🔴 Hard: Softmax Attention\n", "\n", "Implement the core attention mechanism used in Transformers.\n", "\n", "$$\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n", "\n", "Here $d_k$ is the feature dimension shared by `Q` and `K` (their last dimension), and the softmax is taken over the key axis (the last dimension of the $QK^T$ score matrix), so each query's attention weights sum to 1.\n", "\n", "### Signature\n", "```python\n", "def scaled_dot_product_attention(\n", "    Q: torch.Tensor,  # (batch, seq_q, d_k)\n", "    K: torch.Tensor,  # (batch, seq_k, d_k)\n", "    V: torch.Tensor,  # (batch, seq_k, d_v)\n", ") -> torch.Tensor:    # (batch, seq_q, d_v)\n", "    ...\n", "```\n", "\n", "### Rules\n", "- Do **NOT** use `F.scaled_dot_product_attention`\n", "- You **may** use `torch.softmax` and `torch.bmm`\n", "- Must support autograd\n", "- Must handle cross-attention (seq_q ≠ seq_k)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Install torch-judge in Colab (no-op in JupyterLab/Docker)\n", "try:\n", " import google.colab\n", " get_ipython().run_line_magic('pip', 'install -q torch-judge')\n", "except ImportError:\n", " pass\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import math" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ✏️ YOUR IMPLEMENTATION HERE\n", "\n", "def scaled_dot_product_attention(Q, K, V):\n", " pass # Replace this" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🧪 Debug\n", "torch.manual_seed(42)\n", "Q = torch.randn(2, 4, 8)\n", "K = torch.randn(2, 4, 8)\n", "V = torch.randn(2, 4, 8)\n", "\n", "out = scaled_dot_product_attention(Q, K, V)\n", "print(\"Output shape:\", out.shape) # should be (2, 4, 8)\n", "print(\"Has NaN? \", torch.isnan(out).any().item()) # should be False\n", "print(\"Has Inf? \", torch.isinf(out).any().item()) # should be False\n", "\n", "# Cross-attention: seq_q != seq_k\n", "Q2 = torch.randn(1, 3, 16)\n", "K2 = torch.randn(1, 5, 16)\n", "V2 = torch.randn(1, 5, 32)\n", "out2 = scaled_dot_product_attention(Q2, K2, V2)\n", "print(\"Cross-attn shape:\", out2.shape) # should be (1, 3, 32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ✅ SUBMIT\n", "from torch_judge import check\n", "check(\"attention\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }