{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "07383deb", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "<style>\n", " html {\n", " font-size: 18px !important;\n", " }\n", "\n", " body {\n", " background-color: #FFF !important;\n", " font-weight: 1rem;\n", " font-family: 'Source Sans Pro', \"Helvetica Neue\", Helvetica, Arial, sans-serif;\n", " }\n", "\n", " body .notebook-app {\n", " background-color: #FFF !important;\n", " }\n", "\n", " #header {\n", " box-shadow: none !important;\n", " }\n", "\n", " #notebook {\n", " padding-top: 0px;\n", " }\n", "\n", " #notebook-container {\n", " box-shadow: none;\n", " -webkit-box-shadow: none;\n", " padding: 10px;\n", " }\n", "\n", " div.cell {\n", " width: 1000px;\n", " margin-left: 0% !important;\n", " margin-right: auto;\n", " }\n", "\n", " div.cell.selected {\n", " border: 1px dashed #CCCCCC;\n", " }\n", "\n", " .edit_mode div.cell.selected {\n", " border: 1px dashed #828282;\n", " }\n", "\n", " div.output_wrapper {\n", " margin-top: 8px;\n", " }\n", "\n", " a {\n", " color: #383838;\n", " }\n", "\n", " code,\n", " kbd,\n", " pre,\n", " samp {\n", " font-family: 'Menlo', monospace !important;\n", " font-size: 0.75rem !important;\n", " }\n", "\n", " h1 {\n", " font-size: 2rem !important;\n", " font-weight: 500 !important;\n", " letter-spacing: 3px !important;\n", " text-transform: uppercase !important;\n", " }\n", "\n", " h2 {\n", " font-size: 1.8rem !important;\n", " font-weight: 400 !important;\n", " letter-spacing: 3px !important;\n", " text-transform: none !important;\n", " }\n", "\n", " h3 {\n", " font-size: 1.5rem !important;\n", " font-weight: 400 !important;\n", " font-style: italic !important;\n", " display: block !important;\n", " }\n", "\n", " h4,\n", " h5,\n", " h6 {\n", " font-size: 1rem !important;\n", " font-weight: 400 !important;\n", " display: block !important;\n", " }\n", "\n", " .prompt {\n", " font-family: 'Menlo', monospace !important;\n", " font-size: 0.75rem;\n", " text-align: right;\n", " line-height: 1.21429rem;\n", " }\n", "\n", " /* INTRO PAGE */\n", "\n", " .toolbar_info,\n", " .list-container {\n", " ;\n", " }\n", " /* NOTEBOOK */\n", "\n", " div#header-container {\n", " display: none !important;\n", " }\n", "\n", " div#notebook {\n", " border-top: none;\n", " font-size: 1rem;\n", " }\n", "\n", " div.input_prompt {\n", " color: #C74483;\n", " }\n", "\n", " .code_cell div.input_prompt:after,\n", " div.output_prompt:after {\n", " content: '\\25b6';\n", " }\n", "\n", " div.output_prompt {\n", " color: #2B88D9;\n", " }\n", "\n", " div.input_area {\n", " border-radius: 0px;\n", " border: 1px solid #d8d8d8;\n", " }\n", "\n", " div.output_area pre {\n", " font-weight: normal;\n", " }\n", "\n", " div.output_subarea {\n", " font-weight: normal;\n", " }\n", "\n", " .rendered_html pre,\n", " .rendered_html table,\n", " .rendered_html th,\n", " .rendered_html tr,\n", " .rendered_html td {\n", " border: 1px #828282 solid;\n", " font-size: 0.75rem;\n", " font-family: 'Menlo', monospace;\n", " }\n", "\n", " .rendered_html th,\n", " .rendered_html tr,\n", " .rendered_html td {\n", " padding: 5px 10px;\n", " }\n", "\n", " .rendered_html th {\n", " font-weight: normal;\n", " background: #f8f8f8;\n", " }\n", "\n", " a:link{\n", " font-weight: bold;\n", " color:#447adb;\n", " }\n", " a:visited{\n", " font-weight: bold;\n", " color: #1d3b84;\n", " }\n", " a:hover{\n", " font-weight: bold;\n", " color: #1d3b84;\n", " }\n", " a:focus{\n", " font-weight: bold;\n", " color:#447adb;\n", " }\n", " 
a:active{\n", " font-weight: bold;\n", " color:#447adb;\n", " }\n", " .rendered_html :link {\n", " text-decoration: underline; \n", " }\n", "\n", " div.output_html {\n", " font-weight: 1rem;\n", " font-family: 'Source Sans Pro', \"Helvetica Neue\", Helvetica, Arial, sans-serif;\n", " }\n", "\n", " table.dataframe tr {\n", " border: 1px #CCCCCC;\n", " }\n", "\n", " div.cell.selected {\n", " border-radius: 0px;\n", " }\n", "\n", " div.cell.edit_mode {\n", " border-radius: 0px;\n", " border: thin solid #CF5804;\n", " }\n", "\n", " span.ansiblue {\n", " color: #00A397;\n", " }\n", "\n", " span.ansigray {\n", " color: #d8d8d8;\n", " }\n", "\n", " span.ansigreen {\n", " color: #688A0A;\n", " }\n", "\n", " span.ansipurple {\n", " color: #975DDE;\n", " }\n", "\n", " span.ansired {\n", " color: #D43132;\n", " }\n", "\n", " span.ansiyellow {\n", " color: #D9AA00;\n", " }\n", "\n", " div.output_stderr {\n", " background-color: #D43132;\n", " }\n", "\n", " div.output_stderr pre {\n", " color: #e8e8e8;\n", " }\n", "\n", " .cm-s-ipython.CodeMirror {\n", " background: #F8F8F8;\n", " }\n", "\n", " .cm-s-ipython div.CodeMirror-selected {\n", " background: #e8e8e8 !important;\n", " }\n", "\n", " .cm-s-ipython .CodeMirror-gutters {\n", " background: #F8F8F8;\n", " border-right: 0px;\n", " }\n", "\n", " .cm-s-ipython .CodeMirror-linenumber {\n", " color: #b8b8b8;\n", " }\n", "\n", " .cm-s-ipython .CodeMirror-cursor {\n", " border-left: 1px solid #585858 !important;\n", " }\n", "\n", " .cm-s-ipython span.cm-atom {\n", " color: #C74483;\n", " }\n", "\n", " .cm-s-ipython span.cm-number {\n", " color: #C74483;\n", " }\n", "\n", " .cm-s-ipython span.cm-property,\n", " .cm-s-ipython span.cm-attribute {\n", " color: #688A0A;\n", " }\n", "\n", " .cm-s-ipython span.cm-keyword {\n", " font-weight: normal;\n", " color: #D43132;\n", " }\n", "\n", " .cm-s-ipython span.cm-string {\n", " color: #D9AA00;\n", " }\n", "\n", " .cm-s-ipython span.cm-operator {\n", " font-weight: normal;\n", " }\n", "\n", " .cm-s-ipython span.cm-builtin {\n", " color: #2B88D9;\n", " }\n", "\n", " .cm-s-ipython span.cm-variable {\n", " color: #00A397;\n", " }\n", "\n", " .cm-s-ipython span.cm-variable-2 {\n", " color: #2B88D9;\n", " }\n", "\n", " .cm-s-ipython span.cm-def {\n", " color: #00A397;\n", " }\n", "\n", " .cm-s-ipython span.cm-error {\n", " background: #FFBDBD;\n", " color: #D43132;\n", " }\n", "\n", " .cm-s-ipython span.cm-tag {\n", " color: #D43132;\n", " }\n", "\n", " .cm-s-ipython span.cm-link {\n", " color: #975DDE;\n", " }\n", "\n", " .cm-s-ipython .CodeMirror-matchingbracket {\n", " text-decoration: underline;\n", " !important;\n", " }\n", "</style>\n", "\n", "<script>\n", " MathJax.Hub.Config({\n", " TeX: {\n", " extensions: [\"AMSmath.js\"]\n", " },\n", " tex2jax: {\n", " inlineMath: [ ['$','$'], [\"\\\\(\",\"\\\\)\"] ],\n", " displayMath: [ ['$$','$$'], [\"\\\\[\",\"\\\\]\"] ]\n", " },\n", " displayAlign: 'center', // Change this to 'center' to center equations.\n", " \"HTML-CSS\": {\n", " scale:100,\n", " availableFonts: [],\n", " preferredFont:null,\n", " webFont: \"TeX\",\n", " styles: {'.MathJax_Display': {\"margin\": 4}}\n", " }\n", " });\n", "</script>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# code for loading the format for the notebook\n", "import os\n", "\n", "# path : store the current path to convert back to it later\n", "path = os.getcwd()\n", "os.chdir(os.path.join('..', '..', 'notebook_format'))\n", "\n", 
"from formats import load_style\n", "load_style(css_style='custom2.css', plot_style=False)" ] }, { "cell_type": "code", "execution_count": 2, "id": "d17fa0e4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Author: Ethen\n", "\n", "Last updated: 2023-04-16\n", "\n", "Python implementation: CPython\n", "Python version : 3.8.10\n", "IPython version : 8.4.0\n", "\n", "torch : 2.0.0\n", "datasets : 2.11.0\n", "transformers: 4.28.1\n", "evaluate : 0.4.0\n", "numpy : 1.23.2\n", "pandas : 1.4.3\n", "\n" ] } ], "source": [ "os.chdir(path)\n", "\n", "# 1. magic for inline plot\n", "# 2. magic to print version\n", "# 3. magic so that the notebook will reload external python modules\n", "# 4. magic to enable retina (high resolution) plots\n", "# https://gist.github.com/minrk/3301035\n", "%matplotlib inline\n", "%load_ext watermark\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import os\n", "import torch\n", "import evaluate\n", "import numpy as np\n", "import pandas as pd\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from time import perf_counter\n", "from torch.utils.data import DataLoader\n", "from datasets import load_dataset, DatasetDict, disable_progress_bar\n", "from datasets.utils.logging import set_verbosity_error\n", "from transformers import (\n", " pipeline,\n", " Trainer,\n", " TrainingArguments,\n", " AutoConfig,\n", " AutoTokenizer,\n", " AutoModelForSequenceClassification,\n", " DataCollatorWithPadding\n", ")\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# prevent dataset from floading outputs to our notebook\n", "disable_progress_bar()\n", "set_verbosity_error()\n", "\n", "%watermark -a 'Ethen' -d -u -v -p torch,datasets,transformers,evaluate,numpy,pandas" ] }, { "cell_type": "markdown", "id": "0268da0e", "metadata": {}, "source": [ "# Response Knowledge Distillation" ] }, { "cell_type": "markdown", "id": "87cf9884", "metadata": {}, "source": [ "In this documentation, we'll deep dive into a technique called knowledge distillation that's commonly used to compress large model, a.k.a. teacher model, into a smaller model, a.k.a student model. The hope is that these student models, which typically have fewer layers or/and fewer neurons per layer will be capable of reproducing the behavior of teacher models while being more light weight. In other words, making the model more cost efficient when it comes to serving in production setting without lossing too much performance. And just to clarify, as knowledge distillation is a broad topic, there are two primary types of knowledge distillation, task-specific knowledge distillation (left) and task-agnostic knowledge distillation (right). 
Here, our primary focus will be the former.\n", "\n", "<img src=\"img/distillation_task.png\" width=\"70%\" height=\"70%\">\n", "\n", "Task-specific response knowledge distillation involves optimizing a weighted combination of two objective functions:\n", "\n", "\\begin{align}\n", "L = \\alpha L_{CE} + (1 - \\alpha) L_{KD} \\text{, where } \\alpha \\in [0, 1]\n", "\\end{align}\n", "\n", "$L_{CE}$ is the cross entropy loss between the student logits $z_s$ and our one hot encoded ground truth labels $y$:\n", "\n", "\\begin{align}\n", "L_{CE} = - \\sum^c_{j=1}y_j \\log \\sigma_j(z_s, 1)\n", "\\end{align}\n", "\n", "Here, $\\sigma_j$ is the softmax output that takes a model's logits, $z$ ($z_t$ stands for the teacher model's logits, whereas $z_s$ stands for the student model's logits), as well as a temperature scaling parameter, $T$, as its inputs: $\\sigma_j = \\frac{\\exp\\left(z_j / T \\right)}{\\sum_{k} \\exp\\left(z_k / T \\right)}$. With the temperature parameter set to 1, $L_{CE}$ is the standard loss function that we generally optimize in supervised classification settings.\n", "\n", "For the knowledge distillation loss, $L_{KD}$, we add a KL-divergence loss between the teacher model's response and the student model's response. By adding this loss term, we train our student model to become better at mimicking the teacher's predictions.\n", "\n", "\\begin{align}\n", "L_{KD} = T^2 \\sum^c_{j=1}\\sigma_j(z_t, T) \\log \\frac{\\sigma_j(z_t, T)}{\\sigma_j(z_s, T)}\n", "\\end{align}\n", "\n", "The idea behind temperature scaling is that teacher models tend to assign extremely high predicted scores to the true class, so their predictions don't provide much additional information beyond the ground truth labels the dataset already provides. To tackle this issue, temperature scaling acts as a scaling parameter to \"soften\" our predictions. The intuition is that it allows us to learn \"ish\" concepts in our data, e.g. a 1-ish 7 (a 7 that looks like a 1, or more formally, although our model predicted 7 with the highest score, it still assigns some amount of score to 1). Note:\n", "\n", "- When the student model is a lot smaller than the teacher model, we tend to keep a smaller temperature, because as we raise this temperature parameter, the resulting predicted distribution may start to contain too much \"knowledge\" for the student to capture effectively.\n", "- Once our student model has been trained, the temperature parameter $T$ is set back to 1 during the inference stage.\n", "- There's a multiplication term, $T^2$, in our knowledge distillation loss. Since the magnitudes of the gradients produced by soft targets scale as $1/T^2$, this multiplication term ensures the contributions from the ground truth hard targets and the teacher's predicted soft targets remain roughly equal.\n", "\n", "As we can see, the main idea behind response knowledge distillation is that while training our student model, instead of solely optimizing for our task's original loss function using the dataset's ground truth labels (e.g. in a classification task this may be cross entropy loss), we augment it with the teacher model's predicted output probabilities, with a parameter $\\alpha$ in our loss function controlling the weighting between the two loss terms."
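] }, { "cell_type": "markdown", "id": "7d3e5f2a", "metadata": {}, "source": [ "To make the math above concrete, below is a minimal sketch that computes this combined loss on a toy batch of logits. The logit values, label, `alpha` and `temperature` are made-up illustrative numbers rather than outputs of any model in this notebook. Note that PyTorch's `F.kl_div` expects its first argument to be in log-space, which is why the student distribution goes through `log_softmax`." ] }, { "cell_type": "code", "execution_count": null, "id": "8c4b6a1d", "metadata": {}, "outputs": [], "source": [ "# toy logits for a 3-class problem, made up purely for illustration\n", "z_t = torch.tensor([[4.0, 1.0, -2.0]])  # teacher logits\n", "z_s = torch.tensor([[1.5, 0.5, -1.0]])  # student logits\n", "y = torch.tensor([0])  # hard ground truth label\n", "alpha = 0.5\n", "temperature = 1.5\n", "\n", "# raising T softens the teacher's distribution, surfacing \"ish\" concepts\n", "print(F.softmax(z_t / 1.0, dim=-1))  # roughly [0.95, 0.05, 0.00]\n", "print(F.softmax(z_t / 4.0, dim=-1))  # roughly [0.59, 0.28, 0.13]\n", "\n", "# standard cross entropy between student logits (T = 1) and hard labels\n", "loss_ce = F.cross_entropy(z_s, y)\n", "\n", "# KL-divergence between temperature softened teacher and student,\n", "# multiplied by T^2 to keep the gradient magnitudes of the two loss\n", "# terms comparable\n", "loss_kd = F.kl_div(\n", "    F.log_softmax(z_s / temperature, dim=-1),\n", "    F.softmax(z_t / temperature, dim=-1),\n", "    reduction=\"batchmean\"\n", ") * temperature ** 2\n", "\n", "alpha * loss_ce + (1 - alpha) * loss_kd"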
] }, { "cell_type": "markdown", "id": "b2c925d2", "metadata": {}, "source": [ "## Data Preprocessing" ] }, { "cell_type": "markdown", "id": "1d576158", "metadata": {}, "source": [ "For this example, we will be using qqp (Quora Question Pairs), a [text classification task](https://huggingface.co/tasks/text-classification) from the [glue benchmark](https://huggingface.co/datasets/glue). This is a collection of question pairs from the community question-answering website Quora. Our task is to determine whether a pair of questions is semantically equivalent." ] }, { "cell_type": "code", "execution_count": 3, "id": "bdbbabb6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['question1', 'question2', 'label', 'idx'],\n", " num_rows: 363846\n", " })\n", " validation: Dataset({\n", " features: ['question1', 'question2', 'label', 'idx'],\n", " num_rows: 40430\n", " })\n", " test: Dataset({\n", " features: ['question1', 'question2', 'label', 'idx'],\n", " num_rows: 390965\n", " })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_dict = load_dataset('glue', 'qqp')\n", "dataset_dict" ] }, { "cell_type": "code", "execution_count": 4, "id": "d92e63f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'question1': 'What can one do after MBBS?',\n", " 'question2': 'What do i do after my MBBS ?',\n", " 'label': 1,\n", " 'idx': 3}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example = dataset_dict['train'][3]\n", "example" ] }, { "cell_type": "markdown", "id": "9f80c1da", "metadata": {}, "source": [ "## Teacher Model" ] }, { "cell_type": "markdown", "id": "f3b38af2", "metadata": {}, "source": [ "To establish our baseline, we'll piggyback on one of the pretrained models available from the huggingface hub. In this case, we'll pick a teacher model that is already trained on our targeted dataset." ] }, { "cell_type": "code", "execution_count": 5, "id": "0ba918b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# of parameters: 109483778\n" ] } ], "source": [ "teacher_checkpoint = 'textattack/bert-base-uncased-QQP'\n", "teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_checkpoint)\n", "teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_checkpoint).to(device)\n", "print('# of parameters: ', teacher_model.num_parameters())" ] }, { "cell_type": "markdown", "id": "daed7ba0", "metadata": {}, "source": [ "We generate a sample prediction using our tokenizer and model, double checking that our result matches the [pipeline wrapper class](https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/pipelines)."
] }, { "cell_type": "code", "execution_count": 6, "id": "8369a332", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 2054, 2064, 2028, 2079, 2044, 16914, 5910, 1029, 102,\n", " 2054, 2079, 1045, 2079, 2044, 2026, 16914, 5910, 1029, 102]],\n", " device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],\n", " device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],\n", " device='cuda:0')}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized = teacher_tokenizer(\n", " example['question1'],\n", " example['question2'],\n", " return_tensors='pt'\n", ").to(teacher_model.device)\n", "tokenized" ] }, { "cell_type": "code", "execution_count": 7, "id": "7fb1591a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0223, 0.9777]], device='cuda:0')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "teacher_model.eval()\n", "with torch.no_grad():\n", " output = teacher_model(**tokenized)\n", " batch_scores = F.softmax(output.logits, dim=-1)\n", "\n", "batch_scores" ] }, { "cell_type": "code", "execution_count": 8, "id": "597fafa1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'label': 'LABEL_1', 'score': 0.9777140021324158}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier = pipeline(\"text-classification\", model=teacher_checkpoint, device=teacher_model.device)\n", "output = classifier({\"text\": example['question1'], \"text_pair\": example['question2']})\n", "output" ] }, { "cell_type": "markdown", "id": "a562bf7a", "metadata": {}, "source": [ "## Student Model" ] }, { "cell_type": "markdown", "id": "1abafbb5", "metadata": {}, "source": [ "As always, we are free to choose different student models and compare results, though as a general principle, we typically avoid distilling different model families against each other: different inputs/tokens will result in different embeddings, and transferring knowledge across different embedding spaces tends not to work well.\n", "\n", "In the next code chunk, apart from the typical step of initializing our student model using the `.from_pretrained` method, we also copy some additional config, such as the number of labels as well as the label id to label name mapping, from the teacher model's config." ] }, { "cell_type": "code", "execution_count": 9, "id": "15d87d23", "metadata": {}, "outputs": [], "source": [ "student_checkpoint = 'distilbert-base-uncased'\n", "student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)\n", "student_config = AutoConfig.from_pretrained(\n", " student_checkpoint,\n", " num_labels=teacher_model.config.num_labels,\n", " id2label=teacher_model.config.id2label,\n", " label2id=teacher_model.config.label2id\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "id": "ae893d79", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g.
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "# of parameters: 66955010\n" ] } ], "source": [ "def student_model_init():\n", " student_model = AutoModelForSequenceClassification.from_pretrained(\n", " student_checkpoint,\n", " config=student_config\n", " ).to(device)\n", " return student_model\n", "\n", "\n", "student_model = student_model_init()\n", "print('# of parameters: ', student_model.num_parameters())" ] }, { "cell_type": "code", "execution_count": 11, "id": "e502b119", "metadata": {}, "outputs": [], "source": [ "def tokenize_dataset(dataset, tokenizer):\n", " def tokenize_fn(batch):\n", " return tokenizer(batch[\"question1\"], batch[\"question2\"], truncation=True)\n", "\n", " return dataset.map(\n", " tokenize_fn,\n", " batched=True,\n", " num_proc=8,\n", " remove_columns=[\"question1\", \"question2\", \"idx\"]\n", " )" ] }, { "cell_type": "code", "execution_count": 12, "id": "80ed9740", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'label': 0,\n", " 'input_ids': [101,\n", " 2129,\n", " 2003,\n", " 1996,\n", " 2166,\n", " 1997,\n", " 1037,\n", " 8785,\n", " 3076,\n", " 1029,\n", " 2071,\n", " 2017,\n", " 6235,\n", " 2115,\n", " 2219,\n", " 6322,\n", " 1029,\n", " 102,\n", " 2029,\n", " 2504,\n", " 1997,\n", " 17463,\n", " 8156,\n", " 2003,\n", " 2438,\n", " 2005,\n", " 1996,\n", " 11360,\n", " 1046,\n", " 14277,\n", " 2102,\n", " 2629,\n", " 1029,\n", " 102],\n", " 'attention_mask': [1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1,\n", " 1]}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_dict_student_tokenized = tokenize_dataset(dataset_dict, student_tokenizer)\n", "dataset_dict_student_tokenized['train'][0]" ] }, { "cell_type": "markdown", "id": "08c92452", "metadata": {}, "source": [ "For model performance, we'll compute some of the standard text classification metrics. Huggingface evaluate allows us to combine multiple metrics' calculations in one go using the `.combine` method. As `roc_auc` expects a different input (it requires the predicted score instead of the predicted labels) compared to `f1`, `precision`, `recall`, we load it separately."
] }, { "cell_type": "code", "execution_count": 13, "id": "abf7c4d0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'f1': 1.0, 'precision': 1.0, 'recall': 1.0}\n" ] } ], "source": [ "clf_metrics = evaluate.combine([\"f1\", \"precision\", \"recall\"])\n", "roc_auc_metric = evaluate.load(\"roc_auc\")\n", "\n", "results = clf_metrics.compute(predictions=[0, 1], references=[0, 1])\n", "print(results)" ] }, { "cell_type": "code", "execution_count": 14, "id": "f4d5aa84", "metadata": {}, "outputs": [], "source": [ "def compute_metrics(pred):\n", " scores, labels = pred\n", " predictions = np.argmax(scores, axis=1)\n", " metrics = clf_metrics.compute(predictions=predictions, references=labels)\n", " metrics['roc_auc'] = roc_auc_metric.compute(prediction_scores=scores[:, 1], references=labels)['roc_auc']\n", " return metrics" ] }, { "cell_type": "markdown", "id": "8f79457c", "metadata": {}, "source": [ "In the next few code chunks, we'll train a student model both with and without knowledge distillation for comparison." ] }, { "cell_type": "code", "execution_count": 15, "id": "17d04c76", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g.
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='11372' max='11372' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [11372/11372 26:47, Epoch 2/2]\n", " </div>\n", " <table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>Epoch</th>\n", " <th>Training Loss</th>\n", " <th>Validation Loss</th>\n", " <th>F1</th>\n", " <th>Precision</th>\n", " <th>Recall</th>\n", " <th>Roc Auc</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>1</td>\n", " <td>0.282800</td>\n", " <td>0.261373</td>\n", " <td>0.850860</td>\n", " <td>0.821411</td>\n", " <td>0.882499</td>\n", " <td>0.955464</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.179600</td>\n", " <td>0.247799</td>\n", " <td>0.866363</td>\n", " <td>0.852121</td>\n", " <td>0.881088</td>\n", " <td>0.963130</td>\n", " </tr>\n", " </tbody>\n", "</table><p>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=11372, training_loss=0.2587387173518806, metrics={'train_runtime': 1607.3946, 'train_samples_per_second': 452.715, 'train_steps_per_second': 7.075, 'total_flos': 1.4672952700483704e+16, 'train_loss': 0.2587387173518806, 'epoch': 2.0})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_size = 64\n", "num_train_epochs = 2\n", "learning_rate = 0.0001\n", "weight_decay = 0.01\n", "\n", "student_finetuned_checkpoint = \"distilbert-base-uncased-finetuned-qqp\"\n", "student_training_args = TrainingArguments(\n", " output_dir=student_finetuned_checkpoint,\n", " num_train_epochs=num_train_epochs,\n", " learning_rate=learning_rate,\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " weight_decay=weight_decay,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " save_total_limit=2,\n", " load_best_model_at_end=True\n", ")\n", "\n", "student_trainer = Trainer(\n", " model_init=student_model_init,\n", " args=student_training_args,\n", " tokenizer=student_tokenizer, \n", " train_dataset=dataset_dict_student_tokenized[\"train\"],\n", " 
eval_dataset=dataset_dict_student_tokenized[\"validation\"],\n", " compute_metrics=compute_metrics\n", ")\n", "student_trainer.train()" ] }, { "cell_type": "markdown", "id": "094be856", "metadata": {}, "source": [ "In order to finetune a model using knowledge distillation, we will subclass `TrainingArguments` to include our two hyperparameters, $\\alpha$ and $T$, as well as `Trainer`, mainly to override its `compute_loss` method so we can add our knowledge distillation loss term." ] }, { "cell_type": "code", "execution_count": 16, "id": "9e6da70f", "metadata": {}, "outputs": [], "source": [ "class DistillationTrainingArguments(TrainingArguments):\n", " def __init__(self, *args, alpha=0.5, temperature=1.5, **kwargs):\n", " super().__init__(*args, **kwargs)\n", " self.alpha = alpha\n", " self.temperature = temperature\n", "\n", "\n", "class DistillationTrainer(Trainer):\n", " def __init__(self, *args, teacher_model=None, **kwargs):\n", " super().__init__(*args, **kwargs)\n", " self.teacher = teacher_model\n", " # place teacher on same device as student\n", " self._move_model_to_device(self.teacher, self.model.device)\n", " self.teacher.eval()\n", "\n", " self.kl_div_loss = nn.KLDivLoss(reduction=\"batchmean\")\n", "\n", " def compute_loss(self, model, inputs, return_outputs=False):\n", " # compute student and teacher output\n", " outputs_student = model(**inputs)\n", " with torch.no_grad():\n", " outputs_teacher = self.teacher(**inputs)\n", "\n", " # Soften probabilities and compute distillation loss\n", " # note, the kl divergence loss expects the input to be in log-space\n", " # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html\n", " distillation_loss = self.kl_div_loss(\n", " F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),\n", " F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1)\n", " ) * (self.args.temperature ** 2)\n", " # Return weighted student loss\n", " loss = self.args.alpha * outputs_student.loss + (1. - self.args.alpha) * distillation_loss\n", " return (loss, outputs_student) if return_outputs else loss" ] }, { "cell_type": "code", "execution_count": 17, "id": "3813c8c0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g.
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. 
Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", " <div>\n", " \n", " <progress value='11372' max='11372' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", " [11372/11372 44:05, Epoch 2/2]\n", " </div>\n", " <table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: left;\">\n", " <th>Epoch</th>\n", " <th>Training Loss</th>\n", " <th>Validation Loss</th>\n", " <th>F1</th>\n", " <th>Precision</th>\n", " <th>Recall</th>\n", " <th>Roc Auc</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <td>1</td>\n", " <td>0.418600</td>\n", " <td>0.403182</td>\n", " <td>0.813883</td>\n", " <td>0.881096</td>\n", " <td>0.756198</td>\n", " <td>0.953579</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.352200</td>\n", " <td>0.396250</td>\n", " <td>0.853818</td>\n", " <td>0.878482</td>\n", " <td>0.830501</td>\n", " <td>0.961582</td>\n", " </tr>\n", " </tbody>\n", "</table><p>" ], "text/plain": [ "<IPython.core.display.HTML object>" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=11372, training_loss=0.4020600615778818, metrics={'train_runtime': 2645.7956, 'train_samples_per_second': 275.037, 'train_steps_per_second': 4.298, 'total_flos': 1.4672952700483704e+16, 'train_loss': 0.4020600615778818, 'epoch': 2.0})" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "student_distillation_checkpoint = \"distilbert-base-uncased-finetuned-distillation-qqp\"\n", "student_distillation_training_args = DistillationTrainingArguments(\n", " output_dir=student_distillation_checkpoint,\n", " num_train_epochs=num_train_epochs,\n", " learning_rate=learning_rate,\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " weight_decay=weight_decay,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " save_total_limit=2,\n", " load_best_model_at_end=True,\n", " alpha=0.8\n", ")\n", "\n", "student_distillation_trainer = DistillationTrainer(\n", " model_init=student_model_init,\n", " args=student_distillation_training_args,\n", " tokenizer=student_tokenizer,\n", " teacher_model=teacher_model, \n", " train_dataset=dataset_dict_student_tokenized['train'],\n", " eval_dataset=dataset_dict_student_tokenized['validation'],\n", " compute_metrics=compute_metrics\n", ")\n", "student_distillation_trainer.train()" ] }, { "cell_type": "markdown", "id": "b4b32405", "metadata": {}, "source": [ "## Benchmark" ] }, { "cell_type": "markdown", "id": "e127b4ba", "metadata": {}, "source": [ "When determining which model to move forward with for production, we usually look at model performance, latency, as well as memory footprint (a.k.a. model size). We'll create a helper class for measuring these key aspects and run our models through it for a fair comparison."
] }, { "cell_type": "code", "execution_count": 18, "id": "a0540811", "metadata": {}, "outputs": [], "source": [ "class Benchmark:\n", "\n", " def __init__(\n", " self,\n", " dataset,\n", " latency_warmup: int = 10,\n", " latency_rounds: int = 100,\n", " perf_batch_size: int = 128,\n", " perf_round_digits: int = 3\n", " ):\n", " self.dataset = dataset\n", " self.latency_warmup = latency_warmup\n", " self.latency_rounds = latency_rounds\n", " self.perf_batch_size = perf_batch_size\n", " self.perf_round_digits = perf_round_digits\n", "\n", " self.temp_model_path = \"model.pt\"\n", "\n", " def run(self, tokenizer, model, run_name):\n", " \"\"\"run benchmark for a given tokenizer and model,\n", " we can provide a run_name to differentiate the results\n", " from different runs in the final dictionary.\n", " \n", " e.g.\n", " {\n", " \"run_name\": {\n", " 'size_mb': 417.73,\n", " 'num_parameters': 109483778,\n", " 'latency_avg_ms': 8.33,\n", " 'latency_std_ms': 1.16,\n", " 'f1': 0.878,\n", " 'precision': 0.867,\n", " 'recall': 0.89,\n", " 'roc_auc': 0.968\n", " }\n", " }\n", " \"\"\"\n", " model.eval()\n", " \n", " size = self.compute_size(model)\n", " latency = self.compute_latency(tokenizer, model)\n", " performance = self.compute_performance(tokenizer, model)\n", "\n", " # merge various metrics into one single dictionary\n", " metrics = {**size, **latency, **performance}\n", " return {run_name: metrics}\n", " \n", " def predict(self, example, tokenizer, model):\n", " inputs = tokenizer(\n", " example[\"question1\"],\n", " example[\"question2\"],\n", " return_tensors=\"pt\"\n", " ).to(model.device)\n", " with torch.no_grad():\n", " output = model(**inputs.to(model.device))\n", "\n", " return output\n", "\n", " def compute_size(self, model):\n", " \"\"\"save the model's parameters temporarily to a local path for calculating model size.\n", " Once calculation is done, purge the checkpoint.\n", " Size is reported in megabytes.\n", "\n", " https://pytorch.org/tutorials/beginner/saving_loading_models.html\n", " \"\"\"\n", " torch.save(model.state_dict(), self.temp_model_path)\n", " size_mb = os.path.getsize(self.temp_model_path) / (1024 * 1024)\n", " size_mb = round(size_mb, 2)\n", " os.remove(self.temp_model_path)\n", " print(f\"Model size (MB): {size_mb}\")\n", " print(f\"# of parameters: {model.num_parameters()}\")\n", " return {\"size_mb\": size_mb, \"num_parameters\": model.num_parameters()}\n", " \n", " def compute_latency(self, tokenizer, model):\n", " \"\"\"\n", " Pick the first example of the input dataset, compute the average latency as well as\n", " standard deviation over a configurable number of runs.\n", " Latency is reported in milliseconds.\n", " \"\"\"\n", " example = self.dataset[0]\n", " latencies = []\n", "\n", " for _ in range(self.latency_warmup):\n", " _ = self.predict(example, tokenizer, model)\n", "\n", " for _ in range(self.latency_rounds):\n", " start_time = perf_counter()\n", " _ = self.predict(example, tokenizer, model)\n", " latency = perf_counter() - start_time\n", " latencies.append(latency)\n", "\n", " # Compute run statistics\n", " latency_avg_ms = round(1000 * np.mean(latencies), 2)\n", " latency_std_ms = round(1000 * np.std(latencies), 2)\n", " print(f\"Average latency (ms): {latency_avg_ms} +\\\\- {latency_std_ms}\")\n", " return {\"latency_avg_ms\": latency_avg_ms, \"latency_std_ms\": latency_std_ms}\n", " \n", " def compute_performance(self, tokenizer, model):\n", " \"\"\"compute f1/precision/recall/roc_auc metrics around sequence classification.\"\"\"\n", " 
clf_metrics = evaluate.combine([\"f1\", \"precision\", \"recall\"])\n", " roc_auc_metric = evaluate.load(\"roc_auc\")\n", "\n", " scores = []\n", " predictions = []\n", " references = []\n", " \n", " dataset_tokenized = tokenize_dataset(self.dataset, tokenizer)\n", " \n", " data_collator = DataCollatorWithPadding(tokenizer)\n", " data_loader = DataLoader(dataset_tokenized, batch_size=self.perf_batch_size, collate_fn=data_collator)\n", " for example in data_loader:\n", " labels = example.pop(\"labels\")\n", " with torch.no_grad():\n", " output = model(**example.to(model.device))\n", " score = F.softmax(output.logits, dim=-1)\n", " prediction = score.argmax(dim=-1)\n", "\n", " scores += tensor_to_list(score[:, 1])\n", " predictions += tensor_to_list(prediction)\n", " references += tensor_to_list(labels)\n", "\n", " metrics = clf_metrics.compute(predictions=predictions, references=references)\n", " metrics[\"roc_auc\"] = roc_auc_metric.compute(prediction_scores=scores, references=references)[\"roc_auc\"]\n", " for metric, value in metrics.items():\n", " metrics[metric] = round(value, self.perf_round_digits)\n", "\n", " return metrics\n", " \n", " \n", "def tensor_to_list(tensor):\n", " return tensor.cpu().numpy().tolist()" ] }, { "cell_type": "code", "execution_count": 19, "id": "6c1b7b2b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model size (MB): 417.72\n", "# of parameters: 109483778\n", "Average latency (ms): 13.54 +\\- 0.07\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] } ], "source": [ "benchmark_metrics_dict = {}\n", "benchmark = Benchmark(dataset_dict[\"validation\"])\n", "benchmark_metrics = benchmark.run(teacher_tokenizer, teacher_model, \"bert_uncased_teacher\")\n", "benchmark_metrics_dict.update(benchmark_metrics)" ] }, { "cell_type": "code", "execution_count": 20, "id": "0a793671", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model size (MB): 255.45\n", "# of parameters: 66955010\n", "Average latency (ms): 7.66 +\\- 0.31\n" ] } ], "source": [ "benchmark_metrics = benchmark.run(\n", " student_tokenizer,\n", " student_trainer.model,\n", " \"distilbert_student\"\n", ")\n", "benchmark_metrics_dict.update(benchmark_metrics)" ] }, { "cell_type": "code", "execution_count": 21, "id": "301a2880", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model size (MB): 255.45\n", "# of parameters: 66955010\n", "Average latency (ms): 7.11 +\\- 0.09\n" ] } ], "source": [ "benchmark_metrics = benchmark.run(\n", " student_tokenizer,\n", " student_distillation_trainer.model,\n", " \"distilbert_distillation_student\"\n", ")\n", "benchmark_metrics_dict.update(benchmark_metrics)" ] }, { "cell_type": "code", "execution_count": 22, "id": "e6f53b7e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>size_mb</th>\n", " <th>num_parameters</th>\n", " 
<th>latency_avg_ms</th>\n", " <th>latency_std_ms</th>\n", " <th>f1</th>\n", " <th>precision</th>\n", " <th>recall</th>\n", " <th>roc_auc</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>bert_uncased_teacher</th>\n", " <td>417.72</td>\n", " <td>109483778</td>\n", " <td>13.54</td>\n", " <td>0.07</td>\n", " <td>0.878</td>\n", " <td>0.867</td>\n", " <td>0.890</td>\n", " <td>0.968</td>\n", " </tr>\n", " <tr>\n", " <th>distilbert_student</th>\n", " <td>255.45</td>\n", " <td>66955010</td>\n", " <td>7.66</td>\n", " <td>0.31</td>\n", " <td>0.866</td>\n", " <td>0.852</td>\n", " <td>0.881</td>\n", " <td>0.963</td>\n", " </tr>\n", " <tr>\n", " <th>distilbert_distillation_student</th>\n", " <td>255.45</td>\n", " <td>66955010</td>\n", " <td>7.11</td>\n", " <td>0.09</td>\n", " <td>0.854</td>\n", " <td>0.878</td>\n", " <td>0.831</td>\n", " <td>0.962</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " size_mb num_parameters latency_avg_ms \\\n", "bert_uncased_teacher 417.72 109483778 13.54 \n", "distilbert_student 255.45 66955010 7.66 \n", "distilbert_distillation_student 255.45 66955010 7.11 \n", "\n", " latency_std_ms f1 precision recall \\\n", "bert_uncased_teacher 0.07 0.878 0.867 0.890 \n", "distilbert_student 0.31 0.866 0.852 0.881 \n", "distilbert_distillation_student 0.09 0.854 0.878 0.831 \n", "\n", " roc_auc \n", "bert_uncased_teacher 0.968 \n", "distilbert_student 0.963 \n", "distilbert_distillation_student 0.962 " ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame.from_dict(benchmark_metrics_dict, orient=\"index\")" ] }, { "cell_type": "markdown", "id": "ecdd3efa", "metadata": {}, "source": [ "The final table is a comparison of our teacher model (BERT) and two student models (DistilBERT), where one of the students was trained with the knowledge distillation loss and the other wasn't. A quick observation: we can definitely shrink our model size and improve latency by using a student model without much loss in model performance. Note that we also didn't spend much time tuning the additional loss weighting, $\\alpha$, and temperature scaling, $T$, hyperparameters that come with knowledge distillation." ] }, { "cell_type": "markdown", "id": "3d31fafb-d3ab-4086-b03e-c53ee4a312cf", "metadata": {}, "source": [ "## Notes" ] }, { "cell_type": "markdown", "id": "2f3cda57-2a86-417b-967f-e1b5223d02cb", "metadata": {}, "source": [ "It is not surprising that large models tend to give superior performance. As software and hardware continue to advance, the barrier to training or accessing these large models will continue to lower, making scaling up a promising approach to obtain better performance for whatever applications we care about. That being said, there will always be scenarios where smaller models are preferable, and knowledge distillation [[6]](https://arxiv.org/abs/1503.02531) is a popular way of compressing our large models into less expensive ones while still retaining the majority of their performance.\n", "\n", "As mentioned in DistilBERT [[7]](https://arxiv.org/abs/1910.01108), the authors were able to compress a 110 million parameter BERT-base model into a 66 million parameter DistilBERT model while retaining 97% of the original performance measured on GLUE benchmark's dev set. If we were to distill a pre-trained model ourselves, it's worth mentioning that a better student initialization strategy is to make sure our students are \"well read\" [[8]](https://arxiv.org/abs/1908.08962), i.e. our students typically share the teacher's architecture, only with a smaller number of layers; instead of initializing them by truncating teacher layers, or by taking one layer out of two as in DistilBERT, we should initialize them from weights that have also gone through a pre-training procedure similar to our teacher's. A minimal sketch of this initialization idea is shown below."
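] }, { "cell_type": "markdown", "id": "5a6b7c8d", "metadata": {}, "source": [ "The following is a minimal sketch of that idea, not something we ran in this notebook: Google released compact pre-trained BERT checkpoints alongside [[8]](https://arxiv.org/abs/1908.08962), and assuming a checkpoint such as `google/bert_uncased_L-4_H-256_A-4` (4 layers, hidden size 256) is available on the huggingface hub, a \"well read\" student can be initialized just like any other pretrained model. The checkpoint name here is an assumption, swap in whichever size fits the latency budget." ] }, { "cell_type": "code", "execution_count": null, "id": "6b7c8d9e", "metadata": {}, "outputs": [], "source": [ "# hypothetical sketch: initialize a \"well read\" student from a compact\n", "# pre-trained checkpoint instead of truncating our teacher's layers,\n", "# the checkpoint name is an assumption, pick whichever size fits\n", "well_read_checkpoint = \"google/bert_uncased_L-4_H-256_A-4\"\n", "well_read_tokenizer = AutoTokenizer.from_pretrained(well_read_checkpoint)\n", "well_read_student = AutoModelForSequenceClassification.from_pretrained(\n", "    well_read_checkpoint,\n", "    num_labels=teacher_model.config.num_labels,\n", "    id2label=teacher_model.config.id2label,\n", "    label2id=teacher_model.config.label2id\n", ").to(device)\n", "print('# of parameters: ', well_read_student.num_parameters())"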
] }, { "cell_type": "markdown", "id": "6befdd48", "metadata": {}, "source": [ "# Reference" ] }, { "cell_type": "markdown", "id": "681c7e41", "metadata": {}, "source": [ "- [[1]](https://www.philschmid.de/knowledge-distillation-bert-transformers) Blog: Task-specific knowledge distillation for BERT using Transformers & Amazon SageMaker\n", "- [[2]](https://medium.com/huggingface/distilbert-8cf3380435b5) Blog: Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT\n", "- [[3]](https://neptune.ai/blog/knowledge-distillation) Blog: Knowledge Distillation: Principles, Algorithms, Applications\n", "- [[4]](https://lewtun.github.io/blog/weeknotes/nlp/huggingface/transformers/2021/01/17/wknotes-distillation-and-generation.html) Blog: Weeknotes: Distilling distilled transformers\n", "- [[5]](https://intellabs.github.io/distiller/knowledge_distillation.html) Doc: Neural Network Distiller - Knowledge Distillation\n", "- [[6]](https://arxiv.org/abs/1503.02531) Paper: Geoffrey Hinton, Oriol Vinyals, et al. - Distilling the Knowledge in a Neural Network - 2015\n", "- [[7]](https://arxiv.org/abs/1910.01108) Paper: Victor Sanh, Lysandre Debut, Julien Chaumond, Thomas Wolf - DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter - 2019\n", "- [[8]](https://arxiv.org/abs/1908.08962) Paper: Iulia Turc, Ming-Wei Chang, Kenton Lee, Kristina Toutanova - Well-Read Students Learn Better: On the Importance of Pre-training Compact Models - 2019" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": true, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "233.719px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 5 }