{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# 🤗 and 3LC example on the IMDb dataset\n", "\n", "This notebook demonstrates fine-tuning a pretrained DistilBERT model from `transformers` on the `IMDb` dataset, using the 3LC integrations with `Trainer` and `datasets` from Hugging Face. 3LC metrics are collected before and after one epoch of training.\n", "\n", "![img](../images/huggingface-imdb.png)\n", "\n", "\n", "\n", "The notebook covers:\n", "\n", "- Creating a `Table` from a `datasets` dataset.\n", "- Fine-tuning a pretrained `transformers` model on the IMDb dataset with `TLCTrainer`.\n", "- Using a custom function for metrics collection." ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Project Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "EPOCHS = 10\n", "TRAIN_BATCH_SIZE = 16\n", "EVAL_BATCH_SIZE = 256\n", "TRAIN_DATASET_NAME = \"hf-imdb-train\"\n", "EVAL_DATASET_NAME = \"hf-imdb-test\"\n", "DOWNLOAD_PATH = \"../../transient_data\"\n", "NUM_WORKERS = 4\n", "DEVICE = None\n", "PROJECT_NAME = \"3LC Tutorials - Hugging Face IMDB\"\n", "RUN_NAME = \"Train DistilBERT on IMDB\"\n", "DESCRIPTION = \"Example notebook for training a DistilBERT model on the IMDB dataset\"\n", "INSTALL_DEPENDENCIES = True" ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [ "if INSTALL_DEPENDENCIES:\n", " %pip install scikit-learn\n", " %pip install \"3lc[huggingface]\" \"transformers<=4.56.0\"\n", " %pip install git+https://github.com/3lc-ai/3lc-examples.git" ] }, { "cell_type": "markdown", "id": "4", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import datasets\n", "import evaluate\n", "import numpy as np\n", "import tlc\n", "import torch\n", "from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, TrainingArguments\n", "\n", "os.environ[\"TRANSFORMERS_NO_ADVISORY_WARNINGS\"] = \"true\" # Removing DistilBertTokenizerFast tokenizer warning\n", "\n", "datasets.utils.logging.disable_progress_bar()" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "if DEVICE is None:\n", " if torch.cuda.is_available():\n", " DEVICE = \"cuda\"\n", " elif torch.backends.mps.is_available():\n", " DEVICE = \"mps\"\n", " else:\n", " DEVICE = \"cpu\"" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "## Initialize a 3LC Run\n", "\n", "We initialize a Run with a call to `tlc.init`, and add the configuration to the Run object." ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "config = {\n", " \"epochs\": EPOCHS,\n", " \"train_batch_size\": TRAIN_BATCH_SIZE,\n", " \"eval_batch_size\": EVAL_BATCH_SIZE,\n", "}\n", "\n", "run = tlc.init(\n", " project_name=PROJECT_NAME,\n", " run_name=RUN_NAME,\n", " description=DESCRIPTION,\n", " parameters=config,\n", " if_exists=\"overwrite\",\n", ")" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "With the 3LC integration, you can use `tlc.Table.from_hugging_face()` as a drop-in replacement for\n", "`datasets.load_dataset()` to create a `tlc.Table`. Notice `.latest()`, which gets the latest version of the 3LC dataset." 
] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [ "train_dataset = tlc.Table.from_hugging_face(\n", " \"imdb\",\n", " split=\"train\",\n", " project_name=PROJECT_NAME,\n", " dataset_name=TRAIN_DATASET_NAME,\n", " description=\"IMDB train dataset\",\n", " if_exists=\"overwrite\",\n", ")\n", "\n", "eval_dataset = tlc.Table.from_hugging_face(\n", " \"imdb\",\n", " split=\"test\",\n", " project_name=PROJECT_NAME,\n", " dataset_name=EVAL_DATASET_NAME,\n", " description=\"IMDB test dataset\",\n", " if_exists=\"overwrite\",\n", ")" ] }, { "cell_type": "markdown", "id": "11", "metadata": {}, "source": [ "You can use the data produced by these Tables like you would with a 🤗 dataset." ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "train_dataset_hf = datasets.load_dataset(\"imdb\", split=\"train\")\n", "train_dataset_hf[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "13", "metadata": {}, "outputs": [], "source": [ "train_dataset[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "from tlc_tools.split import split_table\n", "\n", "splits = split_table(\n", " train_dataset, splits={\"train-subset\": 0.01, \"eval-subset\": 0.005, \"dontcare\": 0.985}, if_exists=\"rename\"\n", ")\n", "\n", "train_dataset = splits[\"train-subset\"]\n", "eval_dataset = splits[\"eval-subset\"]" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "`Table` provides a method `map` to apply both preprocessing and on-the-fly transforms to your data before it is sent to the model." ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\", model_max_length=512)\n", "\n", "\n", "def tokenize(sample):\n", " return {**sample, **tokenizer(sample[\"text\"], truncation=True)}" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "train_tokenized = train_dataset.map(tokenize)\n", "eval_tokenized = eval_dataset.map(tokenize)" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "code", "execution_count": null, "id": "19", "metadata": {}, "outputs": [], "source": [ "id2label = {0: \"neg\", 1: \"pos\"}\n", "label2id = {\"neg\": 0, \"pos\": 1}\n", "\n", "# For demonstration purposes, we use the distilbert-base-uncased model with a different set of labels than\n", "# it was trained on. As a result, there will be a warning about the inconsistency of the classifier and\n", "# pre_classifier weights. This is expected and can be ignored.\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", " \"distilbert-base-uncased\", num_labels=2, id2label=id2label, label2id=label2id\n", ")" ] }, { "cell_type": "markdown", "id": "20", "metadata": {}, "source": [ "## Setup Metrics Collection\n", "\n", "Computing metrics is done by implementing a function which returns per-sample metrics you would like to see in the 3LC Dashboard. \n", "\n", "We keep the metrics function in Hugging Face to see the intermediate aggregate metrics.\n", "\n", "For special metrics such as the predicted category we specify that we would like this to be shown as a `CategoricalLabel`. 
" ] }, { "cell_type": "code", "execution_count": null, "id": "21", "metadata": {}, "outputs": [], "source": [ "accuracy = evaluate.load(\"accuracy\")\n", "\n", "\n", "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " predictions = np.argmax(predictions, axis=1)\n", " return accuracy.compute(predictions=predictions, references=labels)\n", "\n", "\n", "def compute_tlc_metrics(logits, labels):\n", " probabilities = torch.nn.functional.softmax(logits, dim=-1)\n", "\n", " predictions = logits.argmax(dim=-1)\n", " loss = torch.nn.functional.cross_entropy(logits, labels, reduction=\"none\")\n", " confidence = probabilities.gather(dim=-1, index=predictions.unsqueeze(-1)).squeeze()\n", "\n", " return {\n", " \"predicted\": predictions,\n", " \"loss\": loss,\n", " \"confidence\": confidence,\n", " }\n", "\n", "\n", "compute_tlc_metrics.column_schemas = {\n", " \"predicted\": tlc.CategoricalLabelSchema(id2label),\n", " \"loss\": tlc.Float32Schema(),\n", " \"confidence\": tlc.Float32Schema(),\n", "}" ] }, { "attachments": {}, "cell_type": "markdown", "id": "22", "metadata": {}, "source": [ "## Train the model with TLCTrainer\n", "\n", "To perform model training, we replace the usual `Trainer` with `TLCTrainer` and provide the per-sample metrics collection function. We also specify that we would like to collect metrics prior to training." ] }, { "cell_type": "code", "execution_count": null, "id": "23", "metadata": {}, "outputs": [], "source": [ "from tlc.integration.hugging_face import TLCTrainer\n", "\n", "training_args = TrainingArguments(\n", " output_dir=DOWNLOAD_PATH,\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=TRAIN_BATCH_SIZE,\n", " per_device_eval_batch_size=EVAL_BATCH_SIZE,\n", " num_train_epochs=EPOCHS,\n", " weight_decay=0.01,\n", " report_to=\"none\",\n", " eval_strategy=\"epoch\",\n", " use_cpu=DEVICE == \"cpu\",\n", " dataloader_num_workers=NUM_WORKERS,\n", " # disable_tqdm=True,\n", ")\n", "\n", "trainer = TLCTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_tokenized,\n", " eval_dataset=eval_tokenized,\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_hf_metrics=compute_metrics,\n", " compute_tlc_metrics=compute_tlc_metrics,\n", " compute_tlc_metrics_on_train_begin=True,\n", " tlc_metrics_collection_epoch_frequency=1,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "24", "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 5 }