{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "# Quantizing a model with Intel Neural Compressor (INC) for text classification tasks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook shows how to apply different quantization approaches such as dynamic, static and aware training quantization, using the [Intel Neural Compressor](https://github.com/intel/neural-compressor) (INC) library, for any tasks of the GLUE benchmark. This is made possible thanks to 🤗 [Optimum](https://github.com/huggingface/optimum), an extension of 🤗 [Transformers](https://github.com/huggingface/transformers), providing a set of performance optimization tools enabling maximum efficiency to train and run models on targeted hardwares. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Optimum. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! pip install datasets transformers optimum[intel]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of 🤗 Optimum is at least 1.2.3 since the functionality was introduced in that version:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.2.3\n" ] } ], "source": [ "from optimum.intel.version import __version__\n", "\n", "print(__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that quantization is currently only supported for CPUs, so we will not be utilizing GPUs / CUDA in this notebook. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences which are:\n", "\n", "- [CoLA](https://nyu-mll.github.io/CoLA/) (Corpus of Linguistic Acceptability) Determine if a sentence is grammatically correct or not.\n", "- [MNLI](https://arxiv.org/abs/1704.05426) (Multi-Genre Natural Language Inference) Determine if a sentence entails, contradicts or is unrelated to a given hypothesis. This dataset has two versions, one with the validation and test set coming from the same distribution, another called mismatched where the validation and test use out-of-domain data.\n", "- [MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398) (Microsoft Research Paraphrase Corpus) Determine if two sentences are paraphrases from one another or not.\n", "- [QNLI](https://rajpurkar.github.io/SQuAD-explorer/) (Question-answering Natural Language Inference) Determine if the answer to a question is in the second sentence or not. This dataset is built from the SQuAD dataset.\n", "- [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs) (Quora Question Pairs2) Determine if two questions are semantically equivalent or not.\n", "- [RTE](https://aclweb.org/aclwiki/Recognizing_Textual_Entailment) (Recognizing Textual Entailment) Determine if a sentence entails a given hypothesis or not.\n", "- [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) Determine if the sentence has a positive or negative sentiment.\n", "- [STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) (Semantic Textual Similarity Benchmark) Determine the similarity of two sentences with a score from 1 to 5.\n", "- [WNLI](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not. This dataset is built from the Winograd Schema Challenge dataset.\n", "\n", "We will see how to apply post-training static quantization on a DistilBERT model fine-tuned on the SST-2 task:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", "task = \"sst2\"\n", "model_checkpoint = \"distilbert-base-uncased-finetuned-sst-2-english\"\n", "batch_size = 16\n", "max_train_samples = 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can set our `quantization_approach` to either `dynamic`, `static` or `aware_training` in order to apply respectively dynamic, static and aware training quantization. \n", "- Post-training static quantization : introduces an additional calibration step where data is fed through the network in order to compute the activations quantization parameters.\n", "- Post-training dynamic quantization : dynamically computes activations quantization parameters based on the data observed at runtime.\n", "- Quantization aware training : simulates the effects of quantization during training in order to alleviate its effects on the model's performance.\n", "\n", "Quantization will be applied on the embeddings, and on the linear layers as well as on their corresponding input activations." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "SUPPORTED_QUANTIZATION_APPROACH = [\"dynamic\", \"static\", \"aware_training\"]\n", "\n", "quantization_approach = \"static\"" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our quantized model to the baseline). This can be easily done with the functions `load_dataset` and `load_metric`. " ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [], "source": [ "from datasets import load_dataset, load_metric" ] }, { "cell_type": "markdown", "metadata": { "id": "CKx2zKs5IrIq" }, "source": [ "Apart from `mnli-mm` being a special code, we can directly pass our task name to those functions. `load_dataset` will cache the dataset to avoid downloading it again the next time you run this cell." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270, "referenced_widgets": [ "69caab03d6264fef9fc5649bffff5e20", "3f74532faa86412293d90d3952f38c4a", "50615aa59c7247c4804ca5cbc7945bd7", "fe962391292a413ca55dc932c4279fa7", "299f4b4c07654e53a25f8192bd1d7bbd", "ad04ed1038154081bbb0c1444784dcc2", "7c667ad22b5740d5a6319f1b1e3a8097", "46c2b043c0f84806978784a45a4e203b", "80e2943be35f46eeb24c8ab13faa6578", "de5956b5008d4fdba807bae57509c393", "931db1f7a42f4b46b7ff8c2e1262b994", "6c1db72efff5476e842c1386fadbbdba", "ccd2f37647c547abb4c719b75a26f2de", "d30a66df5c0145e79693e09789d96b81", "5fa26fc336274073abbd1d550542ee33", "2b34de08115d49d285def9269a53f484", "d426be871b424affb455aeb7db5e822e", "160bf88485f44f5cb6eaeecba5e0901f", "745c0d47d672477b9bb0dae77b926364", "d22ab78269cd4ccfbcf70c707057c31b", "d298eb19eeff453cba51c2804629d3f4", "a7204ade36314c86907c562e0a2158b8", "e35d42b2d352498ca3fc8530393786b2", "75103f83538d44abada79b51a1cec09e", "f6253931d90543e9b5fd0bb2d615f73a", "051aa783ff9e47e28d1f9584043815f5", "0984b2a14115454bbb009df71c1cf36f", "8ab9dfce29854049912178941ef1b289", "c9de740e007141958545e269372780a4", "cbea68b25d6d4ba09b2ce0f27b1726d5", "5781fc45cf8d486cb06ed68853b2c644", "d2a92143a08a4951b55bab9bc0a6d0d3", "a14c3e40e5254d61ba146f6ec88eae25", "c4ffe6f624ce4e978a0d9b864544941a", "1aca01c1d8c940dfadd3e7144bb35718", "9fbbaae50e6743f2aa19342152398186", "fea27ca6c9504fc896181bc1ff5730e5", "940d00556cb849b3a689d56e274041c2", "5cdf9ed939fb42d4bf77301c80b8afca", "94b39ccfef0b4b08bf2fb61bb0a657c1", "9a55087c85b74ea08b3e952ac1d73cbe", "2361ab124daf47cc885ff61f2899b2af", "1a65887eb37747ddb75dc4a40f7285f2", "3c946e2260704e6c98593136bd32d921", "50d325cdb9844f62a9ecc98e768cb5af", "aa781f0cfe454e9da5b53b93e9baabd8", "6bb68d3887ef43809eb23feb467f9723", "7e29a8b952cf4f4ea42833c8bf55342f", "dd5997d01d8947e4b1c211433969b89b", "2ace4dc78e2f4f1492a181bcd63304e7", "bbee008c2791443d8610371d1f16b62b", "31b1c8a2e3334b72b45b083688c1a20c", "7fb7c36adc624f7dbbcb4a831c1e4f63", "0b7c8f1939074794b3d9221244b1344d", "a71908883b064e1fbdddb547a8c41743", "2f5223f26c8541fc87e91d2205c39995" ] }, "id": "s_AY1ATSIrIq", "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-06-14 15:28:50 [WARNING] Reusing dataset glue (/home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bc4ed01451ab449c800ca5ffd30dcbc7", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\n", " \n", " \n", " [54/54 02:38]\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "The full-precision model has an eval_accuracy of 91.09.\n" ] } ], "source": [ "metric_name = \"eval_\" + (\"pearson\" if task == \"stsb\" else \"matthews_correlation\" if task == \"cola\" else \"accuracy\")\n", "\n", "def eval_func(model):\n", " trainer.model = model\n", " metrics = trainer.evaluate()\n", " return metrics.get(metric_name)\n", "\n", "fp_model_result = eval_func(fp_model)\n", "print(f\"The full-precision model has an {metric_name} of {round(fp_model_result * 100, 2)}.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We instantiate `IncQuantizationConfig` using a configuration file containing all the informations related to quantization and tuning objective. We can set the quantization approach as well as the accuracy target, currently tolerating a 0.02 relative performance drop when compared to our baseline which is the full-precision model." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from optimum.intel.neural_compressor import IncQuantizationConfig, IncQuantizationMode\n", "\n", "config = \"echarlaix/bert-base-uncased-sst2-static-quant-test\"\n", "q8_config = IncQuantizationConfig.from_pretrained(config, config_file_name=\"quantization.yml\")\n", "\n", "accuracy_criterion = 0.02\n", "q8_config.set_config(\"tuning.accuracy_criterion.relative\", accuracy_criterion)\n", "q8_approach = getattr(IncQuantizationMode, quantization_approach.upper()).value\n", "q8_config.set_config(\"quantization.approach\", q8_approach)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For both static and aware training quantization, we use PyTorch FX Graph Mode Quantization." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "if quantization_approach != \"dynamic\":\n", " q8_config.set_config(\"model.framework\", \"pytorch_fx\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To instantiate an `IncQuantizer`, we need a configuration containing all the informations relative to quantization and tuning (which can be either a path to a YAML file or an `IncQuantizationConfig` object), the model to quantize and finally an evaluation function which will be used to evaluate the quantization impact and thus verify if it fits the tolerance defined by the user.\n", "\n", "In the case of static quantization, our `IncQuantizer` will also need a calibration dataloader in order to perform the calibration step.\n", "\n", "In the case of aware training quantization, it will need a training function, the latter will be used to perform the training will applying quantization.\n", "\n", "We can now instantiate our `IncOptimizer` which will take care of the quantization process." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n", "2022-06-14 15:29:34 [INFO] Start sequential pipeline execution.\n", "2022-06-14 15:29:34 [INFO] The 0th step being executing is QUANTIZATION.\n", "2022-06-14 15:29:34 [INFO] Pass query framework capability elapsed time: 166.45 ms\n", "2022-06-14 15:29:34 [INFO] Get FP32 model baseline.\n", "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n", "***** Running Evaluation *****\n", " Num examples = 872\n", " Batch size = 16\n", "2022-06-14 15:30:10 [INFO] Save tuning history to /home/ella/Projects/huggingface/notebooks/examples/nc_workspace/2022-06-14_15-28-47/./history.snapshot.\n", "2022-06-14 15:30:10 [INFO] FP32 baseline is: [Accuracy: 0.9109, Duration (seconds): 36.2036]\n", "/home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/ao/quantization/qconfig.py:88: UserWarning: QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead\n", " warnings.warn(\"QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead\")\n", "2022-06-14 15:30:10 [INFO] Fx trace of the entire model failed, We will conduct auto quantization\n", "/home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/ao/quantization/observer.py:177: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n", " warnings.warn(\n", "2022-06-14 15:30:10 [WARNING] Please note that calibration sampling size 100 isn't divisible exactly by batch size 16. So the real sampling size is 112.\n", "/home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/nn/quantized/_reference/modules/linear.py:41: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " torch.tensor(weight_qparams[\"scale\"], dtype=torch.float, device=device))\n", "/home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/nn/quantized/_reference/modules/linear.py:44: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " torch.tensor(\n", "2022-06-14 15:30:18 [INFO] |*********Mixed Precision Statistics********|\n", "2022-06-14 15:30:18 [INFO] +---------------------+-------+------+------+\n", "2022-06-14 15:30:18 [INFO] | Op Type | Total | INT8 | FP32 |\n", "2022-06-14 15:30:18 [INFO] +---------------------+-------+------+------+\n", "2022-06-14 15:30:18 [INFO] | Embedding | 2 | 2 | 0 |\n", "2022-06-14 15:30:18 [INFO] | LayerNorm | 13 | 0 | 13 |\n", "2022-06-14 15:30:18 [INFO] | quantize_per_tensor | 38 | 38 | 0 |\n", "2022-06-14 15:30:18 [INFO] | Linear | 38 | 38 | 0 |\n", "2022-06-14 15:30:18 [INFO] | dequantize | 38 | 38 | 0 |\n", "2022-06-14 15:30:18 [INFO] | Dropout | 6 | 0 | 6 |\n", "2022-06-14 15:30:18 [INFO] +---------------------+-------+------+------+\n", "2022-06-14 15:30:18 [INFO] Pass quantize model elapsed time: 8274.07 ms\n", "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n", "***** Running Evaluation *****\n", " Num examples = 872\n", " Batch size = 16\n", "2022-06-14 15:30:43 [INFO] Tune 1 result is: [Accuracy (int8|fp32): 0.9005|0.9109, Duration (seconds) (int8|fp32): 24.7176|36.2036], Best tune result is: [Accuracy: 0.9005, Duration (seconds): 24.7176]\n", "2022-06-14 15:30:43 [INFO] |**********************Tune Result Statistics**********************|\n", "2022-06-14 15:30:43 [INFO] +--------------------+----------+---------------+------------------+\n", "2022-06-14 15:30:43 [INFO] | Info Type | Baseline | Tune 1 result | Best tune result |\n", "2022-06-14 15:30:43 [INFO] +--------------------+----------+---------------+------------------+\n", "2022-06-14 15:30:43 [INFO] | Accuracy | 0.9109 | 0.9005 | 0.9005 |\n", "2022-06-14 15:30:43 [INFO] | Duration (seconds) | 36.2036 | 24.7176 | 24.7176 |\n", "2022-06-14 15:30:43 [INFO] +--------------------+----------+---------------+------------------+\n", "2022-06-14 15:30:43 [INFO] Save tuning history to /home/ella/Projects/huggingface/notebooks/examples/nc_workspace/2022-06-14_15-28-47/./history.snapshot.\n", "2022-06-14 15:30:43 [INFO] Specified timeout or max trials is reached! Found a quantized model which meet accuracy goal. Exit.\n", "2022-06-14 15:30:43 [INFO] Save deploy yaml to /home/ella/Projects/huggingface/notebooks/examples/nc_workspace/2022-06-14_15-28-47/deploy.yaml\n" ] } ], "source": [ "from optimum.intel.neural_compressor import IncQuantizer, IncOptimizer\n", "\n", "quantizer = IncQuantizer(\n", " config_path_or_obj=q8_config,\n", " eval_func=eval_func,\n", " train_func=train_func if quantization_approach == \"aware_training\" else None,\n", " calib_dataloader=trainer.get_train_dataloader() if quantization_approach == \"static\" else None,\n", ")\n", "\n", "optimizer = IncOptimizer(fp_model, quantizer=quantizer)\n", "q_model = optimizer.fit()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n", "***** Running Evaluation *****\n", " Num examples = 872\n", " Batch size = 16\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The resulting quantized model has an eval_accuracy of 90.05.\n", "This results in a drop of 1.04 in eval_accuracy when compared to the full-precision model.\n" ] } ], "source": [ "q_model_result = eval_func(q_model.model)\n", "print(f\"The resulting quantized model has an {metric_name} of {round(q_model_result * 100, 2)}.\")\n", "print(f\"This results in a drop of {round((fp_model_result - q_model_result) * 100, 2)} in {metric_name} when compared to the full-precision model.\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The full-precision model size is 255 MB while the quantized model one is 65 MB.\n", "The quantized model is 3.93x smaller than the full-precision one.\n" ] } ], "source": [ "import torch\n", "\n", "def get_model_size(model):\n", " torch.save(model.state_dict(), \"tmp.pt\")\n", " model_size = os.path.getsize(\"tmp.pt\") / (1024*1024)\n", " os.remove(\"tmp.pt\")\n", " return round(model_size, 2)\n", "\n", "fp_model_size = get_model_size(fp_model) \n", "q_model_size = get_model_size(q_model.model) \n", "\n", "print(f\"The full-precision model size is {round(fp_model_size)} MB while the quantized model one is {round(q_model_size)} MB.\")\n", "print(f\"The quantized model is {round(fp_model_size / q_model_size, 2)}x smaller than the full-precision one.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We save the resulting quantized model as well as its configuration." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Configuration saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "2022-06-14 15:31:09 [INFO] Model weights saved to distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2\n" ] } ], "source": [ "optimizer.save_pretrained(output)" ] }, { "cell_type": "markdown", "metadata": { "id": "7k8ge1L1IrJk" }, "source": [ "## Loading the quantized model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The previously saved config file containing all the informations relative to the model quantization is used to instantiate an`IncOptimizedConfig`. We then load the model using `IncQuantizedModelForSequenceClassification`." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "YUdakNBhIrJl" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "Model config DistilBertConfig {\n", " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2\",\n", " \"activation\": \"gelu\",\n", " \"architectures\": [\n", " \"DistilBertForSequenceClassification\"\n", " ],\n", " \"attention_dropout\": 0.1,\n", " \"dim\": 768,\n", " \"dropout\": 0.1,\n", " \"finetuning_task\": \"sst-2\",\n", " \"hidden_dim\": 3072,\n", " \"id2label\": {\n", " \"0\": \"NEGATIVE\",\n", " \"1\": \"POSITIVE\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"label2id\": {\n", " \"NEGATIVE\": 0,\n", " \"POSITIVE\": 1\n", " },\n", " \"max_position_embeddings\": 512,\n", " \"model_type\": \"distilbert\",\n", " \"n_heads\": 12,\n", " \"n_layers\": 6,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"problem_type\": \"single_label_classification\",\n", " \"qa_dropout\": 0.1,\n", " \"seq_classif_dropout\": 0.2,\n", " \"sinusoidal_pos_embds\": false,\n", " \"tie_weights_\": true,\n", " \"torch_dtype\": \"int8\",\n", " \"transformers_version\": \"4.19.4\",\n", " \"vocab_size\": 30522\n", "}\n", "\n", "loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "Model config DistilBertConfig {\n", " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english\",\n", " \"activation\": \"gelu\",\n", " \"architectures\": [\n", " \"DistilBertForSequenceClassification\"\n", " ],\n", " \"attention_dropout\": 0.1,\n", " \"dim\": 768,\n", " \"dropout\": 0.1,\n", " \"finetuning_task\": \"sst-2\",\n", " \"hidden_dim\": 3072,\n", " \"id2label\": {\n", " \"0\": \"NEGATIVE\",\n", " \"1\": \"POSITIVE\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"label2id\": {\n", " \"NEGATIVE\": 0,\n", " \"POSITIVE\": 1\n", " },\n", " \"max_position_embeddings\": 512,\n", " \"model_type\": \"distilbert\",\n", " \"n_heads\": 12,\n", " \"n_layers\": 6,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"problem_type\": \"single_label_classification\",\n", " \"qa_dropout\": 0.1,\n", " \"seq_classif_dropout\": 0.2,\n", " \"sinusoidal_pos_embds\": false,\n", " \"tie_weights_\": true,\n", " \"torch_dtype\": \"int8\",\n", " \"transformers_version\": \"4.19.4\",\n", " \"vocab_size\": 30522\n", "}\n", "\n", "loading weights file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/pytorch_model.bin\n", "All model checkpoint weights were used when initializing DistilBertForSequenceClassification.\n", "\n", "All the weights of DistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2.\n", "If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.\n", "/home/ella/miniconda3/envs/optimum_inc/lib/python3.8/site-packages/torch/ao/quantization/observer.py:1124: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point \n", " warnings.warn(\n", "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence. If idx, sentence are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n", "***** Running Evaluation *****\n", " Num examples = 872\n", " Batch size = 16\n" ] } ], "source": [ "from optimum.intel.neural_compressor.quantization import IncQuantizedModelForSequenceClassification\n", "\n", "loaded_model = IncQuantizedModelForSequenceClassification.from_pretrained(output)\n", "loaded_model.eval()\n", "loaded_model_result = eval_func(loaded_model)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The quantized model was successfully loaded.\n" ] } ], "source": [ "if loaded_model_result == q_model_result:\n", " print(\"The quantized model was successfully loaded.\")\n", "else:\n", " print(\"The quantized model was NOT successfully loaded.\")" ] } ], "metadata": { "colab": { "name": "Text Classification on GLUE", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 1 }