{ "cells": [ { "cell_type": "markdown", "id": "f6a0ffe1", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Attention Scoring Functions\n", "\n", "We implement a masked softmax, build two attention scoring functions on top of it (scaled dot product attention and additive attention), and visualize the resulting attention weights as heatmaps." ] }, { "cell_type": "code", "execution_count": 1, "id": "8e33a108", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:45.433072Z", "iopub.status.busy": "2023-08-18T19:43:45.432523Z", "iopub.status.idle": "2023-08-18T19:43:48.504425Z", "shell.execute_reply": "2023-08-18T19:43:48.503548Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "9440ec4c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Dot Product Attention\n", "\n", "Masked Softmax Operation" ] }, { "cell_type": "code", "execution_count": 2, "id": "080c4919", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.508521Z", "iopub.status.busy": "2023-08-18T19:43:48.507880Z", "iopub.status.idle": "2023-08-18T19:43:48.515032Z", "shell.execute_reply": "2023-08-18T19:43:48.514260Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [
 "def masked_softmax(X, valid_lens):\n",
 "    \"\"\"Perform softmax operation by masking elements on the last axis.\n",
 "\n",
 "    `X` must have rank >= 2; the softmax is taken over its last axis and\n",
 "    every entry at or beyond the valid length of its row gets (numerically)\n",
 "    zero probability. `X` itself is not modified.\n",
 "\n",
 "    `valid_lens` is one of:\n",
 "      * None: plain softmax, no masking;\n",
 "      * a 1D tensor with one length per batch example (axis 0 of `X`),\n",
 "        shared by every row of that example;\n",
 "      * a tensor with one length per row of `X`, e.g. of shape\n",
 "        (batch_size, no. of queries) for a 3D `X`; it is flattened.\n",
 "    \"\"\"\n",
 "    def _sequence_mask(X, valid_len, value=0):\n",
 "        # X: (num_rows, maxlen); valid_len: (num_rows,)\n",
 "        maxlen = X.size(1)\n",
 "        mask = torch.arange(maxlen, dtype=torch.float32,\n",
 "                            device=X.device)[None, :] < valid_len[:, None]\n",
 "        # Out-of-place fill: X is typically a reshape *view* of the caller's\n",
 "        # tensor, so writing into it would leak the mask into the caller's\n",
 "        # scores and fails on leaf tensors that require grad.\n",
 "        return X.masked_fill(~mask, value)\n",
 "\n",
 "    if valid_lens is None:\n",
 "        return nn.functional.softmax(X, dim=-1)\n",
 "    shape = X.shape\n",
 "    if valid_lens.dim() == 1:\n",
 "        # Repeat each example's length once per row of that example, i.e.\n",
 "        # over all axes between the batch axis and the last axis (1 row for\n",
 "        # a 2D X, shape[1] rows for a 3D X).\n",
 "        valid_lens = torch.repeat_interleave(valid_lens,\n",
 "                                             math.prod(shape[1:-1]))\n",
 "    else:\n",
 "        valid_lens = valid_lens.reshape(-1)\n",
 "    # A large negative score makes exp() underflow to zero in the softmax.\n",
 "    X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)\n",
 "    return nn.functional.softmax(X.reshape(shape), dim=-1)"
 ] }, { "cell_type": "markdown", "id": "ca7b144c", "metadata": {
"slideshow": { "slide_type": "slide" } }, "source": [ "Illustrate how this function works" ] }, { "cell_type": "code", "execution_count": 3, "id": "b0fb493b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.518456Z", "iopub.status.busy": "2023-08-18T19:43:48.517778Z", "iopub.status.idle": "2023-08-18T19:43:48.554108Z", "shell.execute_reply": "2023-08-18T19:43:48.553283Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[[0.4448, 0.5552, 0.0000, 0.0000],\n", " [0.4032, 0.5968, 0.0000, 0.0000]],\n", "\n", " [[0.2795, 0.2805, 0.4400, 0.0000],\n", " [0.2798, 0.3092, 0.4110, 0.0000]]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# One valid length per batch example, shared by all of its rows:\n", "# example 0 keeps its first 2 entries, example 1 its first 3.\n", "masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))" ] }, { "cell_type": "code", "execution_count": 4, "id": "0eff10c9", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.557828Z", "iopub.status.busy": "2023-08-18T19:43:48.557262Z", "iopub.status.idle": "2023-08-18T19:43:48.564098Z", "shell.execute_reply": "2023-08-18T19:43:48.563239Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n", " [0.4109, 0.2794, 0.3097, 0.0000]],\n", "\n", " [[0.3960, 0.6040, 0.0000, 0.0000],\n", " [0.2557, 0.1833, 0.2420, 0.3190]]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# One valid length per row: the rows of example 0 keep 1 and 3 entries,\n", "# the rows of example 1 keep 2 and 4.\n", "masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))" ] }, { "cell_type": "code", "execution_count": 5, "id": "1d592456", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.567605Z", "iopub.status.busy": "2023-08-18T19:43:48.567037Z", "iopub.status.idle": "2023-08-18T19:43:48.572146Z", "shell.execute_reply": "2023-08-18T19:43:48.571131Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "# Batch matrix multiplication: (2, 3, 4) x (2, 4, 6) -> (2, 3, 6).\n", "Q = torch.ones((2, 3, 4))\n", "K = torch.ones((2, 4, 6))\n",
"d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))" ] }, { "cell_type": "markdown", "id": "935cb045", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Scaled Dot Product Attention" ] }, { "cell_type": "code", "execution_count": 6, "id": "33207d5f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.575743Z", "iopub.status.busy": "2023-08-18T19:43:48.575036Z", "iopub.status.idle": "2023-08-18T19:43:48.581055Z", "shell.execute_reply": "2023-08-18T19:43:48.580209Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [], "source": [
 "class DotProductAttention(nn.Module):\n",
 "    \"\"\"Scaled dot product attention.\n",
 "\n",
 "    Every query is scored against every key by their dot product divided\n",
 "    by sqrt(d), where d is the feature size of the queries. The scores are\n",
 "    normalized with `masked_softmax`, kept in `self.attention_weights`, and\n",
 "    used (after dropout) to take a weighted sum of the values.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, dropout):\n",
 "        super().__init__()\n",
 "        self.dropout = nn.Dropout(dropout)\n",
 "\n",
 "    def forward(self, queries, keys, values, valid_lens=None):\n",
 "        # torch.bmm needs 3D inputs: queries (batch, n_queries, d),\n",
 "        # keys (batch, n_keys, d), values (batch, n_keys, value_dim).\n",
 "        scale = math.sqrt(queries.shape[-1])\n",
 "        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale\n",
 "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
 "        # Dropout acts on the weights used for the output only; the stored\n",
 "        # attention_weights stay undropped.\n",
 "        weights = self.dropout(self.attention_weights)\n",
 "        return torch.bmm(weights, values)"
 ] }, { "cell_type": "markdown", "id": "c63fab7b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Illustrate how the `DotProductAttention` class works" ] }, { "cell_type": "code", "execution_count": 8, "id": "f40e370d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.594461Z", "iopub.status.busy": "2023-08-18T19:43:48.593898Z", "iopub.status.idle": "2023-08-18T19:43:48.969221Z", "shell.execute_reply": "2023-08-18T19:43:48.968308Z" }, "origin_pos": 37, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:43:48.906167\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n",
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "queries = torch.normal(0, 1, (2, 1, 2))\n", "keys = torch.normal(0, 1, (2, 10, 2))\n", "values = torch.normal(0, 1, (2, 10, 4))\n", "valid_lens = torch.tensor([2, 6])\n", "\n", "attention = DotProductAttention(dropout=0.5)\n", "attention.eval()\n", "d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))\n", "\n", "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", "                  xlabel='Keys', ylabel='Queries')" ] }, { "cell_type": "markdown", "id": "ae2e4399", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Additive Attention" ] }, { "cell_type": "code", "execution_count": 9, "id": "3a2e6dee", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.973108Z", "iopub.status.busy": "2023-08-18T19:43:48.972388Z", "iopub.status.idle": "2023-08-18T19:43:48.979819Z", "shell.execute_reply": "2023-08-18T19:43:48.978914Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [], "source": [
 "class AdditiveAttention(nn.Module):\n",
 "    \"\"\"Additive attention.\n",
 "\n",
 "    Queries and keys are projected by `W_q` and `W_k` into a shared space\n",
 "    of size `num_hiddens` (their own feature sizes may differ), every query\n",
 "    is added to every key, and `w_v` maps tanh of each sum to a scalar\n",
 "    score. The scores are normalized with `masked_softmax`, kept in\n",
 "    `self.attention_weights`, and used (after dropout) to take a weighted\n",
 "    sum of the values.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, num_hiddens, dropout, **kwargs):\n",
 "        super().__init__(**kwargs)\n",
 "        # LazyLinear infers each input width on the first forward pass.\n",
 "        self.W_k = nn.LazyLinear(num_hiddens, bias=False)\n",
 "        self.W_q = nn.LazyLinear(num_hiddens, bias=False)\n",
 "        self.w_v = nn.LazyLinear(1, bias=False)\n",
 "        self.dropout = nn.Dropout(dropout)\n",
 "\n",
 "    def forward(self, queries, keys, values, valid_lens=None):\n",
 "        # valid_lens defaults to None (no masking), as in DotProductAttention.\n",
 "        queries, keys = self.W_q(queries), self.W_k(keys)\n",
 "        # queries: (batch, n_queries, 1, num_hiddens) and keys:\n",
 "        # (batch, 1, n_keys, num_hiddens) broadcast to every (query, key)\n",
 "        # pair: (batch, n_queries, n_keys, num_hiddens).\n",
 "        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))\n",
 "        # w_v has a single output; drop that axis to get scores of shape\n",
 "        # (batch, n_queries, n_keys).\n",
 "        scores = self.w_v(features).squeeze(-1)\n",
 "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
 "        return torch.bmm(self.dropout(self.attention_weights), values)"
 ] }, { "cell_type": "markdown", "id": "b338adaa", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "See how `AdditiveAttention` works" ] }, { "cell_type":
"code", "execution_count": 11, "id": "bf7a330b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:43:48.996815Z", "iopub.status.busy": "2023-08-18T19:43:48.996248Z", "iopub.status.idle": "2023-08-18T19:43:49.212301Z", "shell.execute_reply": "2023-08-18T19:43:49.211395Z" }, "origin_pos": 50, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:43:49.170875\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Reuse keys, values and valid_lens from the dot product attention demo.\n", "# The queries now have 20 features while the keys have 2: W_q and W_k\n", "# project both into num_hiddens dimensions, so the sizes need not match.\n", "queries = torch.normal(0, 1, (2, 1, 20))\n", "\n", "attention = AdditiveAttention(num_hiddens=8, dropout=0.1)\n", "attention.eval()\n", "d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))\n", "\n", "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n", " xlabel='Keys', ylabel='Queries')" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }