{
"cells": [
{
"cell_type": "markdown",
"id": "81234860",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# 自注意力和位置编码\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1f68f3c6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:34.234618Z",
"iopub.status.busy": "2023-08-18T07:01:34.233587Z",
"iopub.status.idle": "2023-08-18T07:01:37.175197Z",
"shell.execute_reply": "2023-08-18T07:01:37.174050Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "942f6c8e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"自注意力"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "91993c5f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.181087Z",
"iopub.status.busy": "2023-08-18T07:01:37.180270Z",
"iopub.status.idle": "2023-08-18T07:01:37.209854Z",
"shell.execute_reply": "2023-08-18T07:01:37.208705Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"MultiHeadAttention(\n",
" (attention): DotProductAttention(\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
" (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
")"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hidden size and number of attention heads used by the cells below.\n",
"num_hiddens = 100\n",
"num_heads = 5\n",
"# num_hiddens is passed for the first four positional arguments; the trailing\n",
"# 0.5 is the dropout probability shown in the printed module.\n",
"attention = d2l.MultiHeadAttention(\n",
"    num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)\n",
"# eval() hands back the module itself, so the cell displays its structure.\n",
"attention.eval()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "05a56888",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.214732Z",
"iopub.status.busy": "2023-08-18T07:01:37.214099Z",
"iopub.status.idle": "2023-08-18T07:01:37.231099Z",
"shell.execute_reply": "2023-08-18T07:01:37.229941Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 4, 100])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch_size = 2\n",
"num_queries = 4\n",
"valid_lens = torch.tensor([3, 2])\n",
"# Self-attention: the same all-ones tensor is supplied as the first three\n",
"# arguments (presumably queries, keys and values -- confirm against d2l).\n",
"X = torch.ones((batch_size, num_queries, num_hiddens))\n",
"attention(X, X, X, valid_lens).shape"
]
},
{
"cell_type": "markdown",
"id": "dfba3e26",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"位置编码"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a1520381",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.236150Z",
"iopub.status.busy": "2023-08-18T07:01:37.235749Z",
"iopub.status.idle": "2023-08-18T07:01:37.246341Z",
"shell.execute_reply": "2023-08-18T07:01:37.245419Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module):\n",
"    \"\"\"Add fixed sinusoidal positional encodings to the input, then apply dropout.\n",
"\n",
"    The table ``P`` has shape ``(1, max_len, num_hiddens)``. Even columns hold\n",
"    ``sin(pos / 10000 ** (2j / num_hiddens))`` and odd columns hold the matching\n",
"    cosine, where ``pos`` is the row index and ``j`` counts the frequencies.\n",
"\n",
"    Args:\n",
"        num_hiddens: Size of the last dimension of ``P``; may be even or odd.\n",
"        dropout: Probability passed to ``nn.Dropout``.\n",
"        max_len: Number of rows (positions) precomputed in ``P``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_hiddens, dropout, max_len=1000):\n",
"        super().__init__()\n",
"        self.dropout = nn.Dropout(dropout)\n",
"        # Precomputed once and stored as a plain tensor attribute.\n",
"        self.P = torch.zeros((1, max_len, num_hiddens))\n",
"        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)\n",
"        scales = torch.pow(10000, torch.arange(\n",
"            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n",
"        angles = positions / scales\n",
"        self.P[:, :, 0::2] = torch.sin(angles)\n",
"        # With an odd num_hiddens there is one fewer odd column than even\n",
"        # columns, so trim the angles to the number of cosine slots. For an\n",
"        # even num_hiddens this slice keeps every column.\n",
"        self.P[:, :, 1::2] = torch.cos(angles[:, :num_hiddens // 2])\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Return ``dropout(X + P[:, :X.shape[1], :])``.\n",
"\n",
"        ``P`` is copied to ``X.device`` before the addition. ``X.shape[1]`` is\n",
"        expected to be at most ``max_len``, since only that many rows exist.\n",
"        \"\"\"\n",
"        X = X + self.P[:, :X.shape[1], :].to(X.device)\n",
"        return self.dropout(X)"
]
},
{
"cell_type": "markdown",
"id": "c553976d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"行代表词元在序列中的位置,列代表位置编码的不同维度"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2530db11",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.253441Z",
"iopub.status.busy": "2023-08-18T07:01:37.251675Z",
"iopub.status.idle": "2023-08-18T07:01:37.511460Z",
"shell.execute_reply": "2023-08-18T07:01:37.510281Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"encoding_dim = 32\n",
"num_steps = 60\n",
"pos_encoding = PositionalEncoding(encoding_dim, 0)\n",
"pos_encoding.eval()\n",
"# Run an all-zero batch through the module; X is used below for its length.\n",
"X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))\n",
"P = pos_encoding.P[:, :X.shape[1], :]\n",
"# One curve per encoding column 6..9, plotted against the position index.\n",
"d2l.plot(\n",
"    torch.arange(num_steps),\n",
"    P[0, :, 6:10].T,\n",
"    xlabel='Row (position)',\n",
"    figsize=(6, 2.5),\n",
"    legend=[\"Col %d\" % d for d in range(6, 10)],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "df574435",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"二进制表示"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "07196b9a",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.516113Z",
"iopub.status.busy": "2023-08-18T07:01:37.515203Z",
"iopub.status.idle": "2023-08-18T07:01:37.523367Z",
"shell.execute_reply": "2023-08-18T07:01:37.520554Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0的二进制是:000\n",
"1的二进制是:001\n",
"2的二进制是:010\n",
"3的二进制是:011\n",
"4的二进制是:100\n",
"5的二进制是:101\n",
"6的二进制是:110\n",
"7的二进制是:111\n"
]
}
],
"source": [
"# Print 0..7 as zero-padded 3-bit binary strings, one per line.\n",
"print(*(f'{i}的二进制是:{i:>03b}' for i in range(8)), sep='\\n')"
]
},
{
"cell_type": "markdown",
"id": "87817add",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"在编码维度上降低频率"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fb689860",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.528541Z",
"iopub.status.busy": "2023-08-18T07:01:37.527891Z",
"iopub.status.idle": "2023-08-18T07:01:37.784120Z",
"shell.execute_reply": "2023-08-18T07:01:37.782997Z"
},
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Rebuild the 4-D heatmap input from pos_encoding instead of from P itself,\n",
"# so re-running this cell does not keep adding leading dimensions to P.\n",
"P = pos_encoding.P[:, :num_steps, :].unsqueeze(0)\n",
"d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',\n",
"                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}