{
"cells": [
{
"cell_type": "markdown",
"id": "f6a0ffe1",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Attention Scoring Functions\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8e33a108",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:45.433072Z",
"iopub.status.busy": "2023-08-18T19:43:45.432523Z",
"iopub.status.idle": "2023-08-18T19:43:48.504425Z",
"shell.execute_reply": "2023-08-18T19:43:48.503548Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "9440ec4c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Dot Product Attention\n",
"Masked Softmax Operation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "080c4919",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.508521Z",
"iopub.status.busy": "2023-08-18T19:43:48.507880Z",
"iopub.status.idle": "2023-08-18T19:43:48.515032Z",
"shell.execute_reply": "2023-08-18T19:43:48.514260Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def masked_softmax(X, valid_lens): \n",
" \"\"\"Perform softmax operation by masking elements on the last axis.\"\"\"\n",
" def _sequence_mask(X, valid_len, value=0):\n",
" maxlen = X.size(1)\n",
" mask = torch.arange((maxlen), dtype=torch.float32,\n",
" device=X.device)[None, :] < valid_len[:, None]\n",
" X[~mask] = value\n",
" return X\n",
"\n",
" if valid_lens is None:\n",
" return nn.functional.softmax(X, dim=-1)\n",
" else:\n",
" shape = X.shape\n",
" if valid_lens.dim() == 1:\n",
" valid_lens = torch.repeat_interleave(valid_lens, shape[1])\n",
" else:\n",
" valid_lens = valid_lens.reshape(-1)\n",
" X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)\n",
" return nn.functional.softmax(X.reshape(shape), dim=-1)"
]
},
{
"cell_type": "markdown",
"id": "ca7b144c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Illustrate how this function works"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b0fb493b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.518456Z",
"iopub.status.busy": "2023-08-18T19:43:48.517778Z",
"iopub.status.idle": "2023-08-18T19:43:48.554108Z",
"shell.execute_reply": "2023-08-18T19:43:48.553283Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.4448, 0.5552, 0.0000, 0.0000],\n",
" [0.4032, 0.5968, 0.0000, 0.0000]],\n",
"\n",
" [[0.2795, 0.2805, 0.4400, 0.0000],\n",
" [0.2798, 0.3092, 0.4110, 0.0000]]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0eff10c9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.557828Z",
"iopub.status.busy": "2023-08-18T19:43:48.557262Z",
"iopub.status.idle": "2023-08-18T19:43:48.564098Z",
"shell.execute_reply": "2023-08-18T19:43:48.563239Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.4109, 0.2794, 0.3097, 0.0000]],\n",
"\n",
" [[0.3960, 0.6040, 0.0000, 0.0000],\n",
" [0.2557, 0.1833, 0.2420, 0.3190]]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1d592456",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.567605Z",
"iopub.status.busy": "2023-08-18T19:43:48.567037Z",
"iopub.status.idle": "2023-08-18T19:43:48.572146Z",
"shell.execute_reply": "2023-08-18T19:43:48.571131Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"Q = torch.ones((2, 3, 4))\n",
"K = torch.ones((2, 4, 6))\n",
"d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))"
]
},
{
"cell_type": "markdown",
"id": "935cb045",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Scaled Dot Product Attention"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "33207d5f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.575743Z",
"iopub.status.busy": "2023-08-18T19:43:48.575036Z",
"iopub.status.idle": "2023-08-18T19:43:48.581055Z",
"shell.execute_reply": "2023-08-18T19:43:48.580209Z"
},
"origin_pos": 28,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class DotProductAttention(nn.Module): \n",
" \"\"\"Scaled dot product attention.\"\"\"\n",
" def __init__(self, dropout):\n",
" super().__init__()\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, queries, keys, values, valid_lens=None):\n",
" d = queries.shape[-1]\n",
" scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)\n",
" self.attention_weights = masked_softmax(scores, valid_lens)\n",
" return torch.bmm(self.dropout(self.attention_weights), values)"
]
},
{
"cell_type": "markdown",
"id": "c63fab7b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Illustrate how the `DotProductAttention` class works"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f40e370d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.594461Z",
"iopub.status.busy": "2023-08-18T19:43:48.593898Z",
"iopub.status.idle": "2023-08-18T19:43:48.969221Z",
"shell.execute_reply": "2023-08-18T19:43:48.968308Z"
},
"origin_pos": 37,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"queries = torch.normal(0, 1, (2, 1, 2))\n",
"keys = torch.normal(0, 1, (2, 10, 2))\n",
"values = torch.normal(0, 1, (2, 10, 4))\n",
"valid_lens = torch.tensor([2, 6])\n",
"\n",
"attention = DotProductAttention(dropout=0.5)\n",
"attention.eval()\n",
"d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))\n",
"\n",
"d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
" xlabel='Keys', ylabel='Queries')"
]
},
{
"cell_type": "markdown",
"id": "ae2e4399",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Additive Attention"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3a2e6dee",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.973108Z",
"iopub.status.busy": "2023-08-18T19:43:48.972388Z",
"iopub.status.idle": "2023-08-18T19:43:48.979819Z",
"shell.execute_reply": "2023-08-18T19:43:48.978914Z"
},
"origin_pos": 41,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class AdditiveAttention(nn.Module): \n",
" \"\"\"Additive attention.\"\"\"\n",
" def __init__(self, num_hiddens, dropout, **kwargs):\n",
" super(AdditiveAttention, self).__init__(**kwargs)\n",
" self.W_k = nn.LazyLinear(num_hiddens, bias=False)\n",
" self.W_q = nn.LazyLinear(num_hiddens, bias=False)\n",
" self.w_v = nn.LazyLinear(1, bias=False)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, queries, keys, values, valid_lens):\n",
" queries, keys = self.W_q(queries), self.W_k(keys)\n",
" features = queries.unsqueeze(2) + keys.unsqueeze(1)\n",
" features = torch.tanh(features)\n",
" scores = self.w_v(features).squeeze(-1)\n",
" self.attention_weights = masked_softmax(scores, valid_lens)\n",
" return torch.bmm(self.dropout(self.attention_weights), values)"
]
},
{
"cell_type": "markdown",
"id": "b338adaa",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"See how `AdditiveAttention` works"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bf7a330b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.996815Z",
"iopub.status.busy": "2023-08-18T19:43:48.996248Z",
"iopub.status.idle": "2023-08-18T19:43:49.212301Z",
"shell.execute_reply": "2023-08-18T19:43:49.211395Z"
},
"origin_pos": 50,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"queries = torch.normal(0, 1, (2, 1, 20))\n",
"\n",
"attention = AdditiveAttention(num_hiddens=8, dropout=0.1)\n",
"attention.eval()\n",
"d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))\n",
"\n",
"d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
" xlabel='Keys', ylabel='Queries')"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}