{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "QHnVupBBn9eR" }, "source": [ "# Training and Evaluating FiftyOne Datasets with Detectron2\n", "\n", "FiftyOne has all of the building blocks necessary to develop high-quality datasets to train your models, as well as advanced model evaluation capabilities. To make use of these, FiftyOne easily integrates with your existing model training and inference pipelines. In this walktrhough we'll cover how you can use your FiftyOne datasets to train a model with [Detectron2](https://github.com/facebookresearch/detectron2), Facebook AI Reasearch's library for detection and segmentation algorithms.\n", "\n", "This walkthrough is based off of the [official Detectron2 tutorial](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5), augmented to load data to and from FiftyOne.\n", "\n", "\n", "Specifically, this walkthrough covers:\n", "\n", "- Loading a dataset from the FiftyOne Zoo, and splitting it into training/validation\n", "- Initializing a segmentation model from the detectron2 model zoo\n", "- Loading ground truth annotations from a FiftyOne dataset into a detectron2 model training pipeline and training the model\n", "- Loading predictions from a detectron2 model into a FiftyOne dataset\n", "- Evaluating model predictions in FiftyOne\n", "\n", "\n", "**So, what’s the takeaway?**\n", "\n", "By writing two simple functions, you can integrate FiftyOne into your Detectron2 model training and inference pipelines." ] }, { "cell_type": "markdown", "metadata": { "id": "sHAIvEZsgxyH" }, "source": [ "## Setup\n", "\n", "To get started, you need to install [FiftyOne](https://voxel51.com/docs/fiftyone/getting_started/install.html) and [detectron2](https://detectron2.readthedocs.io/en/latest/tutorials/install.html):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lG2a9JXZgvSh" }, "outputs": [], "source": [ "!pip install fiftyone" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5xCYXOy_g2uh" }, "outputs": [], "source": [ "import fiftyone as fo\n", "import fiftyone.zoo as foz" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FsePPpwZSmqt" }, "outputs": [], "source": [ "!python -m pip install pyyaml==5.1\n", "\n", "# Detectron2 has not released pre-built binaries for the latest pytorch (https://github.com/facebookresearch/detectron2/issues/4053)\n", "# so we install from source instead. 
This takes a few minutes.\n", "!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'\n", "\n", "# Install pre-built detectron2 that matches pytorch version, if released:\n", "# See https://detectron2.readthedocs.io/tutorials/install.html for instructions\n", "#!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/{CUDA_VERSION}/{TORCH_VERSION}/index.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0d288Z2mF5dC", "outputId": "12d42152-f2d7-4d71-82dd-ff7ea3b0b1a8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nvcc: NVIDIA (R) Cuda compiler driver\n", "Copyright (c) 2005-2020 NVIDIA Corporation\n", "Built on Mon_Oct_12_20:09:46_PDT_2020\n", "Cuda compilation tools, release 11.1, V11.1.105\n", "Build cuda_11.1.TC455_06.29190527_0\n", "torch: 1.12 ; cuda: cu113\n", "detectron2: 0.6\n" ] } ], "source": [ "import torch, detectron2\n", "!nvcc --version\n", "TORCH_VERSION = \".\".join(torch.__version__.split(\".\")[:2])\n", "CUDA_VERSION = torch.__version__.split(\"+\")[-1]\n", "print(\"torch: \", TORCH_VERSION, \"; cuda: \", CUDA_VERSION)\n", "print(\"detectron2:\", detectron2.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZyAvNCJMmvFF" }, "outputs": [], "source": [ "# Setup detectron2 logger\n", "import detectron2\n", "from detectron2.utils.logger import setup_logger\n", "setup_logger()\n", "\n", "# import some common libraries\n", "import numpy as np\n", "import os, cv2\n", "\n", "# import some common detectron2 utilities\n", "from detectron2 import model_zoo\n", "from detectron2.engine import DefaultPredictor\n", "from detectron2.config import get_cfg\n", "from detectron2.data import MetadataCatalog, DatasetCatalog" ] }, { "cell_type": "markdown", "metadata": { "id": "b2bjrfb2LDeo" }, "source": [ "## Train on a FiftyOne dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "tjbUIhSxUdm_" }, "source": [ "In this section, we show how to use a custom FiftyOne Dataset to train a detectron2 model.\n", "We'll train a license plate segmentation model, starting from a model pre-trained on the COCO dataset that is available in detectron2's model zoo.\n", "\n", "Since the COCO dataset doesn't have a \"Vehicle registration plates\" category, we will use segmentations of license plates from the Open Images v6 dataset in the [FiftyOne Dataset Zoo](https://voxel51.com/docs/fiftyone/user_guide/dataset_zoo/datasets.html#open-images-v6) to train the model to recognize this new category.\n", "\n", "\n", "## Prepare the dataset\n", "\n", "For this example, we will just use some of the samples from the official \"validation\" split of the dataset. To improve model performance, we could also add data from the official \"train\" split, but that would take longer to train, so we'll stick to the \"validation\" split for this walkthrough."
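 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you do want to add data from the official \"train\" split later, the next cell is a minimal, optional sketch of how that could look: it loads a capped number of extra samples and merges them into the `dataset` that we load below. The `max_samples` value, the `dataset_name`, and the `merge_samples()` step are illustrative assumptions and are not executed as part of this walkthrough." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# OPTIONAL (not run in this walkthrough): a sketch of pulling a capped number of\n", "# \"train\" split samples and merging them into `dataset` once it has been loaded below.\n", "# `max_samples` and `dataset_name` are illustrative choices, not values from the tutorial.\n", "extra = foz.load_zoo_dataset(\n", "    \"open-images-v6\",\n", "    split=\"train\",\n", "    classes=[\"Vehicle registration plate\"],\n", "    label_types=[\"segmentations\"],\n", "    label_field=\"segmentations\",\n", "    max_samples=500,\n", "    dataset_name=\"open-images-v6-plates-train-extra\",\n", ")\n", "\n", "# Merge the extra samples (matched by filepath) into the working dataset;\n", "# run this only after `dataset` has been created in the cell below\n", "# dataset.merge_samples(extra)"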
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f1SnOsM9ijEq", "outputId": "6b8f7cde-951f-49f1-f5f1-b8deecf6f207" }, "outputs": [], "source": [ "dataset = foz.load_zoo_dataset(\n", " \"open-images-v6\", \n", " split=\"validation\", \n", " classes=[\"Vehicle registration plate\"], \n", " label_types=[\"segmentations\"],\n", " label_field=\"segmentations\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specifying a `classes` when downloading a dataset from the zoo will ensure that only samples with one of the given classes will be present. However, these samples may still contain other labels, so we can use the powerful [filtering capability](https://voxel51.com/docs/fiftyone/user_guide/using_views.html#filtering) of FiftyOne to easily keep only the \"Vehicle registration plate\" labels.\n", "We will also untag these samples as \"validation\" and create our own split out of them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_PlG3hEuhVW0" }, "outputs": [], "source": [ "from fiftyone import ViewField as F\n", "\n", "# Remove other classes and existing tags\n", "dataset.filter_labels(\"segmentations\", F(\"label\") == \"Vehicle registration plate\").save()\n", "dataset.untag_samples(\"validation\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gnFQLpvROpoL" }, "outputs": [], "source": [ "import fiftyone.utils.random as four\n", "\n", "four.random_split(dataset, {\"train\": 0.8, \"val\": 0.2})" ] }, { "cell_type": "markdown", "metadata": { "id": "tVJoOm6LVJwW" }, "source": [ "Next we will register the FiftyOne dataset to detectron2, following the [detectron2 custom dataset tutorial](https://detectron2.readthedocs.io/tutorials/datasets.html).\n", "Here, the dataset is in its custom format, therefore we write a function to parse it and prepare it into [detectron2's standard format](https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html#standard-dataset-dicts).\n", "\n", "Note: In this example, we are specifically parsing the segmentations into bounding boxes and polylines. 
This function may require tweaks depending on the model being trained and the data it expects.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PIbAM2pv-urF" }, "outputs": [], "source": [ "from detectron2.structures import BoxMode\n", "\n", "def get_fiftyone_dicts(samples):\n", "    samples.compute_metadata()\n", "\n", "    dataset_dicts = []\n", "    for sample in samples.select_fields([\"id\", \"filepath\", \"metadata\", \"segmentations\"]):\n", "        height = sample.metadata[\"height\"]\n", "        width = sample.metadata[\"width\"]\n", "        record = {}\n", "        record[\"file_name\"] = sample.filepath\n", "        record[\"image_id\"] = sample.id\n", "        record[\"height\"] = height\n", "        record[\"width\"] = width\n", "\n", "        objs = []\n", "        for det in sample.segmentations.detections:\n", "            # FiftyOne stores boxes as relative [top-left-x, top-left-y, width, height];\n", "            # convert them to absolute XYWH pixel coordinates for detectron2\n", "            tlx, tly, w, h = det.bounding_box\n", "            bbox = [int(tlx*width), int(tly*height), int(w*width), int(h*height)]\n", "            # Convert the segmentation into a flat [x1, y1, x2, y2, ...] polygon in pixels\n", "            fo_poly = det.to_polyline()\n", "            poly = [(x*width, y*height) for x, y in fo_poly.points[0]]\n", "            poly = [p for x in poly for p in x]\n", "            obj = {\n", "                \"bbox\": bbox,\n", "                \"bbox_mode\": BoxMode.XYWH_ABS,\n", "                \"segmentation\": [poly],\n", "                \"category_id\": 0,\n", "            }\n", "            objs.append(obj)\n", "\n", "        record[\"annotations\"] = objs\n", "        dataset_dicts.append(record)\n", "\n", "    return dataset_dicts\n", "\n", "for d in [\"train\", \"val\"]:\n", "    view = dataset.match_tags(d)\n", "    DatasetCatalog.register(\"fiftyone_\" + d, lambda view=view: get_fiftyone_dicts(view))\n", "    MetadataCatalog.get(\"fiftyone_\" + d).set(thing_classes=[\"vehicle_registration_plate\"])\n", "\n", "metadata = MetadataCatalog.get(\"fiftyone_train\")" ] }, { "cell_type": "markdown", "metadata": { "id": "6ljbWTX0Wi8E" }, "source": [ "To verify that the dataset is in the correct format, let's visualize the annotations of the training set:\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OLsNgm-Zv-zM", "outputId": "cd200349-cb35-40e9-e253-cff402e98ded" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [] }, { "data": { "text/html": [ "\n", "\n", "\n", "