{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "0", "metadata": { "id": "QHnVupBBn9eR" }, "source": [ "# Fine-tune an object detection model using Detectron2\n", "\n", "This notebook shows how to collect bounding box metrics while training a model using Detectron2.\n", "\n", "It is a modified version of the official Detectron2 Colab tutorial, which can be found [here](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5).\n", "\n", "![](../images/balloons-det-d2.png)\n", "\n", "In this tutorial, we fine-tune a pre-trained Detectron2 model for object detection on a custom dataset in the COCO format.\n", "We integrate with 3LC by creating a training run, registering 3LC datasets, and collecting per-sample bounding box metrics.\n", "\n", "This notebook demonstrates:\n", "\n", "+ Training a Detectron2 model on a custom dataset.\n", "+ Integrating a COCO dataset with 3LC using `register_coco_instances()`.\n", "+ Collecting per-sample bounding box metrics using `BoundingBoxMetricsCollector`.\n", "+ Supplying a custom per-sample metrics function to the collector."
] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Setup Project" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - Balloons\"\n", "RUN_NAME = \"Fine-tune balloon detector\"\n", "DESCRIPTION = \"Train a balloon detector using detectron2\"\n", "TRAIN_DATASET_NAME = \"balloons-train\"\n", "VAL_DATASET_NAME = \"balloons-val\"\n", "DOWNLOAD_PATH = \"../../transient_data\"\n", "DATA_PATH = \"../../data\"\n", "MODEL_CONFIG = \"COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml\"\n", "MAX_ITERS = 200\n", "BATCH_SIZE = 2\n", "MAX_DETECTIONS_PER_IMAGE = 30\n", "SCORE_THRESH_TEST = 0.5\n", "INSTALL_DEPENDENCIES = True" ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [ "if INSTALL_DEPENDENCIES:\n", " # NOTE: There is no single version of detectron2 that is appropriate for all users and all systems.\n", " # This notebook uses a particular prebuilt version of detectron2 that is only available for\n", " # Linux and for specific versions of torch, torchvision, and CUDA. It may not be appropriate\n", " # for your system. 
See https://detectron2.readthedocs.io/en/latest/tutorials/install.html for\n", " # instructions on how to install or build a version of detectron2 for your system.\n", " %pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html\n", " %pip install detectron2 -f \"https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.10/index.html\"\n", " %pip install 3lc\n", " %pip install opencv-python\n", " %pip install matplotlib\n", " %pip install numpy==1.24.4" ] }, { "cell_type": "markdown", "id": "4", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": { "id": "ZyAvNCJMmvFF" }, "outputs": [], "source": [ "import os\n", "import random\n", "from pathlib import Path\n", "\n", "import cv2\n", "import matplotlib.pyplot as plt\n", "import tlc\n", "from detectron2 import model_zoo\n", "from detectron2.config import get_cfg\n", "from detectron2.data import DatasetCatalog, MetadataCatalog\n", "from detectron2.utils.logger import setup_logger\n", "from detectron2.utils.visualizer import Visualizer\n", "\n", "logger = setup_logger()\n", "logger.setLevel(\"ERROR\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6", "metadata": { "id": "b2bjrfb2LDeo" }, "source": [ "## Prepare the dataset\n", "\n", "In this section, we show how to train an existing detectron2 model on a custom dataset in the COCO format.\n", "\n", "We use [the balloon segmentation dataset](https://github.com/matterport/Mask_RCNN/tree/master/samples/balloon)\n", "which only has one class: balloon. 
\n", "\n", "You can find a [modified COCO version](https://github.com/3lc-ai/3lc-examples/tree/main/data/balloons) of this dataset in the \"data\" directory of our [repository](https://github.com/3lc-ai/notebook-examples/) once you have cloned it.\n", "\n", "We'll train a balloon detection model starting from a model pre-trained on the COCO dataset, available in the detectron2 model zoo.\n", "\n", "Note that the COCO dataset does not have a \"balloon\" category. After a few minutes of fine-tuning, the model will be able to recognize this new class." ] }, { "cell_type": "code", "execution_count": null, "id": "7", "metadata": {}, "outputs": [], "source": [ "BALLOONS_ROOT = (Path(DATA_PATH) / \"balloons\").resolve().absolute()\n", "assert BALLOONS_ROOT.exists()\n", "\n", "train_json_path = BALLOONS_ROOT / \"train/train-annotations.json\"\n", "train_image_folder = BALLOONS_ROOT / \"train\"\n", "val_json_path = BALLOONS_ROOT / \"val/val-annotations.json\"\n", "val_image_folder = BALLOONS_ROOT / \"val\"" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8", "metadata": {}, "source": [ "## Register the dataset with 3LC\n", "\n", "Now that we have the dataset in the COCO format, we can register it with 3LC.\n", "The 3LC `register_coco_instances()` takes the same positional arguments as its detectron2 counterpart (dataset name, metadata dict, annotation file, image root), plus a `project_name` that sets the 3LC project the datasets belong to."
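, "\n",
"\n",
"As an optional sanity check, the annotation files can be inspected with the Python standard library before registering them. The sketch below assumes the standard COCO layout (top-level `images`, `annotations` and `categories` lists, with an `image_id` on every annotation); `summarize_coco` is a hypothetical helper, not part of 3LC or detectron2.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"\n",
"def summarize_coco(coco: dict) -> dict:\n",
"    \"\"\"Summarize a COCO-style annotation dict (hypothetical helper).\"\"\"\n",
"    # Count how many annotations (boxes) refer to each image id\n",
"    boxes_per_image = Counter(ann[\"image_id\"] for ann in coco[\"annotations\"])\n",
"    return {\n",
"        \"num_images\": len(coco[\"images\"]),\n",
"        \"num_annotations\": len(coco[\"annotations\"]),\n",
"        \"categories\": [cat[\"name\"] for cat in coco[\"categories\"]],\n",
"        # Images without annotations never appear in the Counter\n",
"        \"max_boxes_per_image\": max(boxes_per_image.values(), default=0),\n",
"    }\n",
"```\n",
"\n",
"For example, loading `train_json_path` with `json.load` and passing the result to `summarize_coco` gives a quick overview of the training split."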
] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [ "from tlc.integration.detectron2 import register_coco_instances\n", "\n", "register_coco_instances(\n", "    TRAIN_DATASET_NAME,\n", "    {},\n", "    str(train_json_path),\n", "    str(train_image_folder),\n", "    project_name=PROJECT_NAME,\n", ")\n", "\n", "register_coco_instances(\n", "    VAL_DATASET_NAME,\n", "    {},\n", "    str(val_json_path),\n", "    str(val_image_folder),\n", "    project_name=PROJECT_NAME,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": { "tags": [] }, "outputs": [], "source": [ "# The detectron2 dataset dicts and dataset metadata can be read from the DatasetCatalog and\n", "# MetadataCatalog, respectively.\n", "dataset_metadata = MetadataCatalog.get(TRAIN_DATASET_NAME)\n", "dataset_dicts = DatasetCatalog.get(TRAIN_DATASET_NAME)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "11", "metadata": { "id": "6ljbWTX0Wi8E" }, "source": [ "To verify that the dataset is in the correct format, let's visualize the annotations of three randomly selected samples from the training set:\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "UkNbUzUOLYf0", "outputId": "4f5ed932-624a-4ede-9d5b-22371569fe1d" }, "outputs": [], "source": [ "import numpy as np\n", "from detectron2.utils.file_io import PathManager\n", "\n", "for d in random.sample(dataset_dicts, 3):\n", "    filename = tlc.Url(d[\"file_name\"]).to_absolute().to_str()\n", "    if \"s3://\" in filename:\n", "        # Remote images are read as raw bytes and decoded with OpenCV\n", "        with PathManager.open(filename, \"rb\") as f:\n", "            img = np.asarray(bytearray(f.read()), dtype=\"uint8\")\n", "        img = cv2.imdecode(img, cv2.IMREAD_COLOR)\n", "    else:\n", "        img = cv2.imread(filename)\n", "    # OpenCV loads images in BGR order, while the Visualizer expects RGB\n", "    visualizer = Visualizer(img[:, :, ::-1], metadata=dataset_metadata, scale=0.5)\n", "    out = visualizer.draw_dataset_dict(d)\n", "    # get_image() returns an RGB array that matplotlib can display directly\n", "    plt.imshow(out.get_image())\n", "    plt.title(filename.split(\"/\")[-1])\n", "    plt.show()" ] },
{ "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "## Create a custom metrics collection function\n", "\n", "We will use a `BoundingBoxMetricsCollector` to collect per-sample bounding box metrics.\n", "It also accepts a custom function through its `extra_metrics_fn` argument. That function receives the ground truths and predictions and stores, under a new key in the `metrics` dictionary, a list with one value per sample." ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "def custom_bbox_metrics_collector(gts, preds, metrics):\n", "    \"\"\"Example function that computes custom metrics for bounding box detection.\"\"\"\n", "\n", "    # Store the number of ground truth boxes and predicted boxes for each sample\n", "    num_gts = [len(gt[\"annotations\"]) for gt in gts]\n", "    num_preds = [len(pred[\"annotations\"]) for pred in preds]\n", "\n", "    metrics[\"num_gts\"] = num_gts\n", "    metrics[\"num_preds\"] = num_preds" ] }, { "attachments": {}, "cell_type": "markdown", "id": "15", "metadata": { "id": "wlqXIXXhW8dA" }, "source": [ "## Train!\n", "\n", "Now, let's fine-tune a COCO-pretrained R50-FPN Faster R-CNN model on the balloon dataset. This notebook trains for `MAX_ITERS` iterations (200 by default); for reference, the original Detectron2 tutorial reports about 2 minutes for 300 iterations on a P100 GPU."
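, "\n",
"\n",
"Detectron2 measures training length in iterations rather than epochs: each iteration consumes `cfg.SOLVER.IMS_PER_BATCH` images, so roughly `MAX_ITER * IMS_PER_BATCH / num_images` passes are made over the training set. The hypothetical helper below sketches that arithmetic. Because detectron2's default training sampler draws an endless shuffled stream of indices, the result is an average rather than an exact epoch count.\n",
"\n",
"```python\n",
"def iterations_to_epochs(max_iter: int, ims_per_batch: int, num_images: int) -> float:\n",
"    \"\"\"Approximate number of passes over the training set (hypothetical helper).\"\"\"\n",
"    if num_images <= 0:\n",
"        raise ValueError(\"num_images must be positive\")\n",
"    # Total images seen during training divided by the dataset size\n",
"    return max_iter * ims_per_batch / num_images\n",
"```\n",
"\n",
"In this notebook, `iterations_to_epochs(MAX_ITERS, BATCH_SIZE, len(dataset_dicts))` gives the value for the current settings; for a hypothetical 100-image training set, 200 iterations at 2 images per batch correspond to 4.0 epochs."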
] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "run = tlc.init(\n", " project_name=PROJECT_NAME,\n", " run_name=RUN_NAME,\n", " description=DESCRIPTION,\n", " if_exists=\"overwrite\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "# For a full list of config values: https://github.com/facebookresearch/detectron2/blob/main/detectron2/config/defaults.py\n", "cfg = get_cfg()\n", "cfg.merge_from_file(model_zoo.get_config_file(MODEL_CONFIG))\n", "cfg.DATASETS.TRAIN = (TRAIN_DATASET_NAME,)\n", "cfg.DATASETS.TEST = (VAL_DATASET_NAME,)\n", "cfg.DATALOADER.NUM_WORKERS = 0\n", "cfg.OUTPUT_DIR = DOWNLOAD_PATH\n", "cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_CONFIG) # Initialize training from the model zoo weights\n", "cfg.SOLVER.IMS_PER_BATCH = BATCH_SIZE # Images per training iteration, i.e. the usual \"batch size\"\n", "cfg.SOLVER.BASE_LR = 0.00025 # Small learning rate for fine-tuning the pre-trained weights\n", "cfg.SOLVER.MAX_ITER = (\n", " MAX_ITERS # Enough for this toy dataset; a practical dataset needs longer training\n", ")\n", "cfg.SOLVER.STEPS = [] # Do not decay the learning rate\n", "cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = (\n", " 128 # The \"RoIHead batch size\". 
128 is faster, and good enough for this toy dataset (default: 512)\n", ")\n", "cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 # The dataset has only one class (balloon).\n", "\n", "cfg.TEST.DETECTIONS_PER_IMAGE = MAX_DETECTIONS_PER_IMAGE\n", "cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESH_TEST\n", "cfg.MODEL.DEVICE = \"cuda\"\n", "cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False\n", "\n", "os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)\n", "\n", "config = {\n", " \"model_config\": MODEL_CONFIG,\n", " \"solver.ims_per_batch\": BATCH_SIZE,\n", " \"test.detections_per_image\": MAX_DETECTIONS_PER_IMAGE,\n", " \"model.roi_heads.score_thresh_test\": SCORE_THRESH_TEST,\n", "}\n", "\n", "run.set_parameters(config)" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7unkuuiqLdqd", "outputId": "ba1716cd-3f3b-401d-bae5-8fbbd2199d9c" }, "outputs": [], "source": [ "from detectron2.engine import DefaultTrainer\n", "from tlc.integration.detectron2 import DetectronMetricsCollectionHook, MetricsCollectionHook\n", "\n", "trainer = DefaultTrainer(cfg)\n", "\n", "metrics_collector = tlc.BoundingBoxMetricsCollector(\n", " classes=dataset_metadata.thing_classes,\n", " label_mapping=dataset_metadata.thing_dataset_id_to_contiguous_id,\n", " extra_metrics_fn=custom_bbox_metrics_collector,\n", ")\n", "\n", "# Add schemas for the custom metrics defined above\n", "metrics_collector.add_schema(\"num_gts\", tlc.Int32Schema(description=\"The number of ground truth boxes\"))\n", "metrics_collector.add_schema(\"num_preds\", tlc.Int32Schema(description=\"The number of predicted boxes\"))\n", "\n", "# Register the metrics collection hooks with the trainer:\n", "# + Collect metrics on the training set every 50 iterations, starting at iteration 0, and again after training\n", "# + Collect metrics on the validation set after training\n", "# + Collect the default detectron2 metrics every 5 iterations\n", "trainer.register_hooks(\n", " [\n", " MetricsCollectionHook(\n", 
dataset_name=TRAIN_DATASET_NAME,\n", " metrics_collectors=[metrics_collector],\n", " collection_frequency=50,\n", " collection_start_iteration=0,\n", " collect_metrics_after_train=True,\n", " ),\n", " MetricsCollectionHook(\n", " dataset_name=VAL_DATASET_NAME,\n", " metrics_collectors=[metrics_collector],\n", " collect_metrics_after_train=True,\n", " ),\n", " DetectronMetricsCollectionHook(\n", " collection_frequency=5,\n", " ),\n", " ]\n", ")\n", "trainer.resume_or_load(resume=False)\n", "trainer.train()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }