{
"cells": [
{
"cell_type": "markdown",
"id": "248bcc01",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Self-Attention and Positional Encoding\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b2969e34",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:32.804452Z",
"iopub.status.busy": "2023-08-18T19:30:32.803811Z",
"iopub.status.idle": "2023-08-18T19:30:35.929844Z",
"shell.execute_reply": "2023-08-18T19:30:35.926598Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "39ee3522",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
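"## Self-Attention\n",
"\n",
"Multi-head self-attention uses the same tensor `X` as queries, keys, and values; the output keeps the input shape `(batch_size, num_queries, num_hiddens)`."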
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "13743b61",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:35.935527Z",
"iopub.status.busy": "2023-08-18T19:30:35.934433Z",
"iopub.status.idle": "2023-08-18T19:30:35.974177Z",
"shell.execute_reply": "2023-08-18T19:30:35.973091Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# Self-attention: the same tensor X serves as queries, keys and values.\n",
"num_hiddens, num_heads, dropout = 100, 5, 0.5\n",
"attention = d2l.MultiHeadAttention(num_hiddens, num_heads, dropout)\n",
"batch_size, num_queries = 2, 4\n",
"# One valid length per example in the batch.\n",
"valid_lens = torch.tensor([3, 2])\n",
"X = torch.ones((batch_size, num_queries, num_hiddens))\n",
"# The output must keep the input shape (batch_size, num_queries, num_hiddens).\n",
"d2l.check_shape(attention(X, X, X, valid_lens),\n",
"                (batch_size, num_queries, num_hiddens))"
]
},
{
"cell_type": "markdown",
"id": "525745e4",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
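"## Positional Encoding\n",
"\n",
"Self-attention by itself does not take token order into account, so we add fixed sinusoidal position information to the input representations."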
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3eb1b5ef",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:35.979909Z",
"iopub.status.busy": "2023-08-18T19:30:35.978770Z",
"iopub.status.idle": "2023-08-18T19:30:35.987465Z",
"shell.execute_reply": "2023-08-18T19:30:35.986155Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module):\n",
"    \"\"\"Sinusoidal positional encoding.\n",
"\n",
"    Precomputes a fixed table P of shape (1, max_len, num_hiddens) with\n",
"    P[0, i, 2j] = sin(i / 10000^(2j / num_hiddens)) and\n",
"    P[0, i, 2j + 1] = cos(i / 10000^(2j / num_hiddens)),\n",
"    adds its first num_steps rows to the input and applies dropout.\n",
"\n",
"    Args:\n",
"        num_hiddens: encoding dimension (last dimension of the input).\n",
"        dropout: dropout probability applied after adding the encoding.\n",
"        max_len: longest sequence the precomputed table supports.\n",
"    \"\"\"\n",
"    def __init__(self, num_hiddens, dropout, max_len=1000):\n",
"        super().__init__()\n",
"        self.dropout = nn.Dropout(dropout)\n",
"        # Angle table, shape (max_len, ceil(num_hiddens / 2)): row i holds\n",
"        # position i divided by each of the geometric frequency scales.\n",
"        X = torch.arange(max_len, dtype=torch.float32).reshape(\n",
"            -1, 1) / torch.pow(10000, torch.arange(\n",
"            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n",
"        P = torch.zeros((1, max_len, num_hiddens))\n",
"        P[:, :, 0::2] = torch.sin(X)\n",
"        # An odd num_hiddens has one fewer cosine column than sine column;\n",
"        # drop the last frequency there instead of failing on a shape mismatch.\n",
"        P[:, :, 1::2] = torch.cos(X[:, :num_hiddens // 2])\n",
"        # A non-persistent buffer follows module.to(device) / .half() without\n",
"        # adding a new entry to state_dict; it is still readable as self.P.\n",
"        self.register_buffer('P', P, persistent=False)\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Add the encoding to X, expected as (batch_size, num_steps, num_hiddens).\"\"\"\n",
"        num_steps = X.shape[1]\n",
"        if num_steps > self.P.shape[1]:\n",
"            raise ValueError(f'sequence length {num_steps} exceeds max_len '\n",
"                             f'{self.P.shape[1]}')\n",
"        X = X + self.P[:, :num_steps, :].to(X.device)\n",
"        return self.dropout(X)"
]
},
{
"cell_type": "markdown",
"id": "a0548d5d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"In the positional encoding matrix, rows correspond to positions within a sequence\n",
"and columns to encoding dimensions; the plot below shows columns 6-9 as functions of position."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "51320f4e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:35.991251Z",
"iopub.status.busy": "2023-08-18T19:30:35.990632Z",
"iopub.status.idle": "2023-08-18T19:30:36.368109Z",
"shell.execute_reply": "2023-08-18T19:30:36.366973Z"
},
"origin_pos": 21,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# With an all-zero input and dropout 0, the output X equals the encoding itself.\n",
"encoding_dim, num_steps = 32, 60\n",
"pos_encoding = PositionalEncoding(encoding_dim, 0)\n",
"X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n",
"P = pos_encoding.P[:, :X.shape[1], :]\n",
"# Plot encoding dimensions (columns) 6-9 as functions of position.\n",
"shown_cols = range(6, 10)\n",
"d2l.plot(torch.arange(num_steps), P[0, :, shown_cols.start:shown_cols.stop].T,\n",
"         xlabel='Row (position)', figsize=(6, 2.5),\n",
"         legend=['Col %d' % col for col in shown_cols])"
]
},
{
"cell_type": "markdown",
"id": "12388348",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
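"For comparison, consider the binary representations of 0-7: the lowest bit alternates with every number, and each higher bit alternates at half the rate of the bit below it."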
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6f42d89b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:36.373921Z",
"iopub.status.busy": "2023-08-18T19:30:36.373258Z",
"iopub.status.idle": "2023-08-18T19:30:36.380089Z",
"shell.execute_reply": "2023-08-18T19:30:36.378862Z"
},
"origin_pos": 25,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 in binary is 000\n",
"1 in binary is 001\n",
"2 in binary is 010\n",
"3 in binary is 011\n",
"4 in binary is 100\n",
"5 in binary is 101\n",
"6 in binary is 110\n",
"7 in binary is 111\n"
]
}
],
"source": [
"# Zero-padded 3-bit binary strings for 0..7.\n",
"for i in range(8):\n",
"    bits = format(i, '03b')\n",
"    print(f'{i} in binary is {bits}')"
]
},
{
"cell_type": "markdown",
"id": "e137cf32",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Analogously, the positional encoding uses decreasing\n",
"frequencies along the encoding dimension (columns), as the heatmap below shows."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c5f60f9f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:36.384358Z",
"iopub.status.busy": "2023-08-18T19:30:36.383531Z",
"iopub.status.idle": "2023-08-18T19:30:36.858217Z",
"shell.execute_reply": "2023-08-18T19:30:36.857049Z"
},
"origin_pos": 28,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Bind a new name instead of reassigning P, so re-running this cell is idempotent.\n",
"# Assumes P from the plotting cell above, shape (1, num_steps, encoding_dim);\n",
"# the two unsqueezes give (1, 1, num_steps, encoding_dim), presumably read by\n",
"# d2l.show_heatmaps as a 1 x 1 grid holding one heatmap.\n",
"P_heatmap = P[0, :, :].unsqueeze(0).unsqueeze(0)\n",
"d2l.show_heatmaps(P_heatmap, xlabel='Column (encoding dimension)',\n",
"                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}