{ "cells": [ { "cell_type": "markdown", "id": "248bcc01", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Self-Attention and Positional Encoding\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "b2969e34", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:32.804452Z", "iopub.status.busy": "2023-08-18T19:30:32.803811Z", "iopub.status.idle": "2023-08-18T19:30:35.929844Z", "shell.execute_reply": "2023-08-18T19:30:35.926598Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "39ee3522", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Self-Attention" ] }, { "cell_type": "code", "execution_count": 2, "id": "13743b61", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:35.935527Z", "iopub.status.busy": "2023-08-18T19:30:35.934433Z", "iopub.status.idle": "2023-08-18T19:30:35.974177Z", "shell.execute_reply": "2023-08-18T19:30:35.973091Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [
  "# Self-attention: queries, keys and values are all the same tensor X,\n",
  "# and the result is checked to keep X's shape.\n",
  "num_hiddens, num_heads = 100, 5\n",
  "attention = d2l.MultiHeadAttention(num_hiddens, num_heads, dropout=0.5)\n",
  "batch_size, num_queries = 2, 4\n",
  "valid_lens = torch.tensor([3, 2])\n",
  "X = torch.ones((batch_size, num_queries, num_hiddens))\n",
  "expected_shape = (batch_size, num_queries, num_hiddens)\n",
  "d2l.check_shape(attention(X, X, X, valid_lens), expected_shape)"
 ] }, { "cell_type": "markdown", "id": "525745e4", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Positional Encoding" ] }, { "cell_type": "code", "execution_count": 3, "id": "3eb1b5ef", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:35.979909Z", "iopub.status.busy": "2023-08-18T19:30:35.978770Z", "iopub.status.idle": "2023-08-18T19:30:35.987465Z", "shell.execute_reply": "2023-08-18T19:30:35.986155Z" }, "origin_pos": 16, "tab": [ "pytorch" ] }, "outputs": [],
"source": [ "class PositionalEncoding(nn.Module): \n", " \"\"\"Positional encoding.\"\"\"\n", " def __init__(self, num_hiddens, dropout, max_len=1000):\n", " super().__init__()\n", " self.dropout = nn.Dropout(dropout)\n", " self.P = torch.zeros((1, max_len, num_hiddens))\n", " X = torch.arange(max_len, dtype=torch.float32).reshape(\n", " -1, 1) / torch.pow(10000, torch.arange(\n", " 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n", " self.P[:, :, 0::2] = torch.sin(X)\n", " self.P[:, :, 1::2] = torch.cos(X)\n", "\n", " def forward(self, X):\n", " X = X + self.P[:, :X.shape[1], :].to(X.device)\n", " return self.dropout(X)" ] }, { "cell_type": "markdown", "id": "a0548d5d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Rows correspond to positions within a sequence\n", "and columns represent different positional encoding dimensions" ] }, { "cell_type": "code", "execution_count": 4, "id": "51320f4e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:35.991251Z", "iopub.status.busy": "2023-08-18T19:30:35.990632Z", "iopub.status.idle": "2023-08-18T19:30:36.368109Z", "shell.execute_reply": "2023-08-18T19:30:36.366973Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:30:36.288792\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "encoding_dim, num_steps = 32, 60\n",
  "pos_encoding = PositionalEncoding(encoding_dim, 0)\n",
  "X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n",
  "# Keep only the table rows for the num_steps positions encoded above.\n",
  "P = pos_encoding.P[:, :X.shape[1], :]\n",
  "d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T,\n",
  "         xlabel='Row (position)', figsize=(6, 2.5),\n",
  "         legend=[f'Col {col}' for col in range(6, 10)])"
 ] }, { "cell_type": "markdown", "id": "12388348", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The binary representations" ] }, { "cell_type": "code", "execution_count": 5, "id": "6f42d89b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:36.373921Z", "iopub.status.busy": "2023-08-18T19:30:36.373258Z", "iopub.status.idle": "2023-08-18T19:30:36.380089Z", "shell.execute_reply": "2023-08-18T19:30:36.378862Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 in binary is 000\n", "1 in binary is 001\n", "2 in binary is 010\n", "3 in binary is 011\n", "4 in binary is 100\n", "5 in binary is 101\n", "6 in binary is 110\n", "7 in binary is 111\n" ] } ], "source": [
  "for i in range(8):\n",
  "    print(i, 'in binary is', format(i, '03b'))"
 ] }, { "cell_type": "markdown", "id": "e137cf32", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The positional encoding decreases\n", "frequencies along the encoding dimension" ] }, { "cell_type": "code", "execution_count": 6, "id": "c5f60f9f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:36.384358Z", "iopub.status.busy": "2023-08-18T19:30:36.383531Z", "iopub.status.idle": "2023-08-18T19:30:36.858217Z", "shell.execute_reply": "2023-08-18T19:30:36.857049Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:30:36.784791\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, 
https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "# Reshape the (1, num_steps, encoding_dim) table to (1, 1, num_steps, encoding_dim)\n",
  "# under a new name: overwriting P would grow its rank by one on every re-run.\n",
  "P_heatmap = P[0, :, :].unsqueeze(0).unsqueeze(0)\n",
  "d2l.show_heatmaps(P_heatmap, xlabel='Column (encoding dimension)',\n",
  "                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')"
 ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }