{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune Hugging Face SegFormer on a custom dataset\n", "\n", "This tutorial covers metrics collection on a custom semantic segmentation dataset using `3lc` and training using 🤗 `transformers`.\n", "\n", "It is based on the original notebook found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb).\n", "\n", "![](../images/semseg-ade20k-finetune.png)\n", "\n", "\n", "\n", "A small subset of the [ADE20K\n", "dataset](https://groups.csail.mit.edu/vision/datasets/ADE20K/) is used for this\n", "tutorial. The subset consists of 5 training images and 5 validation images, with\n", "semantic masks containing 150 labels.\n", "\n", "During training, per-sample loss, embeddings, and predictions are collected." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Project setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - Semantic Segmentation ADE20k\"\n", "DATASET_NAME = \"ADE20k_toy_dataset\"\n", "DOWNLOAD_PATH = \"../../transient_data\"\n", "EPOCHS = 200\n", "NUM_WORKERS = 0\n", "BATCH_SIZE = 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install 3lc[huggingface] \"transformers<=4.56.0\"\n", "%pip install git+https://github.com/3lc-ai/3lc-examples.git" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "import os\n", "from pathlib import Path\n", "\n", "import tlc\n", "import torch\n", "from huggingface_hub import hf_hub_download\n", "from PIL import Image\n", "from torch.utils.data import DataLoader, Dataset\n", "from tqdm.notebook import tqdm\n", "from transformers import SegformerImageProcessor\n", "\n", "from tlc_tools.common import download_and_extract_zipfile, infer_torch_device" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DATASET_ROOT = (Path(DOWNLOAD_PATH) / \"ADE20k_toy_dataset\").resolve()\n", "\n", "if not DATASET_ROOT.exists():\n", " print(\"Downloading data...\")\n", " download_and_extract_zipfile(\n", " url=\"https://www.dropbox.com/s/l1e45oht447053f/ADE20k_toy_dataset.zip?dl=1\",\n", " location=DOWNLOAD_PATH,\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fetch the label map from the Hugging Face Hub" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# load id2label mapping from a JSON on the hub\n", "repo_id = \"huggingface/label-files\"\n", "filename = \"ade20k-id2label.json\"\n", "with open(hf_hub_download(repo_id=repo_id, filename=filename, repo_type=\"dataset\")) as f:\n", " id2label = json.load(f)\n", "id2label = {int(k): v for k, v in id2label.items()}\n", "label2id = {v: k for k, v in id2label.items()}\n", "unreduced_label_map = {0.0: \"background\", **{k + 1: v for k, v in id2label.items()}}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "id2label" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize a Run" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "DEVICE = infer_torch_device()\n", "\n", "run = tlc.init(\n", " PROJECT_NAME,\n", " description=\"Train a SegFormer model on ADE20k toy dataset\",\n", " parameters={\n", " \"epochs\": EPOCHS,\n", " \"batch_size\": BATCH_SIZE,\n", " \"device\": str(DEVICE),\n", " },\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Torch Datasets and 3LC Tables" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SemanticSegmentationDataset(Dataset):\n", " \"\"\"Image (semantic) segmentation dataset.\"\"\"\n", "\n", " def __init__(self, root_dir: str, train: bool = True):\n", " \"\"\"\n", " :param root_dir: Root directory of the dataset containing the images + annotations.\n", " :param train: Whether to load \"training\" or \"validation\" images + annotations.\n", " \"\"\"\n", " self.root_dir = root_dir\n", " self.train = train\n", "\n", " sub_path = \"training\" if self.train else \"validation\"\n", " self.img_dir = os.path.join(self.root_dir, \"images\", sub_path)\n", " self.ann_dir = os.path.join(self.root_dir, \"annotations\", sub_path)\n", "\n", " # read images\n", " image_file_names = []\n", " for _, _, files in os.walk(self.img_dir):\n", " image_file_names.extend(files)\n", " self.images = sorted(image_file_names)\n", "\n", " # read annotations\n", " annotation_file_names = []\n", " for _, _, files in os.walk(self.ann_dir):\n", " annotation_file_names.extend(files)\n", " self.annotations = sorted(annotation_file_names)\n", "\n", " assert len(self.images) == len(self.annotations), \"There must be as many images as there are segmentation maps\"\n", "\n", " def __len__(self):\n", " return len(self.images)\n", "\n", " def __getitem__(self, idx):\n", " image = Image.open(os.path.join(self.img_dir, self.images[idx]))\n", " segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx]))\n", "\n", " # We need to include the original segmentation map size, in order to post-process the model output\n", " return image, segmentation_map, (segmentation_map.size[1], segmentation_map.size[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dataset = SemanticSegmentationDataset(root_dir=DATASET_ROOT, train=True)\n", "val_dataset = SemanticSegmentationDataset(root_dir=DATASET_ROOT, train=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dataset[0][1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the Tables" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "structure = (\n", " tlc.PILImage(\"image\"),\n", " tlc.SegmentationPILImage(\"segmentation_map\", classes=unreduced_label_map),\n", " tlc.HorizontalTuple(\"mask size\", [tlc.Int(\"width\"), tlc.Int(\"height\")]),\n", ")\n", "\n", "train_table = tlc.Table.from_torch_dataset(\n", " train_dataset,\n", " structure,\n", " project_name=PROJECT_NAME,\n", " dataset_name=DATASET_NAME,\n", " table_name=\"train\",\n", " if_exists=\"overwrite\",\n", ")\n", "\n", "val_table = tlc.Table.from_torch_dataset(\n", " val_dataset,\n", " structure,\n", " project_name=PROJECT_NAME,\n", " dataset_name=DATASET_NAME,\n", " table_name=\"val\",\n", " if_exists=\"overwrite\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MapFn:\n", " def __init__(self, image_processor: SegformerImageProcessor):\n", " 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MapFn:\n", "    def __init__(self, image_processor: SegformerImageProcessor):\n", "        self.image_processor = image_processor\n", "\n", "    def __call__(self, sample):\n", "        image, segmentation_map, mask_size = sample\n", "        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors=\"pt\")\n", "\n", "        for k in encoded_inputs.keys():\n", "            encoded_inputs[k].squeeze_()  # remove batch dimension\n", "\n", "        encoded_inputs.update({\"mask_size\": torch.tensor(mask_size)})\n", "\n", "        return encoded_inputs\n", "\n", "\n", "image_processor = SegformerImageProcessor(reduce_labels=True)\n", "\n", "# Apply the image processor to the datasets\n", "train_table.map(MapFn(image_processor))\n", "val_table.map(MapFn(image_processor))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_table[0].keys()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_table.url" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define the model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import SegformerForSemanticSegmentation\n", "\n", "model = SegformerForSemanticSegmentation.from_pretrained(\n", "    \"nvidia/mit-b0\",\n", "    num_labels=150,\n", "    id2label=id2label,\n", "    label2id=label2id,\n", ").to(DEVICE)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Predict on a single sample\n", "model(train_table[0][\"pixel_values\"].unsqueeze(0).to(DEVICE))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Set up metrics collection" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 1. EmbeddingsMetricsCollector to collect hidden layer activations\n", "for ind, (name, _) in enumerate(model.named_modules()):\n", "    print(ind, \"=>\", name)\n", "\n", "# Interesting layers for embedding collection:\n", "# - segformer.encoder.layer_norm.3 (Index: 197)\n", "# - decode_head.linear_c.2.proj (Index: 204)\n", "# - decode_head.linear_c.3.proj (Index: 207)\n", "\n", "layers = [197, 204, 207]\n", "\n", "embedding_collector = tlc.EmbeddingsMetricsCollector(layers=layers)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 2. A metrics collection callable to collect per-sample loss\n", "\n", "\n", "def metrics_fn(batch, predictor_output):\n", "    labels = batch[\"labels\"].to(DEVICE)\n", "    logits = predictor_output.forward.logits\n", "    upsampled_logits = torch.nn.functional.interpolate(\n", "        logits, size=labels.shape[-2:], mode=\"bilinear\", align_corners=False\n", "    )\n", "    loss = torch.nn.functional.cross_entropy(upsampled_logits, labels, reduction=\"none\", ignore_index=255)\n", "    # Average over non-ignored pixels only, so samples with large ignored\n", "    # regions are not diluted towards zero\n", "    valid = (labels != 255).float()\n", "    loss = (loss * valid).sum(dim=(1, 2)) / valid.sum(dim=(1, 2)).clamp(min=1)\n", "    return {\"loss\": loss.detach().cpu().numpy()}" ] },
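{ "cell_type": "markdown", "metadata": {}, "source": [ "As an optional shape check, the per-sample reduction in `metrics_fn` can be exercised on dummy tensors: one scalar loss should come out per image. The tensor sizes below are illustrative placeholders only:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative shape check with dummy tensors (sizes are placeholders)\n", "dummy_logits = torch.randn(2, 150, 128, 128)  # (batch, classes, h, w)\n", "dummy_labels = torch.randint(0, 150, (2, 512, 512))  # (batch, H, W)\n", "up = torch.nn.functional.interpolate(dummy_logits, size=dummy_labels.shape[-2:], mode=\"bilinear\", align_corners=False)\n", "per_pixel = torch.nn.functional.cross_entropy(up, dummy_labels, reduction=\"none\", ignore_index=255)\n", "print(per_pixel.mean(dim=(1, 2)).shape)  # torch.Size([2]): one loss value per sample" ] },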
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 3. A SegmentationMetricsCollector to write out the predictions\n", "\n", "\n", "def preprocess_fn(batch, predictor_output: tlc.PredictorOutput):\n", "    \"\"\"Convert logits to masks with the same size as the input, and un-reduce the labels.\"\"\"\n", "    processed_masks = image_processor.post_process_semantic_segmentation(\n", "        predictor_output.forward,\n", "        batch[\"mask_size\"].tolist(),\n", "    )\n", "    for i, mask in enumerate(processed_masks):\n", "        # argmax predictions are reduced ids 0..149; shift them back into the\n", "        # raw annotation space (1..150), where 0 is reserved for background\n", "        processed_masks[i] = mask + 1\n", "\n", "    return batch, processed_masks\n", "\n", "\n", "segmentation_collector = tlc.SegmentationMetricsCollector(label_map=unreduced_label_map, preprocess_fn=preprocess_fn)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define a single function to collect all metrics\n", "\n", "# A Predictor object wraps the model and enables embedding collection\n", "predictor = tlc.Predictor(model, device=DEVICE, layers=layers)\n", "\n", "# Arguments for the dataloader used during metrics collection\n", "mc_dataloader_args = {\"batch_size\": BATCH_SIZE}\n", "\n", "\n", "def collect_metrics(epoch):\n", "    tlc.collect_metrics(\n", "        train_table,\n", "        [segmentation_collector, metrics_fn, embedding_collector],\n", "        predictor,\n", "        constants={\"epoch\": epoch},\n", "        dataloader_args=mc_dataloader_args,\n", "        split=\"train\",\n", "    )\n", "    tlc.collect_metrics(\n", "        val_table,\n", "        [segmentation_collector, metrics_fn, embedding_collector],\n", "        predictor,\n", "        constants={\"epoch\": epoch},\n", "        dataloader_args=mc_dataloader_args,\n", "        split=\"val\",\n", "    )\n", "\n", "\n", "# Collect metrics once before training starts (epoch = -1)\n", "collect_metrics(-1)" ] },
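{ "cell_type": "markdown", "metadata": {}, "source": [ "To make the two id spaces concrete: a reduced prediction of `0` becomes `1` after the shift, which `unreduced_label_map` resolves to the same class name that `id2label` stores at `0`. A dummy illustration (optional, placeholder values only):" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Dummy illustration of the reduced -> un-reduced id shift (placeholder values)\n", "reduced_pred = torch.tensor([[0, 42], [149, 0]])  # stand-in for an argmax output\n", "unreduced_pred = reduced_pred + 1  # raw annotation ids 1..150\n", "assert unreduced_label_map[int(unreduced_pred[0, 0])] == id2label[0]\n", "print(unreduced_label_map[int(unreduced_pred[0, 1])])  # name of reduced class 42" ] },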
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Uses the \"weights\" column of the Table to sample the data\n", "sampler = train_table.create_sampler()\n", "\n", "train_dataloader = DataLoader(train_table, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS)\n", "valid_dataloader = DataLoader(val_table, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss_fn(logits, labels):\n", " upsampled_logits = torch.nn.functional.interpolate(\n", " logits, size=labels.shape[-2:], mode=\"bilinear\", align_corners=False\n", " )\n", " if model.config.num_labels > 1:\n", " loss_fct = torch.nn.CrossEntropyLoss(ignore_index=model.config.semantic_loss_ignore_index)\n", " loss = loss_fct(upsampled_logits, labels)\n", " elif model.config.num_labels == 1:\n", " valid_mask = ((labels >= 0) & (labels != model.config.semantic_loss_ignore_index)).float()\n", " loss_fct = torch.nn.BCEWithLogitsLoss(reduction=\"none\")\n", " loss = loss_fct(upsampled_logits.squeeze(1), labels.float())\n", " loss = (loss * valid_mask).mean()\n", "\n", " return loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define optimizer\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)\n", "\n", "# move model to GPU\n", "model.to(DEVICE)\n", "\n", "model.train()\n", "for epoch in range(EPOCHS): # loop over the dataset multiple times\n", " print(\"Epoch:\", epoch)\n", " agg_loss = 0.0\n", " seen_samples = 0\n", " for _idx, batch in enumerate(tqdm(train_dataloader)):\n", " # get the inputs;\n", " pixel_values = batch[\"pixel_values\"].to(DEVICE)\n", " labels = batch[\"labels\"].to(DEVICE)\n", "\n", " # zero the parameter gradients\n", " optimizer.zero_grad()\n", "\n", " # forward + backward + optimize\n", " outputs = model(pixel_values=pixel_values, labels=labels)\n", " _, logits = outputs.loss, outputs.logits\n", " loss = loss_fn(outputs.logits, labels)\n", "\n", " agg_loss += loss.item() * pixel_values.shape[0]\n", " seen_samples += pixel_values.shape[0]\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # evaluate\n", " with torch.no_grad():\n", " upsampled_logits = torch.nn.functional.interpolate(\n", " logits, size=labels.shape[-2:], mode=\"bilinear\", align_corners=False\n", " )\n", " predicted = upsampled_logits.argmax(dim=1)\n", "\n", " # Log aggregated metrics directly to the active Run\n", " tlc.log(\n", " {\n", " \"epoch\": epoch,\n", " \"running_train_loss\": loss.item() / seen_samples,\n", " }\n", " )\n", "\n", " if epoch % 50 == 0 and epoch != 0:\n", " collect_metrics(epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Collect metrics after training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "collect_metrics(epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dimensionality reduce collected metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run.reduce_embeddings_by_foreign_table_url(train_table.url)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" }, 
"test_marks": [ "slow" ] }, "nbformat": 4, "nbformat_minor": 2 }