{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Teachable machine\n", "\n", "**Create a deep model for your own task (webcam/other set of images)**\n", "\n", "* **Set Data Pipeline** - DataSet, DataLoader, Transforms\n", "* **Build Neural Network** - Network Module, Loss\n", "* **Train Model** - Optimizer, Babysitting Learning\n", "* **Transfer Learning** - Feature extractor, Fine-tuning\n", "\n", "\"Teach a machine using your camera\" - [Experiment][teachable-experiment] / [YouTube Presentation][teachable-youtube]\n", "\n", "\n", " \n", "\n", "\n", "[teachable-youtube]:https://youtu.be/3BhkeY974Rg\n", "[teachable-experiment]:https://teachablemachine.withgoogle.com/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "print(\"Torch version:\", torch.__version__)\n", "\n", "import torchvision\n", "print(\"Torchvision version:\", torchvision.__version__)\n", "\n", "import numpy as np\n", "print(\"Numpy version:\", np.__version__)\n", "\n", "import matplotlib\n", "print(\"Matplotlib version:\", matplotlib.__version__)\n", "\n", "import PIL\n", "print(\"PIL version:\", PIL.__version__)\n", "\n", "import IPython\n", "print(\"IPython version:\", IPython.__version__)\n", "\n", "import cv2\n", "print('OpenCV version:', cv2.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Setup Matplotlib\n", "%matplotlib inline\n", "#%config InlineBackend.figure_format = 'retina' # If you have a retina screen\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Data Set" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython import display\n", "import os, time\n", "\n", "# Path to write images\n", "img_path = os.path.join('images/normal')\n", "prefix = 'session1'\n", "\n", "# Connect to webcam\n", "if 'webcam' not in locals() or webcam is None:\n", " webcam = cv2.VideoCapture(0)\n", "\n", "try:\n", " # Try to read from the webcam\n", " webcam_found, _ = webcam.read()\n", "\n", " if webcam_found:\n", " # How many photos to save\n", " n_images = int(input(\"Number of photos: \"))\n", "\n", " # Create figure to display webcam\n", " fig = plt.figure()\n", " axis = fig.gca()\n", " \n", " # Collect images\n", " live_in = 3\n", " image_taken = 0\n", "\n", " while image_taken < n_images:\n", " # Take a picture with the webcam\n", " _, image = webcam.read()\n", "\n", " # Process it\n", " image = cv2.resize(image, (250, 250)) # Reduce size\n", " image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # To RGB\n", "\n", " # Plot it\n", " axis.cla()\n", " axis.imshow(image_rgb)\n", "\n", " if live_in == 0:\n", " # We are live!\n", " image_taken += 1\n", " axis.set_title('Click ! ({}/{})'.format(image_taken, n_images))\n", "\n", " # Save the image\n", " path = os.path.join(img_path, '{}-{}.png'.format(prefix, image_taken))\n", " cv2.imwrite(path, image)\n", "\n", " # Time before taking the next picture\n", " sleep_time = 0.2\n", "\n", " else:\n", " # We are not live\n", " axis.set_title(\"We're live in .. 
{}\".format(live_in))\n", " sleep_time = 1\n", " live_in -= 1\n", "\n", " display.clear_output(wait=True)\n", " display.display(fig)\n", "\n", " # Sleep\n", " time.sleep(sleep_time)\n", " \n", " # Clear output\n", " display.clear_output()\n", "\n", " else:\n", " print('Cannot read from webcam, do you have one connected?')\n", " \n", "except KeyboardInterrupt:\n", " # Clear output\n", " display.clear_output()\n", " \n", "finally: \n", " # Disconnect webcam\n", " del(webcam)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set Data Pipeline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torchvision import transforms\n", "\n", "# Data transformations\n", "normalize = transforms.Normalize(\n", " mean=[0.485, 0.456, 0.406], # values for PyTorch models\n", " std=[0.229, 0.224, 0.225]\n", ")\n", "train_transform = transforms.Compose([\n", " transforms.RandomCrop(224),\n", " transforms.ToTensor(),\n", " normalize\n", "])\n", "valid_transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor(),\n", " normalize\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create data set\n", "trainset = torchvision.datasets.ImageFolder('images', train_transform)\n", "validset = torchvision.datasets.ImageFolder('images', valid_transform)\n", "\n", "classes = trainset.classes\n", "n_classes = len(classes)\n", "print('Classes:', classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data.sampler import SubsetRandomSampler\n", "\n", "# Define train/validation sets\n", "n_images = len(trainset) # number of images in our data set\n", "idx = np.arange(n_images) # idx: 0 .. (n_images - 1)\n", "np.random.shuffle(idx) # shuffle\n", "\n", "# Create train/validation samplers\n", "valid_size = 100\n", "train_sampler = SubsetRandomSampler(idx[:-valid_size])\n", "valid_sampler = SubsetRandomSampler(idx[-valid_size:])\n", "\n", "print('Train set:', len(train_sampler))\n", "print('Validation set:', len(valid_sampler))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create data loaders\n", "train_loader = torch.utils.data.DataLoader(trainset, batch_size=4, sampler=train_sampler)\n", "valid_loader = torch.utils.data.DataLoader(validset, batch_size=4, sampler=valid_sampler)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Plot a few samples\n", "train_iter = iter(train_loader)\n", "images, labels = next(train_iter)\n", "\n", "print('Classes:', ', '.join(classes[i] for i in labels))\n", "grid = torchvision.utils.make_grid(images, normalize=True)\n", "plt.imshow(grid.numpy().transpose((1, 2, 0)))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build Neural Network - Transfer learning\n", "\n", "**Can we reuse what's been learned on other tasks?** - Source [cs231n][cs231-transfer]\n", "\n", "> In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data.sampler import SubsetRandomSampler\n", "\n", "# Define train/validation sets\n", "n_images = len(trainset) # number of images in our data set\n", "idx = np.arange(n_images) # idx: 0 .. (n_images - 1)\n", "np.random.shuffle(idx) # shuffle\n", "\n", "# Create train/validation samplers\n", "valid_size = 100 # keep this below the total number of images\n", "train_sampler = SubsetRandomSampler(idx[:-valid_size])\n", "valid_sampler = SubsetRandomSampler(idx[-valid_size:])\n", "\n", "print('Train set:', len(train_sampler))\n", "print('Validation set:', len(valid_sampler))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create data loaders\n", "train_loader = torch.utils.data.DataLoader(trainset, batch_size=4, sampler=train_sampler)\n", "valid_loader = torch.utils.data.DataLoader(validset, batch_size=4, sampler=valid_sampler)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Plot a few samples\n", "train_iter = iter(train_loader)\n", "images, labels = next(train_iter)\n", "\n", "print('Classes:', ', '.join(classes[i] for i in labels))\n", "grid = torchvision.utils.make_grid(images, normalize=True)\n", "plt.imshow(grid.numpy().transpose((1, 2, 0)))\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Build Neural Network - Transfer learning\n", "\n", "**Can we reuse what's been learned on other tasks?** - Source [cs231n][cs231-transfer]\n", "\n", "> In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.\n", "\n", "**Transfer Learning Scenarios**\n", "\n", "* Pretrained network as a **Feature Extractor**\n", "* Adjust weights - **Fine-tuning**\n", "\n", "[cs231-transfer]:http://cs231n.github.io/transfer-learning/" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def resnet_frozen():\n", "    # Pretrained network\n", "    model = torchvision.models.resnet18(pretrained=True)\n", "\n", "    # Freeze parameters\n", "    for param in model.parameters():\n", "        param.requires_grad = False\n", "\n", "    # Replace the classification layer (new layers require gradients by default)\n", "    model.fc = torch.nn.Linear(model.fc.in_features, len(classes))\n", "\n", "    return model\n", "\n", "resnet_frozen()" ] },
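{ "cell_type": "markdown", "metadata": {}, "source": [ "With every pretrained layer frozen, only the freshly created `fc` layer should require gradients. A quick sketch to verify this before training:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compare trainable vs. total parameters: only the new fc layer should require gradients\n", "model_check = resnet_frozen()\n", "n_total = sum(p.numel() for p in model_check.parameters())\n", "n_trainable = sum(p.numel() for p in model_check.parameters() if p.requires_grad)\n", "print('Trainable parameters: {:,} / {:,}'.format(n_trainable, n_total))" ] },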
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train Model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "\n", "# Create model\n", "model = resnet_frozen()\n", "\n", "# Criterion and optimizer for \"training\"\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01)\n", "\n", "# Backprop step\n", "def compute_loss(output, target):\n", "    # Targets from the DataLoader are already LongTensors\n", "    return criterion(output, target)\n", "\n", "def backpropagation(output, target):\n", "    optimizer.zero_grad() # Clear the gradients\n", "    loss = compute_loss(output, target) # Compute loss\n", "    loss.backward() # Backpropagation\n", "    optimizer.step() # Let the optimizer adjust our model\n", "    return loss.item()\n", "\n", "# Helper function\n", "def get_accuracy(output, y):\n", "    predictions = torch.argmax(output, dim=1) # Max activation\n", "    return (predictions == y).float().mean().item()\n", "\n", "# Create a figure to visualize the results\n", "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))\n", "\n", "def plot_learning():\n", "    # Plot what the network learned\n", "    fig.suptitle('Epoch {}, batch {:,}/{:,}'.format(epoch, batch, len(train_loader)))\n", "    ax1.cla()\n", "    ax2.cla()\n", "\n", "    # Set titles\n", "    if len(stats['val_t']) > 0:\n", "        ax1.set_title('Loss, val: {:.3f}'.format(np.mean(stats['val_loss'][-10:])))\n", "        ax2.set_title('Accuracy, val: {:.3f}'.format(np.mean(stats['val_acc'][-10:])))\n", "    else:\n", "        ax1.set_title('Loss')\n", "        ax2.set_title('Accuracy')\n", "\n", "    ax1.plot(stats['train_t'], stats['train_loss'], label='train')\n", "    ax1.plot(stats['val_t'], stats['val_loss'], label='valid')\n", "    ax1.legend()\n", "    ax2.plot(stats['train_t'], stats['train_acc'], label='train')\n", "    ax2.plot(stats['val_t'], stats['val_acc'], label='valid')\n", "    ax2.set_ylim(0, 1)\n", "    ax2.legend()\n", "\n", "    # Jupyter trick\n", "    IPython.display.clear_output(wait=True)\n", "    IPython.display.display(fig)\n", "\n", "# Collect loss / accuracy values\n", "stats = defaultdict(list)\n", "t = 0 # Number of samples seen\n", "print_step = 10 # Refresh rate (in samples)\n", "\n", "# Train Network\n", "epoch = 1\n", "do_training = True\n", "\n", "while do_training:\n", "    # Set model in \"training\" mode\n", "    model.train()\n", "\n", "    # Train by small batches of data\n", "    for batch, (batch_X, batch_y) in enumerate(train_loader, 1):\n", "        # Forward pass & backpropagation\n", "        output = model(batch_X)\n", "        loss = backpropagation(output, batch_y)\n", "\n", "        # Log \"train\" stats\n", "        stats['train_loss'].append(loss)\n", "        stats['train_acc'].append(get_accuracy(output, batch_y))\n", "        stats['train_t'].append(t)\n", "\n", "        if t % print_step == 0:\n", "            # Plot learning\n", "            plot_learning()\n", "\n", "        # Update t\n", "        t += train_loader.batch_size\n", "\n", "    # Set model in \"validation\" mode\n", "    model.eval()\n", "\n", "    # Log \"validation\" stats (no gradients needed for evaluation)\n", "    loss_vals, acc_vals = [], []\n", "    with torch.no_grad():\n", "        for X, y in valid_loader:\n", "            output = model(X)\n", "            loss_vals.append(compute_loss(output, y).item())\n", "            acc_vals.append(get_accuracy(output, y))\n", "\n", "    stats['val_loss'].append(np.mean(loss_vals))\n", "    stats['val_acc'].append(np.mean(acc_vals))\n", "    stats['val_t'].append(t)\n", "\n", "    # Plot learning\n", "    plot_learning()\n", "\n", "    # Should we continue?\n", "    do_training = int(input('Continue training? 1 (yes) or 0 (no): '))\n", "    epoch += 1\n", "\n", "# Clear output\n", "IPython.display.clear_output(wait=True)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tuning - Smaller learning rate" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def resnet():\n", "    # Pretrained model - this time all layers stay trainable\n", "    model = torchvision.models.resnet18(pretrained=True)\n", "\n", "    # Classification layer\n", "    model.fc = torch.nn.Linear(model.fc.in_features, len(classes))\n", "\n", "    return model" ] },
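{ "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of the fine-tuning setup (assuming you then re-run the training loop from the previous section): every layer stays trainable, so the optimizer now updates all parameters, with a 10x smaller learning rate so the pretrained features are only gently adjusted." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fine-tuning sketch: all parameters trainable, smaller learning rate than before\n", "model = resnet()\n", "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=0.001) # 10x smaller than the 0.01 used above\n", "\n", "# Now re-run the training loop above to fine-tune the whole network" ] },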
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load libraries\n", "import torch\n", "print(\"Torch version:\", torch.__version__)\n", "\n", "import torchvision\n", "print(\"Torchvision version:\", torchvision.__version__)\n", "\n", "import matplotlib\n", "print(\"Matplotlib version:\", matplotlib.__version__)\n", "\n", "import numpy as np\n", "print(\"Numpy version:\", np.__version__)\n", "\n", "import cv2\n", "print('OpenCV version:', cv2.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Setup Matplotlib\n", "%matplotlib inline\n", "#%config InlineBackend.figure_format = 'retina' # If you have a retina screen\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Load Model\n", "state = torch.load(os.path.join('data', 'webcam-model.p'))\n", "model = state['model']\n", "classes = state['classes']\n", "print('Classes:', classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torchvision import transforms\n", "\n", "# Define image transformation\n", "image_transform = transforms.Compose([\n", " transforms.ToPILImage(), # Convert webcam images to PIL format\n", " transforms.Resize((224, 224)), # Resize\n", " transforms.ToTensor(),\n", " transforms.Normalize(\n", " mean=[0.485, 0.456, 0.406], # values for PyTorch models\n", " std=[0.229, 0.224, 0.225]\n", " )\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# We will need some tools from PyTorch\n", "from torch.autograd import Variable\n", "import torch.nn as nn\n", "\n", "# Tools to display webcam feed\n", "from IPython import display\n", "import time\n", "\n", "# Connect to webcam\n", "if 'webcam' not in locals() or webcam is None:\n", " webcam = cv2.VideoCapture(0)\n", "\n", "try:\n", " # Try to read from the webcam\n", " webcam_found, _ = webcam.read()\n", "\n", " if webcam_found:\n", " # Create figure\n", " fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))\n", "\n", " # Set network in \"evaluation\" mode\n", " model.eval()\n", "\n", " for i in range(100):\n", " # Take a picture with the webcam\n", " _, image = webcam.read()\n", "\n", " # Process it\n", " image = cv2.resize(image, (250, 250)) # Reduce size\n", " image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # To RGB\n", " image_pytorch = image_transform(image_rgb)\n", "\n", " # Classify image\n", " output = model(Variable(image_pytorch[None, :]))\n", " probs = nn.functional.softmax(output, 1).data.numpy()[0]\n", "\n", " # Plot the image\n", " ax1.cla()\n", " ax1.barh(np.arange(len(classes)), probs, height=0.5, tick_label=classes)\n", " ax1.set_xlim(0, 1)\n", " ax2.cla()\n", " ax2.imshow(image_rgb, aspect='auto')\n", "\n", " # Jupyter trick\n", " display.clear_output(wait=True)\n", " display.display(fig)\n", "\n", " # Rest a bit for CPU\n", " time.sleep(0.2)\n", "\n", " # Clear output\n", " display.clear_output()\n", "\n", " else:\n", " print('Cannot read from webcam, do you have one connected?')\n", "\n", "except KeyboardInterrupt:\n", " # Clear output\n", " display.clear_output()\n", " \n", "finally: \n", " # Disconnect webcam\n", " del(webcam)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": 
".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }