{
"cells": [
{
"cell_type": "markdown",
"id": "f19011b1",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Pretraining BERT\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f710ca0b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:39.682439Z",
"iopub.status.busy": "2023-08-18T19:41:39.681740Z",
"iopub.status.idle": "2023-08-18T19:41:48.235030Z",
"shell.execute_reply": "2023-08-18T19:41:48.233725Z"
},
"origin_pos": 4,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"# Minibatch size and maximum sequence length (presumably in tokens)\n",
"batch_size = 512\n",
"max_len = 64\n",
"train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)"
]
},
{
"cell_type": "markdown",
"id": "2f3a414f",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Define a small BERT with 2 encoder blocks, 128 hidden units, and 2 self-attention heads."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bfbf346b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:48.240739Z",
"iopub.status.busy": "2023-08-18T19:41:48.239930Z",
"iopub.status.idle": "2023-08-18T19:41:48.297798Z",
"shell.execute_reply": "2023-08-18T19:41:48.296836Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# Small BERT: 128 hidden units, 256 FFN hidden units, 2 heads, 2 blocks\n",
"net = d2l.BERTModel(\n",
"    len(vocab),\n",
"    num_hiddens=128,\n",
"    ffn_num_hiddens=256,\n",
"    num_heads=2,\n",
"    num_blks=2,\n",
"    dropout=0.2,\n",
")\n",
"devices = d2l.try_all_gpus()\n",
"# Default reduction='mean': each call returns a single averaged scalar\n",
"loss = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "markdown",
"id": "f9dc526d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Compute the loss for both the masked language modeling (MLM) and next sentence prediction (NSP) tasks."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ba08f5b3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:48.302735Z",
"iopub.status.busy": "2023-08-18T19:41:48.302136Z",
"iopub.status.idle": "2023-08-18T19:41:48.307953Z",
"shell.execute_reply": "2023-08-18T19:41:48.307116Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,\n",
" segments_X, valid_lens_x,\n",
" pred_positions_X, mlm_weights_X,\n",
" mlm_Y, nsp_y):\n",
" _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,\n",
" valid_lens_x.reshape(-1),\n",
" pred_positions_X)\n",
" mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\\\n",
" mlm_weights_X.reshape(-1, 1)\n",
" mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)\n",
" nsp_l = loss(nsp_Y_hat, nsp_y)\n",
" l = mlm_l + nsp_l\n",
" return mlm_l, nsp_l, l"
]
},
{
"cell_type": "markdown",
"id": "e1825deb",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Pretrain BERT (`net`) on the WikiText-2 dataset (`train_iter`)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4ad41e4e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:48.326142Z",
"iopub.status.busy": "2023-08-18T19:41:48.325308Z",
"iopub.status.idle": "2023-08-18T19:42:05.912894Z",
"shell.execute_reply": "2023-08-18T19:42:05.911726Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLM loss 5.885, NSP loss 0.760\n",
"4413.2 sentence pairs/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def train_bert(train_iter, net, loss, vocab_size, devices, num_steps,\n",
"               lr=0.01):\n",
"    \"\"\"Pretrain `net` on MLM + NSP for `num_steps` minibatch steps.\n",
"\n",
"    Args:\n",
"        train_iter: re-iterable source of minibatches; each batch unpacks\n",
"            into (tokens_X, segments_X, valid_lens_x, pred_positions_X,\n",
"            mlm_weights_X, mlm_Y, nsp_y).\n",
"        net: the BERT model; it is wrapped in nn.DataParallel and moved to\n",
"            devices[0] here.\n",
"        loss: loss module forwarded to _get_batch_loss_bert.\n",
"        vocab_size: number of MLM output classes.\n",
"        devices: devices for nn.DataParallel; batches go to devices[0].\n",
"        num_steps: number of optimizer steps to take (passes over\n",
"            train_iter are repeated as needed).\n",
"        lr: Adam learning rate (default 0.01).\n",
"\n",
"    Raises:\n",
"        ValueError: if a full pass over train_iter yields no batches, which\n",
"            would otherwise loop forever (e.g. an exhausted generator).\n",
"    \"\"\"\n",
"    # One forward pass on a real batch before wrapping; presumably this\n",
"    # materializes lazily-initialized layers (TODO confirm for d2l.BERTModel).\n",
"    net(*next(iter(train_iter))[:4])\n",
"    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
"    trainer = torch.optim.Adam(net.parameters(), lr=lr)\n",
"    step, timer = 0, d2l.Timer()\n",
"    animator = d2l.Animator(xlabel='step', ylabel='loss',\n",
"                            xlim=[1, num_steps], legend=['mlm', 'nsp'])\n",
"    # Sum of MLM losses, sum of NSP losses, no. of sentence pairs, no. of steps\n",
"    metric = d2l.Accumulator(4)\n",
"    while step < num_steps:\n",
"        saw_batch = False\n",
"        for batch in train_iter:\n",
"            saw_batch = True\n",
"            (tokens_X, segments_X, valid_lens_x, pred_positions_X,\n",
"             mlm_weights_X, mlm_Y, nsp_y) = [t.to(devices[0]) for t in batch]\n",
"            trainer.zero_grad()\n",
"            timer.start()\n",
"            mlm_l, nsp_l, l = _get_batch_loss_bert(\n",
"                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,\n",
"                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)\n",
"            l.backward()\n",
"            trainer.step()\n",
"            # .item() converts the scalar losses to Python floats.\n",
"            metric.add(mlm_l.item(), nsp_l.item(), tokens_X.shape[0], 1)\n",
"            timer.stop()\n",
"            animator.add(step + 1,\n",
"                         (metric[0] / metric[3], metric[1] / metric[3]))\n",
"            step += 1\n",
"            if step >= num_steps:\n",
"                break\n",
"        if not saw_batch:\n",
"            raise ValueError('train_iter yielded no batches')\n",
"\n",
"    if metric[3] == 0:\n",
"        # num_steps <= 0: nothing was trained, so there is no average.\n",
"        print('No training steps were run.')\n",
"        return\n",
"    print(f'MLM loss {metric[0] / metric[3]:.3f}, '\n",
"          f'NSP loss {metric[1] / metric[3]:.3f}')\n",
"    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '\n",
"          f'{str(devices)}')\n",
"\n",
"train_bert(train_iter, net, loss, len(vocab), devices, 50)"
]
},
{
"cell_type": "markdown",
"id": "1a9cdae3",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Representing Text with BERT"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e17d97e2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:42:05.917554Z",
"iopub.status.busy": "2023-08-18T19:42:05.916794Z",
"iopub.status.idle": "2023-08-18T19:42:05.924854Z",
"shell.execute_reply": "2023-08-18T19:42:05.923643Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_bert_encoding(net, tokens_a, tokens_b=None):\n",
"    \"\"\"Return BERT's encodings for one sentence or a sentence pair.\n",
"\n",
"    Uses the notebook-level `vocab` and `devices` globals. The model is\n",
"    run in eval mode so dropout does not randomize the encodings; its\n",
"    previous train/eval mode is restored afterwards.\n",
"\n",
"    Returns:\n",
"        encoded_X: the first output of `net`, for a batch of one sequence\n",
"            (leading dimension 1).\n",
"    \"\"\"\n",
"    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)\n",
"    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)\n",
"    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)\n",
"    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)\n",
"    was_training = net.training\n",
"    net.eval()\n",
"    try:\n",
"        encoded_X, _, _ = net(token_ids, segments, valid_len)\n",
"    finally:\n",
"        net.train(was_training)\n",
"    return encoded_X"
]
},
{
"cell_type": "markdown",
"id": "f1d39471",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Consider the single sentence \"a crane is flying\"."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0cf00a35",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:42:05.930210Z",
"iopub.status.busy": "2023-08-18T19:42:05.929167Z",
"iopub.status.idle": "2023-08-18T19:42:06.033511Z",
"shell.execute_reply": "2023-08-18T19:42:06.032431Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 6, 128]),\n",
" torch.Size([1, 128]),\n",
" tensor([0.8414, 1.4830, 0.8226], device='cuda:0', grad_fn=))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens_a = ['a', 'crane', 'is', 'flying']\n",
"encoded_text = get_bert_encoding(net, tokens_a)\n",
"# Index 0 is the classification token (presumably '<cls>', prepended by\n",
"# d2l.get_tokens_and_segments); index 2 is 'crane', after '<cls>' and 'a'.\n",
"encoded_text_cls = encoded_text[:, 0]\n",
"encoded_text_crane = encoded_text[:, 2]\n",
"(encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0, :3])"
]
},
{
"cell_type": "markdown",
"id": "2fe5b688",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Now consider the sentence pair\n",
"\"a crane driver came\" and \"he just left\"."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8c94b0fd",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:42:06.038841Z",
"iopub.status.busy": "2023-08-18T19:42:06.037905Z",
"iopub.status.idle": "2023-08-18T19:42:06.052408Z",
"shell.execute_reply": "2023-08-18T19:42:06.051278Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 10, 128]),\n",
" torch.Size([1, 128]),\n",
" tensor([0.0430, 1.6132, 0.0437], device='cuda:0', grad_fn=))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens_a = ['a', 'crane', 'driver', 'came']\n",
"tokens_b = ['he', 'just', 'left']\n",
"encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)\n",
"# Index 0: classification token (presumably '<cls>'); index 2: 'crane',\n",
"# which sits in the first sentence right after '<cls>' and 'a'.\n",
"encoded_pair_cls = encoded_pair[:, 0]\n",
"encoded_pair_crane = encoded_pair[:, 2]\n",
"(encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0, :3])"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}