{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tuning a SegFormer model with Pytorch Lightning\n", "\n", "This notebook fine-tunes a SegFormer model for semantic segmentation using PyTorch Lightning.\n", "\n", "The original notebook can be found [here](https://colab.research.google.com/drive/1250K828ixr-sG2xzLVYYN5AVa3cTaDEF).\n", "\n", "![](../images/lightning-balloons-semseg.png)\n", "\n", "\n", "\n", "In this tutorial we will see how to fine-tune a pre-trained SegFormer model for semantic segmentation on a custom dataset.\n", "We will integrate with 3LC by creating a training run, registering 3LC datasets, and collecting per-sample predicted masks.\n", "\n", "This notebook demonstrates:\n", "\n", "+ Training a SegFormer model on a custom dataset with Pytorch Lightning.\n", "+ Registering train/val/test sets into 3LC Tables\n", "+ Collecting per-sample semantic segmentation, predicted masks through callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Project Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - Image Segmentation\"\n", "RUN_NAME = \"Train Balloon SegFormer\"\n", "DESCRIPTION = \"Train a SegFormer model using PyTorch Lightning\"\n", "TRAIN_DATASET_NAME = \"balloons-train\"\n", "VAL_DATASET_NAME = \"balloons-val\"\n", "TEST_DATASET_NAME = \"balloons-test\"\n", "MODEL = \"nvidia/mit-b0\"\n", "DATA_PATH = \"../../data\"\n", "EPOCHS = 10\n", "BATCH_SIZE = 8\n", "NUM_WORKERS = 0\n", "DEVICE = None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install 3lc[huggingface]\n", "%pip install pytorch-lightning\n", "%pip install matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import numpy as np\n", "import pytorch_lightning as pl\n", "import tlc\n", "import torch\n", "from evaluate import load\n", "from matplotlib import pyplot as plt\n", "from PIL import Image\n", "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "from torchvision.datasets import VisionDataset\n", "from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if DEVICE is None:\n", " if torch.cuda.is_available():\n", " device = \"cuda:0\"\n", " elif torch.backends.mps.is_available():\n", " # Disable MPS due to tensor view issues with SegFormer\n", " device = \"cpu\"\n", " else:\n", " device = \"cpu\"\n", "else:\n", " device = DEVICE\n", "\n", "device = torch.device(device)\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Datasets and Training helpers\n", "\n", "We will create a Table with the images and their associated masks.\n", "\n", "Moreover, we will also define helpers to pre-process this dataset into a suitable form for training and collecting metrics.\n", "\n", "To finish, we define a Pytorch LightningModule to define the steps for training, validation and test.\n", "\n" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TLCSemanticSegmentationDataset(VisionDataset):\n", " \"\"\"Image (semantic) segmentation dataset.\"\"\"\n", "\n", " def __init__(self, root_dir, transforms=None):\n", " super().__init__(root_dir, transforms=transforms)\n", " self.root_dir = root_dir\n", " image_file_names = [f for f in os.listdir(self.root_dir) if \".jpg\" in f]\n", " mask_file_names = [f for f in os.listdir(self.root_dir) if \".png\" in f]\n", " self.images = sorted(image_file_names)\n", " self.masks = sorted(mask_file_names)\n", "\n", " def __len__(self):\n", " return len(self.images)\n", "\n", " def __getitem__(self, idx):\n", " image = Image.open(os.path.join(self.root_dir, self.images[idx]))\n", " segmentation_map = Image.open(os.path.join(self.root_dir, self.masks[idx]))\n", "\n", " if self.transforms is not None:\n", " return self.transforms(image, segmentation_map, image.size, segmentation_map.size)\n", "\n", " return image, segmentation_map, image.size, segmentation_map.size" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_id2label(root_dir):\n", " classes_csv_file = os.path.join(root_dir, \"_classes.csv\")\n", " with open(classes_csv_file) as fid:\n", " data = [line.split(\",\") for idx, line in enumerate(fid) if idx != 0]\n", " return {float(x[0]): x[1].strip() for x in data}\n", "\n", "\n", "image_processor = SegformerImageProcessor.from_pretrained(MODEL)\n", "image_processor.do_reduce_labels = False\n", "image_processor.size = 128\n", "dataset_location = tlc.Url(DATA_PATH + \"/balloons-mask-segmentation\").to_absolute()\n", "id2label = get_id2label(f\"{dataset_location}/train/\") # Assuming the same classes for train, val, and test\n", "\n", "model = SegformerForSemanticSegmentation.from_pretrained(\n", " MODEL,\n", " num_labels=len(id2label.keys()),\n", " id2label=id2label,\n", " label2id={v: k for k, v in id2label.items()},\n", " ignore_mismatched_sizes=True,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "structure = (\n", " tlc.PILImage(\"image\"),\n", " tlc.SegmentationPILImage(\"mask\", id2label),\n", " tlc.HorizontalTuple(\"image size\", [tlc.Int(\"width\"), tlc.Int(\"height\")]),\n", " tlc.HorizontalTuple(\"mask size\", [tlc.Int(\"width\"), tlc.Int(\"height\")]),\n", ")\n", "\n", "\n", "def mc_preprocess_fn(batch, predictor_output):\n", " \"\"\"Transform a batch of inputs and model outputs to a format expected by the metrics collector.\"\"\"\n", "\n", " original_mask_size = batch[\"mask_size\"].tolist()\n", " outputs = predictor_output.forward\n", "\n", " predicted_masks = image_processor.post_process_semantic_segmentation(\n", " outputs=outputs,\n", " target_sizes=original_mask_size,\n", " )\n", " return batch, predicted_masks\n", "\n", "\n", "segmentation_metrics_collector = tlc.SegmentationMetricsCollector(\n", " label_map=id2label,\n", " preprocess_fn=mc_preprocess_fn,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@tlc.lightning_module(\n", " structure=structure,\n", " dataset_prefix=\"balloons\",\n", " run_name=RUN_NAME,\n", " run_description=DESCRIPTION,\n", " if_run_exists=\"overwrite\",\n", " if_dataset_exists=\"overwrite\",\n", " project_name=PROJECT_NAME,\n", " metrics_collectors=[segmentation_metrics_collector],\n", " metrics_collection_interval=10,\n", ")\n", "class SegformerFinetuner(pl.LightningModule):\n", " def __init__(\n", " self,\n", " 
model,\n", " id2label,\n", " train_dataloader=None,\n", " val_dataloader=None,\n", " test_dataloader=None,\n", " metrics_interval=100,\n", " ):\n", " super().__init__()\n", " self.train_dl = train_dataloader\n", " self.val_dl = val_dataloader\n", " self.test_dl = test_dataloader\n", " self.metrics_interval = metrics_interval\n", "\n", " self.num_classes = len(id2label.keys())\n", "\n", " self.model = model\n", "\n", " self.train_mean_iou = load(\"mean_iou\")\n", " self.val_mean_iou = load(\"mean_iou\")\n", " self.test_mean_iou = load(\"mean_iou\")\n", "\n", " self.training_step_outputs = [] # >=2.0.0 fix\n", " self.validation_step_outputs = [] # >=2.0.0 fix\n", " self.test_step_outputs = [] # >=2.0.0 fix\n", "\n", " def forward(self, images, masks=None):\n", " outputs = self.model(images, masks)\n", " return outputs\n", "\n", " def training_step(self, batch, batch_idx):\n", " images, masks = batch[\"pixel_values\"], batch[\"labels\"]\n", " outputs = self(images, masks)\n", " loss, logits = outputs[0], outputs[1]\n", " upsampled_logits = nn.functional.interpolate(\n", " logits, size=masks.shape[-2:], mode=\"bilinear\", align_corners=False\n", " )\n", " predicted = upsampled_logits.argmax(dim=1)\n", " self.train_mean_iou.add_batch(\n", " predictions=predicted.detach().cpu().numpy(),\n", " references=masks.detach().cpu().numpy(),\n", " )\n", " if batch_idx % self.metrics_interval == 0:\n", " metrics = self.train_mean_iou.compute(\n", " num_labels=self.num_classes,\n", " ignore_index=255,\n", " reduce_labels=False,\n", " )\n", " metrics = {\n", " \"loss\": loss,\n", " \"mean_iou\": metrics[\"mean_iou\"],\n", " \"mean_accuracy\": metrics[\"mean_accuracy\"],\n", " }\n", " for k, v in metrics.items():\n", " self.log(k, v, prog_bar=True)\n", "\n", " tlc.log(\n", " {\n", " **{k: v.item() for k, v in metrics.items()},\n", " \"step\": self.global_step,\n", " }\n", " )\n", "\n", " else:\n", " metrics = {\"loss\": loss}\n", "\n", " self.training_step_outputs.append(metrics) # >=2.0.0 fix\n", " return metrics\n", "\n", " def validation_step(self, batch):\n", " images, masks = batch[\"pixel_values\"], batch[\"labels\"]\n", " outputs = self(images, masks)\n", " loss, logits = outputs[0], outputs[1]\n", " upsampled_logits = nn.functional.interpolate(\n", " logits, size=masks.shape[-2:], mode=\"bilinear\", align_corners=False\n", " )\n", " predicted = upsampled_logits.argmax(dim=1)\n", " self.val_mean_iou.add_batch(\n", " predictions=predicted.detach().cpu().numpy(),\n", " references=masks.detach().cpu().numpy(),\n", " )\n", " self.validation_step_outputs.append(loss) # >=2.0.0 fix\n", "\n", " return {\"val_loss\": loss}\n", "\n", " def on_validation_epoch_end(self):\n", " metrics = self.val_mean_iou.compute(\n", " num_labels=self.num_classes,\n", " ignore_index=255,\n", " reduce_labels=False,\n", " )\n", "\n", " avg_val_loss = torch.stack(self.validation_step_outputs).mean() # >=2.0.0 fix\n", " val_mean_iou = metrics[\"mean_iou\"]\n", " val_mean_accuracy = metrics[\"mean_accuracy\"]\n", "\n", " metrics = {\n", " \"val_loss\": avg_val_loss,\n", " \"val_mean_iou\": val_mean_iou,\n", " \"val_mean_accuracy\": val_mean_accuracy,\n", " }\n", " for k, v in metrics.items():\n", " self.log(k, v, prog_bar=True)\n", "\n", " self.validation_step_outputs.clear() # >=2.0.0 fix\n", "\n", " if not self.trainer.sanity_checking:\n", " tlc.log(\n", " {\n", " **{k: v.item() for k, v in metrics.items()},\n", " \"step\": self.global_step,\n", " }\n", " )\n", "\n", " return metrics\n", "\n", " def test_step(self, batch):\n", " 
"        images, masks = batch[\"pixel_values\"], batch[\"labels\"]\n", "        outputs = self(images, masks)\n", "        loss, logits = outputs[0], outputs[1]\n", "        upsampled_logits = nn.functional.interpolate(\n", "            logits, size=masks.shape[-2:], mode=\"bilinear\", align_corners=False\n", "        )\n", "        predicted = upsampled_logits.argmax(dim=1)\n", "        self.test_mean_iou.add_batch(\n", "            predictions=predicted.detach().cpu().numpy(),\n", "            references=masks.detach().cpu().numpy(),\n", "        )\n", "        self.test_step_outputs.append(loss)\n", "\n", "        return {\"test_loss\": loss}\n", "\n", "    def on_test_epoch_end(self):\n", "        metrics = self.test_mean_iou.compute(\n", "            num_labels=self.num_classes,\n", "            ignore_index=255,\n", "            reduce_labels=False,\n", "        )\n", "\n", "        avg_test_loss = torch.stack(self.test_step_outputs).mean()\n", "        test_mean_iou = metrics[\"mean_iou\"]\n", "        test_mean_accuracy = metrics[\"mean_accuracy\"]\n", "        metrics = {\n", "            \"test_loss\": avg_test_loss,\n", "            \"test_mean_iou\": test_mean_iou,\n", "            \"test_mean_accuracy\": test_mean_accuracy,\n", "        }\n", "        for k, v in metrics.items():\n", "            self.log(k, v)\n", "        self.test_step_outputs.clear()\n", "\n", "        return metrics\n", "\n", "    def configure_optimizers(self):\n", "        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)\n", "\n", "    def train_dataloader(self):\n", "        return self.train_dl\n", "\n", "    def val_dataloader(self):\n", "        return self.val_dl\n", "\n", "    def test_dataloader(self):\n", "        return self.test_dl" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def transforms(image, mask, image_size, mask_size):\n", "    encoded_inputs = image_processor(image, mask, return_tensors=\"pt\")\n", "    for k in encoded_inputs:\n", "        encoded_inputs[k] = encoded_inputs[k].squeeze()  # remove batch dimension\n", "\n", "    # Add the original mask size to the batch so that we can resize the mask back to its original size later\n", "    encoded_inputs.update({\"mask_size\": torch.tensor(mask_size)})\n", "\n", "    return encoded_inputs\n", "\n", "\n", "train_dataset = TLCSemanticSegmentationDataset(f\"{dataset_location}/train/\", transforms)\n", "val_dataset = TLCSemanticSegmentationDataset(f\"{dataset_location}/valid/\", transforms)\n", "test_dataset = TLCSemanticSegmentationDataset(f\"{dataset_location}/test/\", transforms)\n", "\n", "train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)\n", "val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)\n", "test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)\n", "\n", "segformer_finetuner = SegformerFinetuner(\n", "    model,\n", "    id2label,\n", "    train_dataloader=train_dataloader,\n", "    val_dataloader=val_dataloader,\n", "    test_dataloader=test_dataloader,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "early_stop_callback = EarlyStopping(\n", "    monitor=\"val_loss\",\n", "    patience=5,\n", "    verbose=True,\n", ")\n", "\n", "checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor=\"val_loss\", save_last=True)\n", "\n", "trainer = pl.Trainer(\n", "    accelerator=\"cpu\",  # SegFormer hits tensor view issues on MPS, so we train on CPU here\n", "    callbacks=[early_stop_callback, checkpoint_callback],\n", "    max_epochs=EPOCHS,\n",
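"    # Validate every len(train_dataloader) steps, i.e. once per epoch\n",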
"    val_check_interval=len(train_dataloader),\n", "    log_every_n_steps=7,\n", ")\n", "\n", "trainer.fit(segformer_finetuner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Checking results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res = trainer.test(ckpt_path=\"last\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "batch = next(iter(test_dataloader))\n", "images, masks = batch[\"pixel_values\"].to(device), batch[\"labels\"].to(device)\n", "segformer_finetuner.eval().to(device)\n", "outputs = segformer_finetuner.model(images, masks, return_dict=True)\n", "batch_prediction = image_processor.post_process_semantic_segmentation(outputs, batch[\"mask_size\"].tolist())\n", "\n", "n_rows = len(images)\n", "n_cols = 3\n", "fig_width = n_cols * 5\n", "fig_height = n_rows * 5\n", "fig, ax = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height))\n", "fig.suptitle(\"Test Batch Predictions\", fontsize=16)\n", "plt.tight_layout(pad=3.0, h_pad=-1.0, w_pad=1.0, rect=[0, 0, 1, 1])\n", "\n", "for i in range(n_rows):\n", "    for j in range(n_cols):\n", "        ax[i, j].axis(\"off\")\n", "\n", "    ax[i, 0].imshow(masks[i, :, :].cpu().numpy(), cmap=\"gray\")\n", "    ax[i, 0].set_title(f\"Ground Truth (id={i})\", fontsize=14)\n", "\n", "    ax[i, 1].imshow(batch_prediction[i].cpu().numpy(), cmap=\"gray\")\n", "    ax[i, 1].set_title(\"Predicted mask (latest model)\", fontsize=14)\n", "\n", "    im = tlc.active_run().metrics_tables[-1][i][\"predicted_mask\"]\n", "    ax[i, 2].imshow(np.array(im), cmap=\"gray\")\n", "    ax[i, 2].set_title(\"Predicted mask (3LC metrics)\", fontsize=14)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 0 }