{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Distilling_BERT_to_a_CNN.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "Ev67EeXyugYI", "colab_type": "text" }, "source": [ "Distilling BERT to simpler models: CNN and Linear\n", "=================================================\n", "\n", "This is a little extension of the work done in _Distilling Task-Specific Knowledge from BERT into Simple Neural Networks_ by Tang et al. 2019. Hopefully this notebook will serve as an easy-to-follow guide to distillation, which is actually really simple. This is based on work I did for [Polecat](polecat.com).\n", "\n", "Tang demonstrates that training a lower-complexity student model to predict a teacher model's output logits is more effective than directly training the student model on the dataset. This is a really neat way of improving performance of smaller models (which are much easier to productionize).\n", "\n", "In the paper Tang uses BERT to train a BiLSTM. One of the suggestions for future work is to explore to what extent even simpler models can benefit from the technique. This notebook does just that - we'll try and use BERT to train a CNN and simple linear model implemented in PyTorch.\n", "\n", "The linear model is the FastText model (Joulin et al. 2016) which normally is an excellent compromise between speed and accuracy. The task is document classification. We wouldn't expect to get near BERT-like accuracy because FastText is a bag-of-words model (it ignores word order, although you can give it n-grams) but it will be interesting to see if we can increase its accuracy at all. \n", "\n", "The CNN is the basic model described by Kim in _Convolutional Neural Networks for Sentence Classification_ (2014). 
For simplicity, pretrained word embeddings haven't been used, although they would certainly improve performance.\n", "\n", "Let's begin with our dependencies: PyTorch, the great Hugging Face transformers library (for a BERT implementation) and other usual suspects." ] }, { "cell_type": "code", "metadata": { "id": "1T8SM4q2SrD8", "colab_type": "code", "outputId": "4e0129ae-4a43-4e27-8152-459e95b06b8f", "colab": { "base_uri": "https://localhost:8080/", "height": 607 } }, "source": [ "!pip install torch transformers pandas tqdm altair joblib sklearn" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.5.0+cu101)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.6/dist-packages (2.11.0)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (1.0.4)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (4.41.1)\n", "Requirement already satisfied: altair in /usr/local/lib/python3.6/dist-packages (4.1.0)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (0.15.1)\n", "Requirement already satisfied: sklearn in /usr/local/lib/python3.6/dist-packages (0.0)\n", "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch) (0.16.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.18.5)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n", "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from transformers) (0.1.91)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)\n", "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from 
transformers) (0.7)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n", "Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers) (0.0.43)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)\n", "Requirement already satisfied: tokenizers==0.7.0 in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7.0)\n", "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.8.1)\n", "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas) (2018.9)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from altair) (2.11.2)\n", "Requirement already satisfied: toolz in /usr/local/lib/python3.6/dist-packages (from altair) (0.10.0)\n", "Requirement already satisfied: jsonschema in /usr/local/lib/python3.6/dist-packages (from altair) (2.6.0)\n", "Requirement already satisfied: entrypoints in /usr/local/lib/python3.6/dist-packages (from altair) (0.3)\n", "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from sklearn) (0.22.2.post1)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.9)\n", "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n", "Requirement already 
satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)\n", "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->altair) (1.1.1)\n", "Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->sklearn) (1.4.1)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "tKhjpP3sHljm", "colab_type": "code", "colab": {} }, "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import altair as alt\n", "from pathlib import Path\n", "from joblib import Memory\n", "from sklearn.metrics import f1_score\n", "from tqdm import tqdm\n", "from torch.utils.data import DataLoader, TensorDataset\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, DistilBertForSequenceClassification" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "KOtQVQ9pVGDV", "colab_type": "code", "outputId": "f754b516-b367-492e-8d11-27a802a7c2ee", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ], "execution_count": 3, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "device(type='cuda')" ] }, "metadata": { "tags": [] }, "execution_count": 3 } ] }, { "cell_type": "markdown", "metadata": { "id": "6ZQZZTqKWxgR", "colab_type": "text" }, "source": [ "Dataset\n", "-------\n", "\n", "We'll use the Amazon review dataset. It is freely available and consists of product reviews with a star rating, and the task is simply to predict the star rating. It's a challenging task.\n", "\n", "First, some data wrangling. 
I'm afraid this notebook won't run out-of-the-box, because the data and teacher model are too large to distribute." ] }, { "cell_type": "code", "metadata": { "id": "-FrGQvQzS7wj", "colab_type": "code", "colab": {} }, "source": [ "ROOT = Path(\"/mnt/gdrive/My Drive/\")\n", "\n", "if not ROOT.exists():\n", " from google.colab import drive\n", " drive.mount(\"/mnt/gdrive\")\n", "\n", "assert ROOT.exists()\n", "DATA = ROOT / \"data\"\n", "MODELS = ROOT / \"models\"" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "y-PMOCCl1u7X", "colab_type": "code", "colab": {} }, "source": [ "CACHE = ROOT / \"cache/distillation\"\n", "\n", "if not CACHE.exists():\n", " CACHE.mkdir(parents=True)\n", "\n", "memory = Memory(CACHE, verbose=False)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "FCNkvexhTHEP", "colab_type": "code", "colab": {} }, "source": [ "market = \"uk\"\n", "\n", "reviews = (pd.read_csv(DATA / \"amazon\" / f\"amazon_reviews_multilingual_{market.upper()}_v1_00.tsv.gz\",\n", " sep=\"\\t\",\n", " usecols=[\"review_id\", \"star_rating\", \"review_headline\", \"review_body\"],\n", " dtype={\"review_id\": \"string\",\n", " \"star_rating\": \"Int32\",\n", " \"review_headline\": \"string\",\n", " \"review_body\": \"string\"})\n", " .dropna())" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "S6o-g2uR5nvN", "colab_type": "text" }, "source": [ "\n", "We balance the classes and shuffle the dataset. Ideally we should also remove some low-value reviews, e.g. single-word reviews and reviews in other languages. But there are few enough of these to not make much difference as far as this exploration goes." 
] }, { "cell_type": "code", "metadata": { "id": "EjO9V73UTK6m", "colab_type": "code", "outputId": "3cf66541-1d3a-4baf-93ad-bdb4d915f6f9", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "MAX_LEN = 50_000\n", "\n", "classes = {1, 2, 3, 4, 5}\n", "class_examples = [reviews[reviews.star_rating == rating] for rating in classes]\n", "\n", "min_len = min(MAX_LEN // len(classes), *[len(c) for c in class_examples])\n", "\n", "balanced_df = pd.concat([c.sample(min_len, random_state=42) for c in class_examples])\n", "\n", "shuffled_df = balanced_df.sample(len(balanced_df))\n", "shuffled_df[\"label\"] = shuffled_df.star_rating.astype(int) - 1\n", "\n", "len(shuffled_df)" ], "execution_count": 7, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "50000" ] }, "metadata": { "tags": [] }, "execution_count": 7 } ] }, { "cell_type": "code", "metadata": { "id": "MalNs-F4TPUb", "colab_type": "code", "outputId": "d3e0ca95-4005-4fa4-f5ba-6eb0dfaf6b31", "colab": { "base_uri": "https://localhost:8080/", "height": 112 } }, "source": [ "shuffled_df.head(2)" ], "execution_count": 8, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>review_id</th>\n", " <th>star_rating</th>\n", " <th>review_headline</th>\n", " <th>review_body</th>\n", " <th>label</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>626969</th>\n", " <td>R16AH12YPHGU7C</td>\n", " <td>1</td>\n", " <td>No instructions</td>\n", " <td>No instructions, only pictures that you can fi...</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1101179</th>\n", " 
<td>R1W6Y6B361L24G</td>\n", " <td>3</td>\n", " <td>A Little Slight But Still Entertaining</td>\n", " <td>Although the Bee Gees had included some R+B/so...</td>\n", " <td>2</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " review_id ... label\n", "626969 R16AH12YPHGU7C ... 0\n", "1101179 R1W6Y6B361L24G ... 2\n", "\n", "[2 rows x 5 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 8 } ] }, { "cell_type": "markdown", "metadata": { "id": "w06iuA7559bF", "colab_type": "text" }, "source": [ "Split the data into a training set and a test set." ] }, { "cell_type": "code", "metadata": { "id": "Bf0GbuU1TSrf", "colab_type": "code", "outputId": "42779698-abb3-4529-998c-a42d03a6b674", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "train_frac = 0.8\n", "split_idx = int(train_frac * len(shuffled_df))\n", "\n", "train_df = shuffled_df[:split_idx]\n", "test_df = shuffled_df[split_idx:]\n", "len(train_df), len(test_df)" ], "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(40000, 10000)" ] }, "metadata": { "tags": [] }, "execution_count": 9 } ] }, { "cell_type": "markdown", "metadata": { "id": "55NhgXuN6B_o", "colab_type": "text" }, "source": [ "Tokenize the text and convert it to PyTorch tensors. We also need two masking vectors for each example as input to BERT (the attention mask and the token type IDs)." ] }, { "cell_type": "code", "metadata": { "id": "fhxUiqpaWPq8", "colab_type": "code", "colab": {} }, "source": [ "try:\n", "    tokenizer = tokenizer  # reuse the tokenizer if this cell has already been run\n", "except NameError:\n", "    tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-multilingual-cased\")" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "PlKDEQpsFXJX", "colab_type": "text" }, "source": [ "To reduce training time, we will precompute the teacher's predictions. Additionally `joblib` is used to cache the results of this function so less time is wasted during development. 
This is the reason for the Torch/pickling warnings that occur later on in this notebook." ] }, { "cell_type": "code", "metadata": { "id": "XmQTwaz-VB_0", "colab_type": "code", "colab": {} }, "source": [ "@memory.cache(ignore=[\"teacher\"]) # cache but don't bother serializing the teacher - doesn't change and is slow to pickle\n", "def dataframe_to_dataset(df, teacher):\n", " max_len = 128\n", " features = tokenizer.batch_encode_plus(df.review_body,\n", " max_length=max_len,\n", " pad_to_max_length=True,\n", " return_attention_masks=True,\n", " return_token_type_ids=True,\n", " return_tensors=\"pt\")\n", "\n", " pre_dataset = TensorDataset(features[\"input_ids\"],\n", " features[\"attention_mask\"],\n", " features[\"token_type_ids\"])\n", " \n", " teacher.to(device)\n", " teacher.eval()\n", " teacher_predictions = []\n", " for batch in tqdm(DataLoader(pre_dataset, batch_size=32, shuffle=False)):\n", " batch = tuple([b.to(device) for b in batch])\n", " inputs = {\"input_ids\": batch[0], \"attention_mask\": batch[1]}\n", " if teacher.base_model_prefix == \"bert\":\n", " inputs[\"token_type_ids\"] = batch[2]\n", " with torch.no_grad():\n", " outputs = teacher(**inputs)\n", " teacher_predictions.append(outputs[0].to(torch.device(\"cpu\"))) # put back on CPU\n", "\n", " dataset = TensorDataset(features[\"input_ids\"],\n", " features[\"attention_mask\"],\n", " features[\"token_type_ids\"],\n", " torch.tensor(df.label.astype(\"int\").to_numpy(), dtype=torch.long),\n", " torch.cat(teacher_predictions, axis=0))\n", " return dataset" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "YpfP65ABWq7W", "colab_type": "text" }, "source": [ "Hyperparameters\n", "---------------\n", "\n", "These are more-or-less the default hyperparameters for FastText. The embedding dimension is reduced to 50 to speed up processing slightly.\n", "\n", "Beware the batch size - we're using a batch size of **1** for training the linear model. 
This has a significant impact on its accuracy, and it's lightweight enough that we can get away with it." ] }, { "cell_type": "code", "metadata": { "id": "os3jL-4jVzii", "colab_type": "code", "colab": {} }, "source": [ "N_EPOCHS = 5\n", "EMBEDDING_DIM = 50\n", "LR = 0.5\n", "BATCH_SIZE = 32 # evaluation batch size; the linear model is trained with a batch size of 1\n", "N_LABELS = 5 # num review ratings" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "j4Q4hf5eV5A9", "colab_type": "code", "colab": {} }, "source": [ "padding_idx = tokenizer.vocab[\"[PAD]\"]\n", "n_vocab = len(tokenizer.vocab)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "DZX0nhVkWtts", "colab_type": "text" }, "source": [ "Teacher\n", "-------\n", "\n", "The teacher is actually DistilBERT, rather than BERT. So we are distilling from a distilled model! Ideally the teacher should be BERT proper so that results are more comparable. But this is running on Google Colab with limited GPU time, so a compromise is necessary.\n", "\n", "I trained this DistilBERT model on the same dataset previously. Later on we'll check its accuracy." ] }, { "cell_type": "code", "metadata": { "id": "66VmdmCjUjVu", "colab_type": "code", "colab": {} }, "source": [ "try:\n", " config = config\n", " teacher = teacher\n", "except NameError:\n", " config = AutoConfig.from_pretrained(\"distilbert-base-multilingual-cased\")\n", " config.num_labels = N_LABELS\n", " teacher = DistilBertForSequenceClassification(config)\n", " teacher.load_state_dict(torch.load(MODELS / \"distilbert_uk_50000.bin\", map_location=device))" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "di07LSp6WvXx", "colab_type": "text" }, "source": [ "Students\n", "--------\n", "\n", "This is a simple convolutional neural network (CNN) with dropout as per Kim (2014)." 
] }, { "cell_type": "code", "metadata": { "id": "5G3Mb6E2TW-v", "colab_type": "code", "colab": {} }, "source": [ "class CNN(nn.Module):\n", "\n", " def __init__(self,\n", " n_vocab,\n", " n_labels,\n", " embedding_dim=50,\n", " n_filters=100,\n", " filter_sizes=[3, 4, 5],\n", " dropout=0.5,\n", " special_chars=[],\n", " pretrained_embeddings=None): # TODO make number of conv layers configurable\n", " super(CNN, self).__init__()\n", " self.n_vocab = n_vocab\n", " self.n_labels = n_labels\n", " self.embedding_dim = embedding_dim\n", " self.n_filters = n_filters\n", " self.filter_sizes = filter_sizes\n", " self.dropout_p = dropout\n", " self.width = len(filter_sizes) * n_filters\n", "\n", " if pretrained_embeddings is not None:\n", " assert n_vocab == pretrained_embeddings.shape[0]\n", " assert embedding_dim == pretrained_embeddings.shape[1]\n", " self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings)\n", " else:\n", " self.embedding = nn.Embedding(n_vocab, embedding_dim)\n", " \n", " self.conv0 = nn.Conv2d(in_channels=1,\n", " out_channels=n_filters,\n", " kernel_size=(filter_sizes[0], embedding_dim))\n", " self.conv1 = nn.Conv2d(in_channels=1,\n", " out_channels=n_filters,\n", " kernel_size=(filter_sizes[1], embedding_dim))\n", " self.conv2 = nn.Conv2d(in_channels=1,\n", " out_channels=n_filters,\n", " kernel_size=(filter_sizes[2], embedding_dim))\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " self.fc = nn.Linear(in_features=self.width, out_features=n_labels)\n", "\n", " for special in special_chars:\n", " self.embedding.weight.data[special] = torch.zeros(embedding_dim)\n", "\n", " def forward(self, input_ids, **kwargs):\n", " \"\"\"Only input ids are required - kwargs are for API compat with BERT.\"\"\"\n", " X = self.embedding(input_ids)\n", " X = X.unsqueeze(1) # add single channel as dim 1\n", " X0 = F.relu(self.conv0(X).squeeze(3))\n", " X1 = F.relu(self.conv1(X).squeeze(3))\n", " X2 = F.relu(self.conv2(X).squeeze(3))\n", " X0 = 
F.max_pool1d(X0, X0.shape[2]).squeeze(2)\n", " X1 = F.max_pool1d(X1, X1.shape[2]).squeeze(2)\n", " X2 = F.max_pool1d(X2, X2.shape[2]).squeeze(2)\n", " X = torch.cat([X0, X1, X2], dim=1)\n", " X = self.dropout(X)\n", " X = self.fc(X)\n", " return X" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "2NkOWG08Taam", "colab_type": "text" }, "source": [ "This is a faithful implementation of the FastText linear model (Joulin et al. 2016)." ] }, { "cell_type": "code", "metadata": { "id": "cCtKES4bVmeU", "colab_type": "code", "colab": {} }, "source": [ "class LinearModel(nn.Module):\n", " \n", " def __init__(self, n_vocab, n_labels, embedding_dim, padding_idx):\n", " super(LinearModel, self).__init__()\n", " self.embeddings = nn.Embedding(n_vocab, embedding_dim, padding_idx=padding_idx)\n", " self.output = nn.Linear(embedding_dim, n_labels)\n", " with torch.no_grad():\n", " # FastText initializes embeddings with uniform distribution vs normal in PyTorch\n", " self.embeddings.weight.uniform_(to=1.0 / embedding_dim)\n", " self.embeddings.weight[padding_idx] = 0 # but FT doesn't have a padding token\n", " # FastText initializes output with zeros vs some random dist in PyTorch\n", " self.output.weight.zero_()\n", "\n", " def forward(self, input_ids, **kwargs):\n", " \"\"\"Only input ids are required - kwargs are for API compat with BERT.\"\"\"\n", " X = self.embeddings(input_ids)\n", " X = X.mean(dim=1)\n", " X = self.output(X)\n", " return X" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "otPP3--DW_58", "colab_type": "text" }, "source": [ "Training\n", "--------\n", "\n", "This function trains the model for one epoch. If no teacher is provided it uses cross entropy loss (i.e. 
softmax then NLL) and compares the model predictions to the target label.\n", "\n", "If a teacher is provided then model predictions are compared to the teacher's predictions and MSE loss is used.\n", "\n", "In the paper Tang defines a cost function that is a balance between the two, i.e. $L = \\alpha L_{CE} + (1 - \\alpha) L_{MSE}$, but in practice observed that the best value for $\\alpha$ was zero, which leaves only the MSE term used here.\n", "\n", "The accuracy on the training set is also output for visibility." ] }, { "cell_type": "code", "metadata": { "id": "xqpDH4XtXAyX", "colab_type": "code", "colab": {} }, "source": [ "def train_epoch(train_iter, model, optim, epoch_num, distil=False):\n", " train_loss = 0\n", " train_acc = 0\n", " y_true = []\n", " y_pred = []\n", " \n", " model.to(device)\n", " model.train()\n", " \n", " if distil:\n", " cost = nn.MSELoss()\n", " else:\n", " cost = nn.CrossEntropyLoss()\n", "\n", " for batch in tqdm(train_iter, total=len(train_iter), desc=f\"Batch progress for epoch {epoch_num}\"):\n", " \n", " batch = tuple([t.to(device) for t in batch])\n", " inputs = {\"input_ids\": batch[0],\n", " \"attention_mask\": batch[1]}\n", " labels = batch[3]\n", "\n", " optim.zero_grad()\n", " output = model(**inputs)\n", "\n", " if distil:\n", " target = batch[4]\n", " else:\n", " target = labels\n", "\n", " batch_loss = cost(output, target)\n", "\n", " # Had some trouble with linear distilled model dying in training\n", " # but since starting to debug it the issue hasn't reoccurred.\n", " # Gradient clipping might help. 
\n", " if torch.isnan(batch_loss):\n", " print(\"NAN batch loss!\", epoch_num, batch_loss, output, target)\n", "\n", " train_loss += batch_loss.item()\n", "\n", " batch_acc = (output.argmax(1) == labels).sum().item()\n", " train_acc += batch_acc\n", " y_true.extend(labels.tolist())\n", " y_pred.extend(output.argmax(1).tolist())\n", "\n", " batch_loss.backward()\n", " optim.step()\n", "\n", " return train_loss / len(train_iter), train_acc / len(train_iter.dataset), f1_score(y_true, y_pred, average=\"macro\") # classes are already balanced" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "2fTVu9qSZKjA", "colab_type": "code", "colab": {} }, "source": [ "def train_loop(model, optim, train_loader, test_loader, n_epochs=5, sched=None, distil=False):\n", " training_results = {\"epoch\": list(range(n_epochs)),\n", " \"train_loss\": [],\n", " \"train_acc\": [],\n", " \"train_f1_macro\": [],\n", " \"test_loss\": [],\n", " \"test_acc\": [],\n", " \"test_f1_macro\": []}\n", "\n", " model.to(device)\n", "\n", " try:\n", " for i in range(n_epochs):\n", " train_loss, train_acc, train_f1 = train_epoch(train_loader, model, optim, epoch_num=i, distil=distil)\n", " if sched is not None:\n", " sched.step()\n", " test_loss, test_acc, test_f1 = validate(test_loader, model)\n", " training_results[\"train_loss\"].append(train_loss)\n", " training_results[\"train_acc\"].append(train_acc)\n", " training_results[\"train_f1_macro\"].append(train_f1)\n", " training_results[\"test_loss\"].append(test_loss)\n", " training_results[\"test_acc\"].append(test_acc)\n", " training_results[\"test_f1_macro\"].append(test_f1)\n", " except KeyboardInterrupt:\n", " pass\n", "\n", " return pd.DataFrame(training_results)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "LqK2kyVK_ZEO", "colab_type": "text" }, "source": [ "The validation function is similar but in this case there is no option to compare to the teacher's predictions, 
because that's not the ultimate point of the exercise - at the end of it all we just want a better small model.\n", "\n", "The metrics are _accuracy_ and _macro F1_. Macro F1 will help us understand whether the model is performing similarly across all classes (e.g. a model that always predicts class 2 will have accuracy of 20% but a terrible F1 score)." ] }, { "cell_type": "code", "metadata": { "id": "oYvXJauRnc7B", "colab_type": "code", "colab": {} }, "source": [ "def validate(test_iter, model):\n", " test_acc = 0 \n", " test_loss = 0\n", " y_true = []\n", " y_pred = []\n", "\n", " cost = nn.CrossEntropyLoss()\n", "\n", " model.to(device)\n", " model.eval()\n", "\n", " for batch in tqdm(test_iter, desc=\"Validating\"):\n", " \n", " batch = tuple([t.to(device) for t in batch])\n", " inputs = {\"input_ids\": batch[0],\n", " \"attention_mask\": batch[1],\n", " \"token_type_ids\": batch[2]}\n", " labels = batch[3]\n", "\n", " with torch.no_grad():\n", " output = model(**inputs)\n", " \n", " batch_loss = cost(output, labels)\n", " test_loss += batch_loss.item()\n", " \n", " batch_acc = (output.argmax(1) == labels).sum().item() \n", " test_acc += batch_acc\n", " y_true.extend(labels.tolist())\n", " y_pred.extend(output.argmax(1).tolist())\n", "\n", " return test_loss / len(test_iter), test_acc / len(test_iter.dataset), f1_score(y_true, y_pred, average=\"macro\") # classes are balanced" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "S8zpmz6zGzHe", "colab_type": "text" }, "source": [ "We will construct two models of each architecture; one to train directly and one to train with distillation." 
] }, { "cell_type": "code", "metadata": { "id": "3a5_1_rNUZ4I", "colab_type": "code", "colab": {} }, "source": [ "linear_model = LinearModel(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, padding_idx=padding_idx)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "xg42hAkZX-kg", "colab_type": "code", "colab": {} }, "source": [ "linear_model_dist = LinearModel(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, padding_idx=padding_idx)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ksGxZeekUXk0", "colab_type": "code", "colab": {} }, "source": [ "cnn = CNN(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, special_chars=[padding_idx])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "mor5doN_YIl0", "colab_type": "code", "colab": {} }, "source": [ "cnn_dist = CNN(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, special_chars=[padding_idx])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "os1muwk5VOn_", "colab_type": "code", "outputId": "ee589aac-7af7-455d-e646-e46a51e65a3d", "colab": { "base_uri": "https://localhost:8080/", "height": 71 } }, "source": [ "test_loader = DataLoader(dataframe_to_dataset(test_df, teacher), batch_size=BATCH_SIZE, shuffle=False)" ], "execution_count": 24, "outputs": [ { "output_type": "stream", "text": [ "100%|██████████| 313/313 [00:37<00:00, 8.25it/s]\n", "/usr/local/lib/python3.6/dist-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n", " warnings.warn(\"pickle support for Storage will be removed in 1.5. Use `torch.save` instead\", FutureWarning)\n" ], "name": "stderr" } ] }, { "cell_type": "markdown", "metadata": { "id": "gxboj94idLIP", "colab_type": "text" }, "source": [ "Sanity check - we expect 20% accuracy in each case. 
The loss (which is cross-entropy for validation, regardless of the training method) should be about 1.6 ($\\ln 5 \\approx 1.61$), i.e. the loss we'd expect at random accuracy. The CNNs might have a higher loss: their randomly initialized layers produce non-uniform logits, whereas the linear model's output weights start at zero. Dropout is not the cause, because `validate` puts the model in evaluation mode, which disables it." ] }, { "cell_type": "code", "metadata": { "id": "tUtRBjRpn13o", "colab_type": "code", "outputId": "b863d179-923d-4481-ed55-629cd18d7ea8", "colab": { "base_uri": "https://localhost:8080/", "height": 53 } }, "source": [ "validate(test_loader, linear_model)" ], "execution_count": 25, "outputs": [ { "output_type": "stream", "text": [ "Validating: 100%|██████████| 313/313 [00:00<00:00, 1010.79it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/plain": [ "(1.6121794857537022, 0.1963, 0.06563571010616068)" ] }, "metadata": { "tags": [] }, "execution_count": 25 } ] }, { "cell_type": "code", "metadata": { "id": "Pavg85SzYMsp", "colab_type": "code", "outputId": "ee239336-babf-43fd-ce8b-e97803a33e1d", "colab": { "base_uri": "https://localhost:8080/", "height": 53 } }, "source": [ "validate(test_loader, linear_model_dist)" ], "execution_count": 26, "outputs": [ { "output_type": "stream", "text": [ "Validating: 100%|██████████| 313/313 [00:00<00:00, 1017.86it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/plain": [ "(1.6159831746317708, 0.2011, 0.06697194238614604)" ] }, "metadata": { "tags": [] }, "execution_count": 26 } ] }, { "cell_type": "code", "metadata": { "id": "t84_mQhGdJ-B", "colab_type": "code", "outputId": "b2c1e07f-2eec-4d5b-91a2-7759911fb260", "colab": { "base_uri": "https://localhost:8080/", "height": 53 } }, "source": [ "validate(test_loader, cnn)" ], "execution_count": 27, "outputs": [ { "output_type": "stream", "text": [ "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.14it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/plain": [ "(1.9059412772663105, 0.2051, 0.06808298755186722)" ] }, "metadata": { "tags": [] }, 
"execution_count": 27 } ] }, { "cell_type": "code", "metadata": { "id": "gAQT2-zUYP40", "colab_type": "code", "outputId": "f78ccf34-8345-4a68-caa0-8067c729828c", "colab": { "base_uri": "https://localhost:8080/", "height": 53 } }, "source": [ "validate(test_loader, cnn_dist)" ], "execution_count": 28, "outputs": [ { "output_type": "stream", "text": [ "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.44it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/plain": [ "(1.8158414051555598, 0.2162, 0.1188892656541893)" ] }, "metadata": { "tags": [] }, "execution_count": 28 } ] }, { "cell_type": "markdown", "metadata": { "id": "TzgzIprndMye", "colab_type": "text" }, "source": [ "Training\n", "--------\n", "\n", "We use SGD, a batch size of 1 and a linearly decreasing learning rate because this is empirically best for the linear model (see Joulin et al.). So the results we get for this model should be very close to what would be achieved using the [fasttext](https://fasttext.cc) library. " ] }, { "cell_type": "code", "metadata": { "id": "ydxYTUEMcmyo", "colab_type": "code", "outputId": "112e1471-fc4e-4214-cf3a-f2c93cddb500", "colab": { "base_uri": "https://localhost:8080/", "height": 214 } }, "source": [ "train_loader = DataLoader(dataframe_to_dataset(train_df, teacher), batch_size=1, shuffle=False) # optimal training for the linear model" ], "execution_count": 29, "outputs": [ { "output_type": "stream", "text": [ "100%|██████████| 1250/1250 [02:30<00:00, 8.29it/s]\n", "/usr/local/lib/python3.6/dist-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n", " warnings.warn(\"pickle support for Storage will be removed in 1.5. 
Use `torch.save` instead\", FutureWarning)\n", "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:1: UserWarning: Persisting input arguments took 0.91s to run.\n", "If this happens often in your code, it can cause performance problems \n", "(results will be correct in all cases). \n", "The reason for this is probably some large input arguments for a wrapped\n", " function (e.g. large strings).\n", "THIS IS A JOBLIB ISSUE. If you can, kindly provide the joblib's team with an\n", " example so that they can fix the problem.\n", " \"\"\"Entry point for launching an IPython kernel.\n" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "id": "bf8Jd5TWY78z", "colab_type": "code", "outputId": "6ab01a5d-41c4-4714-9e11-affce665199b", "colab": { "base_uri": "https://localhost:8080/", "height": 385 } }, "source": [ "optim = torch.optim.SGD(linear_model.parameters(), lr=LR)\n", "sched = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=0.5)\n", "\n", "linear_model_train_results = train_loop(linear_model, optim, train_loader, test_loader, N_EPOCHS, sched, distil=False)\n", "linear_model_train_results[\"model\"] = \"Linear\"\n", "linear_model_train_results" ], "execution_count": 30, "outputs": [ { "output_type": "stream", "text": [ "Batch progress for epoch 0: 100%|██████████| 40000/40000 [01:08<00:00, 585.80it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 995.07it/s] \n", "Batch progress for epoch 1: 100%|██████████| 40000/40000 [01:07<00:00, 591.82it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 1012.24it/s]\n", "Batch progress for epoch 2: 100%|██████████| 40000/40000 [01:07<00:00, 594.67it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 972.25it/s]\n", "Batch progress for epoch 3: 100%|██████████| 40000/40000 [01:07<00:00, 594.98it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 972.59it/s]\n", "Batch progress for epoch 4: 100%|██████████| 40000/40000 [01:07<00:00, 591.64it/s]\n", "Validating: 
100%|██████████| 313/313 [00:00<00:00, 978.57it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>train_acc</th>\n", " <th>train_f1_macro</th>\n", " <th>test_loss</th>\n", " <th>test_acc</th>\n", " <th>test_f1_macro</th>\n", " <th>model</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>1.517863</td>\n", " <td>0.327375</td>\n", " <td>0.323505</td>\n", " <td>1.503577</td>\n", " <td>0.3584</td>\n", " <td>0.280098</td>\n", " <td>Linear</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.286499</td>\n", " <td>0.443450</td>\n", " <td>0.436256</td>\n", " <td>1.385932</td>\n", " <td>0.4166</td>\n", " <td>0.350168</td>\n", " <td>Linear</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>1.202731</td>\n", " <td>0.488375</td>\n", " <td>0.480658</td>\n", " <td>1.331111</td>\n", " <td>0.4408</td>\n", " <td>0.391589</td>\n", " <td>Linear</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>1.157016</td>\n", " <td>0.516250</td>\n", " <td>0.508555</td>\n", " <td>1.294666</td>\n", " <td>0.4603</td>\n", " <td>0.427882</td>\n", " <td>Linear</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>1.131102</td>\n", " <td>0.529975</td>\n", " <td>0.522605</td>\n", " <td>1.267536</td>\n", " <td>0.4727</td>\n", " <td>0.451547</td>\n", " <td>Linear</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " epoch train_loss train_acc ... 
test_acc test_f1_macro model\n", "0 0 1.517863 0.327375 ... 0.3584 0.280098 Linear\n", "1 1 1.286499 0.443450 ... 0.4166 0.350168 Linear\n", "2 2 1.202731 0.488375 ... 0.4408 0.391589 Linear\n", "3 3 1.157016 0.516250 ... 0.4603 0.427882 Linear\n", "4 4 1.131102 0.529975 ... 0.4727 0.451547 Linear\n", "\n", "[5 rows x 8 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 30 } ] }, { "cell_type": "markdown", "metadata": { "id": "xcr_xZg6Ae98", "colab_type": "text" }, "source": [ "A second training loop for the linear model that is trained via distillation." ] }, { "cell_type": "code", "metadata": { "id": "clbIp15Km536", "colab_type": "code", "outputId": "63d167b0-4dc3-4574-8788-bf1601984975", "colab": { "base_uri": "https://localhost:8080/", "height": 385 } }, "source": [ "optim = torch.optim.SGD(linear_model_dist.parameters(), lr=LR)\n", "sched = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=0.5)\n", "\n", "linear_model_dist_train_results = train_loop(linear_model_dist, optim, train_loader, test_loader, N_EPOCHS, sched, distil=True)\n", "linear_model_dist_train_results[\"model\"] = \"Linear (distilled)\"\n", "linear_model_dist_train_results" ], "execution_count": 31, "outputs": [ { "output_type": "stream", "text": [ "Batch progress for epoch 0: 100%|██████████| 40000/40000 [01:06<00:00, 601.33it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 999.62it/s] \n", "Batch progress for epoch 1: 100%|██████████| 40000/40000 [01:06<00:00, 600.73it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 998.13it/s] \n", "Batch progress for epoch 2: 100%|██████████| 40000/40000 [01:06<00:00, 600.61it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 977.32it/s]\n", "Batch progress for epoch 3: 100%|██████████| 40000/40000 [01:06<00:00, 599.38it/s]\n", "Validating: 100%|██████████| 313/313 [00:00<00:00, 1006.41it/s]\n", "Batch progress for epoch 4: 100%|██████████| 40000/40000 [01:06<00:00, 598.48it/s]\n", "Validating: 
100%|██████████| 313/313 [00:00<00:00, 981.18it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>train_acc</th>\n", " <th>train_f1_macro</th>\n", " <th>test_loss</th>\n", " <th>test_acc</th>\n", " <th>test_f1_macro</th>\n", " <th>model</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>2.156607</td>\n", " <td>0.364000</td>\n", " <td>0.367437</td>\n", " <td>1.315366</td>\n", " <td>0.4376</td>\n", " <td>0.431926</td>\n", " <td>Linear (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.915474</td>\n", " <td>0.439525</td>\n", " <td>0.443687</td>\n", " <td>1.281477</td>\n", " <td>0.4598</td>\n", " <td>0.464885</td>\n", " <td>Linear (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>3.414282</td>\n", " <td>0.442000</td>\n", " <td>0.446172</td>\n", " <td>1.294624</td>\n", " <td>0.4524</td>\n", " <td>0.456491</td>\n", " <td>Linear (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>1.905386</td>\n", " <td>0.453650</td>\n", " <td>0.458064</td>\n", " <td>1.275777</td>\n", " <td>0.4533</td>\n", " <td>0.457361</td>\n", " <td>Linear (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>1.464136</td>\n", " <td>0.478500</td>\n", " <td>0.482614</td>\n", " <td>1.259447</td>\n", " <td>0.4655</td>\n", " <td>0.470224</td>\n", " <td>Linear (distilled)</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " epoch 
train_loss train_acc ... test_acc test_f1_macro model\n", "0 0 2.156607 0.364000 ... 0.4376 0.431926 Linear (distilled)\n", "1 1 1.915474 0.439525 ... 0.4598 0.464885 Linear (distilled)\n", "2 2 3.414282 0.442000 ... 0.4524 0.456491 Linear (distilled)\n", "3 3 1.905386 0.453650 ... 0.4533 0.457361 Linear (distilled)\n", "4 4 1.464136 0.478500 ... 0.4655 0.470224 Linear (distilled)\n", "\n", "[5 rows x 8 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 31 } ] }, { "cell_type": "markdown", "metadata": { "id": "T1DLajiEAnFU", "colab_type": "text" }, "source": [ "It's interesting to see that the student learned noticeably faster than the directly-trained model in the first couple of epochs (compare the `test_acc` columns: 0.4376 vs 0.3584 after epoch 0). By the final epoch the picture is mixed: the student has a _slightly_ higher test macro-F1 (0.470 vs 0.452) but slightly lower test accuracy (0.4655 vs 0.4727), so this may not be a significant result.\n", "\n", "Note that you cannot directly compare the training losses, because they come from different loss functions." ] }, { "cell_type": "markdown", "metadata": { "id": "3NA_nKwWHHY2", "colab_type": "text" }, "source": [ "For the CNN we use Adam and a more conventional batch size to speed up training - it's a more complicated architecture." 
] }, { "cell_type": "code", "metadata": { "id": "Rq-Twa_lbgmp", "colab_type": "code", "colab": {} }, "source": [ "train_loader = DataLoader(dataframe_to_dataset(train_df, teacher), batch_size=BATCH_SIZE, shuffle=False) # optimal training for the cnn" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "6qHUwgsSaB4k", "colab_type": "code", "outputId": "ec9e5c10-0739-4bb7-8d5e-97a644d20c4b", "colab": { "base_uri": "https://localhost:8080/", "height": 385 } }, "source": [ "optim = torch.optim.Adam(cnn.parameters())\n", "\n", "cnn_train_results = train_loop(cnn, optim, train_loader, test_loader, n_epochs=N_EPOCHS, sched=None, distil=False)\n", "cnn_train_results[\"model\"] = \"CNN\"\n", "cnn_train_results" ], "execution_count": 33, "outputs": [ { "output_type": "stream", "text": [ "Batch progress for epoch 0: 100%|██████████| 1250/1250 [00:27<00:00, 45.88it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.66it/s]\n", "Batch progress for epoch 1: 100%|██████████| 1250/1250 [00:27<00:00, 45.87it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.49it/s]\n", "Batch progress for epoch 2: 100%|██████████| 1250/1250 [00:27<00:00, 45.86it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.56it/s]\n", "Batch progress for epoch 3: 100%|██████████| 1250/1250 [00:27<00:00, 45.92it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.50it/s]\n", "Batch progress for epoch 4: 100%|██████████| 1250/1250 [00:27<00:00, 45.83it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.60it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr 
style=\"text-align: right;\">\n", " <th></th>\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>train_acc</th>\n", " <th>train_f1_macro</th>\n", " <th>test_loss</th>\n", " <th>test_acc</th>\n", " <th>test_f1_macro</th>\n", " <th>model</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>1.556353</td>\n", " <td>0.308800</td>\n", " <td>0.305734</td>\n", " <td>1.364729</td>\n", " <td>0.4036</td>\n", " <td>0.385092</td>\n", " <td>CNN</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.374294</td>\n", " <td>0.396175</td>\n", " <td>0.390798</td>\n", " <td>1.276479</td>\n", " <td>0.4454</td>\n", " <td>0.444202</td>\n", " <td>CNN</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>1.286066</td>\n", " <td>0.438800</td>\n", " <td>0.433470</td>\n", " <td>1.237772</td>\n", " <td>0.4617</td>\n", " <td>0.454582</td>\n", " <td>CNN</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>1.213118</td>\n", " <td>0.471425</td>\n", " <td>0.466652</td>\n", " <td>1.226823</td>\n", " <td>0.4673</td>\n", " <td>0.454492</td>\n", " <td>CNN</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>1.154105</td>\n", " <td>0.501875</td>\n", " <td>0.497450</td>\n", " <td>1.224071</td>\n", " <td>0.4739</td>\n", " <td>0.462970</td>\n", " <td>CNN</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " epoch train_loss train_acc ... test_acc test_f1_macro model\n", "0 0 1.556353 0.308800 ... 0.4036 0.385092 CNN\n", "1 1 1.374294 0.396175 ... 0.4454 0.444202 CNN\n", "2 2 1.286066 0.438800 ... 0.4617 0.454582 CNN\n", "3 3 1.213118 0.471425 ... 0.4673 0.454492 CNN\n", "4 4 1.154105 0.501875 ... 
0.4739 0.462970 CNN\n", "\n", "[5 rows x 8 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 33 } ] }, { "cell_type": "code", "metadata": { "id": "DKb1QzoLcSMf", "colab_type": "code", "outputId": "b16f5cb9-1617-4df4-98ce-018253e03955", "colab": { "base_uri": "https://localhost:8080/", "height": 385 } }, "source": [ "optim = torch.optim.Adam(cnn_dist.parameters())\n", "\n", "cnn_dist_train_results = train_loop(cnn_dist, optim, train_loader, test_loader, n_epochs=N_EPOCHS, sched=None, distil=True)\n", "cnn_dist_train_results[\"model\"] = \"CNN (distilled)\"\n", "cnn_dist_train_results" ], "execution_count": 34, "outputs": [ { "output_type": "stream", "text": [ "Batch progress for epoch 0: 100%|██████████| 1250/1250 [00:27<00:00, 46.02it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.67it/s]\n", "Batch progress for epoch 1: 100%|██████████| 1250/1250 [00:27<00:00, 46.02it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.51it/s]\n", "Batch progress for epoch 2: 100%|██████████| 1250/1250 [00:27<00:00, 45.76it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.38it/s]\n", "Batch progress for epoch 3: 100%|██████████| 1250/1250 [00:27<00:00, 45.89it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.50it/s]\n", "Batch progress for epoch 4: 100%|██████████| 1250/1250 [00:27<00:00, 45.98it/s]\n", "Validating: 100%|██████████| 313/313 [00:05<00:00, 54.49it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>epoch</th>\n", " <th>train_loss</th>\n", " <th>train_acc</th>\n", " 
<th>train_f1_macro</th>\n", " <th>test_loss</th>\n", " <th>test_acc</th>\n", " <th>test_f1_macro</th>\n", " <th>model</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>2.292530</td>\n", " <td>0.331050</td>\n", " <td>0.334280</td>\n", " <td>1.298172</td>\n", " <td>0.4220</td>\n", " <td>0.427988</td>\n", " <td>CNN (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>1</td>\n", " <td>1.490734</td>\n", " <td>0.416750</td>\n", " <td>0.421440</td>\n", " <td>1.236959</td>\n", " <td>0.4490</td>\n", " <td>0.455647</td>\n", " <td>CNN (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2</td>\n", " <td>1.209579</td>\n", " <td>0.451575</td>\n", " <td>0.455322</td>\n", " <td>1.210864</td>\n", " <td>0.4641</td>\n", " <td>0.470251</td>\n", " <td>CNN (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>3</td>\n", " <td>1.033341</td>\n", " <td>0.472825</td>\n", " <td>0.475761</td>\n", " <td>1.181134</td>\n", " <td>0.4880</td>\n", " <td>0.491456</td>\n", " <td>CNN (distilled)</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4</td>\n", " <td>0.919759</td>\n", " <td>0.484475</td>\n", " <td>0.486859</td>\n", " <td>1.170595</td>\n", " <td>0.4903</td>\n", " <td>0.493316</td>\n", " <td>CNN (distilled)</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " epoch train_loss train_acc ... test_acc test_f1_macro model\n", "0 0 2.292530 0.331050 ... 0.4220 0.427988 CNN (distilled)\n", "1 1 1.490734 0.416750 ... 0.4490 0.455647 CNN (distilled)\n", "2 2 1.209579 0.451575 ... 0.4641 0.470251 CNN (distilled)\n", "3 3 1.033341 0.472825 ... 0.4880 0.491456 CNN (distilled)\n", "4 4 0.919759 0.484475 ... 
0.4903 0.493316 CNN (distilled)\n", "\n", "[5 rows x 8 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 34 } ] }, { "cell_type": "markdown", "metadata": { "id": "ZHQBr5wLHu8h", "colab_type": "text" }, "source": [ "Again the distilled CNN has made greater progress in the early epochs, and its final test accuracy (0.4903 vs 0.4739) and macro-F1 are both higher than the directly-trained CNN's.\n" ] }, { "cell_type": "code", "metadata": { "id": "hFGJUbhHdAit", "colab_type": "code", "colab": {} }, "source": [ "training_results = pd.concat([linear_model_train_results,\n", " linear_model_dist_train_results,\n", " cnn_train_results,\n", " cnn_dist_train_results])" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "DmU4KYQYOCUS", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "outputId": "85698ed3-205c-4dcb-ec7d-5a699621e864" }, "source": [ "training_results.query(\"epoch == 4\").set_index(\"model\")[[\"test_acc\", \"test_f1_macro\"]]" ], "execution_count": 37, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>test_acc</th>\n", " <th>test_f1_macro</th>\n", " </tr>\n", " <tr>\n", " <th>model</th>\n", " <th></th>\n", " <th></th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>Linear</th>\n", " <td>0.4727</td>\n", " <td>0.451547</td>\n", " </tr>\n", " <tr>\n", " <th>Linear (distilled)</th>\n", " <td>0.4655</td>\n", " <td>0.470224</td>\n", " </tr>\n", " <tr>\n", " <th>CNN</th>\n", " <td>0.4739</td>\n", " <td>0.462970</td>\n", " </tr>\n", " <tr>\n", " <th>CNN 
(distilled)</th>\n", " <td>0.4903</td>\n", 
<td>0.493316</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " test_acc test_f1_macro\n", "model \n", "Linear 0.4727 0.451547\n", "Linear (distilled) 0.4655 0.470224\n", "CNN 0.4739 0.462970\n", "CNN (distilled) 0.4903 0.493316" ] }, "metadata": { "tags": [] }, "execution_count": 37 } ] }, { "cell_type": "code", "metadata": { "id": "DX6TEmNxOA6D", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 368 }, "outputId": "dd11bf87-f147-44cf-8460-2f600394c247" }, "source": [ "alt.Chart(training_results).mark_line().encode(\n", " x=\"epoch:Q\",\n", " y=\"test_f1_macro:Q\",\n", " color=\"model\"\n", ")" ], "execution_count": 38, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "alt.Chart(...)" ], "text/html": [ "\n", "<div id=\"altair-viz-5139cbff80934f85a38d4f2e354d48b3\"></div>\n", "<script type=\"text/javascript\">\n", " (function(spec, embedOpt){\n", " let outputDiv = document.currentScript.previousElementSibling;\n", " if (outputDiv.id !== \"altair-viz-5139cbff80934f85a38d4f2e354d48b3\") {\n", " outputDiv = document.getElementById(\"altair-viz-5139cbff80934f85a38d4f2e354d48b3\");\n", " }\n", " const paths = {\n", " \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n", " \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n", " \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.8.1?noext\",\n", " \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n", " };\n", "\n", " function loadScript(lib) {\n", " return new Promise(function(resolve, reject) {\n", " var s = document.createElement('script');\n", " s.src = paths[lib];\n", " s.async = true;\n", " s.onload = () => resolve(paths[lib]);\n", " s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n", " document.getElementsByTagName(\"head\")[0].appendChild(s);\n", " });\n", " }\n", "\n", " function showError(err) {\n", " outputDiv.innerHTML = `<div class=\"error\" 
style=\"color:red;\">${err}</div>`;\n", " throw err;\n", " }\n", "\n", " function displayChart(vegaEmbed) {\n", " vegaEmbed(outputDiv, spec, embedOpt)\n", " .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n", " }\n", "\n", " if(typeof define === \"function\" && define.amd) {\n", " requirejs.config({paths});\n", " require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n", " } else if (typeof vegaEmbed === \"function\") {\n", " displayChart(vegaEmbed);\n", " } else {\n", " loadScript(\"vega\")\n", " .then(() => loadScript(\"vega-lite\"))\n", " .then(() => loadScript(\"vega-embed\"))\n", " .catch(showError)\n", " .then(() => displayChart(vegaEmbed));\n", " }\n", " })({\"config\": {\"view\": {\"continuousWidth\": 400, \"continuousHeight\": 300}}, \"data\": {\"name\": \"data-268ed0945b86b0e4461eb4f0b93e3443\"}, \"mark\": \"line\", \"encoding\": {\"color\": {\"type\": \"nominal\", \"field\": \"model\"}, \"x\": {\"type\": \"quantitative\", \"field\": \"epoch\"}, \"y\": {\"type\": \"quantitative\", \"field\": \"test_f1_macro\"}}, \"$schema\": \"https://vega.github.io/schema/vega-lite/v4.8.1.json\", \"datasets\": {\"data-268ed0945b86b0e4461eb4f0b93e3443\": [{\"epoch\": 0, \"train_loss\": 1.5178632556349039, \"train_acc\": 0.327375, \"train_f1_macro\": 0.3235048916298362, \"test_loss\": 1.503576754761961, \"test_acc\": 0.3584, \"test_f1_macro\": 0.2800984212687457, \"model\": \"Linear\"}, {\"epoch\": 1, \"train_loss\": 1.2864991627186537, \"train_acc\": 0.44345, \"train_f1_macro\": 0.43625562111281885, \"test_loss\": 1.385931742457917, \"test_acc\": 0.4166, \"test_f1_macro\": 0.3501675545506183, \"model\": \"Linear\"}, {\"epoch\": 2, \"train_loss\": 1.2027310057729483, \"train_acc\": 0.488375, \"train_f1_macro\": 0.48065810876467285, \"test_loss\": 1.3311107347186761, \"test_acc\": 0.4408, 
\"test_f1_macro\": 0.3915885148604006, \"model\": \"Linear\"}, {\"epoch\": 3, \"train_loss\": 1.157016348619759, \"train_acc\": 0.51625, \"train_f1_macro\": 0.5085547478932909, \"test_loss\": 1.2946662011618812, \"test_acc\": 0.4603, \"test_f1_macro\": 0.4278821634168263, \"model\": \"Linear\"}, {\"epoch\": 4, \"train_loss\": 1.1311019829526543, \"train_acc\": 0.529975, \"train_f1_macro\": 0.5226048899396063, \"test_loss\": 1.2675358894914865, \"test_acc\": 0.4727, \"test_f1_macro\": 0.45154740769843543, \"model\": \"Linear\"}, {\"epoch\": 0, \"train_loss\": 2.1566071700228844, \"train_acc\": 0.364, \"train_f1_macro\": 0.3674369253282151, \"test_loss\": 1.3153660861067116, \"test_acc\": 0.4376, \"test_f1_macro\": 0.4319263895799342, \"model\": \"Linear (distilled)\"}, {\"epoch\": 1, \"train_loss\": 1.9154739793703048, \"train_acc\": 0.439525, \"train_f1_macro\": 0.44368696049683887, \"test_loss\": 1.2814770217139881, \"test_acc\": 0.4598, \"test_f1_macro\": 0.4648852246938387, \"model\": \"Linear (distilled)\"}, {\"epoch\": 2, \"train_loss\": 3.414282215322592, \"train_acc\": 0.442, \"train_f1_macro\": 0.44617182803340566, \"test_loss\": 1.2946236430646512, \"test_acc\": 0.4524, \"test_f1_macro\": 0.45649070996924584, \"model\": \"Linear (distilled)\"}, {\"epoch\": 3, \"train_loss\": 1.905385964841643, \"train_acc\": 0.45365, \"train_f1_macro\": 0.4580635130038576, \"test_loss\": 1.275777078093812, \"test_acc\": 0.4533, \"test_f1_macro\": 0.45736063357327456, \"model\": \"Linear (distilled)\"}, {\"epoch\": 4, \"train_loss\": 1.4641360079715378, \"train_acc\": 0.4785, \"train_f1_macro\": 0.48261401183505626, \"test_loss\": 1.2594467201552833, \"test_acc\": 0.4655, \"test_f1_macro\": 0.4702240057533781, \"model\": \"Linear (distilled)\"}, {\"epoch\": 0, \"train_loss\": 1.5563534006118775, \"train_acc\": 0.3088, \"train_f1_macro\": 0.3057337455111319, \"test_loss\": 1.364729322945348, \"test_acc\": 0.4036, \"test_f1_macro\": 0.3850920591715258, \"model\": \"CNN\"}, 
{\"epoch\": 1, \"train_loss\": 1.3742936957359313, \"train_acc\": 0.396175, \"train_f1_macro\": 0.3907976960991701, \"test_loss\": 1.2764788667995708, \"test_acc\": 0.4454, \"test_f1_macro\": 0.4442024377139191, \"model\": \"CNN\"}, {\"epoch\": 2, \"train_loss\": 1.28606593170166, \"train_acc\": 0.4388, \"train_f1_macro\": 0.43347004490924823, \"test_loss\": 1.237771660184708, \"test_acc\": 0.4617, \"test_f1_macro\": 0.45458204895491044, \"model\": \"CNN\"}, {\"epoch\": 3, \"train_loss\": 1.213117704153061, \"train_acc\": 0.471425, \"train_f1_macro\": 0.4666523422860749, \"test_loss\": 1.226823113215998, \"test_acc\": 0.4673, \"test_f1_macro\": 0.4544919195998983, \"model\": \"CNN\"}, {\"epoch\": 4, \"train_loss\": 1.154104671573639, \"train_acc\": 0.501875, \"train_f1_macro\": 0.4974503231466116, \"test_loss\": 1.2240713857614194, \"test_acc\": 0.4739, \"test_f1_macro\": 0.46297035197820496, \"model\": \"CNN\"}, {\"epoch\": 0, \"train_loss\": 2.2925297117233274, \"train_acc\": 0.33105, \"train_f1_macro\": 0.3342796380373244, \"test_loss\": 1.2981718089252996, \"test_acc\": 0.422, \"test_f1_macro\": 0.42798823302950384, \"model\": \"CNN (distilled)\"}, {\"epoch\": 1, \"train_loss\": 1.4907343217372895, \"train_acc\": 0.41675, \"train_f1_macro\": 0.42144034919257195, \"test_loss\": 1.2369593903660394, \"test_acc\": 0.449, \"test_f1_macro\": 0.45564661173380216, \"model\": \"CNN (distilled)\"}, {\"epoch\": 2, \"train_loss\": 1.2095785547733306, \"train_acc\": 0.451575, \"train_f1_macro\": 0.455322159487708, \"test_loss\": 1.2108639305392013, \"test_acc\": 0.4641, \"test_f1_macro\": 0.4702512690557981, \"model\": \"CNN (distilled)\"}, {\"epoch\": 3, \"train_loss\": 1.0333412601947785, \"train_acc\": 0.472825, \"train_f1_macro\": 0.47576149267891665, \"test_loss\": 1.181133770904602, \"test_acc\": 0.488, \"test_f1_macro\": 0.4914556919175155, \"model\": \"CNN (distilled)\"}, {\"epoch\": 4, \"train_loss\": 0.9197593289613724, \"train_acc\": 0.484475, \"train_f1_macro\": 
0.4868588211765947, \"test_loss\": 1.170595273803979, \"test_acc\": 0.4903, \"test_f1_macro\": 0.49331588896128026, \"model\": \"CNN (distilled)\"}]}}, {\"mode\": \"vega-lite\"});\n", "</script>" ] }, "metadata": { "tags": [] }, "execution_count": 38 } ] }, { "cell_type": "markdown", "metadata": { "id": "0ng30VxDa-rR", "colab_type": "text" }, "source": [ "The distilled models learned faster (reached higher scores in the early epochs) and finished with slightly better test macro-F1. Test accuracy is less clear-cut: it improved for the CNN but was marginally lower for the distilled linear model.\n", "\n", "However, we would need to check whether these differences are significant. Because the data is multinomial and the metric is F1, we could either use the bootstrap method or follow the approach in _A Bayesian Interpretation of the Confusion Matrix_ (Caelen 2017).\n", "\n", "For another interesting comparison, what can the teacher achieve?\n", "\n", "This isn't quite fair because the teacher was not trained on the same splits of this dataset. But it's a large dataset and the likely proportion of the teacher's training data in this test set is low.\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "OP-97NQElY9s", "colab_type": "code", "outputId": "eab2cd88-a4ea-4213-d756-033f96071b6b", "colab": { "base_uri": "https://localhost:8080/", "height": 53 } }, "source": [ "teacher.to(device)\n", "teacher.eval()\n", "\n", "teacher_test_acc5 = []\n", "for batch_num, batch in enumerate(tqdm(test_loader)):\n", "    batch = tuple([t.to(device) for t in batch])\n", "    inputs = {\"input_ids\": batch[0],\n", "              \"attention_mask\": batch[1]}\n", "    if teacher.base_model_prefix == \"bert\":\n", "        inputs[\"token_type_ids\"] = batch[2]\n", "    labels = batch[3]\n", "    with torch.no_grad():\n", "        logits = teacher(**inputs)[0]\n", "    probs = torch.softmax(logits, dim=1)\n", "    preds_5class = probs.argmax(dim=1)\n", "    acc_5class = (preds_5class == labels).sum().item() / 
len(batch[0])\n", "    teacher_test_acc5.append(acc_5class)\n", "\n", "np.mean(teacher_test_acc5)" ], "execution_count": 39, "outputs": [ { "output_type": 
"stream", "text": [ "100%|██████████| 313/313 [00:38<00:00, 8.11it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/plain": [ "0.6168130990415336" ] }, "metadata": { "tags": [] }, "execution_count": 39 } ] }, { "cell_type": "markdown", "metadata": { "id": "adBkrb-UCj5c", "colab_type": "text" }, "source": [ "(Really brief) Discussion\n", "-------------------------\n", "\n", "So neither model really got close to the teacher's accuracy, but we do see a (potentially significant) improvement in accuracy between the distilled and directly-trained models.\n", "\n", "It's also very interesting to see that the students converged faster. This supports Tang's suggestion that the information about prediction uncertainty is valuable, and that this even outweighs the error from the teacher's inaccurate predictions.\n", "\n", "We might get better results if we implement the data augmentation that Tang suggests. We could also probably do better with a more complex student - you can see that in this [NLP Town blog post](https://www.nlp.town/blog/distilling-bert/), which inspired me to try this. NLP Town trained spaCy's \"ensemble\" classifier, which is a more sophisticated CNN than this and would be expected to perform better." ] }, { "cell_type": "code", "metadata": { "id": "KApe3Xvlozqj", "colab_type": "code", "colab": {} }, "source": [ "" ], "execution_count": 0, "outputs": [] } ] }