{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "You will need an authentication token with your Hugging Face credentials to use the `push_to_hub` method. Execute `huggingface-cli login` in your terminal or by uncommenting the following cell:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# !huggingface-cli login" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from datasets import load_dataset, load_metric\n", "from transformers import (\n", " AutoModelForSequenceClassification,\n", " AutoTokenizer,\n", " DataCollatorWithPadding,\n", " Trainer,\n", " TrainingArguments,\n", ")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "checkpoint = \"bert-base-cased\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset glue (/home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8174fd92eed0af98.arrow\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8c99fb059544bc96.arrow\n", "Loading cached processed dataset at /home/sgugger/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e625eb72bcf1ae1f.arrow\n", "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", "def tokenize_function(examples):\n", " return tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True)\n", "\n", "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n", "\n", "training_args = TrainingArguments(\n", " \"finetuned-bert-mrpc\",\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=16,\n", " learning_rate=2e-5,\n", " weight_decay=0.01,\n", " evaluation_strategy=\"epoch\",\n", " logging_strategy=\"epoch\",\n", " log_level=\"error\",\n", " push_to_hub=True,\n", " push_to_hub_model_id=\"finetuned-bert-mrpc\",\n", " # push_to_hub_organization=\"huggingface\",\n", " # push_to_hub_token=\"my_token\",\n", ")\n", "\n", "data_collator = DataCollatorWithPadding(tokenizer)\n", "\n", "metric = load_metric(\"glue\", \"mrpc\")\n", "\n", "def compute_metrics(eval_preds):\n", " logits, labels = eval_preds\n", " predictions = np.argmax(logits, axis=-1)\n", " return metric.compute(predictions=predictions, references=labels)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(\n", " model,\n", " training_args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " eval_dataset=tokenized_datasets[\"validation\"],\n", " data_collator=data_collator,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics,\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: left;\">\n",
"      <th>Epoch</th>\n",
"      <th>Training Loss</th>\n",
"      <th>Validation Loss</th>\n",
"      <th>Accuracy</th>\n",
"      <th>F1</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <td>1</td>\n",
"      <td>0.573500</td>\n",
"      <td>0.475627</td>\n",
"      <td>0.774510</td>\n",
"      <td>0.835714</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>2</td>\n",
"      <td>0.383800</td>\n",
"      <td>0.452076</td>\n",
"      <td>0.821078</td>\n",
"      <td>0.878939</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <td>3</td>\n",
"      <td>0.233600</td>\n",
"      <td>0.485343</td>\n",
"      <td>0.850490</td>\n",
"      <td>0.897133</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
"