{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-10-22T21:58:13.927758Z",
     "iopub.status.busy": "2022-10-22T21:58:13.927382Z",
     "iopub.status.idle": "2022-10-22T21:58:16.109866Z",
     "shell.execute_reply": "2022-10-22T21:58:16.108310Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: tokenizer in /usr/local/lib/python3.9/dist-packages (3.4.2)\n",
      "Collecting datasets\n",
      "  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m451.7/451.7 KB\u001B[0m \u001B[31m34.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: sentencepiece in /usr/local/lib/python3.9/dist-packages (0.1.97)\n",
      "Requirement already satisfied: protobuf==3.20.0 in /usr/local/lib/python3.9/dist-packages (3.20.0)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from datasets) (1.23.5)\n",
      "Collecting fsspec[http]>=2021.11.1\n",
      "  Downloading fsspec-2022.11.0-py3-none-any.whl (139 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m139.5/139.5 KB\u001B[0m \u001B[31m38.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (4.64.1)\n",
      "Collecting xxhash\n",
      "  Downloading xxhash-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m212.0/212.0 KB\u001B[0m \u001B[31m46.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hCollecting multiprocess\n",
      "  Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m132.9/132.9 KB\u001B[0m \u001B[31m37.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hCollecting pyarrow>=6.0.0\n",
      "  Downloading pyarrow-10.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.9 MB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m35.9/35.9 MB\u001B[0m \u001B[31m71.6 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\n",
      "\u001B[?25hCollecting pandas\n",
      "  Downloading pandas-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.2 MB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m12.2/12.2 MB\u001B[0m \u001B[31m127.0 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m00:01\u001B[0m00:01\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets) (6.0)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets) (22.0)\n",
      "Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in /usr/local/lib/python3.9/dist-packages (from datasets) (0.11.1)\n",
      "Collecting aiohttp\n",
      "  Downloading aiohttp-3.8.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.0/1.0 MB\u001B[0m \u001B[31m84.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hCollecting dill<0.3.7\n",
      "  Downloading dill-0.3.6-py3-none-any.whl (110 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m110.5/110.5 KB\u001B[0m \u001B[31m27.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: requests>=2.19.0 in /usr/lib/python3/dist-packages (from datasets) (2.22.0)\n",
      "Collecting responses<0.19\n",
      "  Downloading responses-0.18.0-py3-none-any.whl (38 kB)\n",
      "Collecting aiosignal>=1.1.2\n",
      "  Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n",
      "Collecting yarl<2.0,>=1.0\n",
      "  Downloading yarl-1.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (264 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m264.6/264.6 KB\u001B[0m \u001B[31m62.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hCollecting frozenlist>=1.1.1\n",
      "  Downloading frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (158 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m158.8/158.8 KB\u001B[0m \u001B[31m31.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets) (22.1.0)\n",
      "Collecting async-timeout<5.0,>=4.0.0a3\n",
      "  Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)\n",
      "Collecting charset-normalizer<3.0,>=2.0\n",
      "  Downloading charset_normalizer-2.1.1-py3-none-any.whl (39 kB)\n",
      "Collecting multidict<7.0,>=4.5\n",
      "  Downloading multidict-6.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m114.2/114.2 KB\u001B[0m \u001B[31m28.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.4.0)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.8.2)\n",
      "Collecting urllib3>=1.25.10\n",
      "  Downloading urllib3-1.26.13-py2.py3-none-any.whl (140 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m140.6/140.6 KB\u001B[0m \u001B[31m21.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->datasets) (2.8.2)\n",
      "Collecting pytz>=2020.1\n",
      "  Downloading pytz-2022.6-py2.py3-none-any.whl (498 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m498.1/498.1 KB\u001B[0m \u001B[31m55.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hRequirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.14.0)\n",
      "Requirement already satisfied: idna>=2.0 in /usr/lib/python3/dist-packages (from yarl<2.0,>=1.0->aiohttp->datasets) (2.8)\n",
      "Installing collected packages: pytz, xxhash, urllib3, pyarrow, multidict, fsspec, frozenlist, dill, charset-normalizer, async-timeout, yarl, responses, pandas, multiprocess, aiosignal, aiohttp, datasets\n",
      "  Attempting uninstall: urllib3\n",
      "    Found existing installation: urllib3 1.25.8\n",
      "    Uninstalling urllib3-1.25.8:\n",
      "      Successfully uninstalled urllib3-1.25.8\n",
      "Successfully installed aiohttp-3.8.3 aiosignal-1.3.1 async-timeout-4.0.2 charset-normalizer-2.1.1 datasets-2.7.1 dill-0.3.6 frozenlist-1.3.3 fsspec-2022.11.0 multidict-6.0.3 multiprocess-0.70.14 pandas-1.5.2 pyarrow-10.0.1 pytz-2022.6 responses-0.18.0 urllib3-1.26.13 xxhash-3.1.0 yarl-1.8.2\n",
      "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001B[0m\u001B[33m\n",
      "\u001B[0m\u001B[33mWARNING: You are using pip version 22.0.4; however, version 22.3.1 is available.\n",
      "You should consider upgrading via the '/usr/bin/python3.9 -m pip install --upgrade pip' command.\u001B[0m\u001B[33m\n",
      "\u001B[0mThu Dec 15 13:32:01 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.6     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |\n",
      "|  0%   29C    P0    69W / 300W |   2154MiB / 22731MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "! pip install tokenizer datasets sentencepiece protobuf==3.20.0\n",
    "! nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-10-22T21:58:16.116662Z",
     "iopub.status.busy": "2022-10-22T21:58:16.116143Z",
     "iopub.status.idle": "2022-10-22T21:58:17.279364Z",
     "shell.execute_reply": "2022-10-22T21:58:17.278660Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.13) or chardet (3.0.4) doesn't match a supported version!\n",
      "  warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "import time\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-10-22T21:58:17.284240Z",
     "iopub.status.busy": "2022-10-22T21:58:17.283941Z",
     "iopub.status.idle": "2022-10-22T21:58:24.270745Z",
     "shell.execute_reply": "2022-10-22T21:58:24.269721Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd37851a08234e43b091d332a8ab7348",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/882 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73f315cb16a94f5a9cd052e336d795a7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/443M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d4a8cca9fc740d2831c65c74d309ab7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/433 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51a012f6514943a28f120b54665a1e6e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/811k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71e85550a456461f9c969f9116d4aa11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/299 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d2ced9224344c3dad68f6d5a5867a72",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/8.78k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7d2527955c464badbd1149d193a1744c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading metadata:   0%|          | 0.00/36.6k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6a25fb0b8774efabdcccd52221bac43",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/18.1k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset xnli/fr to /root/.cache/huggingface/datasets/xnli/fr/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a528f85b67884f08b8cfa8f1e35adcbb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9f983fb156c4733bc40edeb2d7063d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/466M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15067936fa0a4cbb81fb22e11e5663ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/17.9M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  "
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a3148a54e0d42399b448426cbb9bddd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files #1:   0%|          | 0/1 [00:00<?, ?obj/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15df8876991c45af9a5c0733d96d4879",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files #0:   0%|          | 0/1 [00:00<?, ?obj/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/5010 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/2490 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset xnli downloaded and prepared to /root/.cache/huggingface/datasets/xnli/fr/1.1.0/818164464f9c9fd15776ca8a00423b074344c3e929d00a2c1a84aa5a50c928bd. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59e73cbef55c447b875a899aab3edef6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_name = \"BaptisteDoyen/camembert-base-xnli\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "model = model.eval().cuda()\n",
    "\n",
    "model_opt = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "model_opt = model_opt.eval().cuda()\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "dataset = load_dataset(path=\"xnli\", name=\"fr\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we do a warmup, it builds the triton kernels optimized for each size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-10-22T21:58:24.280521Z",
     "iopub.status.busy": "2022-10-22T21:58:24.280032Z",
     "iopub.status.idle": "2022-10-22T22:02:02.893134Z",
     "shell.execute_reply": "2022-10-22T22:02:02.892462Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "370s\n"
     ]
    }
   ],
   "source": [
    "from kernl.model_optimization import optimize_model\n",
    "\n",
    "optimize_model(model_opt)\n",
    "start = time.perf_counter()\n",
    "shapes = [(1, w) for w in range(8, 128 + 8, 8)]\n",
    "with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):\n",
    "    for s in shapes:\n",
    "        inputs = {\n",
    "            \"input_ids\": torch.ones(s, device=\"cuda\", dtype=torch.long),\n",
    "            \"attention_mask\": torch.ones(s, device=\"cuda\", dtype=torch.long),\n",
    "        }\n",
    "        _ = model_opt(**inputs)\n",
    "        _ = model(**inputs)\n",
    "\n",
    "print(f\"{time.perf_counter() - start:.0f}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-10-22T22:02:02.898068Z",
     "iopub.status.busy": "2022-10-22T22:02:02.897814Z",
     "iopub.status.idle": "2022-10-22T22:02:54.585617Z",
     "shell.execute_reply": "2022-10-22T22:02:54.585027Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "complete_time_baseline=38.08s\n",
      "complete_time_optimized=5.25s\n",
      "nb_disagree=1\n",
      "score baseline: 0.82\n",
      "score optimize: 0.82\n"
     ]
    }
   ],
   "source": [
    "complete_time_baseline = 0\n",
    "score_baseline = 0\n",
    "complete_time_optimized = 0\n",
    "score_optimize = 0\n",
    "nb_examples = len(dataset[\"test\"])\n",
    "nb_disagree = 0\n",
    "\n",
    "with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):\n",
    "    for index, content in enumerate(dataset[\"test\"]):\n",
    "        premise, hypothesis, label = content.values()\n",
    "        inputs = tokenizer(premise, hypothesis, return_tensors=\"pt\", pad_to_multiple_of=8, padding=True)\n",
    "        inputs = dict(inputs.to(\"cuda\"))\n",
    "\n",
    "        torch.cuda.synchronize()\n",
    "        start = time.perf_counter()\n",
    "        output_original = model(**inputs)\n",
    "        torch.cuda.synchronize()\n",
    "        complete_time_baseline += time.perf_counter() - start\n",
    "\n",
    "        choice_baseline = torch.argmax(output_original.logits, dim=1)\n",
    "        score_baseline += label == choice_baseline.item()\n",
    "\n",
    "        start = time.perf_counter()\n",
    "        output_optimized = model_opt(**inputs)\n",
    "        torch.cuda.synchronize()\n",
    "        complete_time_optimized += time.perf_counter() - start\n",
    "\n",
    "        choice_optimize = torch.argmax(output_optimized.logits, dim=1)\n",
    "        score_optimize += label == choice_optimize.item()\n",
    "\n",
    "        assert torch.allclose(\n",
    "            output_original.logits, output_optimized.logits, atol=1e-1\n",
    "        ), f\"logits don't match:\\n{output_original}\\n{output_optimized}\"\n",
    "        if choice_baseline != choice_optimize:\n",
    "            nb_disagree += 1\n",
    "\n",
    "print(f\"{complete_time_baseline=:.2f}s\")\n",
    "print(f\"{complete_time_optimized=:.2f}s\")\n",
    "print(f\"{nb_disagree=}\")\n",
    "print(f\"score baseline: {score_baseline / nb_examples:.2f}\")\n",
    "print(f\"score optimize: {score_optimize / nb_examples:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "033db6b4fbad45d38a2b7991eb0666fa": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "244d460a61434f0a99f1d179db9f66e0": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_93d776a4057b4d449d3280c7218aa232",
       "placeholder": "​",
       "style": "IPY_MODEL_7bed2deb71574c6390b221400a85f147",
       "tabbable": null,
       "tooltip": null,
       "value": " 3/3 [00:00&lt;00:00, 77.71it/s]"
      }
     },
     "2d8d2026815a4096b56a6e7b6ce1a3c7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_033db6b4fbad45d38a2b7991eb0666fa",
       "placeholder": "​",
       "style": "IPY_MODEL_f052bf584a244226af4ad40e79a3ee48",
       "tabbable": null,
       "tooltip": null,
       "value": "100%"
      }
     },
     "520958910806468c9f26566c7c3f7e0a": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "58495ed27cf042789ec5bae35e3cc079": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "653e426fbc374eddbfccfdbe0176395b": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "7bed2deb71574c6390b221400a85f147": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "93d776a4057b4d449d3280c7218aa232": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "adbeb356481b492896d7638b43f56c98": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_58495ed27cf042789ec5bae35e3cc079",
       "max": 3,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_520958910806468c9f26566c7c3f7e0a",
       "tabbable": null,
       "tooltip": null,
       "value": 3
      }
     },
     "b6b7ccdf218e4a57ba3f2541a409bf24": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_2d8d2026815a4096b56a6e7b6ce1a3c7",
        "IPY_MODEL_adbeb356481b492896d7638b43f56c98",
        "IPY_MODEL_244d460a61434f0a99f1d179db9f66e0"
       ],
       "layout": "IPY_MODEL_653e426fbc374eddbfccfdbe0176395b",
       "tabbable": null,
       "tooltip": null
      }
     },
     "f052bf584a244226af4ad40e79a3ee48": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}