{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "506b76f3-6de7-4972-9447-800d728b9b5f",
   "metadata": {},
   "source": [
    "# How to fine-tune a T5 model with ONNX Runtime\n",
    "\n",
    "This notebook is largely inspired by the summarization [notebook of Transformers](https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb), which uses PyTorch as the backend for fine-tuning.\n",
    "\n",
    "Here you will use the `ORTSeq2SeqTrainer` class from the [Optimum](https://github.com/huggingface/optimum) library and use [ONNX Runtime](https://microsoft.github.io/onnxruntime/) as the backend to accelerate training.\n",
    "\n",
    "\n",
    "In this notebook, we will walk through fine-tuning the [T5-small](https://huggingface.co/docs/transformers/model_doc/t5) model from 🤗 Transformers on a summarization task. We will use the [XSum](https://arxiv.org/pdf/1808.08745.pdf) dataset (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries. Both training and inference will be done with `ORTSeq2SeqTrainer` from Optimum!\n",
    "\n",
    "Let's speed the training up!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "992a2569-5a97-42b4-bfae-0eb36866c033",
   "metadata": {},
   "source": [
    "__Dependencies__\n",
    "\n",
    "To use ONNX Runtime for training, you need a machine with at least one NVIDIA GPU.\n",
    "\n",
    "__The ONNX Runtime training module needs to be properly installed before launching the notebook! Please follow the instructions in [Optimum's documentation](https://huggingface.co/docs/optimum/onnxruntime/trainer) to set up your environment.__\n",
    "\n",
    "Check your GPU:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd2269d9-cec3-4495-ace0-5c16fef92a05",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fri Sep 16 19:04:38 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 11.3     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla T4            On   | 00000000:00:1E.0 Off |                    0 |\n",
      "| N/A   43C    P0    25W /  70W |   3420MiB / 15109MiB |      0%      Default |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                       GPU Memory |\n",
      "|  GPU       PID   Type   Process name                             Usage      |\n",
      "|=============================================================================|\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0736bfe3-439d-42df-8477-231428662e0e",
   "metadata": {},
   "source": [
    "If you're opening this notebook on Colab, you will probably need to install 🤗 Optimum, 🤗 Transformers, 🤗 Datasets and 🤗 Evaluate, along with `rouge-score` and `nltk`. Run the following cell to install them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4317fec7-6156-49f3-960a-40169eea60ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install optimum transformers datasets evaluate rouge-score nltk \"tokenizers>=0.11.0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "430af801-182a-4926-85f0-80b3f502b904",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
      "[nltk_data]   Unzipping tokenizers/punkt.zip.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import nltk\n",
    "nltk.download(\"punkt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "325cce71-b7b2-469c-a5fd-991b7b529864",
   "metadata": {},
   "source": [
    "__[Optional]__ If you want to share your model with the community and generate an inference API, there are a few more steps to follow.\n",
    "\n",
    "First, you need an access token from the Hugging Face website (sign up [here](https://huggingface.co/welcome) if you haven't already!). Then execute the following cell and paste your token when prompted:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2bffacff-c201-4d79-bbdc-c470f76fb618",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a15538e-d956-460c-851d-1a9ef4e1e8da",
   "metadata": {},
   "source": [
    "Then you need to install Git LFS by running the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a15b4795-7da1-4bc2-8cac-67edcedd7662",
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5af4aed-5277-4271-9663-4bd23a3c51de",
   "metadata": {},
   "source": [
    "Make sure your version of Transformers is at least 4.15.0:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1331aec7-c0de-46c8-9b14-e44c7417b188",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.23.0.dev0\n"
     ]
    }
   ],
   "source": [
    "import transformers\n",
    "\n",
    "print(transformers.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7caff53-f8b7-4423-9daf-0a4cce76050e",
   "metadata": {},
   "source": [
    "__Setup__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "88e4b671-3cad-4a4a-8cfe-0612e297d5b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_checkpoint = \"t5-small\"\n",
    "task = \"xsum\"\n",
    "metric_name = \"rouge\"\n",
    "batch_size = 8\n",
    "learning_rate = 2e-5\n",
    "weight_decay = 0.01\n",
    "num_train_epochs = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5f27b85",
   "metadata": {},
   "source": [
    "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22e6ef84",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.utils import send_example_telemetry\n",
    "\n",
    "send_example_telemetry(\"summarization_notebook_ort\", framework=\"none\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec6e4df1-2d90-478b-bdf2-1e6507851d4a",
   "metadata": {},
   "source": [
    "## Loading the dataset\n",
    "\n",
    "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0cd399ad-2ef2-4d6e-9fc5-daaabef52b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "raw_datasets = load_dataset(task)\n",
    "metric = load_metric(metric_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e904992d-16c8-44c0-8998-2828661eb90e",
   "metadata": {},
   "source": [
    "__[Optional]__ To get a sense of what the data looks like, the following function will show some examples picked at random from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6b98148c-8da3-4535-81d0-a532bef04a58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import random\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def show_random_elements(dataset, num_examples=1):\n",
    "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
    "    picks = []\n",
    "    for _ in range(num_examples):\n",
    "        pick = random.randint(0, len(dataset)-1)\n",
    "        while pick in picks:\n",
    "            pick = random.randint(0, len(dataset)-1)\n",
    "        picks.append(pick)\n",
    "    \n",
    "    df = pd.DataFrame(dataset[picks])\n",
    "    for column, typ in dataset.features.items():\n",
    "        if isinstance(typ, datasets.ClassLabel):\n",
    "            df[column] = df[column].transform(lambda i: typ.names[i])\n",
    "    display(HTML(df.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "02392d62-56b7-4704-bc17-e28904739664",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>document</th>\n",
       "      <th>summary</th>\n",
       "      <th>id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Lydiate's ruptured knee ligaments suffered in Wales' win over South Africa mean he will also miss out on the chance to challenge for a place on the 2017 Lions tour to New Zealand.\\n\"I have absolutely no doubt he'll get back to where he was,\" said Davis.\\n\"So much of it comes down to the determination of the player.\"\\nThe Australian defence coach, a product of rugby league, likens Lydiate's chances of returning to the full capacity of his fitness to the experiences of former South Africa and Bath fly-half Butch James.\\nDavis coached James at Bath before the player returned to South Africa, for whom he made his final Test appearance against Wales at the 2011 World Cup.\\n\"I worked with Butch James at Bath for three years and I think he'd had four knee reconstructions,\" Davis said.\\n\"Each time he was able to come back, still represent South Africa and still play fantastic rugby for Bath and in South Africa.\\n\"Most rugby players are competitive beasts. Most international players have got there because they are determined and love playing the game.\\n\"Dan is going to be no different. He's going to apply himself to his rehab as he would to a training field.\\n\"We have excellent medical staff here who will support him. I'm 100% sure he's going to come back good, fit and firing.\"\\nDavis said Lydiate was \"getting back to some of his top form\" when he was injured.\\n\"We are going to give the best support and the best treatment and try to get him back on the park as quickly as possible,\" he said.\\n\"Saying that, we have a pretty decent back-row roster that we can pick from.\"\\nOspreys travel to French Top 14 side Grenoble in the European Challenge Cup third round on Thursday looking for a third successive win in the competition.\\nDavis said he was unaware of Wales hooker Scott Baldwin's progress amid Head Injury Assessment protocols after a blow he suffered in their win over Edinburgh last Friday.</td>\n",
       "      <td>Wales flanker Dan Lydiate can make a full recovery from the injury that has ended his season, says Ospreys defence coach Brad Davis.</td>\n",
       "      <td>38225298</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_random_elements(raw_datasets[\"train\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21195faa-5d1a-4f85-a334-1be61819be0d",
   "metadata": {},
   "source": [
    "The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e39e3b3b-b8b0-46f1-a4e9-22162d585324",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Metric(name: \"rouge\", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: \"\"\"\n",
       "Calculates average rouge scores for a list of hypotheses and references\n",
       "Args:\n",
       "    predictions: list of predictions to score. Each prediction\n",
       "        should be a string with tokens separated by spaces.\n",
       "    references: list of reference for each prediction. Each\n",
       "        reference should be a string with tokens separated by spaces.\n",
       "    rouge_types: A list of rouge types to calculate.\n",
       "        Valid names:\n",
       "        `\"rouge{n}\"` (e.g. `\"rouge1\"`, `\"rouge2\"`) where: {n} is the n-gram based scoring,\n",
       "        `\"rougeL\"`: Longest common subsequence based scoring.\n",
       "        `\"rougeLSum\"`: rougeLsum splits text using `\"\n",
       "\"`.\n",
       "        See details in https://github.com/huggingface/datasets/issues/617\n",
       "    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.\n",
       "    use_aggregator: Return aggregates if this is set to True\n",
       "Returns:\n",
       "    rouge1: rouge_1 (precision, recall, f1),\n",
       "    rouge2: rouge_2 (precision, recall, f1),\n",
       "    rougeL: rouge_l (precision, recall, f1),\n",
       "    rougeLsum: rouge_lsum (precision, recall, f1)\n",
       "Examples:\n",
       "\n",
       "    >>> rouge = datasets.load_metric('rouge')\n",
       "    >>> predictions = [\"hello there\", \"general kenobi\"]\n",
       "    >>> references = [\"hello there\", \"general kenobi\"]\n",
       "    >>> results = rouge.compute(predictions=predictions, references=references)\n",
       "    >>> print(list(results.keys()))\n",
       "    ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']\n",
       "    >>> print(results[\"rouge1\"])\n",
       "    AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))\n",
       "    >>> print(results[\"rouge1\"].mid.fmeasure)\n",
       "    1.0\n",
       "\"\"\", stored examples: 0)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "356db537-7aca-4007-8abc-93d1d1e66023",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),\n",
       " 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),\n",
       " 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),\n",
       " 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fake_preds = [\"hello there\", \"general kenobi\"]\n",
    "fake_labels = [\"hello there\", \"general kenobi\"]\n",
    "metric.compute(predictions=fake_preds, references=fake_labels)"
   ]
  },
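  {
   "cell_type": "markdown",
   "id": "b7e2c1f0-4a3d-4e8b-9c21-5f6a7d8e9b01",
   "metadata": {},
   "source": [
    "__[Optional]__ To see what the `precision`, `recall` and `fmeasure` values above mean, here is a minimal pure-Python sketch of ROUGE-N for a single prediction/reference pair. It is a simplification under stated assumptions: it only lowercases and splits on whitespace (no stemming, unlike `use_stemmer=True`), and it does not reproduce the bootstrap aggregation behind the `low`/`mid`/`high` scores. The `ngrams` and `rouge_n` helpers are hypothetical names used for illustration only; the rest of the notebook keeps using `metric.compute`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e2c1f0-4a3d-4e8b-9c21-5f6a7d8e9b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "\n",
    "def ngrams(tokens, n):\n",
    "    # All contiguous n-grams of a token list, as tuples\n",
    "    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]\n",
    "\n",
    "\n",
    "def rouge_n(prediction, reference, n=1):\n",
    "    # Simplified tokenization: lowercase + whitespace split, no stemming\n",
    "    pred = ngrams(prediction.lower().split(), n)\n",
    "    ref = ngrams(reference.lower().split(), n)\n",
    "    # Clipped overlap: an n-gram counts at most as often as it appears in both texts\n",
    "    overlap = sum((Counter(pred) & Counter(ref)).values())\n",
    "    precision = overlap / len(pred) if pred else 0.0\n",
    "    recall = overlap / len(ref) if ref else 0.0\n",
    "    if precision + recall == 0:\n",
    "        return {'precision': 0.0, 'recall': 0.0, 'fmeasure': 0.0}\n",
    "    # F-measure is the harmonic mean of precision and recall\n",
    "    fmeasure = 2 * precision * recall / (precision + recall)\n",
    "    return {'precision': precision, 'recall': recall, 'fmeasure': fmeasure}\n",
    "\n",
    "\n",
    "print(rouge_n('hello there', 'hello there'))\n",
    "print(rouge_n('the cat sat', 'the cat sat down'))"
   ]
  },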
  {
   "cell_type": "markdown",
   "id": "cce8654e-7b89-4939-9492-0dc41748a8d5",
   "metadata": {},
   "source": [
    "## Preprocessing the data\n",
    "\n",
    "Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer` which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.\n",
    "\n",
    "To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:\n",
    "\n",
    "* we get a tokenizer that corresponds to the model architecture we want to use,\n",
    "* we download the vocabulary used when pretraining this specific checkpoint.\n",
    "\n",
    "That vocabulary will be cached, so it's not downloaded again the next time we run the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ed0dd242-168d-4a23-9c31-9a771447b0d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/transformers/models/t5/tokenization_t5_fast.py:156: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
      "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
      "- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.\n",
      "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
      "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "    \n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcfff705-6898-4416-a70d-ab38574e4ae4",
   "metadata": {},
   "source": [
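    "To prepare the targets for our model, we need to tokenize them inside the `as_target_tokenizer` context manager. This makes sure the tokenizer uses the special tokens corresponding to the targets (as the warning in the output below notes, this context manager is deprecated in recent versions of Transformers in favor of passing the targets through the `text_target` argument of the regular tokenizer call):"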
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ab1dcab0-2058-4941-8e00-6632246fabb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/transformers/tokenization_utils_base.py:3540: UserWarning: `as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your labels by using the argument `text_target` of the regular `__call__` method (either in the same call as your input texts if you use the same keyword arguments, or in a separate call.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "with tokenizer.as_target_tokenizer():\n",
    "    print(tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bec3614-7d7e-4032-b5da-16cb729afb38",
   "metadata": {},
   "source": [
    "If you are using one of the five T5 checkpoints, we have to prefix the inputs with \"summarize:\" (the model can also translate, and it needs the prefix to know which task it has to perform)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "206bf42d-bbe8-4522-993b-172a384ac77e",
   "metadata": {},
   "outputs": [],
   "source": [
    "if model_checkpoint in [\"t5-small\", \"t5-base\", \"t5-large\", \"t5-3b\", \"t5-11b\"]:\n",
    "    prefix = \"summarize: \"\n",
    "else:\n",
    "    prefix = \"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "772992e6-3b36-4a22-8176-6d55c02805d0",
   "metadata": {},
   "source": [
    "We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that an input longer than what the selected model can handle is truncated to the given maximum length (`max_input_length` for documents, `max_target_length` for summaries). Padding will be dealt with later on (in a data collator), so that we pad examples to the longest length in the batch rather than in the whole dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "bb805008-5448-4b3a-b5dd-af158267290b",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_input_length = 1024\n",
    "max_target_length = 128\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    inputs = [prefix + doc for doc in examples[\"document\"]]\n",
    "    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n",
    "\n",
    "    # Setup the tokenizer for targets\n",
    "    with tokenizer.as_target_tokenizer():\n",
    "        labels = tokenizer(examples[\"summary\"], max_length=max_target_length, truncation=True)\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7fab1f9-3770-4c8b-9464-ba6142da6895",
   "metadata": {},
   "source": [
    "To apply this function to all the document/summary pairs in our dataset, we just use the `map` method of the `raw_datasets` object we created earlier. This applies the function to all the elements of all the splits in `raw_datasets`, so our training, validation and test data are preprocessed in one single command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "51a27c6a-7fc8-429c-9220-64052fd0dd98",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ab34167-4ce6-4aea-aa17-65d74dabcb5d",
   "metadata": {},
   "source": [
    "Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed (and thus must not use the cached data). For instance, it will properly detect if you change the task in the setup cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to ignore the cached files and force the preprocessing to be applied again.\n",
    "\n",
    "Note that we passed `batched=True` to encode the texts in batches. This leverages the full benefit of the fast tokenizer we loaded earlier, which uses multi-threading to process the texts in a batch concurrently."
   ]
  },
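  {
   "cell_type": "markdown",
   "id": "c4a9d2e1-6b7f-4d3a-8e5c-1a2b3c4d5e03",
   "metadata": {},
   "source": [
    "__[Optional]__ With `batched=True`, the function passed to `map` receives a batch in column-oriented form: a dict mapping each column name to a list of values (by default, up to 1,000 rows per batch). That is why `preprocess_function` can iterate over `examples[\"document\"]`. The sketch below mimics that grouping in plain Python on a few hypothetical toy rows; `to_batches` and the toy data are illustrative only and are not part of 🤗 Datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a9d2e1-6b7f-4d3a-8e5c-1a2b3c4d5e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy rows with the same columns as XSum\n",
    "toy_rows = [\n",
    "    {'document': 'First article.', 'summary': 'First summary.', 'id': '1'},\n",
    "    {'document': 'Second article.', 'summary': 'Second summary.', 'id': '2'},\n",
    "    {'document': 'Third article.', 'summary': 'Third summary.', 'id': '3'},\n",
    "]\n",
    "\n",
    "\n",
    "def to_batches(records, size):\n",
    "    # Group row dicts into column dicts, e.g. {'document': [...], 'summary': [...], 'id': [...]}\n",
    "    for start in range(0, len(records), size):\n",
    "        chunk = records[start:start + size]\n",
    "        yield {column: [row[column] for row in chunk] for column in chunk[0]}\n",
    "\n",
    "\n",
    "for toy_batch in to_batches(toy_rows, size=2):\n",
    "    # Same prefixing step as in preprocess_function, applied to a whole column at once\n",
    "    toy_inputs = ['summarize: ' + doc for doc in toy_batch['document']]\n",
    "    print(toy_inputs)"
   ]
  },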
  {
   "cell_type": "markdown",
   "id": "47c1e75c-4247-4f58-bd1d-cb70f5629172",
   "metadata": {},
   "source": [
    "## Fine-tuning the model\n",
    "\n",
    "Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is a sequence-to-sequence one, we use the `AutoModelForSeq2SeqLM` class to first load the PyTorch model. As with the tokenizer, the `from_pretrained` method will download and cache the model for us."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b0331d86-4daa-4b9f-85f8-fafdab53f457",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq\n",
    "from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea682a86-6c45-458f-852c-492b653352ec",
   "metadata": {},
   "source": [
    "Note that we don't get a warning about newly initialized weights, as we would when loading a checkpoint with a new classification head. This means all the weights of the pretrained model are used and there is no randomly initialized head in this case.\n",
    "\n",
    "To instantiate an `ORTSeq2SeqTrainer`, we will need to define three more things. The most important is [`ORTSeq2SeqTrainingArguments`](https://huggingface.co/docs/optimum/onnxruntime/trainer#optimum.onnxruntime.ORTSeq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. You can also use `Seq2SeqTrainingArguments` from Transformers, but `ORTSeq2SeqTrainingArguments` enables additional ONNX Runtime-specific features, such as the fused AdamW optimizer used below. It requires a folder name, which will be used to save the checkpoints of the model; all other arguments are optional:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "74fc3b58-db3b-4e03-9b26-6b9dcfb46d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "args = ORTSeq2SeqTrainingArguments(\n",
    "    f\"{model_name}-finetuned-xsum\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=learning_rate,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    weight_decay=weight_decay,\n",
    "    save_total_limit=3,\n",
    "    num_train_epochs=num_train_epochs,\n",
    "    predict_with_generate=True,\n",
    "    optim=\"adamw_ort_fused\",\n",
    "    # push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3d56ba1-1399-4b00-a175-a1c3a6ff8feb",
   "metadata": {},
   "source": [
    "Here we run the evaluation at the end of each epoch, tweak the learning rate, use the `batch_size` defined at the top of the notebook and customize the weight decay. Since the `ORTSeq2SeqTrainer` will save the model regularly and our dataset is quite large, we tell it to keep at most three checkpoints. Lastly, we use the `predict_with_generate` option (to properly generate summaries) and select ONNX Runtime's fused AdamW optimizer with `optim=\"adamw_ort_fused\"`. Mixed precision training is not enabled here; passing `fp16=True` would turn it on (and make the data collator below pad to a multiple of 8).\n",
    "\n",
    "The last, commented-out argument (`push_to_hub=True`) sets everything up so we can push the model to the [Hub](https://huggingface.co/models) regularly during training. Uncomment it only if you followed the optional login steps at the top of the notebook. If you want to save your model locally under a name that is different from the name of the repository it will be pushed to, or if you want to push your model under an organization rather than your own namespace, use the `hub_model_id` argument to set the repo name (it needs to be the full name, including your namespace: for instance `\"optimum/t5-large-finetuned-xsum\"`).\n",
    "\n",
    "Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fb644a2f-0f18-4646-916d-b4c6b5c4ccf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorForSeq2Seq(\n",
    "    tokenizer,\n",
    "    model=model,\n",
    "    label_pad_token_id=tokenizer.pad_token_id,\n",
    "    pad_to_multiple_of=8 if args.fp16 else None,\n",
    ")"
   ]
  },
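  {
   "cell_type": "markdown",
   "id": "d1f3a5b7-9c2e-4f6a-b8d0-2e4f6a8c0b05",
   "metadata": {},
   "source": [
    "__[Optional]__ To make the collator's job concrete, here is a simplified pure-Python sketch of dynamic padding: each batch is padded only to its own longest sequence, inputs with the pad token and labels with a label padding value. The token IDs are made up, and `pad_batch` is a hypothetical helper, not the `DataCollatorForSeq2Seq` implementation (which also returns tensors and, when given the `model`, prepares `decoder_input_ids`). The sketch defaults to -100 for label padding, which is the collator's default and is ignored by the loss; the cell above passes `tokenizer.pad_token_id` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f3a5b7-9c2e-4f6a-b8d0-2e4f6a8c0b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):\n",
    "    # Pad to the longest sequence in this batch, not in the whole dataset\n",
    "    max_input_len = max(len(f['input_ids']) for f in features)\n",
    "    max_label_len = max(len(f['labels']) for f in features)\n",
    "    padded = {'input_ids': [], 'attention_mask': [], 'labels': []}\n",
    "    for f in features:\n",
    "        n_pad = max_input_len - len(f['input_ids'])\n",
    "        padded['input_ids'].append(f['input_ids'] + [pad_token_id] * n_pad)\n",
    "        # 1 marks real tokens, 0 marks padding the model should ignore\n",
    "        padded['attention_mask'].append([1] * len(f['input_ids']) + [0] * n_pad)\n",
    "        n_label_pad = max_label_len - len(f['labels'])\n",
    "        padded['labels'].append(f['labels'] + [label_pad_token_id] * n_label_pad)\n",
    "    return padded\n",
    "\n",
    "\n",
    "# Two tokenized examples of different lengths (made-up token IDs)\n",
    "toy_features = [\n",
    "    {'input_ids': [21603, 10, 1], 'labels': [100, 1]},\n",
    "    {'input_ids': [21603, 10, 48, 19, 1], 'labels': [100, 19, 430, 1]},\n",
    "]\n",
    "for key, rows in pad_batch(toy_features).items():\n",
    "    print(key, rows)"
   ]
  },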
  {
   "cell_type": "markdown",
   "id": "15b73b4b-3816-4bba-a325-f75e3ec29154",
   "metadata": {},
   "source": [
    "The last thing to define for our `ORTSeq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of pre-processing to decode the predictions into texts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0266a795-ddf8-46e0-8d37-be8863cb62cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
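    "    # Depending on the Transformers version, generated sequences may be padded with -100\n",
    "    # when gathered across batches, so map -100 back to the pad token before decoding.\n",
    "    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)\n",
    "    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",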
    "    # Replace -100 in the labels as we can't decode them.\n",
    "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "\n",
    "    # Rouge expects a newline after each sentence\n",
    "    decoded_preds = [\"\\n\".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]\n",
    "    decoded_labels = [\"\\n\".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]\n",
    "\n",
    "    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)\n",
    "    # Extract a few results\n",
    "    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n",
    "\n",
    "    # Add mean generated length\n",
    "    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]\n",
    "    result[\"gen_len\"] = np.mean(prediction_lens)\n",
    "\n",
    "    return {k: round(v, 4) for k, v in result.items()}"
   ]
  },
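  {
   "cell_type": "markdown",
   "id": "e8c6b4a2-0d1e-4a2b-9f3c-4d5e6f7a8b07",
   "metadata": {},
   "source": [
    "__[Optional]__ The two NumPy steps in `compute_metrics` can be checked on toy arrays. The sketch below assumes a pad token ID of 0 (T5's value) and uses made-up token IDs: it maps the -100 label padding back to the pad ID so the labels can be decoded, then computes the generated length the same way as `gen_len`, by counting the non-pad tokens in each generated sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c6b4a2-0d1e-4a2b-9f3c-4d5e6f7a8b08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "toy_pad_id = 0  # assumption: T5's pad token ID\n",
    "toy_labels = np.array([[37, 1782, 1, -100], [37, 1, -100, -100]])\n",
    "# Each generated row starts with the decoder start token, which for T5 is the pad ID\n",
    "toy_preds = np.array([[0, 37, 1782, 1], [0, 37, 1, 0]])\n",
    "\n",
    "# -100 is not a valid token ID, so replace it with the pad ID before decoding\n",
    "clean_labels = np.where(toy_labels != -100, toy_labels, toy_pad_id)\n",
    "print(clean_labels)\n",
    "\n",
    "# gen_len: number of non-pad tokens per generated sequence, averaged over the batch\n",
    "toy_lens = [np.count_nonzero(pred != toy_pad_id) for pred in toy_preds]\n",
    "print(toy_lens, np.mean(toy_lens))"
   ]
  },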
  {
   "cell_type": "markdown",
   "id": "cc60c851-adaf-43bc-8e0e-ce6fbbdd1e9f",
   "metadata": {},
   "source": [
    "Then we just need to pass all of this along with our datasets to the `ORTSeq2SeqTrainer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ba3192aa-5703-467b-b168-054407c80d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = ORTSeq2SeqTrainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics if args.predict_with_generate else None,\n",
    "    feature=\"seq2seq-lm\",\n",
    ")"
   ]
  },
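  {
   "cell_type": "markdown",
   "id": "c4d8a1b2-6e3f-4a97-8b05-2f1e9d7c6a33",
   "metadata": {},
   "source": [
    "Before launching the run, it can help to sanity-check the numbers that `train` will report. The following is a back-of-the-envelope sketch, not Optimum or Transformers code. It assumes 204,045 examples in the XSum training split and an effective batch size of 8 (per-device batch size times number of GPUs times gradient accumulation steps). Neither value is shown in this section, but together they reproduce the 25,506 steps in the progress bar of the training run. It also assumes `save_steps=500` and `save_total_limit=3`, inferred from the checkpoint log (for example, `checkpoint-500` is deleted right after `checkpoint-2000` is saved).\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Assumptions (not shown in this section): XSum training-set size and effective batch size.\n",
    "NUM_TRAIN_EXAMPLES = 204_045\n",
    "EFFECTIVE_BATCH_SIZE = 8  # per-device batch size x GPUs x gradient accumulation steps\n",
    "NUM_EPOCHS = 1\n",
    "\n",
    "# Assumptions inferred from the checkpoint log: save every 500 steps, keep at most 3.\n",
    "SAVE_STEPS = 500\n",
    "SAVE_TOTAL_LIMIT = 3\n",
    "\n",
    "# One optimizer step per effective batch; the final partial batch still counts as a step.\n",
    "steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / EFFECTIVE_BATCH_SIZE)\n",
    "total_steps = steps_per_epoch * NUM_EPOCHS\n",
    "\n",
    "\n",
    "def simulate_checkpoint_rotation(total_steps, save_steps, save_total_limit):\n",
    "    # Simplified rotation: after each save, delete the oldest checkpoints\n",
    "    # until at most save_total_limit remain (ignores best-model tracking).\n",
    "    kept, deleted = [], []\n",
    "    for step in range(save_steps, total_steps + 1, save_steps):\n",
    "        kept.append(f'checkpoint-{step}')\n",
    "        while len(kept) > save_total_limit:\n",
    "            deleted.append(kept.pop(0))\n",
    "    return kept, deleted\n",
    "\n",
    "\n",
    "kept, deleted = simulate_checkpoint_rotation(total_steps, SAVE_STEPS, SAVE_TOTAL_LIMIT)\n",
    "print('total steps:', total_steps)\n",
    "print('first deleted:', deleted[0])\n",
    "print('still on disk:', kept)\n",
    "```\n",
    "\n",
    "The rotation here is deliberately simplified: it only drops the oldest checkpoints once the limit is exceeded and ignores extra rules the `Trainer` may apply, for example when it keeps a best-model checkpoint."
   ]
  },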
  {
   "cell_type": "markdown",
   "id": "8281d5c7-0a9e-4df0-bd99-80a1ccbf809d",
   "metadata": {},
   "source": [
    "We can now finetune our model by just calling the `train` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8c9f8bfb-1a47-4b33-a632-50170f3817d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document. If summary, id, document are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.\n",
      "You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_training_manager.py:191: UserWarning: Fast path enabled - skipping checks. Rebuild graph: True, Execution agent: True, Device check: True\n",
      "  warnings.warn(\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "/usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_logger.py:52: UserWarning: There were one or more warnings or errors raised while exporting the PyTorch model. Please enable INFO level logging to view all warnings and errors.\n",
      "  warnings.warn(\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='25506' max='25506' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [25506/25506 3:24:00, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Rouge1</th>\n",
       "      <th>Rouge2</th>\n",
       "      <th>Rougel</th>\n",
       "      <th>Rougelsum</th>\n",
       "      <th>Gen Len</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.952500</td>\n",
       "      <td>1.782154</td>\n",
       "      <td>28.623200</td>\n",
       "      <td>7.974700</td>\n",
       "      <td>22.526900</td>\n",
       "      <td>22.527000</td>\n",
       "      <td>18.810800</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-500/spiece.model\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-1000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-1000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-1000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-1000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-1000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-1000/spiece.model\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-1500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-1500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-1500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-1500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-1500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-1500/spiece.model\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-2000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-2000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-2000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-2000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-2000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-2000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-2500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-2500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-2500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-2500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-2500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-2500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-1000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-3000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-3000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-3000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-3000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-3000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-3000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-1500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-3500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-3500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-3500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-3500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-3500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-3500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-2000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-4000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-4000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-4000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-4000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-4000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-4000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-2500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-4500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-4500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-4500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-4500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-4500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-4500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-3000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-5000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-5000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-5000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-5000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-5000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-5000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-3500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-5500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-5500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-5500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-5500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-5500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-5500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-4000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-6000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-6000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-6000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-6000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-6000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-6000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-4500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-6500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-6500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-6500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-6500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-6500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-6500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-5000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-7000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-7000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-7000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-7000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-7000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-7000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-5500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-7500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-7500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-7500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-7500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-7500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-7500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-6000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-8000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-8000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-8000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-8000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-8000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-8000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-6500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-8500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-8500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-8500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-8500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-8500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-8500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-7000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-9000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-9000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-9000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-9000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-9000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-9000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-7500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-9500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-9500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-9500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-9500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-9500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-9500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-8000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-10000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-10000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-10000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-10000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-10000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-10000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-8500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-10500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-10500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-10500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-10500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-10500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-10500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-9000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-11000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-11000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-11000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-11000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-11000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-11000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-9500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-11500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-11500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-11500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-11500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-11500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-11500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-10000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-12000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-12000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-12000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-12000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-12000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-12000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-10500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-12500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-12500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-12500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-12500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-12500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-12500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-11000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-13000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-13000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-13000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-13000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-13000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-13000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-11500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-13500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-13500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-13500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-13500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-13500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-13500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-12000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-14000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-14000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-14000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-14000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-14000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-14000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-12500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-14500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-14500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-14500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-14500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-14500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-14500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-13000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-15000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-15000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-15000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-15000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-15000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-15000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-13500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-15500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-15500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-15500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-15500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-15500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-15500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-14000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-16000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-16000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-16000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-16000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-16000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-16000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-14500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-16500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-16500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-16500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-16500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-16500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-16500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-15000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-17000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-17000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-17000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-17000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-17000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-17000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-15500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-17500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-17500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-17500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-17500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-17500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-17500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-16000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-18000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-18000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-18000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-18000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-18000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-18000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-16500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-18500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-18500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-18500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-18500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-18500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-18500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-17000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-19000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-19000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-19000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-19000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-19000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-19000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-17500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-19500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-19500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-19500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-19500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-19500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-19500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-18000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-20000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-20000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-20000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-20000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-20000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-20000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-18500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-20500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-20500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-20500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-20500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-20500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-20500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-19000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-21000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-21000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-21000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-21000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-21000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-21000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-19500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-21500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-21500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-21500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-21500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-21500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-21500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-20000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-22000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-22000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-22000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-22000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-22000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-22000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-20500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-22500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-22500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-22500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-22500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-22500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-22500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-21000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-23000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-23000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-23000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-23000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-23000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-23000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-21500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-23500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-23500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-23500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-23500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-23500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-23500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-22000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-24000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-24000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-24000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-24000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-24000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-24000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-22500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-24500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-24500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-24500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-24500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-24500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-24500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-23000] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-25000\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-25000/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-25000/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-25000/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-25000/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-25000/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-23500] due to args.save_total_limit\n",
      "Saving model checkpoint to t5-small-finetuned-xsum/checkpoint-25500\n",
      "Configuration saved in t5-small-finetuned-xsum/checkpoint-25500/config.json\n",
      "Model weights saved in t5-small-finetuned-xsum/checkpoint-25500/pytorch_model.bin\n",
      "tokenizer config file saved in t5-small-finetuned-xsum/checkpoint-25500/tokenizer_config.json\n",
      "Special tokens file saved in t5-small-finetuned-xsum/checkpoint-25500/special_tokens_map.json\n",
      "Copy vocab file to t5-small-finetuned-xsum/checkpoint-25500/spiece.model\n",
      "Deleting older checkpoint [t5-small-finetuned-xsum/checkpoint-24000] due to args.save_total_limit\n",
      "The following columns in the evaluation set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document. If summary, id, document are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.\n",
      "WARNING:optimum.onnxruntime.trainer:[INFO] Evaluating with PyTorch backend. If you want to use ONNX Runtime for the evaluation, set `trainer.evaluate(inference_with_ort=True)`.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 11332\n",
      "  Batch size = 8\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=25506, training_loss=2.008956053654633, metrics={'train_runtime': 12251.2493, 'train_samples_per_second': 16.655, 'train_steps_per_second': 2.082, 'total_flos': 5.03014392471552e+16, 'train_loss': 2.008956053654633, 'epoch': 1.0})"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00196cb8-084f-491b-b474-5773014e4f10",
   "metadata": {},
   "source": [
    "You can now upload the results of the training to the Hub by running:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10f5e7fc-2dbd-40e5-978e-f28b527ecc9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10cf3235-73c3-4800-8762-76a9e87095c1",
   "metadata": {},
   "source": [
    "You can also save your fine-tuned model as a PyTorch or ONNX model in the `output_dir` that you set for `ORTSeq2SeqTrainer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "949a4ead-07be-4300-912c-2134645817dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4408a857-9d80-48af-8ebe-e8523cfc2bf6",
   "metadata": {},
   "source": [
    "## Evaluating your model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0f96b4c-6f90-471e-a731-a26c3d080d64",
   "metadata": {},
   "source": [
    "To evaluate the model you just fine-tuned on the validation dataset that you passed to `ORTSeq2SeqTrainer`, call the `evaluate` method.\n",
    "\n",
    "If you set `inference_with_ort=True`, inference runs on the ONNX Runtime backend. Otherwise, it runs on the PyTorch backend, which is the default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4b6cbcf4-d672-4500-a9f9-06fb5fb93495",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the evaluation set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: summary, id, document. If summary, id, document are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.\n",
      "Using framework PyTorch: 1.11.0+cu113\n",
      "Overriding 1 configuration item(s)\n",
      "\t- use_cache -> False\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Using framework PyTorch: 1.11.0+cu113\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Overriding 1 configuration item(s)\n",
      "\t- use_cache -> False\n",
      "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:701: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if causal_mask.shape[1] < attention_mask.shape[1]:\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Using framework PyTorch: 1.11.0+cu113\n",
      "Overriding 1 configuration item(s)\n",
      "\t- use_cache -> True\n",
      "In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode\n",
      "In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "WARNING: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "WARNING:optimum.modeling_base:config.json NOT FOUND in HuggingFace Hub\n",
      "Model config PretrainedConfig {\n",
      "  \"transformers_version\": \"4.23.0.dev0\"\n",
      "}\n",
      "\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "2022-09-16 23:10:03.897260566 [W:onnxruntime:, graph.cc:106 MergeShapeInfo] Error merging shape info for output. 'loss' source:{} target:{-1,-1}. Falling back to lenient merge.\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "Warning: Checker does not support models with experimental ops: ATen\n",
      "2022-09-16 23:10:04.342904389 [W:onnxruntime:, graph.cc:106 MergeShapeInfo] Error merging shape info for output. 'loss' source:{} target:{-1,-1}. Falling back to lenient merge.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1417' max='1417' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1417/1417 37:53]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 1.782148003578186,\n",
       " 'eval_rouge1': 28.6232,\n",
       " 'eval_rouge2': 7.9747,\n",
       " 'eval_rougeL': 22.5269,\n",
       " 'eval_rougeLsum': 22.527,\n",
       " 'eval_gen_len': 18.8108,\n",
       " 'eval_runtime': 2303.9032,\n",
       " 'eval_samples_per_second': 4.919,\n",
       " 'eval_steps_per_second': 0.615,\n",
       " 'epoch': 1.0}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate(inference_with_ort=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f5e9a38-156d-45ce-8ddc-c06c76bef057",
   "metadata": {},
   "source": [
    "## Extended reading\n",
    "\n",
    "Now inspect your trained ONNX model with [Netron](https://netron.app/): you might notice that the computation graph has been optimized. Want to accelerate even more?\n",
    "\n",
    "Check out the [graph optimizers](https://huggingface.co/docs/optimum/onnxruntime/optimization) and [quantizers](https://huggingface.co/docs/optimum/onnxruntime/quantization) in [Optimum](https://github.com/huggingface/optimum) 🤗!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}