{ "cells": [ { "cell_type": "markdown", "id": "81234860", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# 自注意力和位置编码\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "1f68f3c6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:34.234618Z", "iopub.status.busy": "2023-08-18T07:01:34.233587Z", "iopub.status.idle": "2023-08-18T07:01:37.175197Z", "shell.execute_reply": "2023-08-18T07:01:37.174050Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "942f6c8e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "自注意力" ] }, { "cell_type": "code", "execution_count": 2, "id": "91993c5f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:37.181087Z", "iopub.status.busy": "2023-08-18T07:01:37.180270Z", "iopub.status.idle": "2023-08-18T07:01:37.209854Z", "shell.execute_reply": "2023-08-18T07:01:37.208705Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "MultiHeadAttention(\n", " (attention): DotProductAttention(\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " )\n", " (W_q): Linear(in_features=100, out_features=100, bias=False)\n", " (W_k): Linear(in_features=100, out_features=100, bias=False)\n", " (W_v): Linear(in_features=100, out_features=100, bias=False)\n", " (W_o): Linear(in_features=100, out_features=100, bias=False)\n", ")" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_hiddens, num_heads = 100, 5\n", "attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,\n", " num_hiddens, num_heads, 0.5)\n", "attention.eval()" ] }, { "cell_type": "code", "execution_count": 3, "id": "05a56888", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:37.214732Z", "iopub.status.busy": "2023-08-18T07:01:37.214099Z", "iopub.status.idle": "2023-08-18T07:01:37.231099Z", "shell.execute_reply": "2023-08-18T07:01:37.229941Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 4, 100])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])\n", "X = torch.ones((batch_size, num_queries, num_hiddens))\n", "attention(X, X, X, valid_lens).shape" ] }, { "cell_type": "markdown", "id": "dfba3e26", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "位置编码" ] }, { "cell_type": "code", "execution_count": 4, "id": "a1520381", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:37.236150Z", "iopub.status.busy": "2023-08-18T07:01:37.235749Z", "iopub.status.idle": "2023-08-18T07:01:37.246341Z", "shell.execute_reply": "2023-08-18T07:01:37.245419Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class PositionalEncoding(nn.Module):\n", " \"\"\"位置编码\"\"\"\n", " def __init__(self, num_hiddens, dropout, max_len=1000):\n", " super(PositionalEncoding, self).__init__()\n", " self.dropout = nn.Dropout(dropout)\n", " self.P = torch.zeros((1, max_len, num_hiddens))\n", " X = torch.arange(max_len, dtype=torch.float32).reshape(\n", " -1, 1) / torch.pow(10000, torch.arange(\n", " 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n", " self.P[:, :, 0::2] = torch.sin(X)\n", " self.P[:, :, 1::2] = torch.cos(X)\n", "\n", " def forward(self, X):\n", " X = X + self.P[:, :X.shape[1], :].to(X.device)\n", " return self.dropout(X)" ] }, { "cell_type": "markdown", "id": "c553976d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "行代表词元在序列中的位置,列代表位置编码的不同维度" ] }, { "cell_type": "code", "execution_count": 5, "id": "2530db11", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:37.253441Z", "iopub.status.busy": "2023-08-18T07:01:37.251675Z", "iopub.status.idle": "2023-08-18T07:01:37.511460Z", "shell.execute_reply": "2023-08-18T07:01:37.510281Z" }, "origin_pos": 19, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:37.459076\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "encoding_dim, num_steps = 32, 60\n", "pos_encoding = PositionalEncoding(encoding_dim, 0)\n", "pos_encoding.eval()\n", "X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n", "P = pos_encoding.P[:, :X.shape[1], :]\n", "d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',\n", " figsize=(6, 2.5), legend=[\"Col %d\" % d for d in torch.arange(6, 10)])" ] }, { "cell_type": "markdown", "id": "df574435", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "二进制表示" ] }, { "cell_type": "code", "execution_count": 6, "id": "07196b9a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:37.516113Z", "iopub.status.busy": "2023-08-18T07:01:37.515203Z", "iopub.status.idle": "2023-08-18T07:01:37.523367Z", "shell.execute_reply": "2023-08-18T07:01:37.520554Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0的二进制是:000\n", "1的二进制是:001\n", "2的二进制是:010\n", "3的二进制是:011\n", "4的二进制是:100\n", "5的二进制是:101\n", "6的二进制是:110\n", "7的二进制是:111\n" ] } ], "source": [ "for i in range(8):\n", " print(f'{i}的二进制是:{i:>03b}')" ] }, { "cell_type": "markdown", "id": "87817add", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "在编码维度上降低频率" ] }, { "cell_type": "code", "execution_count": 7, "id": "fb689860", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:01:37.528541Z", "iopub.status.busy": "2023-08-18T07:01:37.527891Z", "iopub.status.idle": "2023-08-18T07:01:37.784120Z", "shell.execute_reply": "2023-08-18T07:01:37.782997Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:01:37.711974\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "P = P[0, :, :].unsqueeze(0).unsqueeze(0)\n", "d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',\n", " ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }