{ "cells": [ { "cell_type": "markdown", "id": "76201efe", "metadata": {}, "source": [ "\n", "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " Try in Google Colab\n", " \n", " \n", " \n", " \n", " Share via nbviewer\n", " \n", " \n", " \n", " \n", " View on GitHub\n", " \n", " \n", " \n", " \n", " Download notebook\n", " \n", "
\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1662bb7c", "metadata": {}, "source": [ "# Object masks from prompts with SAM, OpenVINO, and FiftyOne\n" ] }, { "cell_type": "markdown", "id": "e368a0e4", "metadata": {}, "source": [ "**Note**: This notebook is adapted from the [SAM OpenVINO notebook](https://github.com/openvinotoolkit/openvino_notebooks/blob/main/notebooks/237-segment-anything/237-segment-anything.ipynb)." ] }, { "attachments": {}, "cell_type": "markdown", "id": "7fcc21a0", "metadata": {}, "source": [ "Segmentation - identifying which image pixels belong to an object - is a core task in computer vision and is used in a broad array of applications, from analyzing scientific imagery to editing photos. But creating an accurate segmentation model for specific tasks typically requires highly specialized work by technical experts with access to AI training infrastructure and large volumes of carefully annotated in-domain data. Reducing the need for task-specific modeling expertise, training compute, and custom data annotation for image segmentation is the main goal of the [Segment Anything](https://arxiv.org/abs/2304.02643) project.\n", "\n", "The [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) predicts object masks given prompts that indicate the desired object. SAM has learned a general notion of what objects are, and it can generate masks for any object in any image or any video, even including objects and image types that it had not encountered during training. SAM is general enough to cover a broad set of use cases and can be used out of the box on new image “domains” (e.g. underwater photos, MRI or cell microscopy) without requiring additional training (a capability often referred to as zero-shot transfer).\n", "This notebook shows an example of how to convert and use Segment Anything Model in OpenVINO format, allowing it to run on a variety of platforms that support an OpenVINO.\n", "\n", "## Background\n", "\n", "Previously, to solve any kind of segmentation problem, there were two classes of approaches. The first, interactive segmentation, allowed for segmenting any class of object but required a person to guide the method by iterative refining a mask. The second, automatic segmentation, allowed for segmentation of specific object categories defined ahead of time (e.g., cats or chairs) but required substantial amounts of manually annotated objects to train (e.g., thousands or even tens of thousands of examples of segmented cats), along with the compute resources and technical expertise to train the segmentation model. Neither approach provided a general, fully automatic approach to segmentation.\n", "\n", "Segment Anything Model is a generalization of these two classes of approaches. It is a single model that can easily perform both interactive segmentation and automatic segmentation.\n", "The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks. The model consists of 3 parts:\n", "\n", "- **Image Encoder** - Vision Transformer model (VIT) pretrained using Masked Auto Encoders approach (MAE) for encoding image to embedding space. 
"- **Image Encoder** - a Vision Transformer (ViT) pretrained with the Masked Autoencoder (MAE) approach that encodes the image into an embedding space. The image encoder runs once per image and can be applied prior to prompting the model.\n", "- **Prompt Encoder** - Encoder for the segmentation condition. Any of the following can be used as a condition:\n", "  - points - a set of points related to the object that should be segmented. The prompt encoder converts the points to embeddings using positional encoding.\n", "  - boxes - a bounding box in which the object to segment is located. As with points, the coordinates of the bounding box are encoded via positional encoding.\n", "  - segmentation mask - a user-provided segmentation mask is embedded using convolutions and summed element-wise with the image embedding.\n", "  - text - a text representation encoded by the CLIP model.\n", "- **Mask Decoder** - The mask decoder efficiently maps the image embedding, prompt embeddings, and an output token to a mask.\n", "\n", "The diagram below demonstrates the process of mask generation using SAM:\n", "![model_diagram](https://raw.githubusercontent.com/facebookresearch/segment-anything/main/assets/model_diagram.png)\n", "\n", "The model first converts the image into an image embedding that allows high quality masks to be efficiently produced from a prompt. The model returns multiple masks that fit the provided prompt, along with a score for each. The returned masks may overlap, as shown in the diagram; this is useful for complicated cases where the prompt can be interpreted in different ways, e.g. segmenting a whole object versus only a specific part of it, or when the provided point lies at the intersection of multiple objects. The model’s promptable interface allows it to be used in flexible ways that make a wide range of segmentation tasks possible simply by engineering the right prompt for the model (clicks, boxes, text, and so on).\n", "\n", "More details about the approach can be found in the [paper](https://arxiv.org/abs/2304.02643), the original [repo](https://github.com/facebookresearch/segment-anything), and the [Meta AI blog post](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/).\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "55ae4e00", "metadata": {}, "source": [ "## Prerequisites\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "14346a95", "metadata": {}, "source": [ "In addition to OpenVINO and Meta AI's Segment Anything library, we will use the open source library [FiftyOne](https://docs.voxel51.com/) for visualizing and evaluating our Segment Anything Model predictions.\n" ] }, { "cell_type": "code", "execution_count": 147, "id": "133014d4-6766-48c5-94b6-9c78b9dfe309", "metadata": {}, "outputs": [], "source": [ "!pip install -q \"segment_anything\" \"fiftyone\" \"openvino\"" ] }, { "attachments": {}, "cell_type": "markdown", "id": "bd0f6b2b", "metadata": {}, "source": [ "## Convert model to OpenVINO Intermediate Representation\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5c0e935b-b801-46c4-b69d-b508b07425cd", "metadata": {}, "source": [ "### Download model checkpoint and create PyTorch model\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1540f719", "metadata": {}, "source": [ "There are several Segment Anything Model [checkpoints](https://github.com/facebookresearch/segment-anything#model-checkpoints) available for download.\n", "In this tutorial we will use the model based on `vit_b`, but the demonstrated approach is very general and applicable to other SAM models.\n", "Set the model URL, the path for saving the checkpoint, and the model type below to a SAM model checkpoint, then load the model using `sam_model_registry`.\n" ] },
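{ "attachments": {}, "cell_type": "markdown", "id": "checkpoint-options-md", "metadata": {}, "source": [ "If you would like to try a larger SAM variant, the checkpoint URL and `model_type` should be swapped together. The mapping below is a sketch based on the checkpoints page linked above; double-check the exact file names there before downloading.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "checkpoint-options-code", "metadata": {}, "outputs": [], "source": [ "# Commonly published SAM checkpoints, keyed by model type\n", "# (verify the file names against the model checkpoints page linked above)\n", "sam_checkpoint_urls = {\n", "    \"vit_b\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\",\n", "    \"vit_l\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth\",\n", "    \"vit_h\": \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\",\n", "}" ] },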
"execution_count": 1, "id": "76fc53f4", "metadata": { "tags": [] }, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.append(\"../utils\")\n", "from notebook_utils import download_file\n", "\n", "checkpoint = \"sam_vit_b_01ec64.pth\"\n", "model_url = \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\"\n", "model_type = \"vit_b\"\n", "\n", "download_file(model_url)" ] }, { "cell_type": "code", "execution_count": 3, "id": "11bfc8aa", "metadata": { "tags": [] }, "outputs": [], "source": [ "from segment_anything import sam_model_registry\n", "\n", "sam = sam_model_registry[model_type](checkpoint=checkpoint)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1249127f-24dc-4b8f-875a-dbd56887c1fc", "metadata": {}, "source": [ "As we already discussed, Image Encoder part can be used once per image, then changing prompt, prompt encoder and mask decoder can be run multiple times to retrieve different objects from the same image. Taking into account this fact, we split model on 2 independent parts: image_encoder and mask_predictor (combination of Prompt Encoder and Mask Decoder).\n", "\n", "### Image Encoder\n", "\n", "Image Encoder input is tensor with shape `1x3x1024x1024` in `NCHW` format, contains image for segmentation.\n", "Image Encoder output is image embeddings, tensor with shape `1x256x64x64`\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "5b060bdd-dbfe-4466-8f81-09684a0f4204", "metadata": { "tags": [] }, "outputs": [], "source": [ "import warnings\n", "from pathlib import Path\n", "import torch\n", "from openvino.tools import mo\n", "from openvino.runtime import serialize, Core\n", "\n", "core = Core()\n", "\n", "ov_encoder_path = Path(\"sam_image_encoder.xml\")\n", "\n", "if not ov_encoder_path.exists():\n", " onnx_encoder_path = ov_encoder_path.with_suffix(\".onnx\")\n", " if not onnx_encoder_path.exists():\n", " with warnings.catch_warnings():\n", " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n", " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", "\n", " torch.onnx.export(\n", " sam.image_encoder, torch.zeros(1, 3, 1024, 1024), onnx_encoder_path\n", " )\n", "\n", " ov_encoder_model = mo.convert_model(onnx_encoder_path, compress_to_fp16=True)\n", " serialize(ov_encoder_model, str(ov_encoder_path))\n", "else:\n", " ov_encoder_model = core.read_model(ov_encoder_path)\n", "ov_encoder = core.compile_model(ov_encoder_model)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9951b5f0-66a0-46c8-bd5f-486051f6e398", "metadata": {}, "source": [ "### Mask predictor\n", "\n", "This notebook expects the model was exported with the parameter `return_single_mask=True`. It means that model will only return the best mask, instead of returning multiple masks. For high resolution images this can improve runtime when upscaling masks is expensive.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c2aec47b-dff7-4189-be5a-d25564061ead", "metadata": { "tags": [] }, "source": [ "Combined prompt encoder and mask decoder model has following list of inputs:\n", "\n", "- `image_embeddings`: The image embedding from `image_encoder`. Has a batch index of length 1.\n", "- `point_coords`: Coordinates of sparse input prompts, corresponding to both point inputs and box inputs. Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner. 
{ "cell_type": "code", "execution_count": 5, "id": "7da638ba", "metadata": { "tags": [] }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "class SamONNXModel(torch.nn.Module):\n", "    def __init__(\n", "        self,\n", "        model,\n", "        return_single_mask: bool,\n", "        use_stability_score: bool = False,\n", "        return_extra_metrics: bool = False,\n", "    ) -> None:\n", "        super().__init__()\n", "        self.mask_decoder = model.mask_decoder\n", "        self.model = model\n", "        self.img_size = model.image_encoder.img_size\n", "        self.return_single_mask = return_single_mask\n", "        self.use_stability_score = use_stability_score\n", "        self.stability_score_offset = 1.0\n", "        self.return_extra_metrics = return_extra_metrics\n", "\n", "    def _embed_points(\n", "        self, point_coords: torch.Tensor, point_labels: torch.Tensor\n", "    ) -> torch.Tensor:\n", "        point_coords = point_coords + 0.5\n", "        point_coords = point_coords / self.img_size\n", "        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)\n", "        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)\n", "\n", "        point_embedding = point_embedding * (point_labels != -1)\n", "        point_embedding = (\n", "            point_embedding\n", "            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)\n", "        )\n", "\n", "        for i in range(self.model.prompt_encoder.num_point_embeddings):\n", "            point_embedding = (\n", "                point_embedding\n", "                + self.model.prompt_encoder.point_embeddings[i].weight\n", "                * (point_labels == i)\n", "            )\n", "\n", "        return point_embedding\n", "\n", "    def _embed_masks(self, input_mask: torch.Tensor) -> torch.Tensor:\n", "        mask_embedding = self.model.prompt_encoder.mask_downscaling(input_mask)\n", "        return mask_embedding\n", "\n", "    def mask_postprocessing(self, masks: torch.Tensor) -> torch.Tensor:\n", "        masks = torch.nn.functional.interpolate(\n", "            masks,\n", "            size=(self.img_size, self.img_size),\n", "            mode=\"bilinear\",\n", "            align_corners=False,\n", "        )\n", "        return masks\n", "\n", "    def select_masks(\n", "        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int\n", "    ) -> Tuple[torch.Tensor, torch.Tensor]:\n", "        # Determine if we should return the multiclick mask or not from the number of points.\n", "        # The reweighting is used to avoid control flow.\n", "        score_reweight = torch.tensor(\n", "            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]\n", "        ).to(iou_preds.device)\n", "        score = iou_preds + (num_points - 2.5) * score_reweight\n", "        best_idx = torch.argmax(score, dim=1)\n", "        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)\n", "        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)\n", "\n", "        return masks, iou_preds\n", "\n", "    @torch.no_grad()\n", "    def forward(\n",
" self,\n", " image_embeddings: torch.Tensor,\n", " point_coords: torch.Tensor,\n", " point_labels: torch.Tensor,\n", " mask_input: torch.Tensor = None,\n", " ):\n", " sparse_embedding = self._embed_points(point_coords, point_labels)\n", " if mask_input is None:\n", " dense_embedding = self.model.prompt_encoder.no_mask_embed.weight.reshape(\n", " 1, -1, 1, 1\n", " ).expand(point_coords.shape[0], -1, image_embeddings.shape[0], 64)\n", " else:\n", " dense_embedding = self._embed_masks(mask_input)\n", "\n", " masks, scores = self.model.mask_decoder.predict_masks(\n", " image_embeddings=image_embeddings,\n", " image_pe=self.model.prompt_encoder.get_dense_pe(),\n", " sparse_prompt_embeddings=sparse_embedding,\n", " dense_prompt_embeddings=dense_embedding,\n", " )\n", "\n", " if self.use_stability_score:\n", " scores = calculate_stability_score(\n", " masks, self.model.mask_threshold, self.stability_score_offset\n", " )\n", "\n", " if self.return_single_mask:\n", " masks, scores = self.select_masks(masks, scores, point_coords.shape[1])\n", "\n", " upscaled_masks = self.mask_postprocessing(masks)\n", "\n", " if self.return_extra_metrics:\n", " stability_scores = calculate_stability_score(\n", " upscaled_masks, self.model.mask_threshold, self.stability_score_offset\n", " )\n", " areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)\n", " return upscaled_masks, scores, stability_scores, areas, masks\n", "\n", " return upscaled_masks, scores\n", "\n", "\n", "ov_model_path = Path(\"sam_mask_predictor.xml\")\n", "if not ov_model_path.exists():\n", " onnx_model_path = ov_model_path.with_suffix(\".onnx\")\n", " if not onnx_model_path.exists():\n", " onnx_model = SamONNXModel(sam, return_single_mask=True)\n", " dynamic_axes = {\n", " \"point_coords\": {0: \"batch_size\", 1: \"num_points\"},\n", " \"point_labels\": {0: \"batch_size\", 1: \"num_points\"},\n", " }\n", "\n", " embed_dim = sam.prompt_encoder.embed_dim\n", " embed_size = sam.prompt_encoder.image_embedding_size\n", " dummy_inputs = {\n", " \"image_embeddings\": torch.randn(\n", " 1, embed_dim, *embed_size, dtype=torch.float\n", " ),\n", " \"point_coords\": torch.randint(\n", " low=0, high=1024, size=(1, 5, 2), dtype=torch.float\n", " ),\n", " \"point_labels\": torch.randint(\n", " low=0, high=4, size=(1, 5), dtype=torch.float\n", " ),\n", " }\n", " output_names = [\"masks\", \"iou_predictions\"]\n", "\n", " with warnings.catch_warnings():\n", " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n", " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", " torch.onnx.export(\n", " onnx_model,\n", " tuple(dummy_inputs.values()),\n", " onnx_model_path,\n", " input_names=list(dummy_inputs.keys()),\n", " output_names=output_names,\n", " dynamic_axes=dynamic_axes,\n", " )\n", "\n", " ov_model = mo.convert_model(onnx_model_path, compress_to_fp16=True)\n", " serialize(ov_model, str(ov_model_path))\n", "else:\n", " ov_model = core.read_model(ov_model_path)\n", "ov_predictor = core.compile_model(ov_model)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "927a928b", "metadata": {}, "source": [ "## Preparing an image for OpenVINO SAM segmentation\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f7406a84", "metadata": {}, "source": [ "Before we can use the OpenVINO model, we need to convert the input data to the correct format. 
Using an example image, this section details\n", "\n", "- Downloading the image\n", "- Loading the image into FiftyOne for visualization\n", "- Preprocessing for OpenVINO\n", "- Image encoding\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9122811a", "metadata": {}, "source": [ "### Download an image\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ebfc582d", "metadata": {}, "source": [ "To start, we'll use the same image as Facebook Research uses in their [Segment Anything Model demo notebooks](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb):\n" ] }, { "cell_type": "code", "execution_count": null, "id": "90ebe179-46d1-450c-bd81-f838d992c7bc", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 6, "id": "6be6eb55", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import cv2\n", "import fiftyone as fo\n", "\n", "download_file(\n", " \"https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg\"\n", ")\n", "\n", "filepath = \"truck.jpg\"\n", "\n", "image = cv2.imread(filepath)\n", "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d034d1e3", "metadata": {}, "source": [ "### Load the image into FiftyOne\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "69b059e6", "metadata": {}, "source": [ "Create a [FiftyOne](https://docs.voxel51.com/) `Dataset` containing this image:\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "7827899b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Computing metadata...\n", " 100% |█████████████████████| 1/1 [788.9ms elapsed, 0s remaining, 1.3 samples/s] \n" ] } ], "source": [ "dataset = fo.Dataset(name=\"openvino_sam\", persistent=True, overwrite=True)\n", "dataset.add_sample(fo.Sample(filepath=filepath))\n", "dataset.compute_metadata()\n", "sample = dataset.first()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5f7c1baf", "metadata": {}, "source": [ "Above, we are using `overwrite=True` to overwrite the existing dataset, which will allow you to run this cell multiple times without throwing an error. Alternatively, if you don't want to overwrite any datasets, you can create a new dataset without a name via `dataset = fo.Dataset(persistent=True)`.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e4299803", "metadata": {}, "source": [ "Once we've created the dataset, we can visualize it in the [FiftyOne App](https://docs.voxel51.com/user_guide/app.html):\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "b7e9a27a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4e635e5c", "metadata": {}, "source": [ "![initial image](https://user-images.githubusercontent.com/12500356/237286219-383eb29f-5e0a-4444-a712-a57d87f42303.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2c0fe872", "metadata": {}, "source": [ "We will also define a few utility functions to help us convert between FiftyOne and OpenVINO SAM formats. 
`abs_to_rel()` and `rel_to_abs()` convert between absolute and relative pixel coordinates; `convert_sam_to_fo_box()` and `convert_fo_to_sam_box()` convert a box output by OpenVINO SAM to a format that can be used by FiftyOne, and vice versa; and `convert_label()` converts the integer labels output by OpenVINO SAM to strings that can be used by FiftyOne:\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "fe0f6c84", "metadata": { "tags": [] }, "outputs": [], "source": [ "def abs_to_rel(abs_coords, sample):\n", "    rel_coords = np.copy(abs_coords).astype(\"float\")\n", "    rel_coords[:, 0] /= sample.metadata.width\n", "    rel_coords[:, 1] /= sample.metadata.height\n", "    return rel_coords\n", "\n", "\n", "def rel_to_abs(rel_coords, sample):\n", "    abs_coords = np.copy(rel_coords)\n", "    abs_coords[:, 0] *= sample.metadata.width\n", "    abs_coords[:, 1] *= sample.metadata.height\n", "    return abs_coords.astype(\"int\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "51b7b824", "metadata": { "tags": [] }, "outputs": [], "source": [ "def convert_sam_to_fo_box(box, sample):\n", "    ## convert a bounding box from the SAM format with absolute coordinates\n", "    ## [x0, y0, x1, y1]\n", "    ## to the FiftyOne bounding box format with relative coordinates\n", "    ## [top-left-x, top-left-y, width, height]\n", "\n", "    w, h = sample.metadata.width, sample.metadata.height\n", "    fo_box = np.copy(box).astype(\"float\")\n", "    fo_box[0] /= w\n", "    fo_box[2] /= w\n", "    fo_box[1] /= h\n", "    fo_box[3] /= h\n", "    fo_box[2] -= fo_box[0]\n", "    fo_box[3] -= fo_box[1]\n", "    return fo_box\n", "\n", "\n", "def convert_fo_to_sam_box(box, sample):\n", "    ## convert a bounding box from the FiftyOne format with relative coordinates\n", "    ## [top-left-x, top-left-y, width, height]\n", "    ## to the SAM format with absolute coordinates\n", "    ## [x0, y0, x1, y1]\n", "\n", "    w, h = sample.metadata.width, sample.metadata.height\n", "    sam_box = np.copy(box)\n", "    sam_box[0] *= w\n", "    sam_box[2] *= w\n", "    sam_box[1] *= h\n", "    sam_box[3] *= h\n", "    sam_box[2] += sam_box[0]\n", "    sam_box[3] += sam_box[1]\n", "    return sam_box.astype(\"int\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "ded8703f", "metadata": { "tags": [] }, "outputs": [], "source": [ "def convert_label(input_label):\n", "    return str(int(input_label))" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9c8545e4-3031-4c18-997c-59c732da62fc", "metadata": {}, "source": [ "### Preprocessing utilities\n", "\n", "To prepare the input for the Image Encoder, we should:\n", "\n", "1. Convert the BGR image to RGB.\n", "2. Resize the image, preserving its aspect ratio, so that its longest side equals the Image Encoder input size of 1024.\n", "3. Normalize the image: subtract the mean values (123.675, 116.28, 103.53) and divide by the std (58.395, 57.12, 57.375).\n", "4. Transpose the HWC data layout to CHW and add a batch dimension.\n", "5. Add zero padding to the input tensor along height or width (depending on the aspect ratio) to match the Image Encoder's expected input shape.\n", "\n", "These steps are applicable to all available models.\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "f11af7e5-c2d1-44d3-9dbf-dbb19857fd3f", "metadata": { "tags": [] }, "outputs": [], "source": [ "from copy import deepcopy\n", "from typing import Tuple\n", "from torchvision.transforms.functional import resize, to_pil_image\n", "\n", "\n", "class ResizeLongestSide:\n", "    \"\"\"\n", "    Resizes images to longest side 'target_length', as well as provides\n", "    methods for resizing coordinates and boxes. 
Provides methods for\n", " transforming numpy arrays.\n", " \"\"\"\n", "\n", " def __init__(self, target_length: int) -> None:\n", " self.target_length = target_length\n", "\n", " def apply_image(self, image: np.ndarray) -> np.ndarray:\n", " \"\"\"\n", " Expects a numpy array with shape HxWxC in uint8 format.\n", " \"\"\"\n", " target_size = self.get_preprocess_shape(\n", " image.shape[0], image.shape[1], self.target_length\n", " )\n", " return np.array(resize(to_pil_image(image), target_size))\n", "\n", " def apply_coords(\n", " self, coords: np.ndarray, original_size: Tuple[int, ...]\n", " ) -> np.ndarray:\n", " \"\"\"\n", " Expects a numpy array of length 2 in the final dimension. Requires the\n", " original image size in (H, W) format.\n", " \"\"\"\n", " old_h, old_w = original_size\n", " new_h, new_w = self.get_preprocess_shape(\n", " original_size[0], original_size[1], self.target_length\n", " )\n", " coords = deepcopy(coords).astype(float)\n", " coords[..., 0] = coords[..., 0] * (new_w / old_w)\n", " coords[..., 1] = coords[..., 1] * (new_h / old_h)\n", " return coords\n", "\n", " def apply_boxes(\n", " self, boxes: np.ndarray, original_size: Tuple[int, ...]\n", " ) -> np.ndarray:\n", " \"\"\"\n", " Expects a numpy array shape Bx4. Requires the original image size\n", " in (H, W) format.\n", " \"\"\"\n", " boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)\n", " return boxes.reshape(-1, 4)\n", "\n", " @staticmethod\n", " def get_preprocess_shape(\n", " oldh: int, oldw: int, long_side_length: int\n", " ) -> Tuple[int, int]:\n", " \"\"\"\n", " Compute the output size given input size and target long side length.\n", " \"\"\"\n", " scale = long_side_length * 1.0 / max(oldh, oldw)\n", " newh, neww = oldh * scale, oldw * scale\n", " neww = int(neww + 0.5)\n", " newh = int(newh + 0.5)\n", " return (newh, neww)\n", "\n", "\n", "resizer = ResizeLongestSide(1024)\n", "\n", "\n", "def preprocess_image(image: np.ndarray):\n", " resized_image = resizer.apply_image(image)\n", " resized_image = (resized_image.astype(np.float32) - [123.675, 116.28, 103.53]) / [\n", " 58.395,\n", " 57.12,\n", " 57.375,\n", " ]\n", " resized_image = np.expand_dims(\n", " np.transpose(resized_image, (2, 0, 1)).astype(np.float32), 0\n", " )\n", "\n", " # Pad\n", " h, w = resized_image.shape[-2:]\n", " padh = 1024 - h\n", " padw = 1024 - w\n", " x = np.pad(resized_image, ((0, 0), (0, 0), (0, padh), (0, padw)))\n", " return x\n", "\n", "\n", "def postprocess_masks(masks: np.ndarray, orig_size):\n", " size_before_pad = resizer.get_preprocess_shape(\n", " orig_size[0], orig_size[1], masks.shape[-1]\n", " )\n", " masks = masks[..., : int(size_before_pad[0]), : int(size_before_pad[1])]\n", " masks = torch.nn.functional.interpolate(\n", " torch.from_numpy(masks), size=orig_size, mode=\"bilinear\", align_corners=False\n", " ).numpy()\n", " return masks" ] }, { "attachments": {}, "cell_type": "markdown", "id": "56e1a79b-c54b-4201-b8f9-c85f92473f00", "metadata": {}, "source": [ "### Image encoding\n", "\n", "To start work with image, we should preprocess it and obtain image embeddings using `ov_encoder`. 
We will use the same image for all experiments, so it is possible to generate the image embeddings once and then reuse them.\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "5ffe75e8-5751-4338-a7d5-04274a46d7a4", "metadata": { "tags": [] }, "outputs": [], "source": [ "preprocessed_image = preprocess_image(image)\n", "encoding_results = ov_encoder(preprocessed_image)\n", "\n", "image_embeddings = encoding_results[ov_encoder.output(0)]" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8d0bc25b", "metadata": {}, "source": [ "Save the image embeddings on the sample:\n" ] }, { "cell_type": "code", "execution_count": 14, "id": "bfca0e6c", "metadata": { "tags": [] }, "outputs": [], "source": [ "sample[\"image_embeddings\"] = image_embeddings\n", "sample.save()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c0ee1617", "metadata": {}, "source": [ "Next, we define a function which takes in the sample (with the image embeddings) and a prompt, and returns the predicted mask:\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "52636471", "metadata": { "tags": [] }, "outputs": [], "source": [ "def generate_mask(sample, point_coords, point_labels, box_coords=None, box_labels=None):\n", "    image = cv2.imread(sample.filepath)\n", "    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", "    image_embeddings = sample[\"image_embeddings\"]\n", "\n", "    if box_coords is None:\n", "        box_coords = np.array([[0.0, 0.0]])\n", "        box_labels = np.array([-1])\n", "    else:\n", "        box_coords = box_coords.reshape(2, 2)\n", "\n", "    if point_coords is None or len(point_coords) == 0:\n", "        point_coords = np.array([[0.0, 0.0]])\n", "        point_labels = np.array([0])\n", "\n", "    coords = np.concatenate([point_coords, box_coords], axis=0)[None, :, :]\n", "    labels = np.concatenate([point_labels, box_labels], axis=0)[None, :].astype(\n", "        np.float32\n", "    )\n", "    coords = resizer.apply_coords(coords, image.shape[:2]).astype(np.float32)\n", "\n", "    inputs = {\n", "        \"image_embeddings\": image_embeddings,\n", "        \"point_coords\": coords,\n", "        \"point_labels\": labels,\n", "    }\n", "\n", "    results = ov_predictor(inputs)\n", "\n", "    masks = results[ov_predictor.output(0)]\n", "    masks = postprocess_masks(masks, image.shape[:-1])\n", "    masks = masks > 0.0\n", "    return masks[0, 0, :, :].astype(np.uint8)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "cdd117b2", "metadata": {}, "source": [ "This function adds a batch index, concatenates a padding point when no box prompt is provided, and transforms the coordinates to the input tensor coordinate system. It then packages the inputs and runs the mask predictor. Finally, it thresholds the predicted mask to obtain a binary mask (0 - no object, 1 - object).\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1ec0778a", "metadata": {}, "source": [ "## Run OpenVINO SAM with prompts\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e8dcc1c3", "metadata": {}, "source": [ "Now we are ready to run OpenVINO SAM with various prompts for mask generation.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "12207674", "metadata": {}, "source": [ "### Run OpenVINO SAM with single-point prompt\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "bf5a9f55", "metadata": {}, "source": [ "First, we select a single point. 
Starting from the point's location in pixels,\n", "we convert it to relative coordinates and then define a `Keypoint` label in FiftyOne for this point:\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "1c0deef0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "input_point = np.array([[500, 375]])\n", "input_label = np.array([1])\n", "\n", "sample[\"window_point\"] = fo.Keypoint(\n", " label=convert_label(input_label[0]), points=abs_to_rel(input_point, sample)\n", ")\n", "sample.save()\n", "\n", "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e1dd4e2f", "metadata": {}, "source": [ "![window point](https://user-images.githubusercontent.com/12500356/237286223-f0529d1d-06b5-46ee-ae5e-a7e85f08c40b.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "53d6a9c3", "metadata": {}, "source": [ "Hovering over the point in the FiftyOne App, we can see the name of the `Keypoint` label we created, \"window_point\", as well as its label: `\"1\"`.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "badb1175", "metadata": {}, "source": [ "Now it is easy to generate a mask for this point:\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "9417fad0", "metadata": { "tags": [] }, "outputs": [], "source": [ "window_mask = generate_mask(sample, input_point, input_label)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ddde9d35", "metadata": {}, "source": [ "And then we can add the mask to the sample and visualize it in the FiftyOne App:\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "f3edff57", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. 
Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "sample[\"window_mask\"] = fo.Segmentation(mask=window_mask)\n", "sample.save()\n", "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ecd58f9c", "metadata": {}, "source": [ "![window mask](https://user-images.githubusercontent.com/12500356/237286227-31c1f9f6-0e87-4d9f-b625-a19e84af9991.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ad99b3cb", "metadata": {}, "source": [ "### Run OpenVINO SAM with multi-point prompt\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1f1d4d15", "metadata": {}, "source": [ "Now let's provide additional points covering a larger object area.\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "b319da82", "metadata": { "tags": [] }, "outputs": [], "source": [ "input_point = np.array([[500, 375], [1125, 625], [575, 750], [1405, 575]])\n", "input_label = np.array([1, 1, 1, 1])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9e3c49cb-beda-4d3f-8706-12b196640448", "metadata": {}, "source": [ "To see what this prompt for model looks like on this image, we can represent these points with a FiftyOne `Keypoints` label:\n" ] }, { "cell_type": "code", "execution_count": 20, "id": "2d4f9fdd", "metadata": { "tags": [] }, "outputs": [], "source": [ "input_point_fo = abs_to_rel(input_point, sample)\n", "\n", "sample[\"car_points\"] = fo.Keypoints(\n", " keypoints=[\n", " fo.Keypoint(label=convert_label(il), points=[tuple(ip)])\n", " for il, ip in zip(input_label, input_point_fo)\n", " ]\n", ")\n", "sample.save()" ] }, { "cell_type": "code", "execution_count": 21, "id": "78103427-f536-432c-9838-5710068ccaae", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c27f2c45", "metadata": {}, "source": [ "![car points](https://user-images.githubusercontent.com/12500356/237286209-4aace1dc-7bf3-47d6-8633-c91c8123e639.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b1823b37", "metadata": {}, "source": [ "We can generate a mask as in the previous example:\n" ] }, { "cell_type": "code", "execution_count": 22, "id": "8885130f", "metadata": { "tags": [] }, "outputs": [], "source": [ "car_mask = generate_mask(sample, input_point, input_label)\n", "sample[\"car_mask\"] = fo.Segmentation(mask=car_mask)\n", "sample.save()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d3781955", "metadata": {}, "source": [ "And then once again visualize the mask in the FiftyOne App:\n" ] }, { "cell_type": "code", "execution_count": 23, "id": "0c1ec096", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1e36554b", "metadata": {}, "source": [ "![car mask](https://user-images.githubusercontent.com/12500356/237286225-4e4a4e68-25eb-46f0-bb1a-0a3b82501226.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b23bdc35-6dcf-441d-8cf0-614977a7c262", "metadata": {}, "source": [ "Great! 
It looks like the predicted mask now covers the whole truck.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2ef211d0", "metadata": {}, "source": [ "### Run OpenVINO SAM with a box and negative point label\n", "\n", "In this final prompting example, we define the input prompt using a bounding box and a point inside it. The bounding box is represented as a pair of points: its top-left corner and its bottom-right corner. Label `0` for a point means that this point should be excluded from the mask.\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "51e58d2e", "metadata": { "tags": [] }, "outputs": [], "source": [ "input_point = np.array([[575, 750]])\n", "input_label = np.array([0])\n", "\n", "input_box = np.array([425, 600, 700, 875])\n", "box_labels = np.array([2, 3])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9bde6cb1", "metadata": {}, "source": [ "We will represent the point as a `Keypoint` label, and the box as a `Detection` label:\n" ] }, { "cell_type": "code", "execution_count": 25, "id": "b571fbc7-424a-4013-8c37-81f0da7647a5", "metadata": { "tags": [] }, "outputs": [], "source": [ "sample[\"tire_point\"] = fo.Keypoint(\n", "    label=convert_label(input_label[0]), points=abs_to_rel(input_point, sample)\n", ")\n", "sample.save()" ] }, { "cell_type": "code", "execution_count": 26, "id": "8a366af7", "metadata": { "tags": [] }, "outputs": [], "source": [ "sample[\"tire_box\"] = fo.Detection(\n", "    label=\"tire\",\n", "    bounding_box=convert_sam_to_fo_box(input_box, sample),\n", ")\n", "sample.save()" ] }, { "cell_type": "code", "execution_count": 27, "id": "e8b8933f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "01cf49c5", "metadata": {}, "source": [ "![tire prompt](https://user-images.githubusercontent.com/12500356/237286233-f3321f3e-320d-4acf-a570-de999a4c5c00.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6e119dcb", "metadata": {}, "source": [ "This time, we pass the `box_coords` and `box_labels` arguments to `generate_mask()`. There is no padding point, since the input includes a box.\n" ] }, { "cell_type": "code", "execution_count": 28, "id": "bfbe4911", "metadata": { "tags": [] }, "outputs": [], "source": [ "tire_mask = generate_mask(\n", "    sample, input_point, input_label, box_coords=input_box, box_labels=box_labels\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "65edabd2", "metadata": {}, "source": [ "We can visualize the mask in the FiftyOne App with a `Detection` label that has a non-trivial `mask` field:\n" ] }, { "cell_type": "code", "execution_count": 29, "id": "2abfba56", "metadata": { "tags": [] }, "outputs": [], "source": [ "x0, y0, x1, y1 = input_box\n", "mask_trimmed = np.array(tire_mask[y0 : y1 + 1, x0 : x1 + 1])\n", "\n", "sample[\"tire_mask\"] = fo.Detection(\n", "    label=\"tire\",\n", "    bounding_box=convert_sam_to_fo_box(input_box, sample),\n", "    mask=mask_trimmed,\n", ")\n", "\n", "sample.save()" ] }, { "cell_type": "code", "execution_count": 30, "id": "b3abb676", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. 
Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c65e2580", "metadata": {}, "source": [ "![tire mask](https://user-images.githubusercontent.com/12500356/237286231-c0b0b65f-cadb-4626-a051-eea1b801e469.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c4120904", "metadata": {}, "source": [ "## Run OpenVINO SAM in automatic mask generation mode\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a13e5bdd-49fb-4720-80fa-e0fc8c40c898", "metadata": {}, "source": [ "Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image.\n", "The `automatic_mask_generation()` function implements this capability.\n", "\n", "It works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Then, masks are filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes.\n" ] }, { "cell_type": "code", "execution_count": 31, "id": "a2612d32-8ed2-4e35-abb3-7b51c2108d79", "metadata": { "tags": [] }, "outputs": [], "source": [ "from segment_anything.utils.amg import (\n", " MaskData,\n", " generate_crop_boxes,\n", " uncrop_boxes_xyxy,\n", " uncrop_masks,\n", " uncrop_points,\n", " calculate_stability_score,\n", " rle_to_mask,\n", " batched_mask_to_box,\n", " mask_to_rle_pytorch,\n", " is_box_near_crop_edge,\n", " batch_iterator,\n", " remove_small_regions,\n", " build_all_layer_point_grids,\n", " box_xyxy_to_xywh,\n", " area_from_rle,\n", ")\n", "from torchvision.ops.boxes import batched_nms, box_area\n", "from typing import Tuple, List, Dict, Any" ] }, { "cell_type": "code", "execution_count": 32, "id": "901914c4", "metadata": { "tags": [] }, "outputs": [], "source": [ "def process_batch(\n", " image_embedding: np.ndarray,\n", " points: np.ndarray,\n", " im_size: Tuple[int, ...],\n", " crop_box: List[int],\n", " orig_size: Tuple[int, ...],\n", " iou_thresh,\n", " mask_threshold,\n", " stability_score_offset,\n", " stability_score_thresh,\n", ") -> MaskData:\n", " orig_h, orig_w = orig_size\n", "\n", " # Run model on this batch\n", " transformed_points = resizer.apply_coords(points, im_size)\n", " in_points = transformed_points\n", " in_labels = np.ones(in_points.shape[0], dtype=int)\n", "\n", " inputs = {\n", " \"image_embeddings\": image_embedding,\n", " \"point_coords\": in_points[:, None, :],\n", " \"point_labels\": in_labels[:, None],\n", " }\n", " res = ov_predictor(inputs)\n", " masks = postprocess_masks(res[ov_predictor.output(0)], orig_size)\n", " masks = torch.from_numpy(masks)\n", " iou_preds = torch.from_numpy(res[ov_predictor.output(1)])\n", "\n", " # Serialize predictions and store in MaskData\n", " data = MaskData(\n", " masks=masks.flatten(0, 1),\n", " iou_preds=iou_preds.flatten(0, 1),\n", " points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),\n", " )\n", " del masks\n", "\n", " # Filter by predicted IoU\n", " if iou_thresh > 0.0:\n", " keep_mask = data[\"iou_preds\"] > iou_thresh\n", " data.filter(keep_mask)\n", "\n", " # Calculate stability score\n", " data[\"stability_score\"] = calculate_stability_score(\n", " data[\"masks\"], mask_threshold, stability_score_offset\n", " )\n", " if 
stability_score_thresh > 0.0:\n", " keep_mask = data[\"stability_score\"] >= stability_score_thresh\n", " data.filter(keep_mask)\n", "\n", " # Threshold masks and calculate boxes\n", " data[\"masks\"] = data[\"masks\"] > mask_threshold\n", " data[\"boxes\"] = batched_mask_to_box(data[\"masks\"])\n", "\n", " # Filter boxes that touch crop boundaries\n", " keep_mask = ~is_box_near_crop_edge(data[\"boxes\"], crop_box, [0, 0, orig_w, orig_h])\n", " if not torch.all(keep_mask):\n", " data.filter(keep_mask)\n", "\n", " # Compress to RLE\n", " data[\"masks\"] = uncrop_masks(data[\"masks\"], crop_box, orig_h, orig_w)\n", " data[\"rles\"] = mask_to_rle_pytorch(data[\"masks\"])\n", " del data[\"masks\"]\n", "\n", " return data" ] }, { "cell_type": "code", "execution_count": 33, "id": "69c9c706", "metadata": { "tags": [] }, "outputs": [], "source": [ "def process_crop(\n", " image: np.ndarray,\n", " point_grids,\n", " crop_box: List[int],\n", " crop_layer_idx: int,\n", " orig_size: Tuple[int, ...],\n", " box_nms_thresh: float = 0.7,\n", " mask_threshold: float = 0.0,\n", " points_per_batch: int = 64,\n", " pred_iou_thresh: float = 0.88,\n", " stability_score_thresh: float = 0.95,\n", " stability_score_offset: float = 1.0,\n", ") -> MaskData:\n", " # Crop the image and calculate embeddings\n", " x0, y0, x1, y1 = crop_box\n", " cropped_im = image[y0:y1, x0:x1, :]\n", " cropped_im_size = cropped_im.shape[:2]\n", " preprocessed_cropped_im = preprocess_image(cropped_im)\n", " crop_embeddings = ov_encoder(preprocessed_cropped_im)[ov_encoder.output(0)]\n", "\n", " # Get points for this crop\n", " points_scale = np.array(cropped_im_size)[None, ::-1]\n", " points_for_image = point_grids[crop_layer_idx] * points_scale\n", "\n", " # Generate masks for this crop in batches\n", " data = MaskData()\n", " for (points,) in batch_iterator(points_per_batch, points_for_image):\n", " batch_data = process_batch(\n", " crop_embeddings,\n", " points,\n", " cropped_im_size,\n", " crop_box,\n", " orig_size,\n", " pred_iou_thresh,\n", " mask_threshold,\n", " stability_score_offset,\n", " stability_score_thresh,\n", " )\n", " data.cat(batch_data)\n", " del batch_data\n", "\n", " # Remove duplicates within this crop.\n", " keep_by_nms = batched_nms(\n", " data[\"boxes\"].float(),\n", " data[\"iou_preds\"],\n", " torch.zeros(len(data[\"boxes\"])), # categories\n", " iou_threshold=box_nms_thresh,\n", " )\n", " data.filter(keep_by_nms)\n", "\n", " # Return to the original image frame\n", " data[\"boxes\"] = uncrop_boxes_xyxy(data[\"boxes\"], crop_box)\n", " data[\"points\"] = uncrop_points(data[\"points\"], crop_box)\n", " data[\"crop_boxes\"] = torch.tensor([crop_box for _ in range(len(data[\"rles\"]))])\n", "\n", " return data" ] }, { "cell_type": "code", "execution_count": 34, "id": "45110a6e-4c90-477f-9eeb-840ee04fcd73", "metadata": { "tags": [] }, "outputs": [], "source": [ "def generate_masks(\n", " image: np.ndarray, point_grids, crop_n_layers, crop_overlap_ratio, crop_nms_thresh\n", ") -> MaskData:\n", " orig_size = image.shape[:2]\n", " crop_boxes, layer_idxs = generate_crop_boxes(\n", " orig_size, crop_n_layers, crop_overlap_ratio\n", " )\n", "\n", " # Iterate over image crops\n", " data = MaskData()\n", " for crop_box, layer_idx in zip(crop_boxes, layer_idxs):\n", " crop_data = process_crop(image, point_grids, crop_box, layer_idx, orig_size)\n", " data.cat(crop_data)\n", "\n", " # Remove duplicate masks between crops\n", " if len(crop_boxes) > 1:\n", " # Prefer masks from smaller crops\n", " scores = 1 / 
box_area(data[\"crop_boxes\"])\n", " scores = scores.to(data[\"boxes\"].device)\n", " keep_by_nms = batched_nms(\n", " data[\"boxes\"].float(),\n", " scores,\n", " torch.zeros(len(data[\"boxes\"])), # categories\n", " iou_threshold=crop_nms_thresh,\n", " )\n", " data.filter(keep_by_nms)\n", "\n", " data.to_numpy()\n", " return data" ] }, { "cell_type": "code", "execution_count": 35, "id": "4c808299-5114-4178-ad82-852b9e1307f3", "metadata": { "tags": [] }, "outputs": [], "source": [ "def postprocess_small_regions(\n", " mask_data: MaskData, min_area: int, nms_thresh: float\n", ") -> MaskData:\n", " \"\"\"\n", " Removes small disconnected regions and holes in masks, then reruns\n", " box NMS to remove any new duplicates.\n", "\n", " Edits mask_data in place.\n", "\n", " Requires open-cv as a dependency.\n", " \"\"\"\n", " if len(mask_data[\"rles\"]) == 0:\n", " return mask_data\n", "\n", " # Filter small disconnected regions and holes\n", " new_masks = []\n", " scores = []\n", " for rle in mask_data[\"rles\"]:\n", " mask = rle_to_mask(rle)\n", "\n", " mask, changed = remove_small_regions(mask, min_area, mode=\"holes\")\n", " unchanged = not changed\n", " mask, changed = remove_small_regions(mask, min_area, mode=\"islands\")\n", " unchanged = unchanged and not changed\n", "\n", " new_masks.append(torch.as_tensor(mask).unsqueeze(0))\n", " # Give score=0 to changed masks and score=1 to unchanged masks\n", " # so NMS will prefer ones that didn't need postprocessing\n", " scores.append(float(unchanged))\n", "\n", " # Recalculate boxes and remove any new duplicates\n", " masks = torch.cat(new_masks, dim=0)\n", " boxes = batched_mask_to_box(masks)\n", " keep_by_nms = batched_nms(\n", " boxes.float(),\n", " torch.as_tensor(scores),\n", " torch.zeros(len(boxes)), # categories\n", " iou_threshold=nms_thresh,\n", " )\n", "\n", " # Only recalculate RLEs for masks that have changed\n", " for i_mask in keep_by_nms:\n", " if scores[i_mask] == 0.0:\n", " mask_torch = masks[i_mask].unsqueeze(0)\n", " mask_data[\"rles\"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]\n", " # update res directly\n", " mask_data[\"boxes\"][i_mask] = boxes[i_mask]\n", " mask_data.filter(keep_by_nms)\n", "\n", " return mask_data" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1ef507ce-0ed2-4ed0-8ccb-3c0e2368621f", "metadata": {}, "source": [ "There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes\n" ] }, { "cell_type": "code", "execution_count": 36, "id": "ad8806f6-08fb-4156-b33b-3a1b4aa40d99", "metadata": { "tags": [] }, "outputs": [], "source": [ "def automatic_mask_generation(\n", " image: np.ndarray,\n", " min_mask_region_area: int = 0,\n", " points_per_side: int = 32,\n", " crop_n_layers: int = 0,\n", " crop_n_points_downscale_factor: int = 1,\n", " crop_overlap_ratio: float = 512 / 1500,\n", " box_nms_thresh: float = 0.7,\n", " crop_nms_thresh: float = 0.7,\n", ") -> List[Dict[str, Any]]:\n", " \"\"\"\n", " Generates masks for the given image.\n", "\n", " Arguments:\n", " image (np.ndarray): The image to generate masks for, in HWC uint8 format.\n", "\n", " Returns:\n", " list(dict(str, any)): A list over records for masks. 
Each record is\n", " a dict containing the following keys:\n", " segmentation (dict(str, any) or np.ndarray): The mask. If\n", " output_mode='binary_mask', is an array of shape HW. Otherwise,\n", " is a dictionary containing the RLE.\n", " bbox (list(float)): The box around the mask, in XYWH format.\n", " area (int): The area in pixels of the mask.\n", " predicted_iou (float): The model's own prediction of the mask's\n", " quality. This is filtered by the pred_iou_thresh parameter.\n", " point_coords (list(list(float))): The point coordinates input\n", " to the model to generate this mask.\n", " stability_score (float): A measure of the mask's quality. This\n", " is filtered on using the stability_score_thresh parameter.\n", " crop_box (list(float)): The crop of the image used to generate\n", " the mask, given in XYWH format.\n", " \"\"\"\n", " point_grids = build_all_layer_point_grids(\n", " points_per_side,\n", " crop_n_layers,\n", " crop_n_points_downscale_factor,\n", " )\n", " mask_data = generate_masks(\n", " image, point_grids, crop_n_layers, crop_overlap_ratio, crop_nms_thresh\n", " )\n", "\n", " # Filter small disconnected regions and holes in masks\n", " if min_mask_region_area > 0:\n", " mask_data = postprocess_small_regions(\n", " mask_data,\n", " min_mask_region_area,\n", " max(box_nms_thresh, crop_nms_thresh),\n", " )\n", "\n", " mask_data[\"segmentations\"] = [rle_to_mask(rle) for rle in mask_data[\"rles\"]]\n", "\n", " # Write mask records\n", " curr_anns = []\n", " for idx in range(len(mask_data[\"segmentations\"])):\n", " ann = {\n", " \"segmentation\": mask_data[\"segmentations\"][idx],\n", " \"area\": area_from_rle(mask_data[\"rles\"][idx]),\n", " \"bbox\": box_xyxy_to_xywh(mask_data[\"boxes\"][idx]).tolist(),\n", " \"predicted_iou\": mask_data[\"iou_preds\"][idx].item(),\n", " \"point_coords\": [mask_data[\"points\"][idx].tolist()],\n", " \"stability_score\": mask_data[\"stability_score\"][idx].item(),\n", " \"crop_box\": box_xyxy_to_xywh(mask_data[\"crop_boxes\"][idx]).tolist(),\n", " }\n", " curr_anns.append(ann)\n", "\n", " return curr_anns" ] }, { "cell_type": "code", "execution_count": 37, "id": "7b9f1cb7-d4c3-4528-8c13-e90707141908", "metadata": { "tags": [] }, "outputs": [], "source": [ "prediction = automatic_mask_generation(image)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c2f1ee93-afde-4a38-95ef-7b891fe63a1e", "metadata": { "tags": [] }, "source": [ "`automatic_mask_generation` returns a list over masks, where each mask is a dictionary containing various data about the mask. 
These keys are:\n", "\n", "- `segmentation` : the mask\n", "- `area` : the area of the mask in pixels\n", "- `bbox` : the boundary box of the mask in XYWH format\n", "- `predicted_iou` : the model's own prediction for the quality of the mask\n", "- `point_coords` : the sampled input point that generated this mask\n", "- `stability_score` : an additional measure of mask quality\n", "- `crop_box` : the crop of the image used to generate this mask in XYWH format\n" ] }, { "cell_type": "code", "execution_count": 38, "id": "c20b51f3-6e90-486c-8e72-1a5def89cbd8", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of detected masks: 48\n", "Annotation keys: dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])\n" ] } ], "source": [ "print(f\"Number of detected masks: {len(prediction)}\")\n", "print(f\"Annotation keys: {prediction[0].keys()}\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d68e7b03-6c42-4e0f-849d-8fdc72a65f14", "metadata": { "tags": [] }, "source": [ "Now we can add these automatically generated masks to the sample and visualize them in the FiftyOne App. We will use a randomly chosen color for each mask:\n" ] }, { "cell_type": "code", "execution_count": 39, "id": "188a1793", "metadata": { "tags": [] }, "outputs": [], "source": [ "def add_auto_mask(auto_mask, col, full_mask):\n", " mask = auto_mask[\"segmentation\"].astype(np.uint8)\n", " bbox = auto_mask[\"bbox\"]\n", " x0, y0, x1, y1 = bbox\n", " mask[y0 : y1 + 1, x0 : x1 + 1] *= col\n", " full_mask += mask\n", " return full_mask" ] }, { "cell_type": "code", "execution_count": 40, "id": "a31f4b84", "metadata": { "tags": [] }, "outputs": [], "source": [ "full_mask = np.zeros(image.shape[:-1])\n", "for am in sorted(prediction, key=(lambda x: x[\"area\"]), reverse=True):\n", " col = np.random.randint(0, 100)\n", " full_mask = add_auto_mask(am, col, full_mask)\n", "sample[\"autoseg\"] = fo.Segmentation(mask=full_mask)\n", "sample.save()" ] }, { "cell_type": "code", "execution_count": 41, "id": "f1c01677", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6097ed47-689e-4226-8f6c-2025433a3cbf", "metadata": { "tags": [] }, "source": [ "![auto mask](https://user-images.githubusercontent.com/12500356/237520529-4c3533b4-6e32-4e80-a381-781fc7b14769.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "0a585253", "metadata": {}, "source": [ "## Curate OpenVINO SAM predictions across an entire dataset\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "05f9f0bb", "metadata": {}, "source": [ "With FiftyOne, extending any of these example workflows to a full dataset is easy. Doing so will allow us to visually and programmatically understand the quality of the model's predictions across the entire dataset, or any subset which may be of interest.\n", "\n", "For the sake of brevity (this is already a long notebook!), we will just show how to do this for one type of segmentation: namely, we will generate instance segmentations for object detection bounding boxes in a subset of the MS COCO dataset. 
However, the same process can be applied to any of the other segmentation types, or to any other dataset.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e7cfac53", "metadata": {}, "source": [ "To start, we will load a subset of the MS COCO dataset. In FiftyOne, we can load the dataset (or a subset thereof) from the FiftyOne Dataset Zoo. For example, we can load 100 samples from the validation split of the MS COCO 2017 dataset with the following command:\n" ] }, { "cell_type": "code", "execution_count": 42, "id": "81fa2866", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading split 'validation' to '/scratch/user/jacob/zoo_datasets/coco-2017/validation' if necessary\n", "Found annotations at '/scratch/user/jacob/zoo_datasets/coco-2017/raw/instances_val2017.json'\n", "Sufficient images already downloaded\n", "Existing download of split 'validation' is sufficient\n", "Loading 'coco-2017' split 'validation'\n", " 100% |█████████████████| 100/100 [405.8ms elapsed, 0s remaining, 248.7 samples/s] \n", "Dataset 'coco-2017-validation-100' created\n" ] } ], "source": [ "import fiftyone.zoo as foz\n", "\n", "dataset = foz.load_zoo_dataset(\n", "    \"coco-2017\",\n", "    split=\"validation\",\n", "    max_samples=100,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "30bb7abf", "metadata": {}, "source": [ "We can visualize the dataset in the FiftyOne App:\n" ] }, { "cell_type": "code", "execution_count": 43, "id": "81b285b1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "83d39640", "metadata": {}, "source": [ "![coco](https://user-images.githubusercontent.com/12500356/239369373-06f039c0-d429-474a-9d0c-90ec2e06d54b.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "47d508d7", "metadata": {}, "source": [ "We can see that the dataset contains a variety of images with different types of objects, and that the objects are annotated with bounding boxes. The label field is called `ground_truth` and contains a `Detections` label.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "125d2b80", "metadata": {}, "source": [ "When downloading or loading in a new dataset, you also typically need to compute metadata for the dataset, as we did for the single-sample example above, so that we have easy access to the image width and height in pixels. 
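If it is missing, populating it is a one-liner:\n", "\n", "```python\n", "# Populate per-sample metadata (image width, height, number of channels, etc.)\n", "dataset.compute_metadata()\n", "```\n", "\n", "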
{ "cell_type": "code", "execution_count": 44, "id": "2eb12c6b", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.first().metadata" ] }, { "attachments": {}, "cell_type": "markdown", "id": "616d50ca", "metadata": {}, "source": [ "We will then iterate over the samples in the dataset, running each image through the OpenVINO image encoder and storing the resulting embedding on the sample:\n" ] }, { "cell_type": "code", "execution_count": 45, "id": "b541529d", "metadata": { "tags": [] }, "outputs": [], "source": [ "def generate_image_embeddings(dataset):\n", "    for sample in dataset.iter_samples(autosave=True, progress=True):\n", "        # Load the image and convert from OpenCV's BGR ordering to RGB\n", "        image = cv2.imread(sample.filepath)\n", "        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", "        # Run the OpenVINO image encoder and store the embedding on the sample\n", "        preprocessed_image = preprocess_image(image)\n", "        encoding_results = ov_encoder(preprocessed_image)\n", "        image_embeddings = encoding_results[ov_encoder.output(0)]\n", "        sample[\"image_embeddings\"] = image_embeddings" ] }, { "cell_type": "code", "execution_count": 46, "id": "45c8e807-f01b-48da-ad83-9745b315ec39", "metadata": { "tags": [] }, "outputs": [], "source": [ "dataset.add_sample_field(\"image_embeddings\", fo.ArrayField)" ] }, { "cell_type": "code", "execution_count": 47, "id": "f3d3c092-4a79-4feb-96bb-8880e0e805a1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 100% |█████████████████| 100/100 [3.0m elapsed, 0s remaining, 0.6 samples/s] \n" ] } ], "source": [ "generate_image_embeddings(dataset)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9091e7b4", "metadata": {}, "source": [ "Because we already have the object detection bounding boxes, we can use them as input prompts for SAM. Instead of creating a new label field for the masks, we can simply store each mask in the `mask` attribute of the corresponding `Detection` in the existing label field. We will rename the `ground_truth` label field to `sam` to reflect this:\n" ] }, { "cell_type": "code", "execution_count": 48, "id": "0d16558d", "metadata": { "tags": [] }, "outputs": [], "source": [ "dataset.rename_sample_field(\"ground_truth\", \"sam\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "384e0fe9", "metadata": {}, "source": [ "We will use the `generate_mask()` function we defined earlier to generate a mask for each bounding box in an image. We will then crop each mask to the region covered by its bounding box and store it in the `mask` field of the corresponding `Detection` label.\n" ] }, 
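{ "attachments": {}, "cell_type": "markdown", "id": "e4f5a6b7", "metadata": {}, "source": [ "This works because FiftyOne stores instance segmentations directly on `Detection` labels: the `mask` attribute holds a 2D array that spans only the detection's bounding box, not the full image. As a minimal, hypothetical sketch (the label, box, and mask values below are made up purely for illustration):\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f5a6b7c8", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import fiftyone as fo\n", "\n", "# An instance mask covering only the pixels inside the detection's box\n", "# (the shape and values here are arbitrary, for illustration only)\n", "toy_mask = np.zeros((50, 80), dtype=bool)\n", "toy_mask[10:40, 20:60] = True\n", "\n", "detection = fo.Detection(\n", "    label=\"example\",\n", "    bounding_box=[0.25, 0.25, 0.5, 0.5],  # relative [x, y, width, height] in [0, 1]\n", "    mask=toy_mask,\n", ")" ] }, 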
{ "attachments": {}, "cell_type": "markdown", "id": "9cffbdaf", "metadata": {}, "source": [ "First, we will define a helper function that adds a mask to a `Detection` label:\n" ] }, { "cell_type": "code", "execution_count": 49, "id": "3d02efd0", "metadata": { "tags": [] }, "outputs": [], "source": [ "def add_SAM_mask_to_detection(detection, mask, sample):\n", "    # Crop the full-image mask down to the detection's bounding box region\n", "    x0, y0, x1, y1 = convert_fo_to_sam_box(detection.bounding_box, sample)\n", "    mask_trimmed = mask[y0 : y1 + 1, x0 : x1 + 1]\n", "    detection[\"mask\"] = np.array(mask_trimmed)\n", "    return detection" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4a6491c2", "metadata": {}, "source": [ "We can then loop over the detections in each sample, generate a mask for each detection, and add it to the `mask` field of the `Detection` label. We will do so with the `add_SAM_instance_masks_to_sample()` function. We give the function a `field_name` argument so that it can add masks to any label field in the dataset; in this case we will use it to add masks to the `sam` label field:\n" ] }, { "cell_type": "code", "execution_count": 51, "id": "d3f39fe4", "metadata": { "tags": [] }, "outputs": [], "source": [ "def add_SAM_instance_masks_to_sample(sample, field_name=\"sam\"):\n", "    if sample[field_name] is None:\n", "        return\n", "\n", "    dets = sample[field_name].detections\n", "\n", "    # We prompt SAM with boxes only, so no point prompt is provided;\n", "    # labels 2 and 3 mark the box's top-left and bottom-right corners\n", "    input_point = None\n", "    input_label = None\n", "    box_labels = np.array([2, 3])\n", "\n", "    new_dets = []\n", "    for det in dets:\n", "        fo_box = det.bounding_box\n", "        sam_box = convert_fo_to_sam_box(fo_box, sample)\n", "        mask = generate_mask(\n", "            sample, input_point, input_label, box_coords=sam_box, box_labels=box_labels\n", "        )\n", "\n", "        new_dets.append(add_SAM_mask_to_detection(det, mask, sample))\n", "\n", "    sample[field_name] = fo.Detections(detections=new_dets)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ac1430bd", "metadata": {}, "source": [ "Finally, we will loop through all of the samples in our dataset and add the SAM masks to the `sam` label field:\n" ] }, { "cell_type": "code", "execution_count": 52, "id": "1ee01645", "metadata": { "tags": [] }, "outputs": [], "source": [ "def add_SAM_instance_segmentation_masks(dataset, field_name=\"sam\"):\n", "    for sample in dataset.iter_samples(autosave=True, progress=True):\n", "        add_SAM_instance_masks_to_sample(sample, field_name=field_name)" ] }, 
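{ "attachments": {}, "cell_type": "markdown", "id": "a6b7c8d9", "metadata": {}, "source": [ "Running this over a large dataset can take a while, so it can be handy to sanity-check the pipeline on a small random view first. This is an optional sketch (the view size and seed are arbitrary); since FiftyOne views also expose `iter_samples()`, the same function works unchanged:\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b7c8d9e0", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Optional: dry-run the pipeline on a few random samples before the full run\n", "small_view = dataset.take(5, seed=51)\n", "add_SAM_instance_segmentation_masks(small_view)" ] }, 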
"1ca5c9a0", "metadata": {}, "source": [ "Let's run this routine on our dataset and visualize the results in the FiftyOne App, coloring masks by the class of the object they are associated with:\n" ] }, { "cell_type": "code", "execution_count": 53, "id": "5d04ad9e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ " 100% |█████████████████| 100/100 [58.8s elapsed, 0s remaining, 1.4 samples/s] \n" ] } ], "source": [ "add_SAM_instance_segmentation_masks(dataset)" ] }, { "cell_type": "code", "execution_count": 54, "id": "9ec4c30c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. Run `session.show()` to open the App in a cell output.\n" ] } ], "source": [ "session = fo.launch_app(dataset, auto = False)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a6021726", "metadata": {}, "source": [ "![coco instance segmentation](https://user-images.githubusercontent.com/12500356/239368987-cf187421-4645-47f6-a46d-84497ebc0939.png)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6d11ac43", "metadata": {}, "source": [ "It's that simple! I encourage you to try this - and the other segmentation types - out on your own dataset, and see how SAM performs.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "vscode": { "interpreter": { "hash": "cec18e25feb9469b5ff1085a8097bdcd86db6a4ac301d6aeff87d0f3e7ce4ca5" } } }, "nbformat": 4, "nbformat_minor": 5 }