{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b510dd3f",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "# Bidirectional Encoder Representations from Transformers (BERT)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ffa5c8df",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:51.331685Z",
     "iopub.status.busy": "2023-08-18T19:31:51.331049Z",
     "iopub.status.idle": "2023-08-18T19:31:54.812815Z",
     "shell.execute_reply": "2023-08-18T19:31:54.811897Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe8bcf59",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Input Representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c018a43",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:54.816440Z",
     "iopub.status.busy": "2023-08-18T19:31:54.816038Z",
     "iopub.status.idle": "2023-08-18T19:31:54.823105Z",
     "shell.execute_reply": "2023-08-18T19:31:54.822123Z"
    },
    "origin_pos": 4,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_tokens_and_segments(tokens_a, tokens_b=None):\n",
    "    \"\"\"Get tokens of the BERT input sequence and their segment IDs.\n",
    "\n",
    "    Returns `['<cls>'] + tokens_a + ['<sep>']`, extended by\n",
    "    `tokens_b + ['<sep>']` when a second sequence is given, together with\n",
    "    a parallel list of segment IDs: 0 for `<cls>`, tokens_a and the first\n",
    "    `<sep>`; 1 for tokens_b and its trailing `<sep>`.\n",
    "    \"\"\"\n",
    "    # Distinct, non-empty special tokens: an empty string for both would\n",
    "    # make the classification and separator markers indistinguishable.\n",
    "    tokens = ['<cls>'] + tokens_a + ['<sep>']\n",
    "    # 0 and 1 mark segment A and segment B, respectively\n",
    "    segments = [0] * (len(tokens_a) + 2)\n",
    "    if tokens_b is not None:\n",
    "        tokens += tokens_b + ['<sep>']\n",
    "        segments += [1] * (len(tokens_b) + 1)\n",
    "    return tokens, segments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45ba910d",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "`BERTEncoder` class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a90baa53",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:54.828159Z",
     "iopub.status.busy": "2023-08-18T19:31:54.827512Z",
     "iopub.status.idle": "2023-08-18T19:31:54.835645Z",
     "shell.execute_reply": "2023-08-18T19:31:54.834698Z"
    },
    "origin_pos": 7,
    "tab":
[
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class BERTEncoder(nn.Module):\n",
    "    \"\"\"BERT encoder.\n",
    "\n",
    "    Sums token, segment and learnable positional embeddings, then applies\n",
    "    `num_blks` Transformer encoder blocks.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,\n",
    "                 num_blks, dropout, max_len=1000, **kwargs):\n",
    "        super(BERTEncoder, self).__init__(**kwargs)\n",
    "        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)\n",
    "        # Two segment IDs: 0 for the first sequence, 1 for the second\n",
    "        self.segment_embedding = nn.Embedding(2, num_hiddens)\n",
    "        self.blks = nn.Sequential()\n",
    "        for i in range(num_blks):\n",
    "            self.blks.add_module(f\"{i}\", d2l.TransformerEncoderBlock(\n",
    "                num_hiddens, ffn_num_hiddens, num_heads, dropout, True))\n",
    "        # Positional embeddings are learnable here, so they are a parameter\n",
    "        # covering up to `max_len` positions\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,\n",
    "                                                      num_hiddens))\n",
    "\n",
    "    def forward(self, tokens, segments, valid_lens):\n",
    "        # `tokens`/`segments`: (batch size, sequence length). `X` is\n",
    "        # (batch size, sequence length, `num_hiddens`) throughout.\n",
    "        X = self.token_embedding(tokens) + self.segment_embedding(segments)\n",
    "        X = X + self.pos_embedding[:, :X.shape[1], :]\n",
    "        for blk in self.blks:\n",
    "            X = blk(X, valid_lens)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f8d4e1d",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Inference of `BERTEncoder`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56903d71",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:54.990720Z",
     "iopub.status.busy": "2023-08-18T19:31:54.989911Z",
     "iopub.status.idle": "2023-08-18T19:31:55.148261Z",
     "shell.execute_reply": "2023-08-18T19:31:55.147176Z"
    },
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 8, 768])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4\n",
    "num_blks, dropout = 2, 0.2\n",
    "encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads,\n",
    "                      num_blks, dropout)\n",
    "\n",
    "# Two sequences of 8 random token IDs; each segment row splits its\n",
    "# sequence into a first (0) and a second (1) part\n",
    "tokens = torch.randint(0, vocab_size, (2, 8))\n",
    "segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])\n",
    "encoded_X = encoder(tokens, segments, None)\n",
    "encoded_X.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7700375b",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Masked Language Modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a8614e46",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.154295Z",
     "iopub.status.busy": "2023-08-18T19:31:55.153400Z",
     "iopub.status.idle": "2023-08-18T19:31:55.162155Z",
     "shell.execute_reply": "2023-08-18T19:31:55.161271Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class MaskLM(nn.Module):\n",
    "    \"\"\"The masked language model task of BERT.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, **kwargs):\n",
    "        super(MaskLM, self).__init__(**kwargs)\n",
    "        self.mlp = nn.Sequential(nn.LazyLinear(num_hiddens),\n",
    "                                 nn.ReLU(),\n",
    "                                 nn.LayerNorm(num_hiddens),\n",
    "                                 nn.LazyLinear(vocab_size))\n",
    "\n",
    "    def forward(self, X, pred_positions):\n",
    "        \"\"\"Predict vocabulary logits at the given positions.\n",
    "\n",
    "        `X` is indexed as (batch, position, feature); `pred_positions` is\n",
    "        (batch size, number of positions to predict per example).\n",
    "        Returns logits of shape\n",
    "        (batch size, `pred_positions.shape[1]`, `vocab_size`).\n",
    "        \"\"\"\n",
    "        num_pred_positions = pred_positions.shape[1]\n",
    "        pred_positions = pred_positions.reshape(-1)\n",
    "        batch_size = X.shape[0]\n",
    "        # Create the batch index on X's device so the index tensors live\n",
    "        # where X does (no implicit host-to-device copy on GPU).\n",
    "        batch_idx = torch.arange(0, batch_size, device=X.device)\n",
    "        # E.g. batch_size = 2, num_pred_positions = 3 gives\n",
    "        # batch_idx = tensor([0, 0, 0, 1, 1, 1])\n",
    "        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)\n",
    "        masked_X = X[batch_idx, pred_positions]\n",
    "        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))\n",
    "        mlm_Y_hat = self.mlp(masked_X)\n",
    "        return mlm_Y_hat"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ceb49f61",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The forward inference of `MaskLM`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6b3fc7d6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.166188Z",
     "iopub.status.busy": "2023-08-18T19:31:55.165836Z",
     "iopub.status.idle": "2023-08-18T19:31:55.273706Z",
     "shell.execute_reply": "2023-08-18T19:31:55.272680Z"
    },
    "origin_pos": 19,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 10000])"
      ]
     },
     "execution_count": 7,
"metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlm = MaskLM(vocab_size, num_hiddens)\n",
    "# Positions, within each of the 2 sequences, whose tokens are predicted\n",
    "mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])\n",
    "mlm_Y_hat = mlm(encoded_X, mlm_positions)\n",
    "mlm_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8d85768b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.278830Z",
     "iopub.status.busy": "2023-08-18T19:31:55.278109Z",
     "iopub.status.idle": "2023-08-18T19:31:55.288546Z",
     "shell.execute_reply": "2023-08-18T19:31:55.287485Z"
    },
    "origin_pos": 22,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Target token IDs for the predicted positions above\n",
    "mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])\n",
    "loss = nn.CrossEntropyLoss(reduction='none')\n",
    "mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))\n",
    "mlm_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf2bde58",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Next Sentence Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c1951876",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.292806Z",
     "iopub.status.busy": "2023-08-18T19:31:55.291904Z",
     "iopub.status.idle": "2023-08-18T19:31:55.298328Z",
     "shell.execute_reply": "2023-08-18T19:31:55.297464Z"
    },
    "origin_pos": 25,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class NextSentencePred(nn.Module):\n",
    "    \"\"\"The next sentence prediction task of BERT.\"\"\"\n",
    "    def __init__(self, **kwargs):\n",
    "        super(NextSentencePred, self).__init__(**kwargs)\n",
    "        # Two output logits for the binary next-sentence decision\n",
    "        self.output = nn.LazyLinear(2)\n",
    "\n",
    "    def forward(self, X):\n",
    "        # Expects `X` as (batch size, features); the input width is\n",
    "        # inferred by `LazyLinear` on the first call\n",
    "        return self.output(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57961ff1",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The forward inference of a `NextSentencePred`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "aba0cce5",
   "metadata": {
"execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.302539Z",
     "iopub.status.busy": "2023-08-18T19:31:55.301869Z",
     "iopub.status.idle": "2023-08-18T19:31:55.310590Z",
     "shell.execute_reply": "2023-08-18T19:31:55.309427Z"
    },
    "origin_pos": 28,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flatten each example's encoding into one feature vector. A new name keeps\n",
    "# `encoded_X` 3-D, so the earlier `MaskLM` cells can still be re-run.\n",
    "nsp_input = torch.flatten(encoded_X, start_dim=1)\n",
    "nsp = NextSentencePred()\n",
    "nsp_Y_hat = nsp(nsp_input)\n",
    "nsp_Y_hat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ba504299",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.314489Z",
     "iopub.status.busy": "2023-08-18T19:31:55.313852Z",
     "iopub.status.idle": "2023-08-18T19:31:55.321256Z",
     "shell.execute_reply": "2023-08-18T19:31:55.320193Z"
    },
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One binary label per example\n",
    "nsp_y = torch.tensor([0, 1])\n",
    "nsp_l = loss(nsp_Y_hat, nsp_y)\n",
    "nsp_l.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fd4cc06",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Putting It All Together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "73c331cd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:31:55.325106Z",
     "iopub.status.busy": "2023-08-18T19:31:55.324530Z",
     "iopub.status.idle": "2023-08-18T19:31:55.333301Z",
     "shell.execute_reply": "2023-08-18T19:31:55.332018Z"
    },
    "origin_pos": 34,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class BERTModel(nn.Module):\n",
    "    \"\"\"The BERT model.\n",
    "\n",
    "    Wraps `BERTEncoder` with the masked language modeling head (`MaskLM`)\n",
    "    and the next sentence prediction head (`NextSentencePred`).\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,\n",
    "                 num_heads, num_blks, dropout, max_len=1000):\n",
    "        super(BERTModel, self).__init__()\n",
    "        self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens,\n",
    "                                   num_heads, num_blks, dropout,\n",
    "                                   max_len=max_len)\n",
    "        self.hidden = nn.Sequential(nn.LazyLinear(num_hiddens),\n",
    "                                    nn.Tanh())\n",
    "        self.mlm = MaskLM(vocab_size, num_hiddens)\n",
    "        self.nsp = NextSentencePred()\n",
    "\n",
    "    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):\n",
    "        \"\"\"Return `(encoded_X, mlm_Y_hat, nsp_Y_hat)`.\n",
    "\n",
    "        `mlm_Y_hat` is None when `pred_positions` is not given.\n",
    "        \"\"\"\n",
    "        encoded_X = self.encoder(tokens, segments, valid_lens)\n",
    "        if pred_positions is not None:\n",
    "            mlm_Y_hat = self.mlm(encoded_X, pred_positions)\n",
    "        else:\n",
    "            mlm_Y_hat = None\n",
    "        # Next sentence prediction uses only the encoding at position 0,\n",
    "        # the sequence's leading classification token\n",
    "        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))\n",
    "        return encoded_X, mlm_Y_hat, nsp_Y_hat"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "language_info": {
   "name": "python"
  },
  "required_libs": [],
  "rise": {
   "autolaunch": true,
   "enable_chalkboard": true,
   "overlay": "\n",
   "scroll": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}