{ "cells": [ { "cell_type": "markdown", "id": "e90392f7", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# The Transformer Architecture\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "ee18893c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:06.687415Z", "iopub.status.busy": "2023-08-18T19:50:06.687094Z", "iopub.status.idle": "2023-08-18T19:50:09.889628Z", "shell.execute_reply": "2023-08-18T19:50:09.888444Z" }, "origin_pos": 3, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import math\n", "import pandas as pd\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "533fdb62", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Positionwise Feed-Forward Networks" ] }, { "cell_type": "code", "execution_count": 2, "id": "623f67ee", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.894092Z", "iopub.status.busy": "2023-08-18T19:50:09.893416Z", "iopub.status.idle": "2023-08-18T19:50:09.899737Z", "shell.execute_reply": "2023-08-18T19:50:09.898347Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class PositionWiseFFN(nn.Module): \n", " \"\"\"The positionwise feed-forward network.\"\"\"\n", " def __init__(self, ffn_num_hiddens, ffn_num_outputs):\n", " super().__init__()\n", " self.dense1 = nn.LazyLinear(ffn_num_hiddens)\n", " self.relu = nn.ReLU()\n", " self.dense2 = nn.LazyLinear(ffn_num_outputs)\n", "\n", " def forward(self, X):\n", " return self.dense2(self.relu(self.dense1(X)))" ] }, { "cell_type": "markdown", "id": "a4b75c73", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The innermost dimension\n", "of a tensor changes" ] }, { "cell_type": "code", "execution_count": 3, "id": "c462f39f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.906345Z", "iopub.status.busy": "2023-08-18T19:50:09.905327Z", "iopub.status.idle": "2023-08-18T19:50:09.920436Z", "shell.execute_reply": "2023-08-18T19:50:09.919542Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.6300, 0.7739, 0.0278, 0.2508, -0.0519, 0.4881, -0.4105, 0.5163],\n", " [ 0.6300, 0.7739, 0.0278, 0.2508, -0.0519, 0.4881, -0.4105, 0.5163],\n", " [ 0.6300, 0.7739, 0.0278, 0.2508, -0.0519, 0.4881, -0.4105, 0.5163]],\n", " grad_fn=)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ffn = PositionWiseFFN(4, 8)\n", "ffn.eval()\n", "ffn(torch.ones((2, 3, 4)))[0]" ] }, { "cell_type": "markdown", "id": "07085efa", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Compares the normalization across different dimensions\n", "by layer normalization and batch normalization" ] }, { "cell_type": "code", "execution_count": 4, "id": "81c95717", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.926398Z", "iopub.status.busy": "2023-08-18T19:50:09.924518Z", "iopub.status.idle": "2023-08-18T19:50:09.937855Z", "shell.execute_reply": "2023-08-18T19:50:09.936657Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "layer norm: tensor([[-1.0000, 1.0000],\n", " [-1.0000, 1.0000]], grad_fn=) \n", "batch norm: tensor([[-1.0000, -1.0000],\n", " [ 1.0000, 1.0000]], grad_fn=)\n" ] } ], "source": [ "ln = nn.LayerNorm(2)\n", "bn = nn.LazyBatchNorm1d()\n", "X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)\n", "print('layer norm:', ln(X), '\\nbatch norm:', bn(X))" ] }, { "cell_type": "markdown", "id": "4308349f", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Using a residual connection followed by layer normalization" ] }, { "cell_type": "code", "execution_count": 5, "id": "331f12e2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.941811Z", "iopub.status.busy": "2023-08-18T19:50:09.941163Z", "iopub.status.idle": "2023-08-18T19:50:09.948019Z", "shell.execute_reply": "2023-08-18T19:50:09.946884Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class AddNorm(nn.Module): \n", " \"\"\"The residual connection followed by layer normalization.\"\"\"\n", " def __init__(self, norm_shape, dropout):\n", " super().__init__()\n", " self.dropout = nn.Dropout(dropout)\n", " self.ln = nn.LayerNorm(norm_shape)\n", "\n", " def forward(self, X, Y):\n", " return self.ln(self.dropout(Y) + X)" ] }, { "cell_type": "markdown", "id": "f4b454c0", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The output tensor also has the same shape after the addition operation" ] }, { "cell_type": "code", "execution_count": 6, "id": "d3835d18", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.951769Z", "iopub.status.busy": "2023-08-18T19:50:09.951126Z", "iopub.status.idle": "2023-08-18T19:50:09.957096Z", "shell.execute_reply": "2023-08-18T19:50:09.956115Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "add_norm = AddNorm(4, 0.5)\n", "shape = (2, 3, 4)\n", "d2l.check_shape(add_norm(torch.ones(shape), torch.ones(shape)), shape)" ] }, { "cell_type": "markdown", "id": "39166416", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "A single layer within the encoder" ] }, { "cell_type": "code", "execution_count": 7, "id": "2c7cce60", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.960576Z", "iopub.status.busy": "2023-08-18T19:50:09.960031Z", "iopub.status.idle": "2023-08-18T19:50:09.965982Z", "shell.execute_reply": "2023-08-18T19:50:09.965069Z" }, "origin_pos": 33, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class TransformerEncoderBlock(nn.Module): \n", " \"\"\"The Transformer encoder block.\"\"\"\n", " def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,\n", " use_bias=False):\n", " super().__init__()\n", " self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,\n", " dropout, use_bias)\n", " self.addnorm1 = AddNorm(num_hiddens, dropout)\n", " self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)\n", " self.addnorm2 = AddNorm(num_hiddens, dropout)\n", "\n", " def forward(self, X, valid_lens):\n", " Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n", " return self.addnorm2(Y, self.ffn(Y))" ] }, { "cell_type": "markdown", "id": "0e688bde", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "No layer in the Transformer encoder\n", "changes the shape of its input" ] }, { "cell_type": "code", "execution_count": 8, "id": "9aefd8d7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.969308Z", "iopub.status.busy": "2023-08-18T19:50:09.968775Z", "iopub.status.idle": "2023-08-18T19:50:09.982374Z", "shell.execute_reply": "2023-08-18T19:50:09.981506Z" }, "origin_pos": 38, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "X = torch.ones((2, 100, 24))\n", "valid_lens = torch.tensor([3, 2])\n", "encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)\n", "encoder_blk.eval()\n", "d2l.check_shape(encoder_blk(X, valid_lens), X.shape)" ] }, { "cell_type": "markdown", "id": "691630cf", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Transformer encoder" ] }, { "cell_type": "code", "execution_count": 9, "id": "bdcabb11", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.986032Z", "iopub.status.busy": "2023-08-18T19:50:09.985758Z", "iopub.status.idle": "2023-08-18T19:50:09.993370Z", "shell.execute_reply": "2023-08-18T19:50:09.992550Z" }, "origin_pos": 43, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class TransformerEncoder(d2l.Encoder): \n", " \"\"\"The Transformer encoder.\"\"\"\n", " def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,\n", " num_heads, num_blks, dropout, use_bias=False):\n", " super().__init__()\n", " self.num_hiddens = num_hiddens\n", " self.embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)\n", " self.blks = nn.Sequential()\n", " for i in range(num_blks):\n", " self.blks.add_module(\"block\"+str(i), TransformerEncoderBlock(\n", " num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias))\n", "\n", " def forward(self, X, valid_lens):\n", " X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n", " self.attention_weights = [None] * len(self.blks)\n", " for i, blk in enumerate(self.blks):\n", " X = blk(X, valid_lens)\n", " self.attention_weights[\n", " i] = blk.attention.attention.attention_weights\n", " return X" ] }, { "cell_type": "markdown", "id": "424c1cc5", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Create a two-layer Transformer encoder" ] }, { "cell_type": "code", "execution_count": 10, "id": "e09106b4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:09.996457Z", "iopub.status.busy": "2023-08-18T19:50:09.996186Z", "iopub.status.idle": "2023-08-18T19:50:10.014181Z", "shell.execute_reply": "2023-08-18T19:50:10.013041Z" }, "origin_pos": 48, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)\n", "d2l.check_shape(encoder(torch.ones((2, 100), dtype=torch.long), valid_lens),\n", " (2, 100, 24))" ] }, { "cell_type": "markdown", "id": "953a02ea", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The Transformer decoder\n", "is composed of multiple identical layers" ] }, { "cell_type": "code", "execution_count": 11, "id": "23664727", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:10.017719Z", "iopub.status.busy": "2023-08-18T19:50:10.017425Z", "iopub.status.idle": "2023-08-18T19:50:10.027060Z", "shell.execute_reply": "2023-08-18T19:50:10.026020Z" }, "origin_pos": 53, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class TransformerDecoderBlock(nn.Module):\n", " def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):\n", " super().__init__()\n", " self.i = i\n", " self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads,\n", " dropout)\n", " self.addnorm1 = AddNorm(num_hiddens, dropout)\n", " self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads,\n", " dropout)\n", " self.addnorm2 = AddNorm(num_hiddens, dropout)\n", " self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)\n", " self.addnorm3 = AddNorm(num_hiddens, dropout)\n", "\n", " def forward(self, X, state):\n", " enc_outputs, enc_valid_lens = state[0], state[1]\n", " if state[2][self.i] is None:\n", " key_values = X\n", " else:\n", " key_values = torch.cat((state[2][self.i], X), dim=1)\n", " state[2][self.i] = key_values\n", " if self.training:\n", " batch_size, num_steps, _ = X.shape\n", " dec_valid_lens = torch.arange(\n", " 1, num_steps + 1, device=X.device).repeat(batch_size, 1)\n", " else:\n", " dec_valid_lens = None\n", " X2 = self.attention1(X, key_values, key_values, dec_valid_lens)\n", " Y = self.addnorm1(X, X2)\n", " Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)\n", " Z = self.addnorm2(Y, Y2)\n", " return self.addnorm3(Z, self.ffn(Z)), state" ] }, { "cell_type": "markdown", "id": "f1025f42", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The feature dimension (`num_hiddens`) of the decoder is\n", "the same as that of the encoder" ] }, { "cell_type": "code", "execution_count": 12, "id": "1f487464", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:10.030357Z", "iopub.status.busy": "2023-08-18T19:50:10.030070Z", "iopub.status.idle": "2023-08-18T19:50:10.048172Z", "shell.execute_reply": "2023-08-18T19:50:10.046972Z" }, "origin_pos": 58, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "decoder_blk = TransformerDecoderBlock(24, 48, 8, 0.5, 0)\n", "X = torch.ones((2, 100, 24))\n", "state = [encoder_blk(X, valid_lens), valid_lens, [None]]\n", "d2l.check_shape(decoder_blk(X, state)[0], X.shape)" ] }, { "cell_type": "markdown", "id": "ab792ba3", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Construct the entire Transformer decoder" ] }, { "cell_type": "code", "execution_count": 13, "id": "38ebb1e7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:10.051657Z", "iopub.status.busy": "2023-08-18T19:50:10.051318Z", "iopub.status.idle": "2023-08-18T19:50:10.061485Z", "shell.execute_reply": "2023-08-18T19:50:10.060579Z" }, "origin_pos": 63, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "class TransformerDecoder(d2l.AttentionDecoder):\n", " def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,\n", " num_blks, dropout):\n", " super().__init__()\n", " self.num_hiddens = num_hiddens\n", " self.num_blks = num_blks\n", " self.embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)\n", " self.blks = nn.Sequential()\n", " for i in range(num_blks):\n", " self.blks.add_module(\"block\"+str(i), TransformerDecoderBlock(\n", " num_hiddens, ffn_num_hiddens, num_heads, dropout, i))\n", " self.dense = nn.LazyLinear(vocab_size)\n", "\n", " def init_state(self, enc_outputs, enc_valid_lens):\n", " return [enc_outputs, enc_valid_lens, [None] * self.num_blks]\n", "\n", " def forward(self, X, state):\n", " X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n", " self._attention_weights = [[None] * len(self.blks) for _ in range (2)]\n", " for i, blk in enumerate(self.blks):\n", " X, state = blk(X, state)\n", " self._attention_weights[0][\n", " i] = blk.attention1.attention.attention_weights\n", " self._attention_weights[1][\n", " i] = blk.attention2.attention.attention_weights\n", " return self.dense(X), state\n", "\n", " @property\n", " def attention_weights(self):\n", " return self._attention_weights" ] }, { "cell_type": "markdown", "id": "d905f877", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training" ] }, { "cell_type": "code", "execution_count": 14, "id": "74f2da96", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:10.065769Z", "iopub.status.busy": "2023-08-18T19:50:10.064767Z", "iopub.status.idle": "2023-08-18T19:50:42.759965Z", "shell.execute_reply": "2023-08-18T19:50:42.758647Z" }, "origin_pos": 67, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:50:42.628788\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = d2l.MTFraEng(batch_size=128)\n", "num_hiddens, num_blks, dropout = 256, 2, 0.2\n", "ffn_num_hiddens, num_heads = 64, 4\n", "encoder = TransformerEncoder(\n", " len(data.src_vocab), num_hiddens, ffn_num_hiddens, num_heads,\n", " num_blks, dropout)\n", "decoder = TransformerDecoder(\n", " len(data.tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads,\n", " num_blks, dropout)\n", "model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab[''],\n", " lr=0.001)\n", "trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)\n", "trainer.fit(model, data)" ] }, { "cell_type": "markdown", "id": "871837ed", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Translate a few English sentences" ] }, { "cell_type": "code", "execution_count": 15, "id": "06e5e238", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:42.765441Z", "iopub.status.busy": "2023-08-18T19:50:42.764512Z", "iopub.status.idle": "2023-08-18T19:50:42.852261Z", "shell.execute_reply": "2023-08-18T19:50:42.850805Z" }, "origin_pos": 69, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go . => ['va', '!'], bleu,1.000\n", "i lost . => ['je', 'perdu', '.'], bleu,0.687\n", "he's calm . => ['il', 'est', 'mouillé', '.'], bleu,0.658\n", "i'm home . => ['je', 'suis', 'chez', 'moi', '.'], bleu,1.000\n" ] } ], "source": [ "engs = ['go .', 'i lost .', 'he\\'s calm .', 'i\\'m home .']\n", "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n", "preds, _ = model.predict_step(\n", " data.build(engs, fras), d2l.try_gpu(), data.num_steps)\n", "for en, fr, p in zip(engs, fras, preds):\n", " translation = []\n", " for token in data.tgt_vocab.to_tokens(p):\n", " if token == '':\n", " break\n", " translation.append(token)\n", " print(f'{en} => {translation}, bleu,'\n", " f'{d2l.bleu(\" \".join(translation), fr, k=2):.3f}')" ] }, { "cell_type": "markdown", "id": "9f7fb206", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Visualize the Transformer attention weights" ] }, { "cell_type": "code", "execution_count": 17, "id": "520c51a5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:42.931273Z", "iopub.status.busy": "2023-08-18T19:50:42.930241Z", "iopub.status.idle": "2023-08-18T19:50:44.499507Z", "shell.execute_reply": "2023-08-18T19:50:44.498627Z" }, "origin_pos": 75, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:50:44.244798\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_, dec_attention_weights = model.predict_step(\n", " data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps, True)\n", "enc_attention_weights = torch.cat(model.encoder.attention_weights, 0)\n", "shape = (num_blks, num_heads, -1, data.num_steps)\n", "enc_attention_weights = enc_attention_weights.reshape(shape)\n", "d2l.check_shape(enc_attention_weights,\n", " (num_blks, num_heads, data.num_steps, data.num_steps))\n", "\n", "d2l.show_heatmaps(\n", " enc_attention_weights.cpu(), xlabel='Key positions',\n", " ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],\n", " figsize=(7, 3.5))" ] }, { "cell_type": "markdown", "id": "f81b4e58", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "To visualize the decoder self-attention weights and the encoder--decoder attention weights,\n", "we need more data manipulations" ] }, { "cell_type": "code", "execution_count": 20, "id": "8430c053", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:44.533271Z", "iopub.status.busy": "2023-08-18T19:50:44.532389Z", "iopub.status.idle": "2023-08-18T19:50:45.954406Z", "shell.execute_reply": "2023-08-18T19:50:45.953261Z" }, "origin_pos": 83, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:50:45.695784\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dec_attention_weights_2d = [head[0].tolist()\n", " for step in dec_attention_weights\n", " for attn in step for blk in attn for head in blk]\n", "dec_attention_weights_filled = torch.tensor(\n", " pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)\n", "shape = (-1, 2, num_blks, num_heads, data.num_steps)\n", "dec_attention_weights = dec_attention_weights_filled.reshape(shape)\n", "dec_self_attention_weights, dec_inter_attention_weights = \\\n", " dec_attention_weights.permute(1, 2, 3, 0, 4)\n", "\n", "d2l.check_shape(dec_self_attention_weights,\n", " (num_blks, num_heads, data.num_steps, data.num_steps))\n", "d2l.check_shape(dec_inter_attention_weights,\n", " (num_blks, num_heads, data.num_steps, data.num_steps))\n", "\n", "d2l.show_heatmaps(\n", " dec_self_attention_weights[:, :, :, :],\n", " xlabel='Key positions', ylabel='Query positions',\n", " titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))" ] }, { "cell_type": "markdown", "id": "527e9143", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "No query from the output sequence\n", "attends to those padding tokens from the input sequence" ] }, { "cell_type": "code", "execution_count": 21, "id": "1c0b1dfe", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:50:45.958174Z", "iopub.status.busy": "2023-08-18T19:50:45.957587Z", "iopub.status.idle": "2023-08-18T19:50:47.397366Z", "shell.execute_reply": "2023-08-18T19:50:47.396481Z" }, "origin_pos": 85, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:50:47.142723\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "d2l.show_heatmaps(\n", " dec_inter_attention_weights, xlabel='Key positions',\n", " ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],\n", " figsize=(7, 3.5))" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }