{
"cells": [
{
"cell_type": "markdown",
"id": "248bcc01",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Self-Attention and Positional Encoding\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b2969e34",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:32.804452Z",
"iopub.status.busy": "2023-08-18T19:30:32.803811Z",
"iopub.status.idle": "2023-08-18T19:30:35.929844Z",
"shell.execute_reply": "2023-08-18T19:30:35.926598Z"
},
"origin_pos": 3,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "39ee3522",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Self-Attention"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "13743b61",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:35.935527Z",
"iopub.status.busy": "2023-08-18T19:30:35.934433Z",
"iopub.status.idle": "2023-08-18T19:30:35.974177Z",
"shell.execute_reply": "2023-08-18T19:30:35.973091Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"num_hiddens, num_heads = 100, 5\n",
"attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)\n",
"batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])\n",
"X = torch.ones((batch_size, num_queries, num_hiddens))\n",
"d2l.check_shape(attention(X, X, X, valid_lens),\n",
" (batch_size, num_queries, num_hiddens))"
]
},
{
"cell_type": "markdown",
"id": "525745e4",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Positional Encoding"
]
},
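{
"cell_type": "markdown",
"id": "f7a3d9c0",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"The class below precomputes sinusoidal encodings: for position $i$ and\n",
"dimension pair $(2j, 2j+1)$ of a $d$-dimensional encoding,\n",
"\n",
"$$p_{i, 2j} = \\sin\\left(\\frac{i}{10000^{2j/d}}\\right), \\qquad\n",
"p_{i, 2j+1} = \\cos\\left(\\frac{i}{10000^{2j/d}}\\right).$$"
]
},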
{
"cell_type": "code",
"execution_count": 3,
"id": "3eb1b5ef",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:35.979909Z",
"iopub.status.busy": "2023-08-18T19:30:35.978770Z",
"iopub.status.idle": "2023-08-18T19:30:35.987465Z",
"shell.execute_reply": "2023-08-18T19:30:35.986155Z"
},
"origin_pos": 16,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module): \n",
" \"\"\"Positional encoding.\"\"\"\n",
" def __init__(self, num_hiddens, dropout, max_len=1000):\n",
" super().__init__()\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.P = torch.zeros((1, max_len, num_hiddens))\n",
" X = torch.arange(max_len, dtype=torch.float32).reshape(\n",
" -1, 1) / torch.pow(10000, torch.arange(\n",
" 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n",
" self.P[:, :, 0::2] = torch.sin(X)\n",
" self.P[:, :, 1::2] = torch.cos(X)\n",
"\n",
" def forward(self, X):\n",
" X = X + self.P[:, :X.shape[1], :].to(X.device)\n",
" return self.dropout(X)"
]
},
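{
"cell_type": "markdown",
"id": "b8c2e6d4",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"A minimal sanity-check sketch (the shapes and `dropout=0` here are\n",
"illustrative choices, not part of the original demo): with a zero input\n",
"and no dropout, the output is exactly the truncated encoding matrix `P`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5f1a7e9",
"metadata": {
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# Sketch with illustrative shapes; dropout=0 makes the module deterministic\n",
"pos_encoding = PositionalEncoding(num_hiddens=32, dropout=0)\n",
"Y = pos_encoding(torch.zeros((2, 10, 32)))\n",
"d2l.check_shape(Y, (2, 10, 32))\n",
"# The same (1, 10, 32) slice of P is broadcast across the batch\n",
"assert torch.allclose(Y, pos_encoding.P[:, :10, :])"
]
},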
{
"cell_type": "markdown",
"id": "a0548d5d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Rows correspond to positions within a sequence\n",
"and columns represent different positional encoding dimensions"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "51320f4e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:30:35.991251Z",
"iopub.status.busy": "2023-08-18T19:30:35.990632Z",
"iopub.status.idle": "2023-08-18T19:30:36.368109Z",
"shell.execute_reply": "2023-08-18T19:30:36.366973Z"
},
"origin_pos": 21,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"