{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "-Wc92cWK-Aas" }, "source": [ "# Getting started with OWL-ViT\n", "\n", "In this notebook, we are going to run the [OWL-ViT](https://arxiv.org/abs/2205.06230) model (an open-vocabulary object detection model) by Google Research on scikit-image sample images.\n", "\n", "## OWL-ViT: A Quick Intro\n", "OWL-ViT is an open-vocabulary object detector. Given an image and one or multiple free-text queries, it finds objects matching the queries in the image. Unlike traditional object detection models, OWL-ViT is not limited to a fixed set of object classes; it leverages multi-modal representations to perform open-vocabulary detection.\n", "\n", "OWL-ViT uses CLIP with a ViT-like Transformer as its backbone to get multi-modal visual and text features. To use CLIP for object detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a lightweight classification and box head to each transformer output token. Open-vocabulary classification is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained from the text model. The authors first train CLIP from scratch and then fine-tune it end-to-end with the classification and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image can be used to perform zero-shot text-conditioned object detection.\n", "\n", "![owlvit architecture](https://raw.githubusercontent.com/google-research/scenic/a41d24676f64a2158bfcd7cb79b0a87673aa875b/scenic/projects/owl_vit/data/owl_vit_schematic.png)" ] }, { "cell_type": "markdown", "metadata": { "id": "uIcaig48T6yv" }, "source": [ "## Set-up environment\n", "\n", "First, we install the Hugging Face Transformers library (from source for now, as the model was recently added to the library and is under active development). This might take a few minutes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_XLma_DL3S9-", "outputId": "3d8e2639-90a0-4820-bf7f-04eee42e5258" }, "outputs": [], "source": [ "!pip install -q git+https://github.com/huggingface/transformers.git" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Optional:** Install Pillow, matplotlib and OpenCV if you are running this notebook locally." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install Pillow\n", "!pip install matplotlib\n", "!pip install opencv-python" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry. This tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"zeroshot_object_detection_with_owlvit_notebook\", framework=\"pytorch\")" ] }, { "cell_type": "markdown", "metadata": { "id": "QPrCVnimE0qR" }, "source": [ "## Load pre-trained model and processor\n", "\n", "Let's first load the model and its `OwlViTProcessor`, which applies the image preprocessing and tokenizes the text queries. The processor will resize the image(s), rescale the pixel values to the [0, 1] range and normalize them across the channels using the mean and standard deviation specified in the original codebase.\n", "\n", "\n", "Text queries are tokenized using a CLIP tokenizer and stacked to output tensors of shape [batch_size * num_max_text_queries, sequence_length]. 
If you are inputting more than one set of (image, text prompt/s), num_max_text_queries is the maximum number of text queries per image across the batch. Input samples with fewer text queries are padded. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AD8DXCnJ7faH", "outputId": "3207f732-be55-46a2-ce3e-346e32efeac9" }, "outputs": [], "source": [ "from transformers import OwlViTProcessor, OwlViTForObjectDetection\n", "\n", "model = OwlViTForObjectDetection.from_pretrained(\"google/owlvit-base-patch32\")\n", "processor = OwlViTProcessor.from_pretrained(\"google/owlvit-base-patch32\")" ] }, { "cell_type": "markdown", "metadata": { "id": "7QN-vURe3euV" }, "source": [ "## Preprocess input image and text queries\n", "\n", "Let's use the image of astronaut Eileen Collins to test OWL-ViT. It's part of the [NASA](https://www.nasa.gov/multimedia/imagegallery/index.html) Great Images dataset.\n", "\n", "You can use one or multiple text prompts per image to search for the target object(s). Let's start with a simple example where we search for multiple objects in a single image." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 529 }, "id": "LOX1__3nrezW", "outputId": "4435e2c9-b79d-4aaf-d7ef-b1cf2c239d01" }, "outputs": [], "source": [ "import cv2\n", "import skimage.data\n", "import numpy as np\n", "from PIL import Image\n", "\n", "# Load the sample image bundled with scikit-image\n", "image = skimage.data.astronaut()\n", "image = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n", "\n", "# Text queries to search the image for\n", "text_queries = [\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"]\n", "\n", "image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Fu_0bM2Mz0R" }, "outputs": [], "source": [ "import torch\n", "\n", "# Use GPU if available\n", "if torch.cuda.is_available():\n", "    device = torch.device(\"cuda\")\n", "else:\n", "    device = torch.device(\"cpu\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "S6lQ2d9uodiv", "outputId": "f5634352-d44e-47c4-f66c-4ad51522efdb" }, "outputs": [], "source": [ "# Process image and text inputs\n", "inputs = processor(text=text_queries, images=image, return_tensors=\"pt\").to(device)\n", "\n", "# Print input names and shapes\n", "for key, val in inputs.items():\n", "    print(f\"{key}: {val.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ten5ZJQsoUbE" }, "source": [ "## Forward pass\n", "\n", "Now we can pass the inputs to our OWL-ViT model to get object detection predictions.\n", "\n", "The `OwlViTForObjectDetection` model outputs the prediction logits, bounding boxes and class embeddings, along with the image and text embeddings produced by `OwlViTModel`, which is the CLIP backbone."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "t7oKBoKQEz_w", "outputId": "5deb56c6-c302-4175-a854-d4d6e7b7d903" }, "outputs": [], "source": [ "# Set model in evaluation mode\n", "model = model.to(device)\n", "model.eval()\n", "\n", "# Get predictions\n", "with torch.no_grad():\n", "    outputs = model(**inputs)\n", "\n", "# Print the shape of each top-level output tensor\n", "for k, val in outputs.items():\n", "    if k not in {\"text_model_output\", \"vision_model_output\"}:\n", "        print(f\"{k}: shape of {val.shape}\")\n", "\n", "print(\"\\nText model outputs\")\n", "for k, val in outputs.text_model_output.items():\n", "    print(f\"{k}: shape of {val.shape}\")\n", "\n", "print(\"\\nVision model outputs\")\n", "for k, val in outputs.vision_model_output.items():\n", "    print(f\"{k}: shape of {val.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "h2C5dkQ8sBVZ" }, "source": [ "## Draw predictions on image\n", "\n", "Let's draw the detected objects on the input image. Remember that the detected objects correspond to the input text queries. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PNYQe2xbl5vl" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "from transformers.image_utils import ImageFeatureExtractionMixin\n", "mixin = ImageFeatureExtractionMixin()\n", "\n", "# Resize the image to the model's input size\n", "image_size = model.config.vision_config.image_size\n", "image = mixin.resize(image, image_size)\n", "input_image = np.asarray(image).astype(np.float32) / 255.0\n", "\n", "# Threshold to eliminate low probability predictions\n", "score_threshold = 0.1\n", "\n", "# Get the best-scoring text query (max logit) for each predicted box\n", "logits = torch.max(outputs[\"logits\"][0], dim=-1)\n", "scores = torch.sigmoid(logits.values).cpu().detach().numpy()\n", "\n", "# Get prediction labels and bounding boxes\n", "labels = logits.indices.cpu().detach().numpy()\n", "boxes = outputs[\"pred_boxes\"][0].cpu().detach().numpy()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 480 }, "id": "Xc6YjWXBl_vK", "outputId": "23ed2b75-59f0-461b-ce8d-48574de96f92" }, "outputs": [], "source": [ "def plot_predictions(input_image, text_queries, scores, boxes, labels):\n", "    fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n", "    ax.imshow(input_image, extent=(0, 1, 1, 0))\n", "    ax.set_axis_off()\n", "\n", "    for score, box, label in zip(scores, boxes, labels):\n", "        # Skip predictions below the global score_threshold\n", "        if score < score_threshold:\n", "            continue\n", "\n", "        # Boxes are normalized (center x, center y, width, height)\n", "        cx, cy, w, h = box\n", "        ax.plot([cx-w/2, cx+w/2, cx+w/2, cx-w/2, cx-w/2],\n", "                [cy-h/2, cy-h/2, cy+h/2, cy+h/2, cy-h/2], \"r\")\n", "        ax.text(\n", "            cx - w / 2,\n", "            cy + h / 2 + 0.015,\n", "            f\"{text_queries[label]}: {score:1.2f}\",\n", "            ha=\"left\",\n", "            va=\"top\",\n", "            color=\"red\",\n", "            bbox={\n", "                \"facecolor\": \"white\",\n", "                \"edgecolor\": \"red\",\n", "                \"boxstyle\": \"square,pad=.3\"\n", "            })\n", "\n", "plot_predictions(input_image, text_queries, scores, boxes, labels)" ] }, { "cell_type": "markdown", "metadata": { "id": 
"WBgWR3yONIwT" }, "source": [ "## Batch processing\n", "We can also pass in multiple sets of images and text queries to search for different (or the same) objects in different images. Let's load an image of a coffee mug to process alongside the astronaut image.\n", "\n", "For batch processing, we need to pass the text queries to `OwlViTProcessor` as a nested list and the images as a list of PIL images, PyTorch tensors or NumPy arrays." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 417 }, "id": "fZJrboQlswuA", "outputId": "ac0c0d96-f84f-4938-dadc-9a80669cb5a7" }, "outputs": [], "source": [ "# Load the coffee mug image\n", "image = skimage.data.coffee()\n", "image = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n", "image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NVLvC8IrNPQW", "outputId": "a0c49462-2137-4a02-8971-f40765ef35fa" }, "outputs": [], "source": [ "# Preprocessing\n", "images = [skimage.data.astronaut(), skimage.data.coffee()]\n", "images = [Image.fromarray(np.uint8(img)).convert(\"RGB\") for img in images]\n", "\n", "# Nested list of text queries to search each image for\n", "text_queries = [[\"human face\", \"rocket\", \"nasa badge\", \"star-spangled banner\"], [\"coffee mug\", \"spoon\", \"plate\"]]\n", "\n", "# Process image and text inputs\n", "inputs = processor(text=text_queries, images=images, return_tensors=\"pt\").to(device)\n", "\n", "# Print input names and shapes\n", "for key, val in inputs.items():\n", "    print(f\"{key}: {val.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "i9Q2ik1kNZ1j" }, "source": [ "**Note:** The size of `input_ids` and `attention_mask` is `[batch_size * num_max_text_queries, max_length]`. `max_length` is set to 16 for all OWL-ViT models."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ApJXe-2sNR9-", "outputId": "61756193-4fae-4d9d-90a2-13da26ae185e" }, "outputs": [], "source": [ "# Get predictions\n", "with torch.no_grad():\n", "    outputs = model(**inputs)\n", "\n", "# Print the shape of each top-level output tensor\n", "for k, val in outputs.items():\n", "    if k not in {\"text_model_output\", \"vision_model_output\"}:\n", "        print(f\"{k}: shape of {val.shape}\")\n", "\n", "print(\"\\nText model outputs\")\n", "for k, val in outputs.text_model_output.items():\n", "    print(f\"{k}: shape of {val.shape}\")\n", "\n", "print(\"\\nVision model outputs\")\n", "for k, val in outputs.vision_model_output.items():\n", "    print(f\"{k}: shape of {val.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's plot the predictions for the second image\n", "image_idx = 1\n", "image_size = model.config.vision_config.image_size\n", "image = mixin.resize(images[image_idx], image_size)\n", "input_image = np.asarray(image).astype(np.float32) / 255.0\n", "\n", "# Threshold to eliminate low probability predictions\n", "score_threshold = 0.1\n", "\n", "# Get the best-scoring text query (max logit) for each predicted box\n", "logits = torch.max(outputs[\"logits\"][image_idx], dim=-1)\n", "scores = torch.sigmoid(logits.values).cpu().detach().numpy()\n", "\n", "# Get prediction labels and bounding boxes\n", "labels = logits.indices.cpu().detach().numpy()\n", "boxes = outputs[\"pred_boxes\"][image_idx].cpu().detach().numpy()\n", "\n", "plot_predictions(input_image, text_queries[image_idx], scores, boxes, labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "n6VYaXGbNfwU" }, "source": [ "## Post-processing model predictions\n", "Notice how we plotted the output predictions on the resized input image. This is because OWL-ViT outputs normalized box coordinates in `[cx, cy, w, h]` format assuming a fixed input image size. 
We can use the `OwlViTProcessor`'s convenient `post_process()` method to convert the model outputs to the COCO API format and retrieve rescaled coordinates (with respect to the original image sizes) in `[x0, y0, x1, y1]` format." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "R3xneq28NcUm", "outputId": "802a1707-e309-4c96-88bd-a8882afee8d9" }, "outputs": [], "source": [ "# Target image sizes (height, width) to rescale box predictions [batch_size, 2]\n", "target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)\n", "\n", "# Convert outputs (bounding boxes and class logits) to COCO API\n", "results = processor.post_process(outputs=outputs, target_sizes=target_sizes)\n", "\n", "# Loop over predictions for each image in the batch\n", "for i in range(len(images)):\n", "    print(f\"\\nProcessing image {i}\")\n", "    text = text_queries[i]\n", "    boxes, scores, labels = results[i][\"boxes\"], results[i][\"scores\"], results[i][\"labels\"]\n", "\n", "    score_threshold = 0.1\n", "    for box, score, label in zip(boxes, scores, labels):\n", "        # Round the [x0, y0, x1, y1] pixel coordinates for printing\n", "        box = [round(coord, 2) for coord in box.tolist()]\n", "\n", "        if score >= score_threshold:\n", "            print(f\"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bonus: one-shot / image-guided object detection\n", "Instead of performing zero-shot detection with text inputs, we can use the `OwlViTForObjectDetection.image_guided_detection()` method to query an input image with a query / example image and detect similar-looking objects. To do this, we simply pass in `query_images` instead of text to the processor to get the `query_pixel_values`. Note that, unlike text input, `OwlViTProcessor` expects one query image per target image we'd like to query for similar objects. 
We will also see that the output and post-processing of one-shot object detection are very similar to those of zero-shot / text-guided detection.\n", "\n", "Let's try this out by querying an image of cats with another cat image. For this part of the demo, we will perform image-guided object detection, post-process the results and display the predicted bounding boxes on the original input image using OpenCV." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import cv2\n", "import requests\n", "from matplotlib import rcParams\n", "\n", "# Set figure size\n", "%matplotlib inline\n", "rcParams['figure.figsize'] = (11, 8)\n", "\n", "# Input image\n", "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n", "image = Image.open(requests.get(url, stream=True).raw)\n", "target_sizes = torch.Tensor([image.size[::-1]])\n", "\n", "# Query image\n", "query_url = \"http://images.cocodataset.org/val2017/000000058111.jpg\"\n", "query_image = Image.open(requests.get(query_url, stream=True).raw)\n", "\n", "# Display input image and query image\n", "fig, ax = plt.subplots(1, 2)\n", "ax[0].imshow(image)\n", "ax[1].imshow(query_image)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Process input and query image\n", "inputs = processor(images=image, query_images=query_image, return_tensors=\"pt\").to(device)\n", "\n", "# Print input names and shapes\n", "for key, val in inputs.items():\n", "    print(f\"{key}: {val.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get predictions\n", "with torch.no_grad():\n", "    outputs = model.image_guided_detection(**inputs)\n", "\n", "for k, val in outputs.items():\n", "    if k not in {\"text_model_output\", \"vision_model_output\"}:\n", "        print(f\"{k}: shape of {val.shape}\")\n", "\n", "print(\"\\nVision model outputs\")\n", "for k, val in outputs.vision_model_output.items():\n", 
"    print(f\"{k}: shape of {val.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# PIL images are RGB; convert to OpenCV's BGR channel order for drawing\n", "img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)\n", "outputs.logits = outputs.logits.cpu()\n", "outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()\n", "\n", "results = processor.post_process_image_guided_detection(outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes)\n", "boxes, scores = results[0][\"boxes\"], results[0][\"scores\"]\n", "\n", "# Draw predicted bounding boxes\n", "for box, score in zip(boxes, scores):\n", "    box = [int(i) for i in box.tolist()]\n", "\n", "    img = cv2.rectangle(img, tuple(box[:2]), tuple(box[2:]), (255,0,0), 5)\n", "    if box[3] + 25 > 768:\n", "        y = box[3] - 10\n", "    else:\n", "        y = box[3] + 25\n", "\n", "# Reverse the channel order (BGR -> RGB) for matplotlib\n", "plt.imshow(img[:,:,::-1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "OWL-ViT-inference example.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 1 }