{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04-transformers-text-classification.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8ag5ANQPJ_j9",
        "colab_type": "text"
      },
      "source": [
        "# Finetune 🤗 Transformers Models with PyTorch Lightning ⚡\n",
        "\n",
        "This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`. Then, we write a class to perform text classification on any dataset from the[ GLUE Benchmark](https://gluebenchmark.com/). (We just show CoLA and MRPC due to constraint on compute/disk)\n",
        "\n",
        "[HuggingFace's NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola) can help you get a feel for the two datasets we will use and what tasks they are solving for.\n",
        "\n",
        "---\n",
        "  - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
        "  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
        "  - Ask a question on [the forum](https://forums.pytorchlightning.ai/)\n",
        "  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)\n",
        "\n",
        "  - [HuggingFace datasets](https://github.com/huggingface/datasets)\n",
        "  - [HuggingFace transformers](https://github.com/huggingface/transformers)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fqlsVTj7McZ3",
        "colab_type": "text"
      },
      "source": [
        "### Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OIhHrRL-MnKK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install pytorch-lightning datasets transformers"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6yuQT_ZQMpCg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from argparse import ArgumentParser\n",
        "from datetime import datetime\n",
        "from typing import Optional\n",
        "\n",
        "import datasets\n",
        "import numpy as np\n",
        "import pytorch_lightning as pl\n",
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    AdamW,\n",
        "    AutoModelForSequenceClassification,\n",
        "    AutoConfig,\n",
        "    AutoTokenizer,\n",
        "    get_linear_schedule_with_warmup,\n",
        "    glue_compute_metrics\n",
        ")"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ORJfiuiNZ_N",
        "colab_type": "text"
      },
      "source": [
        "## GLUE DataModule"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jW9xQhZxMz1G",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class GLUEDataModule(pl.LightningDataModule):\n",
        "\n",
        "    task_text_field_map = {\n",
        "        'cola': ['sentence'],\n",
        "        'sst2': ['sentence'],\n",
        "        'mrpc': ['sentence1', 'sentence2'],\n",
        "        'qqp': ['question1', 'question2'],\n",
        "        'stsb': ['sentence1', 'sentence2'],\n",
        "        'mnli': ['premise', 'hypothesis'],\n",
        "        'qnli': ['question', 'sentence'],\n",
        "        'rte': ['sentence1', 'sentence2'],\n",
        "        'wnli': ['sentence1', 'sentence2'],\n",
        "        'ax': ['premise', 'hypothesis']\n",
        "    }\n",
        "\n",
        "    glue_task_num_labels = {\n",
        "        'cola': 2,\n",
        "        'sst2': 2,\n",
        "        'mrpc': 2,\n",
        "        'qqp': 2,\n",
        "        'stsb': 1,\n",
        "        'mnli': 3,\n",
        "        'qnli': 2,\n",
        "        'rte': 2,\n",
        "        'wnli': 2,\n",
        "        'ax': 3\n",
        "    }\n",
        "\n",
        "    loader_columns = [\n",
        "        'datasets_idx',\n",
        "        'input_ids',\n",
        "        'token_type_ids',\n",
        "        'attention_mask',\n",
        "        'start_positions',\n",
        "        'end_positions',\n",
        "        'labels'\n",
        "    ]\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        model_name_or_path: str,\n",
        "        task_name: str ='mrpc',\n",
        "        max_seq_length: int = 128,\n",
        "        train_batch_size: int = 32,\n",
        "        eval_batch_size: int = 32,\n",
        "        **kwargs\n",
        "    ):\n",
        "        super().__init__()\n",
        "        self.model_name_or_path = model_name_or_path\n",
        "        self.task_name = task_name\n",
        "        self.max_seq_length = max_seq_length\n",
        "        self.train_batch_size = train_batch_size\n",
        "        self.eval_batch_size = eval_batch_size\n",
        "\n",
        "        self.text_fields = self.task_text_field_map[task_name]\n",
        "        self.num_labels = self.glue_task_num_labels[task_name]\n",
        "        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
        "\n",
        "    def setup(self, stage):\n",
        "        self.dataset = datasets.load_dataset('glue', self.task_name)\n",
        "\n",
        "        for split in self.dataset.keys():\n",
        "            self.dataset[split] = self.dataset[split].map(\n",
        "                self.convert_to_features,\n",
        "                batched=True,\n",
        "                remove_columns=['label'],\n",
        "            )\n",
        "            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n",
        "            self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
        "\n",
        "        self.eval_splits = [x for x in self.dataset.keys() if 'validation' in x]\n",
        "\n",
        "    def prepare_data(self):\n",
        "        datasets.load_dataset('glue', self.task_name)\n",
        "        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
        "    \n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)\n",
        "    \n",
        "    def val_dataloader(self):\n",
        "        if len(self.eval_splits) == 1:\n",
        "            return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)\n",
        "        elif len(self.eval_splits) > 1:\n",
        "            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        if len(self.eval_splits) == 1:\n",
        "            return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)\n",
        "        elif len(self.eval_splits) > 1:\n",
        "            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
        "\n",
        "    def convert_to_features(self, example_batch, indices=None):\n",
        "\n",
        "        # Either encode single sentence or sentence pairs\n",
        "        if len(self.text_fields) > 1:\n",
        "            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
        "        else:\n",
        "            texts_or_text_pairs = example_batch[self.text_fields[0]]\n",
        "\n",
        "        # Tokenize the text/text pairs\n",
        "        features = self.tokenizer.batch_encode_plus(\n",
        "            texts_or_text_pairs,\n",
        "            max_length=self.max_seq_length,\n",
        "            pad_to_max_length=True,\n",
        "            truncation=True\n",
        "        )\n",
        "\n",
        "        # Rename label to labels to make it easier to pass to model forward\n",
        "        features['labels'] = example_batch['label']\n",
        "\n",
        "        return features"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQC3a6KuOpX3",
        "colab_type": "text"
      },
      "source": [
        "#### You could use this datamodule with standalone PyTorch if you wanted..."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JCMH3IAsNffF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "dm = GLUEDataModule('distilbert-base-uncased')\n",
        "dm.prepare_data()\n",
        "dm.setup('fit')\n",
        "next(iter(dm.train_dataloader()))"
      ],
      "execution_count": null,
      "outputs": []
    },
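    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The batches above are plain dictionaries of tensors, so a hand-written PyTorch loop can consume them directly. Below is a minimal sketch of such a loop, not part of the original notebook: `ToyClassifier`, `train_one_epoch`, and the random toy data are hypothetical stand-ins so the example runs without downloading anything.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch import nn\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "\n",
        "class ToyClassifier(nn.Module):\n",
        "    # Hypothetical stand-in for a sequence classifier; returns (loss, logits).\n",
        "    def __init__(self, vocab_size=100, hidden=16, num_labels=2):\n",
        "        super().__init__()\n",
        "        self.embed = nn.Embedding(vocab_size, hidden)\n",
        "        self.head = nn.Linear(hidden, num_labels)\n",
        "\n",
        "    def forward(self, input_ids, labels):\n",
        "        # Mean-pool token embeddings, then classify.\n",
        "        logits = self.head(self.embed(input_ids).mean(dim=1))\n",
        "        loss = nn.functional.cross_entropy(logits, labels)\n",
        "        return loss, logits\n",
        "\n",
        "\n",
        "def train_one_epoch(model, loader, optimizer):\n",
        "    # Plain PyTorch loop over dict batches, unpacked as keyword arguments.\n",
        "    model.train()\n",
        "    losses = []\n",
        "    for batch in loader:\n",
        "        optimizer.zero_grad()\n",
        "        loss, _ = model(**batch)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        losses.append(loss.item())\n",
        "    return losses\n",
        "\n",
        "\n",
        "# Toy data shaped like tokenized batches: each item is a dict of tensors.\n",
        "torch.manual_seed(0)\n",
        "toy_data = [\n",
        "    {'input_ids': torch.randint(0, 100, (8,)), 'labels': torch.tensor(i % 2)}\n",
        "    for i in range(16)\n",
        "]\n",
        "toy_loader = DataLoader(toy_data, batch_size=4)\n",
        "toy_model = ToyClassifier()\n",
        "optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)\n",
        "losses = train_one_epoch(toy_model, toy_loader, optimizer)\n",
        "print(len(losses))  # one loss value per batch\n",
        "```\n",
        "\n",
        "With the real datamodule you would presumably pass `dm.train_dataloader()` and a HuggingFace sequence-classification model in place of the toy loader and model."
      ]
    },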
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l9fQ_67BO2Lj",
        "colab_type": "text"
      },
      "source": [
        "## GLUE Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gtn5YGKYO65B",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class GLUETransformer(pl.LightningModule):\n",
        "    def __init__(\n",
        "        self,\n",
        "        model_name_or_path: str,\n",
        "        num_labels: int,\n",
        "        learning_rate: float = 2e-5,\n",
        "        adam_epsilon: float = 1e-8,\n",
        "        warmup_steps: int = 0,\n",
        "        weight_decay: float = 0.0,\n",
        "        train_batch_size: int = 32,\n",
        "        eval_batch_size: int = 32,\n",
        "        eval_splits: Optional[list] = None,\n",
        "        **kwargs\n",
        "    ):\n",
        "        super().__init__()\n",
        "\n",
        "        self.save_hyperparameters()\n",
        "\n",
        "        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)\n",
        "        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)\n",
        "        self.metric = datasets.load_metric(\n",
        "            'glue',\n",
        "            self.hparams.task_name,\n",
        "            experiment_id=datetime.now().strftime(\"%d-%m-%Y_%H-%M-%S\")\n",
        "        )\n",
        "\n",
        "    def forward(self, **inputs):\n",
        "        return self.model(**inputs)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        outputs = self(**batch)\n",
        "        loss = outputs[0]\n",
        "        return loss\n",
        "\n",
        "    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
        "        outputs = self(**batch)\n",
        "        val_loss, logits = outputs[:2]\n",
        "\n",
        "        if self.hparams.num_labels >= 1:\n",
        "            preds = torch.argmax(logits, axis=1)\n",
        "        elif self.hparams.num_labels == 1:\n",
        "            preds = logits.squeeze()\n",
        "\n",
        "        labels = batch[\"labels\"]\n",
        "\n",
        "        return {'loss': val_loss, \"preds\": preds, \"labels\": labels}\n",
        "\n",
        "    def validation_epoch_end(self, outputs):\n",
        "        if self.hparams.task_name == 'mnli':\n",
        "            for i, output in enumerate(outputs):\n",
        "                # matched or mismatched\n",
        "                split = self.hparams.eval_splits[i].split('_')[-1]\n",
        "                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
        "                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
        "                loss = torch.stack([x['loss'] for x in output]).mean()\n",
        "                self.log(f'val_loss_{split}', loss, prog_bar=True)\n",
        "                split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
        "                self.log_dict(split_metrics, prog_bar=True)\n",
        "            return loss\n",
        "\n",
        "        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
        "        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
        "        loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
        "        self.log('val_loss', loss, prog_bar=True)\n",
        "        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
        "        return loss\n",
        "\n",
        "    def setup(self, stage):\n",
        "        if stage == 'fit':\n",
        "            # Get dataloader by calling it - train_dataloader() is called after setup() by default\n",
        "            train_loader = self.train_dataloader()\n",
        "\n",
        "            # Calculate total steps\n",
        "            self.total_steps = (\n",
        "                (len(train_loader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))\n",
        "                // self.hparams.accumulate_grad_batches\n",
        "                * float(self.hparams.max_epochs)\n",
        "            )\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        \"Prepare optimizer and schedule (linear warmup and decay)\"\n",
        "        model = self.model\n",
        "        no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
        "        optimizer_grouped_parameters = [\n",
        "            {\n",
        "                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
        "                \"weight_decay\": self.hparams.weight_decay,\n",
        "            },\n",
        "            {\n",
        "                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
        "                \"weight_decay\": 0.0,\n",
        "            },\n",
        "        ]\n",
        "        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
        "\n",
        "        scheduler = get_linear_schedule_with_warmup(\n",
        "            optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps\n",
        "        )\n",
        "        scheduler = {\n",
        "            'scheduler': scheduler,\n",
        "            'interval': 'step',\n",
        "            'frequency': 1\n",
        "        }\n",
        "        return [optimizer], [scheduler]\n",
        "\n",
        "    @staticmethod\n",
        "    def add_model_specific_args(parent_parser):\n",
        "        parser = ArgumentParser(parents=[parent_parser], add_help=False)\n",
        "        parser.add_argument(\"--learning_rate\", default=2e-5, type=float)\n",
        "        parser.add_argument(\"--adam_epsilon\", default=1e-8, type=float)\n",
        "        parser.add_argument(\"--warmup_steps\", default=0, type=int)\n",
        "        parser.add_argument(\"--weight_decay\", default=0.0, type=float)\n",
        "        return parser"
      ],
      "execution_count": 5,
      "outputs": []
    },
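    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `setup` hook above estimates how many optimizer steps training will take, which the linear warmup schedule needs as `num_training_steps`. The sketch below reproduces that arithmetic in isolation. The helper name `estimate_total_steps` and the example numbers are made up for illustration.\n",
        "\n",
        "```python\n",
        "def estimate_total_steps(num_examples, batch_size, num_devices=1,\n",
        "                         accumulate_grad_batches=1, max_epochs=1):\n",
        "    # Batches per epoch across all devices (// ignores a partial final batch).\n",
        "    batches_per_epoch = num_examples // (batch_size * max(1, num_devices))\n",
        "    # One optimizer step per `accumulate_grad_batches` batches.\n",
        "    steps_per_epoch = batches_per_epoch // accumulate_grad_batches\n",
        "    return steps_per_epoch * max_epochs\n",
        "\n",
        "\n",
        "# Hypothetical run: 3,200 examples, batch size 32, one GPU, 3 epochs.\n",
        "print(estimate_total_steps(3200, 32, num_devices=1, max_epochs=3))  # 300\n",
        "# Accumulating gradients over 2 batches halves the number of optimizer steps.\n",
        "print(estimate_total_steps(3200, 32, accumulate_grad_batches=2, max_epochs=3))  # 150\n",
        "```\n",
        "\n",
        "The expression in `setup` multiplies by `float(max_epochs)`, so `self.total_steps` is a float there; this sketch keeps integers for readability."
      ]
    },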
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ha-NdIP_xbd3",
        "colab_type": "text"
      },
      "source": [
        "### ⚡ Quick Tip \n",
        "  - Combine arguments from your DataModule, Model, and Trainer into one for easy and robust configuration"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3dEHnl3RPlAR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def parse_args(args=None):\n",
        "    parser = ArgumentParser()\n",
        "    parser = pl.Trainer.add_argparse_args(parser)\n",
        "    parser = GLUEDataModule.add_argparse_args(parser)\n",
        "    parser = GLUETransformer.add_model_specific_args(parser)\n",
        "    parser.add_argument('--seed', type=int, default=42)\n",
        "    return parser.parse_args(args)\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    pl.seed_everything(args.seed)\n",
        "    dm = GLUEDataModule.from_argparse_args(args)\n",
        "    dm.prepare_data()\n",
        "    dm.setup('fit')\n",
        "    model = GLUETransformer(num_labels=dm.num_labels, eval_splits=dm.eval_splits, **vars(args))\n",
        "    trainer = pl.Trainer.from_argparse_args(args)\n",
        "    return dm, model, trainer"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PkuLaeec3sJ-",
        "colab_type": "text"
      },
      "source": [
        "# Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QSpueK5UPsN7",
        "colab_type": "text"
      },
      "source": [
        "## CoLA\n",
        "\n",
        "See an interactive view of the CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NJnFmtpnPu0Y",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "mocked_args = \"\"\"\n",
        "    --model_name_or_path albert-base-v2\n",
        "    --task_name cola\n",
        "    --max_epochs 3\n",
        "    --gpus 1\"\"\".split()\n",
        "\n",
        "args = parse_args(mocked_args)\n",
        "dm, model, trainer = main(args)\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_MrNsTnqdz4z",
        "colab_type": "text"
      },
      "source": [
        "## MRPC\n",
        "\n",
        "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mrpc)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LBwRxg9Cb3d-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "mocked_args = \"\"\"\n",
        "    --model_name_or_path distilbert-base-cased\n",
        "    --task_name mrpc\n",
        "    --max_epochs 3\n",
        "    --gpus 1\"\"\".split()\n",
        "\n",
        "args = parse_args(mocked_args)\n",
        "dm, model, trainer = main(args)\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iZhbn0HzfdCu",
        "colab_type": "text"
      },
      "source": [
        "## MNLI\n",
        "\n",
        " - The MNLI dataset is huge, so we aren't going to bother trying to train it here.\n",
        "\n",
        " - Let's just make sure our multi-dataloader logic is right by skipping over training and going straight to validation.\n",
        "\n",
        "See an interactive view of the MRPC dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=mnli)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AvsZMOggfcWW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "mocked_args = \"\"\"\n",
        "    --model_name_or_path distilbert-base-uncased\n",
        "    --task_name mnli\n",
        "    --max_epochs 1\n",
        "    --gpus 1\n",
        "    --limit_train_batches 10\n",
        "    --progress_bar_refresh_rate 20\"\"\".split()\n",
        "\n",
        "args = parse_args(mocked_args)\n",
        "dm, model, trainer = main(args)\n",
        "trainer.fit(model, dm)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}