# Fine-tuning a BERT model with skorch and Hugging Face

In this notebook, we follow the fine-tuning guidelines from the [Hugging Face documentation](https://huggingface.co/docs/transformers/training). Please check it out if you want to know more about BERT and fine-tuning. Here, we assume that you're familiar with the general ideas.

You will learn how to:
- integrate the [Hugging Face transformers](https://huggingface.co/docs/transformers/index) library with skorch
- use skorch to fine-tune a BERT model on a text classification task
- use skorch with the [Hugging Face accelerate](https://huggingface.co/docs/accelerate/index) library for automatic mixed precision (AMP) training

<table align="left"><td>
<a target="_blank" href="https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank" href="https://github.com/skorch-dev/skorch/blob/master/notebooks/Hugging_Face_Finetuning.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

The first part of the notebook requires Hugging Face `transformers` as an additional dependency. If you have not already installed it, you can do so like this:

`python -m pip install transformers`

In [1]:
import subprocess

# Installation on Google Colab
try:
    import google.colab
    subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'transformers'])
except ImportError:
    pass

## Imports

In [2]:
import numpy as np
import torch
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

## Parameters

Change the values below if you want to try out different model architectures and hyper-parameters.

In [3]:
# Choose a tokenizer and BERT model that work together
TOKENIZER = "distilbert-base-uncased"
PRETRAINED_MODEL = "distilbert-base-uncased"

# model hyper-parameters
OPTMIZER = torch.optim.AdamW
LR = 5e-5
MAX_EPOCHS = 3
CRITERION = nn.CrossEntropyLoss
BATCH_SIZE = 8

# device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data

In [4]:
dataset = fetch_20newsgroups()

For this notebook, we're making use of the 20 newsgroups dataset. It is a text classification dataset with 20 classes. A decent score would be 89% accuracy out of sample. For more details, read the description below:

In [5]:
print(dataset.DESCR.split('Usage')[0])

.. _20newsgroups_dataset:

The 20 newsgroups text dataset
------------------------------

The 20 newsgroups dataset comprises around 18000 newsgroups posts on
20 topics split in two subsets: one for training (or development)
and the other one for testing (or for performance evaluation). The split
between the train and test set is based upon a messages posted before
and after a specific date.

This module contains two loaders. The first one,
:func:`sklearn.datasets.fetch_20newsgroups`,
returns a list of the raw texts that can be fed to text feature
extractors such as :class:`~sklearn.feature_extraction.text.CountVectorizer`
with custom parameters so as to extract feature vectors.
The second one, :func:`sklearn.datasets.fetch_20newsgroups_vectorized`,
returns ready-to-use features, i.e., it is not necessary to use a feature
extractor.

**Data Set Characteristics:**

    Classes                     20
    Samples total            18846
    Dimensionality               1
    Features      

In [6]:
dataset.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [7]:
X = dataset.data
y = dataset.target

In [8]:
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

In [9]:
X_train[:2]

['From: hilmi-er@dsv.su.se (Hilmi Eren)\nSubject: Re: ARMENIA SAYS IT COULD SHOOT DOWN TURKISH PLANES (Henrik)\nLines: 53\nNntp-Posting-Host: alban.dsv.su.se\nReply-To: hilmi-er@dsv.su.se (Hilmi Eren)\nOrganization: Dept. of Computer and Systems Sciences, Stockholm University\n\n\n  \n|>      henrik@quayle.kpc.com writes:\n\n\n|>\tThe Armenians in Nagarno-Karabagh are simply DEFENDING their RIGHTS\n|>        to keep their homeland and it is the AZERIS that are INVADING their \n|>        territorium...\n\t\n\n\tHomeland? First Nagarno-Karabagh was Armenians homeland today\n\tFizuli, Lacin and several villages (in Azerbadjan)\n\tare their homeland. Can\'t you see the\n\tthe  "Great Armenia" dream in this? With facist methods like\n\tkilling, raping and bombing villages. The last move was the \n\tblast of a truck with 60 kurdish refugees, trying to\n\tescape the from Lacin, a city that was "given" to the Kurds\n\tby the Armenians. \n\n\n|>       However, I hope that the Armenians WILL for

## Prepare the training

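We want to use a learning rate schedule that linearly decreases the learning rate over the course of training. The function below returns the factor by which the initial learning rate is multiplied at a given training step.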

In [10]:
num_training_steps = MAX_EPOCHS * (len(X_train) // BATCH_SIZE + 1)

def lr_schedule(current_step):
    factor = float(num_training_steps - current_step) / float(max(1, num_training_steps))
    assert factor > 0
    return factor
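
To get a feeling for the schedule, here is a minimal sketch that evaluates the same kind of factor for a small, hypothetical number of steps. The names `total_steps`, `base_lr`, and `linear_decay` are chosen only for this illustration and are not part of the notebook; `LambdaLR` multiplies the initial learning rate by the returned factor.

```python
# Hypothetical, small step count chosen only for illustration
total_steps = 10
base_lr = 5e-5  # same value as LR in the parameters section


def linear_decay(current_step, total_steps=total_steps):
    """Multiplicative factor that falls linearly from 1 towards 0."""
    return float(total_steps - current_step) / float(max(1, total_steps))


# Factor and resulting learning rate for every step of the toy schedule
factors = [linear_decay(step) for step in range(total_steps)]
learning_rates = [base_lr * factor for factor in factors]

for step, lr in enumerate(learning_rates):
    print(f"step {step}: lr = {lr:.2e}")
```

The factor starts at 1 and shrinks by the same amount at every step, so the learning rate never reaches zero within the planned number of steps.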

Next we wrap the BERT module itself inside a simple `nn.Module`. The only real work for us here is to load the pretrained model and to return the _logits_ from the model output. The rest of the outputs are not needed.

In [11]:
class BertModule(nn.Module):
    def __init__(self, name, num_labels):
        super().__init__()
        self.name = name
        self.num_labels = num_labels
        
        self.reset_weights()
        
    def reset_weights(self):
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            self.name, num_labels=self.num_labels
        )
        
    def forward(self, **kwargs):
        pred = self.bert(**kwargs)
        return pred.logits

### Tokenizer

We make use of `HuggingfacePretrainedTokenizer`, which is a wrapper that skorch provides to use the tokenizers from Hugging Face. In this instance, we use a tokenizer that was pretrained in conjunction with BERT. The tokenizer is automatically downloaded if not already present. More on Hugging Face tokenizers can be found [here](https://huggingface.co/docs/tokenizers/index).

## Training

### Putting it all together

Now we can put together all the parts from above. There is nothing special going on here: we simply use an sklearn `Pipeline` to chain the `HuggingfacePretrainedTokenizer` and the neural net. Using skorch's `NeuralNetClassifier`, we make sure to pass the `BertModule` as the first argument and to set the number of labels based on `y_train`. The criterion is `CrossEntropyLoss` because we return the logits. Moreover, we make use of the learning rate schedule we defined above, and we add the `ProgressBar` callback to monitor our progress.

In [12]:
pipeline = Pipeline([
    ('tokenizer', HuggingfacePretrainedTokenizer(TOKENIZER)),
    ('net', NeuralNetClassifier(
        BertModule,
        module__name=PRETRAINED_MODEL,
        module__num_labels=len(set(y_train)),
        optimizer=OPTMIZER,
        lr=LR,
        max_epochs=MAX_EPOCHS,
        criterion=CRITERION,
        batch_size=BATCH_SIZE,
        iterator_train__shuffle=True,
        device=DEVICE,
        callbacks=[
            LRScheduler(LambdaLR, lr_lambda=lr_schedule, step_every='batch'),
            ProgressBar(),
        ],
    )),
])

Since we are using skorch, we could now take this pipeline and run a grid search or another kind of hyper-parameter sweep to figure out the best hyper-parameters for this model. For example, we could try out a different BERT model or a different `max_length` for the tokenizer.
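
As a rough sketch of what such a sweep would iterate over, the snippet below enumerates a small, hypothetical parameter grid using only the standard library. The keys follow sklearn's `<step>__<parameter>` naming for pipelines, with the step names `tokenizer` and `net` taken from the pipeline above; the candidate values are made up for illustration. In practice, sklearn's own search classes would do this enumeration and the fitting for you.

```python
from itertools import product

# Hypothetical candidate values, chosen only for illustration
param_grid = {
    "tokenizer__max_length": [128, 256],
    "net__lr": [5e-5, 3e-5],
}

# Build one dict of parameters per combination of candidate values
names = list(param_grid)
candidates = [
    dict(zip(names, values))
    for values in product(*(param_grid[name] for name in names))
]

for params in candidates:
    print(params)
```

Each of these dicts could be applied with `pipeline.set_params(**params)` before fitting. Keep in mind that every candidate means a full fine-tuning run, so the grid should stay small.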

### Fitting

In [13]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)

In [14]:
%time pipeline.fit(X_train, y_train)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_clas

  epoch    train_loss    valid_acc    valid_loss       dur
-------  ------------  -----------  ------------  --------
      1        1.1628       0.8338        0.5839  179.8571
      2        0.3709       0.8751        0.4214  178.7779
      3        0.1523       0.8910        0.3945  178.4507

CPU times: user 7min 17s, sys: 1min 56s, total: 9min 14s
Wall time: 9min 29s


Pipeline(steps=[('tokenizer',
                 HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased')),
                ('net',
                 <class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=BertModule(
    (bert): DistilBertForSequenceClassification(
      (distilbert): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddin...
                (lin1): Linear(in_features=768, out_features=3072, bias=True)
                (lin2): Linear(in_features=3072, out_features=768, bias=True)
                (activation): GELUActivation()
              )
              (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            )
          )
        )
      )
      (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
      (classifier): Linear(in_features=768, out_features=20, bias=True)
      (dropout): Dropout(p=0.2, inplace

### Evaluation

In [15]:
%%time
with torch.inference_mode():
    y_pred = pipeline.predict(X_test)

CPU times: user 26.8 s, sys: 68.5 ms, total: 26.9 s
Wall time: 24.6 s


In [16]:
accuracy_score(y_test, y_pred)

0.8985507246376812

We can be happy with the results. We set ourselves the goal of reaching or exceeding 89% accuracy on the test set, and with an accuracy of about 89.9% we managed to do that.

## Training with automatic mixed precision (AMP)

For this to work, you need:
- A GPU that is capable of mixed precision training
- The [accelerate library](https://huggingface.co/docs/accelerate/index), which you can install as: `python -m pip install 'accelerate>=0.11'`.
- skorch version 0.12 or later, or a version installed from the current master branch (`python -m pip install git+https://github.com/skorch-dev/skorch.git`)

Again, we assume that you're familiar with the general concept of mixed precision training. For more information on how skorch integrates with accelerate, please consult the [skorch docs](https://skorch.readthedocs.io/en/latest/user/huggingface.html#accelerate).
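
As a brief reminder of why half precision needs some care, the sketch below rounds values to IEEE 754 half precision using only the standard library's `struct` module. The gradient magnitude and the scaling factor are hypothetical values picked for illustration. With `fp16` mixed precision, accelerate takes care of this kind of scaling for you, so nothing like this has to be written by hand.

```python
import struct


def to_float16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]


tiny_gradient = 1e-8  # hypothetical gradient magnitude
loss_scale = 1024.0   # hypothetical scaling factor

# Stored directly in half precision, the value underflows to zero
unscaled = to_float16(tiny_gradient)

# Scaled up before the conversion, it survives as a small non-zero number
scaled = to_float16(tiny_gradient * loss_scale)

# Dividing by the scale in higher precision recovers roughly the original value
recovered = scaled / loss_scale

print(unscaled, scaled, recovered)
```

The value that vanishes when stored directly in half precision is preserved, up to a small rounding error, once it is scaled before the conversion and unscaled afterwards.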

In [17]:
import subprocess

subprocess.run(['python', '-m', 'pip', 'install', 'accelerate>=0.11'])

CompletedProcess(args=['python', '-m', 'pip', 'install', 'accelerate>=0.11'], returncode=0)

In [18]:
from accelerate import Accelerator
from skorch.hf import AccelerateMixin

In [19]:
class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    """NeuralNetClassifier with accelerate support"""

In [20]:
accelerator = Accelerator(mixed_precision='fp16')

In [21]:
pipeline2 = Pipeline([
    ('tokenizer', HuggingfacePretrainedTokenizer(TOKENIZER)),
    ('net', AcceleratedNet(                   # <= changed
        BertModule,
        accelerator=accelerator,              # <= changed
        module__name=PRETRAINED_MODEL,
        module__num_labels=len(set(y_train)),
        optimizer=OPTMIZER,
        lr=LR,
        max_epochs=MAX_EPOCHS,
        criterion=CRITERION,
        batch_size=BATCH_SIZE,
        iterator_train__shuffle=True,
        # device=DEVICE,                      # <= changed
        callbacks=[
            LRScheduler(LambdaLR, lr_lambda=lr_schedule, step_every='batch'),
            ProgressBar(),
        ],
    )),
])

In [22]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)

In [23]:
pipeline2.fit(X_train, y_train)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_clas

  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        1.1681       0.8397        0.5697  79.4182
      2        0.3479       0.8863        0.4011  78.2318
      3        0.1438       0.8933        0.3853  77.5366

Pipeline(steps=[('tokenizer',
                 HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased')),
                ('net',
                 <class '__main__.AcceleratedNet'>[initialized](
  module_=BertModule(
    (bert): DistilBertForSequenceClassification(
      (distilbert): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(...
                (lin1): Linear(in_features=768, out_features=3072, bias=True)
                (lin2): Linear(in_features=3072, out_features=768, bias=True)
                (activation): GELUActivation()
              )
              (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            )
          )
        )
      )
      (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
      (classifier): Linear(in_features=768, out_features=20, bias=True)
      (dropout): Dropout(p=0.2, inplac

In [24]:
%%time
with torch.inference_mode():
    y_pred = pipeline2.predict(X_test)

CPU times: user 23.6 s, sys: 219 ms, total: 23.8 s
Wall time: 24.7 s


In [25]:
accuracy_score(y_test, y_pred)

0.9070342877341817

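Using AMP, we cut the training time per epoch by more than half (from roughly 179 s to roughly 78 s) while attaining comparable scores (test accuracy of about 0.907 versus 0.899). Prediction time, by contrast, stayed essentially the same at about 25 s wall time.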
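
The per-epoch durations printed in the two training logs allow a rough estimate of the speed-up. The numbers below are copied from those logs; the variable names are chosen only for this illustration, and the result will differ on other hardware.

```python
# Per-epoch durations in seconds, copied from the two training logs above
durations_default = [179.8571, 178.7779, 178.4507]
durations_amp = [79.4182, 78.2318, 77.5366]

total_default = sum(durations_default)
total_amp = sum(durations_amp)

# Ratio of total training time without AMP to total training time with AMP
speedup = total_default / total_amp

print(f"default: {total_default:.1f} s, AMP: {total_amp:.1f} s, speed-up: {speedup:.2f}x")
```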