{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "N6ZDpd9XzFeN" }, "source": [ "##### Copyright 2018 The TensorFlow Hub Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "both", "id": "KUu4vOt5zI9d" }, "outputs": [], "source": [ "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "id": "CxmDMK4yupqg" }, "source": [ "# Object Detection\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", " \n", " See TF Hub models\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Sy553YSVmYiK" }, "source": [ "This Colab demonstrates use of a TF-Hub module trained to perform object detection." ] }, { "cell_type": "markdown", "metadata": { "id": "v4XGxDrCkeip" }, "source": [ "## Setup\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "both", "id": "6cPY9Ou4sWs_" }, "outputs": [], "source": [ "#@title Imports and function definitions\n", "\n", "# For running inference on the TF-Hub module.\n", "import tensorflow as tf\n", "\n", "import tensorflow_hub as hub\n", "\n", "# For downloading the image.\n", "import matplotlib.pyplot as plt\n", "import tempfile\n", "from six.moves.urllib.request import urlopen\n", "from six import BytesIO\n", "\n", "# For drawing onto the image.\n", "import numpy as np\n", "from PIL import Image\n", "from PIL import ImageColor\n", "from PIL import ImageDraw\n", "from PIL import ImageFont\n", "from PIL import ImageOps\n", "\n", "# For measuring the inference time.\n", "import time\n", "\n", "# Print Tensorflow version\n", "print(tf.__version__)\n", "\n", "# Check available GPU devices.\n", "print(\"The following GPU devices are available: %s\" % tf.test.gpu_device_name())" ] }, { "cell_type": "markdown", "metadata": { "id": "ZGkrXGy62409" }, "source": [ "## Example use" ] }, { "cell_type": "markdown", "metadata": { "id": "vlA3CftFpRiW" }, "source": [ "### Helper functions for downloading images and for visualization.\n", "\n", "Visualization code adapted from [TF object detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/utils/visualization_utils.py) for the simplest required functionality." 
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "D9IwDpOtpIHW" }, "outputs": [], "source": [
 "def display_image(image):\n",
 "  \"\"\"Shows `image` in a new 20x15 matplotlib figure with the grid turned off.\"\"\"\n",
 "  plt.figure(figsize=(20, 15))\n",
 "  plt.grid(False)\n",
 "  plt.imshow(image)\n",
 "\n",
 "\n",
 "def download_and_resize_image(url, new_width=256, new_height=256,\n",
 "                              display=False):\n",
 "  \"\"\"Downloads an image, fits it to the requested size and saves it as a JPEG.\n",
 "\n",
 "  Args:\n",
 "    url: Address of the image to fetch.\n",
 "    new_width: Width passed to `ImageOps.fit` for the saved image.\n",
 "    new_height: Height passed to `ImageOps.fit` for the saved image.\n",
 "    display: If true, also shows the resized image with `display_image`.\n",
 "\n",
 "  Returns:\n",
 "    Path of the temporary `.jpg` file that holds the resized RGB image.\n",
 "  \"\"\"\n",
 "  # Reserve a temporary file name. `tempfile.mkstemp` also returns an open\n",
 "  # OS-level descriptor that was never closed; here the handle is released as\n",
 "  # soon as the `with` block exits, while `delete=False` keeps the file on disk.\n",
 "  with tempfile.NamedTemporaryFile(suffix=\".jpg\", delete=False) as temp_file:\n",
 "    filename = temp_file.name\n",
 "  response = urlopen(url)\n",
 "  try:\n",
 "    image_data = BytesIO(response.read())\n",
 "  finally:\n",
 "    # Release the connection even if reading the body fails.\n",
 "    response.close()\n",
 "  pil_image = Image.open(image_data)\n",
 "  pil_image = ImageOps.fit(pil_image, (new_width, new_height), Image.LANCZOS)\n",
 "  pil_image_rgb = pil_image.convert(\"RGB\")\n",
 "  pil_image_rgb.save(filename, format=\"JPEG\", quality=90)\n",
 "  print(\"Image downloaded to %s.\" % filename)\n",
 "  if display:\n",
 "    display_image(pil_image)\n",
 "  return filename\n",
 "\n",
 "\n",
 "def draw_bounding_box_on_image(image,\n",
 "                               ymin,\n",
 "                               xmin,\n",
 "                               ymax,\n",
 "                               xmax,\n",
 "                               color,\n",
 "                               font,\n",
 "                               thickness=4,\n",
 "                               display_str_list=()):\n",
 "  \"\"\"Adds a bounding box, with optional stacked labels, to an image in place.\n",
 "\n",
 "  Args:\n",
 "    image: PIL image to draw on.\n",
 "    ymin: Top edge of the box as a fraction of the image height.\n",
 "    xmin: Left edge of the box as a fraction of the image width.\n",
 "    ymax: Bottom edge of the box as a fraction of the image height.\n",
 "    xmax: Right edge of the box as a fraction of the image width.\n",
 "    color: Color of the box outline and of the label backgrounds.\n",
 "    font: PIL font used to measure and render the labels.\n",
 "    thickness: Line width of the box outline.\n",
 "    display_str_list: Label strings; the first one ends up on top.\n",
 "  \"\"\"\n",
 "  draw = ImageDraw.Draw(image)\n",
 "  im_width, im_height = image.size\n",
 "  (left, right, top, bottom) = (xmin * im_width, xmax * im_width,\n",
 "                                ymin * im_height, ymax * im_height)\n",
 "  draw.line([(left, top), (left, bottom), (right, bottom), (right, top),\n",
 "             (left, top)],\n",
 "            width=thickness,\n",
 "            fill=color)\n",
 "\n",
 "  # If the labels stacked above the box would stick out past the top of the\n",
 "  # image, place them just under the top edge of the box instead.\n",
 "  display_str_heights = [font.getbbox(ds)[3] for ds in display_str_list]\n",
 "  # Each display_str has a top and bottom margin of 0.05x.\n",
 "  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)\n",
 "\n",
 "  if top > total_display_str_height:\n",
 "    text_bottom = top\n",
 "  else:\n",
 "    text_bottom = top + total_display_str_height\n",
 "  # Reverse list and print from bottom to top.\n",
 "  for display_str in display_str_list[::-1]:\n",
 "    bbox = font.getbbox(display_str)\n",
 "    text_width, text_height = bbox[2], bbox[3]\n",
 "    margin = np.ceil(0.05 * text_height)\n",
 "    draw.rectangle([(left, text_bottom - text_height - 2 * margin),\n",
 "                    (left + text_width, text_bottom)],\n",
 "                   fill=color)\n",
 "    draw.text((left + margin, text_bottom - text_height - margin),\n",
 "              display_str,\n",
 "              fill=\"black\",\n",
 "              font=font)\n",
 "    # Move up by the full height of the label just drawn (text plus both\n",
 "    # margins). Subtracting `text_height - 2 * margin` made stacked labels\n",
 "    # overlap each other.\n",
 "    text_bottom -= text_height + 2 * margin\n",
 "\n",
 "\n",
 "def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):\n",
 "  \"\"\"Overlay labeled boxes on an image with formatted scores and label names.\n",
 "\n",
 "  Args:\n",
 "    image: Image array; it is modified in place and also returned.\n",
 "    boxes: Array whose rows are (ymin, xmin, ymax, xmax).\n",
 "    class_names: Byte-string label for each box.\n",
 "    scores: Score for each box; shown as an integer percentage.\n",
 "    max_boxes: Only the first `max_boxes` rows are considered.\n",
 "    min_score: Boxes scoring below this value are skipped.\n",
 "\n",
 "  Returns:\n",
 "    `image`, with the selected boxes drawn onto it.\n",
 "  \"\"\"\n",
 "  colors = list(ImageColor.colormap.values())\n",
 "\n",
 "  try:\n",
 "    font = ImageFont.truetype(\"/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf\",\n",
 "                              25)\n",
 "  except IOError:\n",
 "    print(\"Font not found, using default font.\")\n",
 "    font = ImageFont.load_default()\n",
 "\n",
 "  # Convert to PIL once, draw every box, then copy back once, instead of\n",
 "  # round-tripping the whole image for each box.\n",
 "  image_pil = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
 "  boxes_drawn = 0\n",
 "  for i in range(min(boxes.shape[0], max_boxes)):\n",
 "    if scores[i] >= min_score:\n",
 "      ymin, xmin, ymax, xmax = tuple(boxes[i])\n",
 "      display_str = \"{}: {}%\".format(class_names[i].decode(\"ascii\"),\n",
 "                                     int(100 * scores[i]))\n",
 "      # Derive the color index from the label bytes themselves. The built-in\n",
 "      # `hash()` of bytes is salted per process, so the same class got a\n",
 "      # different color after every kernel restart.\n",
 "      color = colors[int.from_bytes(class_names[i], \"big\") % len(colors)]\n",
 "      draw_bounding_box_on_image(\n",
 "          image_pil,\n",
 "          ymin,\n",
 "          xmin,\n",
 "          ymax,\n",
 "          xmax,\n",
 "          color,\n",
 "          font,\n",
 "          display_str_list=[display_str])\n",
 "      boxes_drawn += 1\n",
 "  # Leave `image` untouched when nothing passed the score threshold.\n",
 "  if boxes_drawn:\n",
 "    np.copyto(image, np.array(image_pil))\n",
 "  return image"
] },
{ "cell_type": "markdown", "metadata": { "id": "D19UCu9Q2-_8" }, "source": [
 "## Apply module\n",
 "\n",
 "Load a public image from Open Images v4, save locally, and display."
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "both", "id": "YLWNhjUY1mhg" }, "outputs": [], "source": [
 "# By Heiko Gorski, Source: https://commons.wikimedia.org/wiki/File:Naxos_Taverna.jpg\n",
 "image_url = \"https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg\" #@param\n",
 "downloaded_image_path = download_and_resize_image(image_url, 1280, 856, True)"
] },
{ "cell_type": "markdown", "metadata": { "id": "t-VdfLbC1w51" }, "source": [
 "Pick an object detection module and apply on the downloaded image. Modules:\n",
 "* **FasterRCNN+InceptionResNet V2**: high accuracy,\n",
 "* **ssd+mobilenet V2**: small and fast."
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "uazJ5ASc2_QE" }, "outputs": [], "source": [
 "module_handle = \"https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1\" #@param [\"https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1\", \"https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1\"]\n",
 "\n",
 "detector = hub.load(module_handle).signatures['default']"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "znW8Fq1EC0x7" }, "outputs": [], "source": [
 "def load_img(path):\n",
 "  \"\"\"Reads the JPEG file at `path` and decodes it into a 3-channel image tensor.\"\"\"\n",
 "  img = tf.io.read_file(path)\n",
 "  img = tf.image.decode_jpeg(img, channels=3)\n",
 "  return img"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "kwGJV96WWBLH" }, "outputs": [], "source": [
 "def run_detector(detector, path):\n",
 "  \"\"\"Runs `detector` on the image stored at `path` and displays the detections.\n",
 "\n",
 "  Prints the number of detections returned and the wall-clock time of the\n",
 "  detector call, then shows the image with the labeled boxes drawn on it.\n",
 "\n",
 "  Args:\n",
 "    detector: Callable applied to the batched float32 image; its result must\n",
 "      provide `detection_boxes`, `detection_class_entities` and\n",
 "      `detection_scores`.\n",
 "    path: Path of the JPEG file to analyze.\n",
 "  \"\"\"\n",
 "  img = load_img(path)\n",
 "\n",
 "  # Convert to float32 and add a leading batch axis before calling the detector.\n",
 "  converted_img = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]\n",
 "  start_time = time.time()\n",
 "  result = detector(converted_img)\n",
 "  end_time = time.time()\n",
 "\n",
 "  # Turn every output tensor into a NumPy array for the drawing helpers.\n",
 "  result = {key:value.numpy() for key,value in result.items()}\n",
 "\n",
 "  print(\"Found %d objects.\" % len(result[\"detection_scores\"]))\n",
 "  print(\"Inference time: \", end_time-start_time)\n",
 "\n",
 "  image_with_boxes = draw_boxes(\n",
 "      img.numpy(), result[\"detection_boxes\"],\n",
 "      result[\"detection_class_entities\"], result[\"detection_scores\"])\n",
 "\n",
 "  display_image(image_with_boxes)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "vchaUW1XDodD" }, "outputs": [], "source": [
 "run_detector(detector, downloaded_image_path)"
] },
{ "cell_type": "markdown", "metadata": { "id": "WUUY3nfRX7VF" }, "source": [
 "### More images\n",
 "Perform inference on some additional images with time tracking.\n"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "rubdr2JXfsa1" }, "outputs": [], "source": [
 "image_urls = [\n",
 "  # Source: https://commons.wikimedia.org/wiki/File:The_Coleoptera_of_the_British_islands_(Plate_125)_(8592917784).jpg\n",
 "  \"https://upload.wikimedia.org/wikipedia/commons/1/1b/The_Coleoptera_of_the_British_islands_%28Plate_125%29_%288592917784%29.jpg\",\n",
 "  # By Américo Toledano, Source: https://commons.wikimedia.org/wiki/File:Biblioteca_Maim%C3%B3nides,_Campus_Universitario_de_Rabanales_007.jpg\n",
 "  \"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0d/Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg/1024px-Biblioteca_Maim%C3%B3nides%2C_Campus_Universitario_de_Rabanales_007.jpg\",\n",
 "  # Source: https://commons.wikimedia.org/wiki/File:The_smaller_British_birds_(8053836633).jpg\n",
 "  \"https://upload.wikimedia.org/wikipedia/commons/0/09/The_smaller_British_birds_%288053836633%29.jpg\",\n",
 "  ]\n",
 "\n",
 "def detect_img(image_url):\n",
 "  \"\"\"Downloads `image_url`, runs the detector on it and prints the total time.\n",
 "\n",
 "  The image is fitted to 640x480 before detection. Uses the module-level\n",
 "  `detector`.\n",
 "  \"\"\"\n",
 "  start_time = time.time()\n",
 "  image_path = download_and_resize_image(image_url, 640, 480)\n",
 "  run_detector(detector, image_path)\n",
 "  end_time = time.time()\n",
 "  # This interval spans download, resizing, detection and drawing, so it is\n",
 "  # labeled as a total; `run_detector` already prints the pure inference time.\n",
 "  print(\"Total time (download + detection):\",end_time-start_time)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "otPnrxMKIrj5" }, "outputs": [], "source": [
 "detect_img(image_urls[0])"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "H5F7DkD5NtOx" }, "outputs": [], "source": [
"detect_img(image_urls[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DZ18R7dWNyoU" }, "outputs": [], "source": [ "detect_img(image_urls[2])" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "object_detection.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }