{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Adding Object Detections to a Dataset\n", "\n", "This recipe provides a glimpse into the possibilities for integrating FiftyOne into your ML workflows. Specifically, it covers:\n", "\n", "- Loading an object detection dataset from the [Dataset Zoo](https://voxel51.com/docs/fiftyone/user_guide/dataset_zoo/index.html)\n", "- [Adding predictions](https://voxel51.com/docs/fiftyone/user_guide/using_datasets.html#object-detection) from an object detector to the dataset\n", "- Launching the [FiftyOne App](https://voxel51.com/docs/fiftyone/user_guide/app.html) and visualizing/exploring your data\n", "- Integrating the App into your data analysis workflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "If you haven't already, install FiftyOne:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install fiftyone" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we'll use an off-the-shelf [Faster R-CNN detection model](https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn) provided by PyTorch. To use it, you'll need to install `torch` and `torchvision`, if necessary." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install torch torchvision" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading a detection dataset\n", "\n", "In this recipe, we'll work with the validation split of the [COCO dataset](https://cocodataset.org/#home), which is conveniently available for download via the [FiftyOne Dataset Zoo](https://voxel51.com/docs/fiftyone/user_guide/dataset_zoo/datasets.html#coco-2017).\n", "\n", "The snippet below will download the validation split and load it into FiftyOne." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Split 'validation' already downloaded\n", "Loading 'coco-2017' split 'validation'\n", " 100% |████████████████████| 5000/5000 [43.3s elapsed, 0s remaining, 114.9 samples/s] \n", "Dataset 'detector-recipe' created\n" ] } ], "source": [ "import fiftyone as fo\n", "import fiftyone.zoo as foz\n", "\n", "dataset = foz.load_zoo_dataset(\n", " \"coco-2017\",\n", " split=\"validation\",\n", " dataset_name=\"detector-recipe\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect the dataset to see what we downloaded:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Name: detector-recipe\n", "Media type: image\n", "Num samples: 5000\n", "Persistent: False\n", "Info: {'classes': ['0', 'person', 'bicycle', ...]}\n", "Tags: ['validation']\n", "Sample fields:\n", " filepath: fiftyone.core.fields.StringField\n", " tags: fiftyone.core.fields.ListField(fiftyone.core.fields.StringField)\n", " metadata: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.metadata.Metadata)\n", " ground_truth: fiftyone.core.fields.EmbeddedDocumentField(fiftyone.core.labels.Detections)\n" ] } ], "source": [ "# Print some information about the dataset\n", "print(dataset)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Print a ground truth detection\n", "sample = dataset.first()\n", "print(sample.ground_truth.detections[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the ground truth detections are stored in the `ground_truth` field of the samples.\n", "\n", "Before we go further, let's launch the [FiftyOne App](https://voxel51.com/docs/fiftyone/user_guide/app.html) and use the GUI to explore the dataset visually:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", " \n", "
\n", " \n", "
\n", "\n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "session = fo.launch_app(dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adding model predictions\n", "\n", "Now let's add some predictions from an object detector to the dataset.\n", "\n", "We'll use an off-the-shelf [Faster R-CNN detection model](https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn) provided by PyTorch. The following cell downloads the model and loads it:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model ready\n" ] } ], "source": [ "import torch\n", "import torchvision\n", "\n", "# Run the model on GPU if it is available\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Load a pre-trained Faster R-CNN model\n", "model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)\n", "model.to(device)\n", "model.eval()\n", "\n", "print(\"Model ready\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The code below performs inference with the model on a randomly chosen subset of 100 samples from the dataset and [stores the predictions](https://voxel51.com/docs/fiftyone/user_guide/using_datasets.html#object-detection) in a `predictions` field of the samples. " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Choose a random subset of 100 samples to add predictions to\n", "predictions_view = dataset.take(100, seed=51)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 100% |██████████████████████| 100/100 [12.7m elapsed, 0s remaining, 0.1 samples/s] \n" ] } ], "source": [ "from PIL import Image\n", "from torchvision.transforms import functional as func\n", "\n", "import fiftyone as fo\n", "\n", "# Get class list\n", "classes = dataset.default_classes\n", "\n", "# Add predictions to samples\n", "with fo.ProgressBar() as pb:\n", " for sample in pb(predictions_view):\n", " # Load image\n", " image = Image.open(sample.filepath)\n", " image = func.to_tensor(image).to(device)\n", " c, h, w = image.shape\n", " \n", " # Perform inference\n", " preds = model([image])[0]\n", " labels = preds[\"labels\"].cpu().detach().numpy()\n", " scores = preds[\"scores\"].cpu().detach().numpy()\n", " boxes = preds[\"boxes\"].cpu().detach().numpy()\n", " \n", " # Convert detections to FiftyOne format\n", " detections = []\n", " for label, score, box in zip(labels, scores, boxes):\n", " # Convert to [top-left-x, top-left-y, width, height]\n", " # in relative coordinates in [0, 1] x [0, 1]\n", " x1, y1, x2, y2 = box\n", " rel_box = [x1 / w, y1 / h, (x2 - x1) / w, (y2 - y1) / h]\n", "\n", " detections.append(\n", " fo.Detection(\n", " label=classes[label],\n", " bounding_box=rel_box,\n", " confidence=score\n", " )\n", " )\n", " \n", " # Save predictions to dataset\n", " sample[\"predictions\"] = fo.Detections(detections=detections)\n", " sample.save()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's load `predictions_view` in the App to visualize the predictions that we added:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", " \n", "
\n", "