{ "cells": [ { "cell_type": "markdown", "id": "f19011b1", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Pretraining BERT\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "f710ca0b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:41:39.682439Z", "iopub.status.busy": "2023-08-18T19:41:39.681740Z", "iopub.status.idle": "2023-08-18T19:41:48.235030Z", "shell.execute_reply": "2023-08-18T19:41:48.233725Z" }, "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "\n", "batch_size, max_len = 512, 64\n", "train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)" ] }, { "cell_type": "markdown", "id": "2f3a414f", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "A small BERT, using 2 layers, 128 hidden units, and 2 self-attention heads" ] }, { "cell_type": "code", "execution_count": 3, "id": "bfbf346b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:41:48.240739Z", "iopub.status.busy": "2023-08-18T19:41:48.239930Z", "iopub.status.idle": "2023-08-18T19:41:48.297798Z", "shell.execute_reply": "2023-08-18T19:41:48.296836Z" }, "origin_pos": 7, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "net = d2l.BERTModel(len(vocab), num_hiddens=128,\n", " ffn_num_hiddens=256, num_heads=2, num_blks=2, dropout=0.2)\n", "devices = d2l.try_all_gpus()\n", "loss = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "id": "f9dc526d", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Computes the loss for both the masked language modeling and next sentence prediction tasks" ] }, { "cell_type": "code", "execution_count": 4, "id": "ba08f5b3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:41:48.302735Z", "iopub.status.busy": "2023-08-18T19:41:48.302136Z", "iopub.status.idle": "2023-08-18T19:41:48.307953Z", "shell.execute_reply": "2023-08-18T19:41:48.307116Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,\n", " segments_X, valid_lens_x,\n", " pred_positions_X, mlm_weights_X,\n", " mlm_Y, nsp_y):\n", " _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,\n", " valid_lens_x.reshape(-1),\n", " pred_positions_X)\n", " mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\\\n", " mlm_weights_X.reshape(-1, 1)\n", " mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)\n", " nsp_l = loss(nsp_Y_hat, nsp_y)\n", " l = mlm_l + nsp_l\n", " return mlm_l, nsp_l, l" ] }, { "cell_type": "markdown", "id": "e1825deb", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Pretrain BERT (`net`) on the WikiText-2 (`train_iter`) dataset" ] }, { "cell_type": "code", "execution_count": 6, "id": "4ad41e4e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:41:48.326142Z", "iopub.status.busy": "2023-08-18T19:41:48.325308Z", "iopub.status.idle": "2023-08-18T19:42:05.912894Z", "shell.execute_reply": "2023-08-18T19:42:05.911726Z" }, "origin_pos": 15, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLM loss 5.885, NSP loss 0.760\n", "4413.2 sentence pairs/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:42:05.863396\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):\n", " net(*next(iter(train_iter))[:4])\n", " net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n", " trainer = torch.optim.Adam(net.parameters(), lr=0.01)\n", " step, timer = 0, d2l.Timer()\n", " animator = d2l.Animator(xlabel='step', ylabel='loss',\n", " xlim=[1, num_steps], legend=['mlm', 'nsp'])\n", " metric = d2l.Accumulator(4)\n", " num_steps_reached = False\n", " while step < num_steps and not num_steps_reached:\n", " for tokens_X, segments_X, valid_lens_x, pred_positions_X,\\\n", " mlm_weights_X, mlm_Y, nsp_y in train_iter:\n", " tokens_X = tokens_X.to(devices[0])\n", " segments_X = segments_X.to(devices[0])\n", " valid_lens_x = valid_lens_x.to(devices[0])\n", " pred_positions_X = pred_positions_X.to(devices[0])\n", " mlm_weights_X = mlm_weights_X.to(devices[0])\n", " mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])\n", " trainer.zero_grad()\n", " timer.start()\n", " mlm_l, nsp_l, l = _get_batch_loss_bert(\n", " net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,\n", " pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)\n", " l.backward()\n", " trainer.step()\n", " metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)\n", " timer.stop()\n", " animator.add(step + 1,\n", " (metric[0] / metric[3], metric[1] / metric[3]))\n", " step += 1\n", " if step == num_steps:\n", " num_steps_reached = True\n", " break\n", "\n", " print(f'MLM loss {metric[0] / metric[3]:.3f}, '\n", " f'NSP loss {metric[1] / metric[3]:.3f}')\n", " print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '\n", " f'{str(devices)}')\n", "\n", "train_bert(train_iter, net, loss, len(vocab), devices, 50)" ] }, { "cell_type": "markdown", "id": "1a9cdae3", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Representing Text with BERT" ] }, { "cell_type": "code", "execution_count": 7, "id": "e17d97e2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:42:05.917554Z", "iopub.status.busy": "2023-08-18T19:42:05.916794Z", "iopub.status.idle": "2023-08-18T19:42:05.924854Z", "shell.execute_reply": "2023-08-18T19:42:05.923643Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_bert_encoding(net, tokens_a, tokens_b=None):\n", " tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n", " token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)\n", " segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)\n", " valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)\n", " encoded_X, _, _ = net(token_ids, segments, valid_len)\n", " return encoded_X" ] }, { "cell_type": "markdown", "id": "f1d39471", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Consider the sentence \"a crane is flying\"" ] }, { "cell_type": "code", "execution_count": 8, "id": "0cf00a35", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:42:05.930210Z", "iopub.status.busy": "2023-08-18T19:42:05.929167Z", "iopub.status.idle": "2023-08-18T19:42:06.033511Z", "shell.execute_reply": "2023-08-18T19:42:06.032431Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 6, 128]),\n", " torch.Size([1, 128]),\n", " tensor([0.8414, 1.4830, 0.8226], device='cuda:0', grad_fn=))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens_a = ['a', 'crane', 'is', 'flying']\n", "encoded_text = get_bert_encoding(net, tokens_a)\n", "encoded_text_cls = encoded_text[:, 0, :]\n", "encoded_text_crane = encoded_text[:, 2, :]\n", "encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]" ] }, { "cell_type": "markdown", "id": "2fe5b688", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Now consider a sentence pair\n", "\"a crane driver came\" and \"he just left\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "8c94b0fd", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:42:06.038841Z", "iopub.status.busy": "2023-08-18T19:42:06.037905Z", "iopub.status.idle": "2023-08-18T19:42:06.052408Z", "shell.execute_reply": "2023-08-18T19:42:06.051278Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 10, 128]),\n", " torch.Size([1, 128]),\n", " tensor([0.0430, 1.6132, 0.0437], device='cuda:0', grad_fn=))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']\n", "encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)\n", "encoded_pair_cls = encoded_pair[:, 0, :]\n", "encoded_pair_crane = encoded_pair[:, 2, :]\n", "encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }