{
"cells": [
{
"cell_type": "markdown",
"id": "f6a0ffe1",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Attention Scoring Functions\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8e33a108",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:45.433072Z",
"iopub.status.busy": "2023-08-18T19:43:45.432523Z",
"iopub.status.idle": "2023-08-18T19:43:48.504425Z",
"shell.execute_reply": "2023-08-18T19:43:48.503548Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "9440ec4c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Dot Product Attention\n",
"Masked Softmax Operation"
]
},
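{
"cell_type": "markdown",
"id": "3c5a2e90",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"Attention pools over sequences of varying length, so padded positions should receive zero weight. A masked softmax keeps only the first $\\ell$ scores of a row $\\mathbf{s} \\in \\mathbb{R}^L$ and pushes the rest to a very negative value before normalizing:\n",
"\n",
"$$\\tilde{s}_i = \\begin{cases} s_i & i \\leq \\ell \\\\ -10^6 & i > \\ell \\end{cases}\n",
"\\qquad \\mathrm{masked\\_softmax}(\\mathbf{s})_i = \\frac{\\exp(\\tilde{s}_i)}{\\sum_{j=1}^L \\exp(\\tilde{s}_j)},$$\n",
"\n",
"so the exponentials of masked scores are numerically zero. This is exactly what the implementation below does with `value=-1e6`."
]
},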
{
"cell_type": "code",
"execution_count": 2,
"id": "080c4919",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.508521Z",
"iopub.status.busy": "2023-08-18T19:43:48.507880Z",
"iopub.status.idle": "2023-08-18T19:43:48.515032Z",
"shell.execute_reply": "2023-08-18T19:43:48.514260Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def masked_softmax(X, valid_lens): \n",
" \"\"\"Perform softmax operation by masking elements on the last axis.\"\"\"\n",
" def _sequence_mask(X, valid_len, value=0):\n",
" maxlen = X.size(1)\n",
" mask = torch.arange((maxlen), dtype=torch.float32,\n",
" device=X.device)[None, :] < valid_len[:, None]\n",
" X[~mask] = value\n",
" return X\n",
"\n",
" if valid_lens is None:\n",
" return nn.functional.softmax(X, dim=-1)\n",
" else:\n",
" shape = X.shape\n",
" if valid_lens.dim() == 1:\n",
" valid_lens = torch.repeat_interleave(valid_lens, shape[1])\n",
" else:\n",
" valid_lens = valid_lens.reshape(-1)\n",
" X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)\n",
" return nn.functional.softmax(X.reshape(shape), dim=-1)"
]
},
{
"cell_type": "markdown",
"id": "ca7b144c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Illustrate how this function works"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b0fb493b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.518456Z",
"iopub.status.busy": "2023-08-18T19:43:48.517778Z",
"iopub.status.idle": "2023-08-18T19:43:48.554108Z",
"shell.execute_reply": "2023-08-18T19:43:48.553283Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.4448, 0.5552, 0.0000, 0.0000],\n",
" [0.4032, 0.5968, 0.0000, 0.0000]],\n",
"\n",
" [[0.2795, 0.2805, 0.4400, 0.0000],\n",
" [0.2798, 0.3092, 0.4110, 0.0000]]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0eff10c9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.557828Z",
"iopub.status.busy": "2023-08-18T19:43:48.557262Z",
"iopub.status.idle": "2023-08-18T19:43:48.564098Z",
"shell.execute_reply": "2023-08-18T19:43:48.563239Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.4109, 0.2794, 0.3097, 0.0000]],\n",
"\n",
" [[0.3960, 0.6040, 0.0000, 0.0000],\n",
" [0.2557, 0.1833, 0.2420, 0.3190]]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))"
]
},
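{
"cell_type": "markdown",
"id": "7d41b6aa",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Batch Matrix Multiplication\n",
"\n",
"Attention is computed for a whole minibatch at once, which relies on batched matrix multiplication: for tensors of shapes $(n, a, b)$ and $(n, b, c)$, `torch.bmm` multiplies the $n$ matrix pairs independently and returns a tensor of shape $(n, a, c)$, as the shape check below confirms."
]
},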
{
"cell_type": "code",
"execution_count": 5,
"id": "1d592456",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.567605Z",
"iopub.status.busy": "2023-08-18T19:43:48.567037Z",
"iopub.status.idle": "2023-08-18T19:43:48.572146Z",
"shell.execute_reply": "2023-08-18T19:43:48.571131Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"Q = torch.ones((2, 3, 4))\n",
"K = torch.ones((2, 4, 6))\n",
"d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))"
]
},
{
"cell_type": "markdown",
"id": "935cb045",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Scaled Dot Product Attention"
]
},
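{
"cell_type": "markdown",
"id": "5e9f0c2d",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"For $n$ queries $\\mathbf{Q} \\in \\mathbb{R}^{n \\times d}$, keys $\\mathbf{K} \\in \\mathbb{R}^{m \\times d}$, and values $\\mathbf{V} \\in \\mathbb{R}^{m \\times v}$, scaled dot product attention computes\n",
"\n",
"$$\\mathrm{softmax}\\left(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{d}}\\right)\\mathbf{V} \\in \\mathbb{R}^{n \\times v},$$\n",
"\n",
"where dividing by $\\sqrt{d}$ keeps the variance of the scores independent of the key dimension. The class below implements this per minibatch, with a masked softmax and dropout applied to the attention weights."
]
},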
{
"cell_type": "code",
"execution_count": 6,
"id": "33207d5f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.575743Z",
"iopub.status.busy": "2023-08-18T19:43:48.575036Z",
"iopub.status.idle": "2023-08-18T19:43:48.581055Z",
"shell.execute_reply": "2023-08-18T19:43:48.580209Z"
},
"origin_pos": 28,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class DotProductAttention(nn.Module): \n",
" \"\"\"Scaled dot product attention.\"\"\"\n",
" def __init__(self, dropout):\n",
" super().__init__()\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, queries, keys, values, valid_lens=None):\n",
" d = queries.shape[-1]\n",
" scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)\n",
" self.attention_weights = masked_softmax(scores, valid_lens)\n",
" return torch.bmm(self.dropout(self.attention_weights), values)"
]
},
{
"cell_type": "markdown",
"id": "c63fab7b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Illustrate how the `DotProductAttention` class works"
]
},
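{
"cell_type": "markdown",
"id": "a81c3b47",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"As a minimal sketch (the tensor sizes here are illustrative choices): two examples, one 2-dimensional query each, ten key-value pairs with 4-dimensional values, and valid lengths 2 and 6, so the output should have shape (2, 1, 4)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b27d94e0",
"metadata": {},
"outputs": [],
"source": [
"# Toy inputs: batch of 2, one query each, ten key-value pairs per example\n",
"queries = torch.normal(0, 1, (2, 1, 2))\n",
"keys = torch.normal(0, 1, (2, 10, 2))\n",
"values = torch.normal(0, 1, (2, 10, 4))\n",
"valid_lens = torch.tensor([2, 6])\n",
"\n",
"attention = DotProductAttention(dropout=0.5)\n",
"attention.eval()  # switch off dropout so the output is deterministic\n",
"d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))"
]
},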
{
"cell_type": "code",
"execution_count": 8,
"id": "f40e370d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:43:48.594461Z",
"iopub.status.busy": "2023-08-18T19:43:48.593898Z",
"iopub.status.idle": "2023-08-18T19:43:48.969221Z",
"shell.execute_reply": "2023-08-18T19:43:48.968308Z"
},
"origin_pos": 37,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"