{
"cells": [
{
"cell_type": "markdown",
"id": "f19011b1",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Pretraining BERT\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f710ca0b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:39.682439Z",
"iopub.status.busy": "2023-08-18T19:41:39.681740Z",
"iopub.status.idle": "2023-08-18T19:41:48.235030Z",
"shell.execute_reply": "2023-08-18T19:41:48.233725Z"
},
"origin_pos": 4,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"# 512 sentence pairs per mini-batch; each sequence truncated/padded to 64 tokens\n",
"batch_size = 512\n",
"max_len = 64\n",
"train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)"
]
},
{
"cell_type": "markdown",
"id": "2f3a414f",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"A small BERT, using 2 Transformer encoder blocks, 128 hidden units, and 2 self-attention heads"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bfbf346b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:48.240739Z",
"iopub.status.busy": "2023-08-18T19:41:48.239930Z",
"iopub.status.idle": "2023-08-18T19:41:48.297798Z",
"shell.execute_reply": "2023-08-18T19:41:48.296836Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# A small BERT: 2 encoder blocks, 128 hidden units, 2 self-attention heads\n",
"net = d2l.BERTModel(\n",
"    len(vocab), num_hiddens=128, ffn_num_hiddens=256,\n",
"    num_heads=2, num_blks=2, dropout=0.2)\n",
"devices = d2l.try_all_gpus()\n",
"# Cross-entropy loss used by the pretraining objectives\n",
"loss = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "markdown",
"id": "f9dc526d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Computes the losses for the masked language modeling and next sentence prediction tasks, and returns both along with their sum"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ba08f5b3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:48.302735Z",
"iopub.status.busy": "2023-08-18T19:41:48.302136Z",
"iopub.status.idle": "2023-08-18T19:41:48.307953Z",
"shell.execute_reply": "2023-08-18T19:41:48.307116Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,\n",
"                         segments_X, valid_lens_x,\n",
"                         pred_positions_X, mlm_weights_X,\n",
"                         mlm_Y, nsp_y):\n",
"    \"\"\"Return the MLM loss, the NSP loss, and their sum for one batch.\n",
"\n",
"    `loss` is the notebook's mean-reduced nn.CrossEntropyLoss(); it is used\n",
"    for NSP only. The MLM loss is computed per token with reduction='none'\n",
"    so that `mlm_weights_X` (nonzero for real masked positions, 0 for\n",
"    padded prediction slots) actually masks out padding: multiplying a\n",
"    mean-reduced scalar by the weights and renormalizing by the weight sum\n",
"    would cancel out, silently averaging padded slots into the loss.\n",
"    \"\"\"\n",
"    # Forward pass: encoded sequence, MLM predictions, NSP predictions\n",
"    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,\n",
"                                  valid_lens_x.reshape(-1),\n",
"                                  pred_positions_X)\n",
"    # Per-position MLM cross-entropy; weights applied elementwise with a\n",
"    # flat reshape (a (-1, 1) reshape would broadcast the per-token vector\n",
"    # into a square matrix), then averaged over the real masked tokens.\n",
"    mlm_per_tok = nn.functional.cross_entropy(\n",
"        mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1),\n",
"        reduction='none')\n",
"    mlm_l = (mlm_per_tok * mlm_weights_X.reshape(-1)).sum() / (\n",
"        mlm_weights_X.sum() + 1e-8)  # 1e-8 guards against an all-zero mask\n",
"    # Next sentence prediction loss (scalar, mean over the batch)\n",
"    nsp_l = loss(nsp_Y_hat, nsp_y)\n",
"    l = mlm_l + nsp_l\n",
"    return mlm_l, nsp_l, l"
]
},
{
"cell_type": "markdown",
"id": "e1825deb",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Pretrain BERT (`net`) on the WikiText-2 (`train_iter`) dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4ad41e4e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:41:48.326142Z",
"iopub.status.busy": "2023-08-18T19:41:48.325308Z",
"iopub.status.idle": "2023-08-18T19:42:05.912894Z",
"shell.execute_reply": "2023-08-18T19:42:05.911726Z"
},
"origin_pos": 15,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLM loss 5.885, NSP loss 0.760\n",
"4413.2 sentence pairs/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"