{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "nterop": { "id": "44" } }, "source": [ "# TF/Keras BERT Baseline (Training/Inference)\n", "> A tutorial on how to train an NLP model with Hugging Face's pretrained BERT in TF/Keras\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- categories: [notebook, kaggle, nlp]" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "nterop": { "id": "1" } }, "source": [ "This notebook shows how to train a neural network model with pre-trained BERT in TensorFlow/Keras. It is based on @xhlulu's [Disaster NLP: Keras BERT using TFHub](https://www.kaggle.com/xhlulu/disaster-nlp-keras-bert-using-tfhub) notebook and the [Text Extraction with BERT](https://keras.io/examples/nlp/text_extraction_with_bert/) example from Keras.\n", "\n", "This is a code competition without internet access, so we load the `transformers` tokenizer and the pre-trained BERT model from Kaggle Datasets instead of downloading them.\n", "\n", "Hope it helps." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "nterop": { "id": "28" } }, "source": [ "# Changelog" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "nterop": { "id": "29" } }, "source": [ "| Version | CV Score | Public Score | Changes | Comment |\n", "|----------|----------|--------------|---------|---------|\n", "| v9 | to be updated | to be updated | use the transformers tokenizer | |\n", "| v8 | 0.653635 | 0.606 | add 5-fold CV + early stopping back | |\n", "| v7 | N/A | 0.617 | fix the bug in the learning rate scheduler | overfitting to train? (n=20) |\n", "| v6 | N/A | 0.566 | add the warm-up learning rate scheduler | **With a bug. Don't use it** |\n", "| v5 | N/A | 0.531 | roll back to v3 | |\n", "| v4 | N/A | 0.573 | add early stopping | seemed to stop too early with `patience=1` (n=5) |\n", "| v3 | N/A | **0.530** | initial baseline | |" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "nterop": { "id": "2" } }, "source": [ "# Load Libraries and Data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-05-07T23:33:06.447697Z", "start_time": "2021-05-07T23:33:06.421563Z" }, "nterop": { "id": "30" } }, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2021-05-08T00:40:33.409410Z", "start_time": "2021-05-08T00:40:32.188406Z" }, "_kg_hide-input": true, "nterop": { "id": "3" } }, "outputs": [], "source": [ "from copy import copy\n", "import joblib\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "import os\n", "import pandas as pd\n", "from pathlib import Path\n", "import seaborn as sns\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.model_selection import KFold\n", "import sys\n", "from warnings import simplefilter\n", "\n", "import tensorflow as tf\n", "from tensorflow.keras import Model, Input\n", "from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler\n", "from tensorflow.keras.initializers import Constant\n", "from tensorflow.keras.layers import Dense, Embedding, Bidirectional, LSTM, Dropout\n", "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n", "from tensorflow.keras.metrics import RootMeanSquaredError\n", "from tensorflow.keras.utils import to_categorical\n", "from tensorflow.keras.optimizers import Adam\n", "from transformers import TFBertModel, BertConfig, BertTokenizerFast\n", "\n", "simplefilter('ignore')\n", "plt.style.use('fivethirtyeight')" ] }, { "cell_type": "code",
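"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A hedged sketch, not part of the original run: since the competition\n", "# kernel has no internet access, the tokenizer and pretrained weights can\n", "# be loaded from a local Kaggle Dataset directory via `from_pretrained`.\n", "# `bert_dir` is an assumed path; point it at whatever attached dataset\n", "# holds the transformers files (config.json, vocab.txt, model weights).\n", "bert_dir = '../input/huggingface-bert-base-uncased'\n", "tokenizer = BertTokenizerFast.from_pretrained(bert_dir)\n", "config = BertConfig.from_pretrained(bert_dir)\n", "bert = TFBertModel.from_pretrained(bert_dir, config=config)" ] }, { "cell_type": "code",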
"execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-05-07T23:33:10.974271Z", "start_time": "2021-05-07T23:33:10.869360Z" }, "nterop": { "id": "31" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Num GPUs Available: 1\n" ] } ], "source": [ "# limit the GPU memory growth\n", "gpu = tf.config.list_physical_devices('GPU')\n", "print(\"Num GPUs Available: \", len(gpu))\n", "if len(gpu) > 0:\n", " tf.config.experimental.set_memory_growth(gpu[0], True)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2021-05-08T00:20:53.916646Z", "start_time": "2021-05-08T00:20:53.681663Z" }, "nterop": { "id": "6" } }, "outputs": [], "source": [ "model_name = 'bert_v9'\n", "\n", "data_dir = Path('../input/commonlitreadabilityprize')\n", "train_file = data_dir / 'train.csv'\n", "test_file = data_dir / 'test.csv'\n", "sample_file = data_dir / 'sample_submission.csv'\n", "\n", "build_dir = Path('../build/')\n", "output_dir = build_dir / model_name\n", "trn_encoded_file = output_dir / 'trn.enc.joblib'\n", "tokenizer_file = output_dir / 'tokenizer.joblib'\n", "val_predict_file = output_dir / f'{model_name}.val.txt'\n", "submission_file = 'submission.csv'\n", "\n", "module_url = \"../input/bert-en-uncased-l24-h1024-a16\"\n", "\n", "id_col = 'id'\n", "target_col = 'target'\n", "text_col = 'excerpt'\n", "\n", "max_len = 205\n", "n_fold = 5\n", "n_est = 2\n", "n_stop = 2\n", "batch_size = 8\n", "seed = 42" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-05-07T23:33:11.380920Z", "start_time": "2021-05-07T23:33:11.252872Z" }, "nterop": { "id": "41" } }, "outputs": [], "source": [ "output_dir.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2021-05-07T23:33:11.635430Z", "start_time": "2021-05-07T23:33:11.382137Z" }, "nterop": { "id": "7" } }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "(2834, 5) (2834,) (7, 3)\n" ] }, { "data": { "text/html": [ "
| id | url_legal | license | excerpt | target | standard_error |
|---|---|---|---|---|---|
| c12129c31 | NaN | NaN | When the young people returned to the ballroom... | -0.340259 | 0.464009 |
| 85aa80a4c | NaN | NaN | All through dinner time, Mrs. Fayre was somewh... | -0.315372 | 0.480805 |
| b69ac6792 | NaN | NaN | As Roger had predicted, the snow departed as q... | -0.580118 | 0.476676 |
| dd1000b26 | NaN | NaN | And outside before the palace a great garden w... | -1.054013 | 0.450007 |
| 37c1b32fb | NaN | NaN | Once upon a time there were Three Bears who li... | 0.247197 | 0.510845 |