{
"cells": [
{
"cell_type": "markdown",
"id": "81234860",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# 自注意力和位置编码\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1f68f3c6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:34.234618Z",
"iopub.status.busy": "2023-08-18T07:01:34.233587Z",
"iopub.status.idle": "2023-08-18T07:01:37.175197Z",
"shell.execute_reply": "2023-08-18T07:01:37.174050Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "942f6c8e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"自注意力"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "91993c5f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.181087Z",
"iopub.status.busy": "2023-08-18T07:01:37.180270Z",
"iopub.status.idle": "2023-08-18T07:01:37.209854Z",
"shell.execute_reply": "2023-08-18T07:01:37.208705Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"MultiHeadAttention(\n",
" (attention): DotProductAttention(\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
" )\n",
" (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
" (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
")"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Every size argument of the attention layer is set to num_hiddens.\n",
"num_hiddens = 100\n",
"num_heads = 5\n",
"attention = d2l.MultiHeadAttention(\n",
"    num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)\n",
"# Put the module in evaluation mode; the returned module repr is the cell output.\n",
"attention.eval()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "05a56888",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.214732Z",
"iopub.status.busy": "2023-08-18T07:01:37.214099Z",
"iopub.status.idle": "2023-08-18T07:01:37.231099Z",
"shell.execute_reply": "2023-08-18T07:01:37.229941Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 4, 100])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch_size, num_queries = 2, 4\n",
"valid_lens = torch.tensor([3, 2])\n",
"# The same tensor is passed as queries, keys and values (self-attention).\n",
"X = torch.ones((batch_size, num_queries, num_hiddens))\n",
"attention(X, X, X, valid_lens).shape"
]
},
{
"cell_type": "markdown",
"id": "dfba3e26",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"位置编码"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a1520381",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.236150Z",
"iopub.status.busy": "2023-08-18T07:01:37.235749Z",
"iopub.status.idle": "2023-08-18T07:01:37.246341Z",
"shell.execute_reply": "2023-08-18T07:01:37.245419Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class PositionalEncoding(nn.Module):\n",
"    \"\"\"Positional encoding: add fixed sine/cosine position signals, then dropout.\n",
"\n",
"    A table ``P`` of shape ``(1, max_len, num_hiddens)`` is precomputed once.\n",
"    For position ``i`` and column pair ``j`` the angle is\n",
"    ``i / 10000 ** (2 * j / num_hiddens)``; even columns store its sine and\n",
"    odd columns store its cosine.\n",
"\n",
"    Args:\n",
"        num_hiddens: size of the last dimension of ``P`` (even or odd).\n",
"        dropout: probability passed to ``nn.Dropout``.\n",
"        max_len: number of positions stored in the table.\n",
"    \"\"\"\n",
"    def __init__(self, num_hiddens, dropout, max_len=1000):\n",
"        super(PositionalEncoding, self).__init__()\n",
"        self.dropout = nn.Dropout(dropout)\n",
"        # One row of angles per position, one column per sine/cosine pair.\n",
"        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)\n",
"        frequencies = torch.pow(10000, torch.arange(\n",
"            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)\n",
"        angles = positions / frequencies\n",
"        self.P = torch.zeros((1, max_len, num_hiddens))\n",
"        self.P[:, :, 0::2] = torch.sin(angles)\n",
"        # With an odd num_hiddens there is one fewer odd column than even\n",
"        # columns, so trim the angles to avoid a shape mismatch on assignment.\n",
"        self.P[:, :, 1::2] = torch.cos(angles[:, :num_hiddens // 2])\n",
"\n",
"    def forward(self, X):\n",
"        \"\"\"Add the encodings of the first ``X.shape[1]`` positions to ``X``.\n",
"\n",
"        Assumes ``X`` is laid out as (batch, positions, num_hiddens) so that it\n",
"        broadcasts against the sliced table. Returns the sum after dropout.\n",
"        \"\"\"\n",
"        # P is a plain tensor attribute, so it is moved to X's device per call.\n",
"        X = X + self.P[:, :X.shape[1], :].to(X.device)\n",
"        return self.dropout(X)"
]
},
{
"cell_type": "markdown",
"id": "c553976d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"行代表词元在序列中的位置,列代表位置编码的不同维度"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2530db11",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:01:37.253441Z",
"iopub.status.busy": "2023-08-18T07:01:37.251675Z",
"iopub.status.idle": "2023-08-18T07:01:37.511460Z",
"shell.execute_reply": "2023-08-18T07:01:37.510281Z"
},
"origin_pos": 19,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"