{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Scavenger hunt\n", "\n", "\"Identify emojis in the real world with your phone’s camera - [Experiment][scavenger-experiment] / [YouTube Presentation][scavenger-youtube]\n", "\n", "\n", " \n", "\n", "\n", "[scavenger-youtube]:https://youtu.be/jr3q_9pJBr8\n", "[scavenger-experiment]:https://emojiscavengerhunt.withgoogle.com/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "print(\"Torch version:\", torch.__version__)\n", "\n", "import torchvision\n", "print(\"Torchvision version:\", torchvision.__version__)\n", "\n", "import numpy as np\n", "print(\"Numpy version:\", np.__version__)\n", "\n", "import matplotlib\n", "print(\"Matplotlib version:\", matplotlib.__version__)\n", "\n", "import PIL\n", "print(\"PIL version:\", PIL.__version__)\n", "\n", "import IPython\n", "print(\"IPython version:\", IPython.__version__)\n", "\n", "import cv2\n", "print('OpenCV version:', cv2.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Setup Matplotlib\n", "%matplotlib inline\n", "#%config InlineBackend.figure_format = 'retina' # If you have a retina screen\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pretrained models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load a pretrained model\n", "model = torchvision.models.resnet18(pretrained=True)\n", "\n", "# Set model in \"evaluation\" model\n", "_ = model.eval()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torchvision import transforms\n", "\n", "# Define the input pipeline\n", "pipeline = transforms.Compose([\n", " transforms.ToPILImage(), # Convert webcam images to PIL format\n", " transforms.ToTensor(), # Convert to PyTorch Tensor\n", " transforms.Normalize( # Normalize using predefined values\n", " mean=[0.485, 0.456, 0.406],\n", " std=[0.229, 0.224, 0.225]\n", " )\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "# Load classes\n", "imagenet_classes = json.load(open('imagenet-classes.json'))\n", "\n", "# Function to transform label(s) to id(s)\n", "to_id = {label: int(i) for i, label in imagenet_classes.items()} # label -> id\n", "to_ids = lambda labels: [to_id[label] for label in labels] # labels -> ids\n", "\n", "# Define items\n", "items = {\n", " 'cup': ['cup', 'goblet', 'coffee_mug', 'espresso', 'eggnog', 'red_wine', 'beer_glass'],\n", " 'clock': ['digital_clock', 'digital_watch', 'analog_clock', 'wall_clock', 'stopwatch'],\n", " 'bottle': ['water_bottle', 'pop_bottle', 'wine_bottle', 'beer_bottle']\n", "}\n", "items_names = list(items.keys())\n", "\n", "# The goal is to find all items\n", "was_found = {c: False for c in items.keys()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test with webcam feed" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython import display\n", "import time\n", "\n", "# Connect to webcam\n", "if 'webcam' not in locals() or webcam is None:\n", " webcam = cv2.VideoCapture(0)\n", "\n", "try:\n", " # Try to read from the webcam\n", " webcam_found, _ = webcam.read()\n", "\n", " if webcam_found:\n", " # Create figure\n", " fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))\n", " \n", " for i in range(1000):\n", " # Take a picture with the webcam\n", " _, image = webcam.read()\n", "\n", " # Process it\n", " image = cv2.resize(image, (224, 224)) # Resize to fit the model\n", " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # To RGB\n", "\n", " # Classify image\n", " image_pytorch = pipeline(image)\n", " output = model(torch.autograd.Variable(image_pytorch[np.newaxis, :]))\n", " all_probs = torch.nn.functional.softmax(output, 1).view(-1).data.numpy()\n", "\n", " # Get probabilities for our items\n", " probs = [all_probs[to_ids(labels)].max() for c, labels in items.items()]\n", " \n", " # Did we find an object?\n", " if max(probs) > 0.2:\n", " found_class = items_names[np.argmax(probs)]\n", " was_found[found_class] = True\n", "\n", " # Plot the image\n", " ax1.cla()\n", " ax1.barh(np.arange(len(items)), probs, height=0.5, tick_label=['{} [{}]'.format(c, '✓' if done else '✗') for c, done in was_found.items()])\n", " ax1.set_xlim(0, 1)\n", " ax2.cla()\n", " ax2.imshow(image, aspect='auto')\n", " ax2.set_title('webcam')\n", " \n", " # Set title\n", " if np.all(list(was_found.values())):\n", " ax1.set_title(r'Bravo \\°$\\smile$°/ !')\n", " else:\n", " ax1.set_title('Find a ..')\n", "\n", " # Jupyter trick\n", " display.clear_output(wait=True)\n", " display.display(fig)\n", "\n", " # Rest a bit for CPU\n", " time.sleep(0.2)\n", "\n", " # Clear output\n", " display.clear_output()\n", "\n", " else:\n", " print('Cannot read from webcam, do you have one connected?')\n", " \n", "except KeyboardInterrupt:\n", " # Clear output\n", " display.clear_output()\n", " \n", "finally: \n", " # Disconnect webcam\n", " del(webcam)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }