{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "# Quantizing a model during fine-tuning with Intel Neural Compressor (INC) for text classification tasks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook shows how to apply quantization aware training, using the [Intel Neural Compressor](https://github.com/intel/neural-compressor) (INC) library, to any task of the GLUE benchmark. This is made possible by 🤗 [Optimum Intel](https://github.com/huggingface/optimum-intel), an extension of 🤗 [Transformers](https://github.com/huggingface/transformers) that provides a set of performance optimization tools to accelerate end-to-end pipelines on a variety of Intel processors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Optimum. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! 
pip install datasets transformers optimum[neural-compressor]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of 🤗 Optimum Intel is at least 1.6.0, since this functionality was introduced in that version:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.7.0.dev0\n" ] } ], "source": [ "from optimum.intel.version import __version__\n", "\n", "print(__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences which are:\n", "\n", "- [CoLA](https://nyu-mll.github.io/CoLA/) (Corpus of Linguistic Acceptability) Determine if a sentence is grammatically correct or not.\n", "- [MNLI](https://arxiv.org/abs/1704.05426) (Multi-Genre Natural Language Inference) Determine if a sentence entails, contradicts or is unrelated to a given hypothesis. This dataset has two versions, one with the validation and test set coming from the same distribution, another called mismatched where the validation and test use out-of-domain data.\n", "- [MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398) (Microsoft Research Paraphrase Corpus) Determine if two sentences are paraphrases of one another or not.\n", "- [QNLI](https://rajpurkar.github.io/SQuAD-explorer/) (Question-answering Natural Language Inference) Determine if the answer to a question is in the second sentence or not. 
This dataset is built from the SQuAD dataset.\n", "- [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs) (Quora Question Pairs2) Determine if two questions are semantically equivalent or not.\n", "- [RTE](https://aclweb.org/aclwiki/Recognizing_Textual_Entailment) (Recognizing Textual Entailment) Determine if a sentence entails a given hypothesis or not.\n", "- [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) Determine if the sentence has a positive or negative sentiment.\n", "- [STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) (Semantic Textual Similarity Benchmark) Determine the similarity of two sentences with a score from 0 to 5.\n", "- [WNLI](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not. This dataset is built from the Winograd Schema Challenge dataset.\n", "\n", "We will see how to apply quantization aware training to a DistilBERT model fine-tuned on the SST-2 task:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n", "task = \"sst2\"\n", "model_checkpoint = \"distilbert-base-uncased-finetuned-sst-2-english\"\n", "batch_size = 16\n", "max_train_samples = 200\n", "max_eval_samples = 200" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) and [🤗 Evaluate](https://github.com/huggingface/evaluate) libraries to download the data and get the metric we need to use for evaluation. 
This can be easily done with the functions `load_dataset` and `load`.\n", "\n", "Apart from `mnli-mm` being a special code, we can directly pass our task name to those functions. `load_dataset` will cache the dataset to avoid downloading it again the next time you run this cell." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270, "referenced_widgets": [ "69caab03d6264fef9fc5649bffff5e20", "3f74532faa86412293d90d3952f38c4a", "50615aa59c7247c4804ca5cbc7945bd7", "fe962391292a413ca55dc932c4279fa7", "299f4b4c07654e53a25f8192bd1d7bbd", "ad04ed1038154081bbb0c1444784dcc2", "7c667ad22b5740d5a6319f1b1e3a8097", "46c2b043c0f84806978784a45a4e203b", "80e2943be35f46eeb24c8ab13faa6578", "de5956b5008d4fdba807bae57509c393", "931db1f7a42f4b46b7ff8c2e1262b994", "6c1db72efff5476e842c1386fadbbdba", "ccd2f37647c547abb4c719b75a26f2de", "d30a66df5c0145e79693e09789d96b81", "5fa26fc336274073abbd1d550542ee33", "2b34de08115d49d285def9269a53f484", "d426be871b424affb455aeb7db5e822e", "160bf88485f44f5cb6eaeecba5e0901f", "745c0d47d672477b9bb0dae77b926364", "d22ab78269cd4ccfbcf70c707057c31b", "d298eb19eeff453cba51c2804629d3f4", "a7204ade36314c86907c562e0a2158b8", "e35d42b2d352498ca3fc8530393786b2", "75103f83538d44abada79b51a1cec09e", "f6253931d90543e9b5fd0bb2d615f73a", "051aa783ff9e47e28d1f9584043815f5", "0984b2a14115454bbb009df71c1cf36f", "8ab9dfce29854049912178941ef1b289", "c9de740e007141958545e269372780a4", "cbea68b25d6d4ba09b2ce0f27b1726d5", "5781fc45cf8d486cb06ed68853b2c644", "d2a92143a08a4951b55bab9bc0a6d0d3", "a14c3e40e5254d61ba146f6ec88eae25", "c4ffe6f624ce4e978a0d9b864544941a", "1aca01c1d8c940dfadd3e7144bb35718", "9fbbaae50e6743f2aa19342152398186", "fea27ca6c9504fc896181bc1ff5730e5", "940d00556cb849b3a689d56e274041c2", "5cdf9ed939fb42d4bf77301c80b8afca", "94b39ccfef0b4b08bf2fb61bb0a657c1", "9a55087c85b74ea08b3e952ac1d73cbe", "2361ab124daf47cc885ff61f2899b2af", "1a65887eb37747ddb75dc4a40f7285f2", 
"3c946e2260704e6c98593136bd32d921", "50d325cdb9844f62a9ecc98e768cb5af", "aa781f0cfe454e9da5b53b93e9baabd8", "6bb68d3887ef43809eb23feb467f9723", "7e29a8b952cf4f4ea42833c8bf55342f", "dd5997d01d8947e4b1c211433969b89b", "2ace4dc78e2f4f1492a181bcd63304e7", "bbee008c2791443d8610371d1f16b62b", "31b1c8a2e3334b72b45b083688c1a20c", "7fb7c36adc624f7dbbcb4a831c1e4f63", "0b7c8f1939074794b3d9221244b1344d", "a71908883b064e1fbdddb547a8c41743", "2f5223f26c8541fc87e91d2205c39995" ] }, "id": "s_AY1ATSIrIq", "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset glue (/home/ella/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b9d3aa274c11409a9e7a2e6bc51b18af", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00\n", " \n", " \n", " [13/13 01:14, Epoch 1/1]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Epoch | Training Loss | Validation Loss | Accuracy
1 | No log | 0.546245 | 0.755000

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "***** Running Evaluation *****\n", " Num examples = 200\n", " Batch size = 16\n", "Configuration saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/checkpoint-13/config.json\n", "tokenizer config file saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/checkpoint-13/tokenizer_config.json\n", "Special tokens file saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/checkpoint-13/special_tokens_map.json\n", "Loading best model from distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/checkpoint-13 (score: 0.755).\n", "There were unexpected keys in the checkpoint model loaded: ['best_configure'].\n", "2023-01-13 13:07:01 [INFO] |********Mixed Precision Statistics*******|\n", "2023-01-13 13:07:01 [INFO] +------------------------+--------+-------+\n", "2023-01-13 13:07:01 [INFO] | Op Type | Total | INT8 |\n", "2023-01-13 13:07:01 [INFO] +------------------------+--------+-------+\n", "2023-01-13 13:07:01 [INFO] | Embedding | 2 | 2 |\n", "2023-01-13 13:07:01 [INFO] | quantize_per_tensor | 51 | 51 |\n", "2023-01-13 13:07:01 [INFO] | LayerNorm | 13 | 13 |\n", "2023-01-13 13:07:01 [INFO] | dequantize | 51 | 51 |\n", "2023-01-13 13:07:01 [INFO] | Linear | 38 | 38 |\n", "2023-01-13 13:07:01 [INFO] | Dropout | 6 | 6 |\n", "2023-01-13 13:07:01 [INFO] +------------------------+--------+-------+\n", "2023-01-13 13:07:01 [INFO] Training finished!\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=13, training_loss=0.5614617054279034, metrics={'train_runtime': 80.3063, 'train_samples_per_second': 2.49, 'train_steps_per_second': 0.162, 'total_flos': 6623369932800.0, 'train_loss': 0.5614617054279034, 'epoch': 1.0})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "We can run evaluation by just calling the `evaluate` method:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "***** Running Evaluation *****\n", " Num examples = 200\n", " Batch size = 16\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [13/13 00:07]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 0.5514087080955505,\n", " 'eval_accuracy': 0.795,\n", " 'eval_runtime': 8.0226,\n", " 'eval_samples_per_second': 24.929,\n", " 'eval_steps_per_second': 1.62,\n", " 'epoch': 1.0}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The full-precision model size is 255 MB while the quantized model size is 65 MB.\n", "The quantized model is 3.93x smaller than the full-precision one.\n" ] } ], "source": [ "import os\n", "import torch\n", "\n", "def get_model_size(model):\n", " torch.save(model.state_dict(), \"tmp.pt\")\n", " model_size = os.path.getsize(\"tmp.pt\") / (1024*1024)\n", " os.remove(\"tmp.pt\")\n", " return round(model_size, 2)\n", "\n", "fp_model_size = get_model_size(fp_model)\n", "q_model_size = get_model_size(trainer.model)\n", "\n", "print(f\"The full-precision model size is {round(fp_model_size)} MB while the quantized model size is {round(q_model_size)} MB.\")\n", "print(f\"The quantized model is {round(fp_model_size / q_model_size, 2)}x smaller than the full-precision one.\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To save the resulting quantized model, you can use the `save_model` method. 
By setting `save_onnx_model` to `True`, the model will also be exported to the ONNX format.\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Configuration saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "tokenizer config file saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/tokenizer_config.json\n", "Special tokens file saved in distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/special_tokens_map.json\n", "2023-01-13 13:07:10 [WARNING] QDQ format requires opset_version >= 13, we reset opset_version=13 here\n", "/home/ella/miniconda3/envs/inc/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:217: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n", " mask, torch.tensor(torch.finfo(scores.dtype).min)\n", "2023-01-13 13:07:12 [INFO] Weight type: QInt8.\n", "2023-01-13 13:07:12 [INFO] Activation type: QUInt8.\n", "WARNING:root:Please consider pre-processing before quantization. See https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md \n", "WARNING:root:Please consider pre-processing before quantization. 
See https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md \n", "2023-01-13 13:07:38 [INFO] ******************************************************************************************************************\n", "2023-01-13 13:07:38 [INFO] The INT8 ONNX Model is exported to path: distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/model.onnx\n", "2023-01-13 13:07:38 [INFO] ******************************************************************************************************************\n" ] } ], "source": [ "trainer.save_model(save_onnx_model=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "7k8ge1L1IrJk" }, "source": [ "## Loading the quantized model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To load your quantized PyTorch or ONNX model, hosted locally or on the 🤗 Hub, instantiate it with the [`INCModelForXxx`](https://huggingface.co/docs/optimum/main/intel/reference_inc#optimum.intel.neural_compressor.INCModel) or [`ORTModelForXxx`](https://huggingface.co/docs/optimum/onnxruntime/package_reference/modeling_ort) classes, respectively:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "YUdakNBhIrJl" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "Model config DistilBertConfig {\n", " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2\",\n", " \"activation\": \"gelu\",\n", " \"architectures\": [\n", " \"DistilBertForSequenceClassification\"\n", " ],\n", " \"attention_dropout\": 0.1,\n", " \"dim\": 768,\n", " \"dropout\": 0.1,\n", " \"finetuning_task\": \"sst-2\",\n", " \"framework\": \"pytorch_fx\",\n", " \"hidden_dim\": 3072,\n", " \"id2label\": {\n", " \"0\": \"NEGATIVE\",\n", " \"1\": \"POSITIVE\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"label2id\": {\n", " 
\"NEGATIVE\": 0,\n", " \"POSITIVE\": 1\n", " },\n", " \"max_position_embeddings\": 512,\n", " \"model_type\": \"distilbert\",\n", " \"n_heads\": 12,\n", " \"n_layers\": 6,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"problem_type\": \"single_label_classification\",\n", " \"qa_dropout\": 0.1,\n", " \"seq_classif_dropout\": 0.2,\n", " \"sinusoidal_pos_embds\": false,\n", " \"tie_weights_\": true,\n", " \"torch_dtype\": \"int8\",\n", " \"transformers_version\": \"4.25.1\",\n", " \"vocab_size\": 30522\n", "}\n", "\n", "loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "Model config DistilBertConfig {\n", " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english\",\n", " \"activation\": \"gelu\",\n", " \"architectures\": [\n", " \"DistilBertForSequenceClassification\"\n", " ],\n", " \"attention_dropout\": 0.1,\n", " \"dim\": 768,\n", " \"dropout\": 0.1,\n", " \"finetuning_task\": \"sst-2\",\n", " \"framework\": \"pytorch_fx\",\n", " \"hidden_dim\": 3072,\n", " \"id2label\": {\n", " \"0\": \"NEGATIVE\",\n", " \"1\": \"POSITIVE\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"label2id\": {\n", " \"NEGATIVE\": 0,\n", " \"POSITIVE\": 1\n", " },\n", " \"max_position_embeddings\": 512,\n", " \"model_type\": \"distilbert\",\n", " \"n_heads\": 12,\n", " \"n_layers\": 6,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"problem_type\": \"single_label_classification\",\n", " \"qa_dropout\": 0.1,\n", " \"seq_classif_dropout\": 0.2,\n", " \"sinusoidal_pos_embds\": false,\n", " \"tie_weights_\": true,\n", " \"torch_dtype\": \"int8\",\n", " \"transformers_version\": \"4.25.1\",\n", " \"vocab_size\": 30522\n", "}\n", "\n", "loading weights file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/pytorch_model.bin\n", "All model checkpoint weights were used when initializing DistilBertForSequenceClassification.\n", "\n", "All the weights of 
DistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2.\n", "If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.\n", "2023-01-13 13:07:39 [WARNING] Please provide the example_inputs or a dataloader to get example_inputs for quantized model.\n", "2023-01-13 13:07:39 [INFO] Fx trace of the entire model failed. We will conduct auto quantization\n", "/home/ella/miniconda3/envs/inc/lib/python3.9/site-packages/torch/ao/quantization/utils.py:287: UserWarning: must run observer before calling calculate_qparams. Returning default values.\n", " warnings.warn(\n", "loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "Model config DistilBertConfig {\n", " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\",\n", " \"activation\": \"gelu\",\n", " \"architectures\": [\n", " \"DistilBertForSequenceClassification\"\n", " ],\n", " \"attention_dropout\": 0.1,\n", " \"dim\": 768,\n", " \"dropout\": 0.1,\n", " \"finetuning_task\": \"sst-2\",\n", " \"framework\": \"pytorch_fx\",\n", " \"hidden_dim\": 3072,\n", " \"id2label\": {\n", " \"0\": \"NEGATIVE\",\n", " \"1\": \"POSITIVE\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"label2id\": {\n", " \"NEGATIVE\": 0,\n", " \"POSITIVE\": 1\n", " },\n", " \"max_position_embeddings\": 512,\n", " \"model_type\": \"distilbert\",\n", " \"n_heads\": 12,\n", " \"n_layers\": 6,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"problem_type\": \"single_label_classification\",\n", " \"qa_dropout\": 0.1,\n", " \"seq_classif_dropout\": 0.2,\n", " \"sinusoidal_pos_embds\": false,\n", " \"tie_weights_\": true,\n", " \"torch_dtype\": \"int8\",\n", " \"transformers_version\": \"4.25.1\",\n", " \"vocab_size\": 30522\n", "}\n", 
"\n", "loading file vocab.txt\n", "loading file tokenizer.json\n", "loading file added_tokens.json\n", "loading file special_tokens_map.json\n", "loading file tokenizer_config.json\n", "loading configuration file distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2/config.json\n", "Model config DistilBertConfig {\n", " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english-finetuned-sst2\",\n", " \"activation\": \"gelu\",\n", " \"architectures\": [\n", " \"DistilBertForSequenceClassification\"\n", " ],\n", " \"attention_dropout\": 0.1,\n", " \"dim\": 768,\n", " \"dropout\": 0.1,\n", " \"finetuning_task\": \"sst-2\",\n", " \"framework\": \"pytorch_fx\",\n", " \"hidden_dim\": 3072,\n", " \"id2label\": {\n", " \"0\": \"NEGATIVE\",\n", " \"1\": \"POSITIVE\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"label2id\": {\n", " \"NEGATIVE\": 0,\n", " \"POSITIVE\": 1\n", " },\n", " \"max_position_embeddings\": 512,\n", " \"model_type\": \"distilbert\",\n", " \"n_heads\": 12,\n", " \"n_layers\": 6,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"problem_type\": \"single_label_classification\",\n", " \"qa_dropout\": 0.1,\n", " \"seq_classif_dropout\": 0.2,\n", " \"sinusoidal_pos_embds\": false,\n", " \"tie_weights_\": true,\n", " \"torch_dtype\": \"int8\",\n", " \"transformers_version\": \"4.25.1\",\n", " \"vocab_size\": 30522\n", "}\n", "\n", "loading file vocab.txt\n", "loading file tokenizer.json\n", "loading file added_tokens.json\n", "loading file special_tokens_map.json\n", "loading file tokenizer_config.json\n" ] } ], "source": [ "from optimum.intel.neural_compressor import INCModelForSequenceClassification\n", "from optimum.onnxruntime import ORTModelForSequenceClassification\n", "\n", "pytorch_model = INCModelForSequenceClassification.from_pretrained(save_directory)\n", "onnx_model = ORTModelForSequenceClassification.from_pretrained(save_directory)" ] } ], "metadata": { "colab": { "name": "Text Classification on GLUE", 
"provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "vscode": { "interpreter": { "hash": "0e8f1b1a672c99a0563c287a44a837e72a39530c4cccf2a666504641a8d1770c" } } }, "nbformat": 4, "nbformat_minor": 1 }