{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "llx_cGpoAyAq" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "colab": {}, "colab_type": "code", "id": "5MAYU_6KA0Kt" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "35lZ8kr3UcsB" }, "source": [ "# Data augmentation" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "MfBg1C5NB3X0" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", "View on TensorFlow.org\n", "Run in Google Colab\n", "View source on GitHub\n", "Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "BZP72A6eFw74" }, "source": [ "## Overview\n", "\n", "This tutorial demonstrates manual image manipulations and augmentation using `tf.image`.\n", "\n", "Data augmentation is a common technique for improving results and avoiding overfitting. See [Overfitting and Underfitting](../keras/overfit_and_underfit.ipynb) for other such techniques." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8sZIVqk7HvnC" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "3TS4vBJBd8jY" }, "outputs": [], "source": [ "!pip install git+https://github.com/tensorflow/docs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "rdP8EQbPsyRA" }, "outputs": [], "source": [ "import urllib\n", "\n", "import tensorflow as tf\n", "from tensorflow.keras.datasets import mnist\n", "from tensorflow.keras import layers\n", "AUTOTUNE = tf.data.experimental.AUTOTUNE\n", "\n", "import tensorflow_docs as tfdocs\n", "import tensorflow_docs.plots\n", "\n", "import tensorflow_datasets as tfds\n", "\n", "import PIL.Image\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "mpl.rcParams['figure.figsize'] = (12, 5)\n", "\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "__EuXwM-8uth" }, "source": [ "First, try the augmentation features on a single image; later, augment a whole dataset and use it to train a model." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "frBSdODBLOOI" }, "source": [ "Download [this image](https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg), by Von.grzanka, for augmentation."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "s5ThIwG8KqzI" }, "outputs": [], "source": [ "image_path = tf.keras.utils.get_file(\"cat.jpg\", \"https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg\")\n", "PIL.Image.open(image_path)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-Ec3bGonGDCF" }, "source": [ "Read and decode the image into a tensor." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "cdCoB8b-uZjf" }, "outputs": [], "source": [ "image_string = tf.io.read_file(image_path)\n", "image = tf.image.decode_jpeg(image_string, channels=3)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "isGwyT0386yi" }, "source": [ "Define a function to visualize and compare the original and augmented images side by side." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "FKnRfw2dvyql" }, "outputs": [], "source": [ "def visualize(original, augmented):\n", " fig = plt.figure()\n", " plt.subplot(1, 2, 1)\n", " plt.title('Original image')\n", " plt.imshow(original)\n", "\n", " plt.subplot(1, 2, 2)\n", " plt.title('Augmented image')\n", " plt.imshow(augmented)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jYLzpEOhGqWY" }, "source": [ "## Augment a single image" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8IiXghY99Bo6" }, "source": [ "### Flipping the image\n", "Flip the image either vertically or horizontally. The example below flips the image horizontally (left to right)."
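, "\n", "For intuition, a horizontal flip simply reverses the column (width) order of the image tensor. A minimal NumPy sketch of the same operation on a hypothetical toy array (not part of the tutorial's pipeline):\n", "\n", "```python\n", "import numpy as np\n", "\n", "img = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # toy (height, width, channels) array\n", "flipped_lr = img[:, ::-1, :]                 # reverse the width axis\n", "```\n"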
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "X14VjLlFxnvZ" }, "outputs": [], "source": [ "flipped = tf.image.flip_left_right(image)\n", "visualize(image, flipped)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ObsvSmu99MfC" }, "source": [ "### Grayscale the image\n", "Convert the image to grayscale." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "mnqQA2ubyo6O" }, "outputs": [], "source": [ "grayscaled = tf.image.rgb_to_grayscale(image)\n", "visualize(image, tf.squeeze(grayscaled))\n", "plt.colorbar()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "juI4A4HF9gYc" }, "source": [ "### Saturate the image\n", "Saturate the image by providing a saturation factor." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "tiTUhw-gzCJW" }, "outputs": [], "source": [ "saturated = tf.image.adjust_saturation(image, 3)\n", "visualize(image, saturated)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "E82CqomP9qcR" }, "source": [ "### Change image brightness\n", "Change the brightness of the image by providing a brightness factor." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "05dA6uEtzfyd" }, "outputs": [], "source": [ "bright = tf.image.adjust_brightness(image, 0.4)\n", "visualize(image, bright)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "5_0kMbmS91x6" }, "source": [ "### Rotate the image\n", "Rotate an image by 90 degrees."
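, "\n", "A 90-degree rotation swaps the height and width axes, so a (height, width, channels) tensor becomes (width, height, channels). A NumPy sketch of the same shape transformation on a hypothetical toy array (np.rot90 rotates counter-clockwise, as tf.image.rot90 does):\n", "\n", "```python\n", "import numpy as np\n", "\n", "img = np.zeros((2, 3, 3))    # toy (height, width, channels) array\n", "rotated_np = np.rot90(img)   # rotates in the height-width plane\n", "print(rotated_np.shape)      # (3, 2, 3): height and width are swapped\n", "```\n"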
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "edNoQzhszxo8" }, "outputs": [], "source": [ "rotated = tf.image.rot90(image)\n", "visualize(image, rotated)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bomBnFWp9895" }, "source": [ "### Center crop the image\n", "Crop the image from the center, keeping the fraction of the image you specify." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "fvgz_6t21dq2" }, "outputs": [], "source": [ "cropped = tf.image.central_crop(image, central_fraction=0.5)\n", "visualize(image, cropped)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8W5E_c7o-H96" }, "source": [ "See the `tf.image` reference for details about available augmentation options." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "92lBGZSQ-1Tx" }, "source": [ "## Augment a dataset and train a model with it" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "lrDez4xIX9Ss" }, "source": [ "Train a model on an augmented dataset.\n", "\n", "Note: The problem solved here is somewhat artificial. It trains a densely connected network to be shift invariant by jittering the input images. It's much more efficient to use convolutional layers instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "mazlEonS_gTR" }, "outputs": [], "source": [ "dataset, info = tfds.load('mnist', as_supervised=True, with_info=True)\n", "train_dataset, test_dataset = dataset['train'], dataset['test']\n", "\n", "num_train_examples = info.splits['train'].num_examples" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "011caOa0YCz5" }, "source": [ "Write a function to augment the images. Map it over the dataset. This returns a dataset that augments the data on the fly."
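, "\n", "The pad-then-random-crop jitter in the next cell can be sketched in plain NumPy (a hypothetical standalone example with a fixed offset, separate from the tf.data pipeline):\n", "\n", "```python\n", "import numpy as np\n", "\n", "img = np.zeros((28, 28, 1))                      # MNIST-sized toy image\n", "padded = np.pad(img, ((3, 3), (3, 3), (0, 0)))   # pad each spatial axis 28 -> 34\n", "top, left = 2, 5                                 # a sample crop offset in [0, 6]\n", "crop = padded[top:top + 28, left:left + 28, :]   # crop back to 28x28\n", "```\n"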
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "3oaSV5QcDS8p" }, "outputs": [], "source": [ "def convert(image, label):\n", " image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]\n", " return image, label\n", "\n", "def augment(image, label):\n", " image, label = convert(image, label)\n", " image = tf.image.resize_with_crop_or_pad(image, 34, 34) # Add 6 pixels of padding\n", " image = tf.image.random_crop(image, size=[28, 28, 1]) # Random crop back to 28x28\n", " image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness\n", "\n", " return image, label" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "xNROtM5uqSjg" }, "outputs": [], "source": [ "BATCH_SIZE = 64\n", "# For this tutorial, only use a subset of the data so it's easier to overfit.\n", "NUM_EXAMPLES = 2048" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dEq0VnUIy-8l" }, "source": [ "Create the augmented dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "jWJpV6JOqN_7" }, "outputs": [], "source": [ "augmented_train_batches = (\n", " train_dataset\n", " # Only train on a subset, so you can quickly see the effect.\n", " .take(NUM_EXAMPLES)\n", " .cache()\n", " .shuffle(num_train_examples//4)\n", " # The augmentation is added here.\n", " .map(augment, num_parallel_calls=AUTOTUNE)\n", " .batch(BATCH_SIZE)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "PGm5X_77zE3q" }, "source": [ "And a non-augmented one for comparison."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "td4yU920qOgU" }, "outputs": [], "source": [ "non_augmented_train_batches = (\n", " train_dataset\n", " # Only train on a subset, so you can quickly see the effect.\n", " .take(NUM_EXAMPLES)\n", " .cache()\n", " .shuffle(num_train_examples//4)\n", " # No augmentation.\n", " .map(convert, num_parallel_calls=AUTOTUNE)\n", " .batch(BATCH_SIZE)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "zkLkgEudzKfc" }, "source": [ "Set up the validation dataset. It is the same whether or not you use augmentation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "6eqKD4zOqpvE" }, "outputs": [], "source": [ "validation_batches = (\n", " test_dataset\n", " .map(convert, num_parallel_calls=AUTOTUNE)\n", " .batch(2 * BATCH_SIZE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Yi9TIwR-ZIOi" }, "source": [ "Create and compile the model: a fully connected neural network with two hidden layers and no convolutions."
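, "\n", "For a sense of scale, the model defined in the next cell (784 flattened inputs, two 4096-unit hidden layers, 10 outputs) has roughly 20 million trainable parameters; a quick back-of-the-envelope check:\n", "\n", "```python\n", "n_in = 28 * 28  # flattened input size\n", "params = (n_in * 4096 + 4096) + (4096 * 4096 + 4096) + (4096 * 10 + 10)\n", "print(params)  # 20037642, about 20 million weights and biases\n", "```\n"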
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "hHhkA4Q0CsHx" }, "outputs": [], "source": [ "def make_model():\n", " model = tf.keras.Sequential([\n", " layers.Flatten(input_shape=(28, 28, 1)),\n", " layers.Dense(4096, activation='relu'),\n", " layers.Dense(4096, activation='relu'),\n", " layers.Dense(10)\n", " ])\n", " model.compile(optimizer='adam',\n", " loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])\n", " return model" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "P0rciou3ZWwy" }, "source": [ "Train the model, **without** augmentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "z8X8CpqvNhG9" }, "outputs": [], "source": [ "model_without_aug = make_model()\n", "\n", "no_aug_history = model_without_aug.fit(non_augmented_train_batches, epochs=50, validation_data=validation_batches)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "yH3LUd1prFWe" }, "source": [ "Train it again with augmentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "z4RqS9FWrEiS" }, "outputs": [], "source": [ "model_with_aug = make_model()\n", "\n", "aug_history = model_with_aug.fit(augmented_train_batches, epochs=50, validation_data=validation_batches)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UEqeeNsHZaC5" }, "source": [ "## Conclusion\n", "\n", "In this example, the augmented model converges to a validation accuracy of about 95%. This is slightly higher (around +1%) than the model trained without data augmentation."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "ZnkliKKm0VVw" }, "outputs": [], "source": [ "plotter = tfdocs.plots.HistoryPlotter()\n", "plotter.plot({\"Augmented\": aug_history, \"Non-Augmented\": no_aug_history}, metric=\"accuracy\")\n", "plt.title(\"Accuracy\")\n", "plt.ylim([0.75, 1])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "cAaUjNhI3j7k" }, "source": [ "In terms of loss, the non-augmented model is clearly in the overfitting regime. The augmented model, while a few epochs slower, is still training correctly and clearly not overfitting." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "1-dTZpc4zq0-" }, "outputs": [], "source": [ "plotter = tfdocs.plots.HistoryPlotter()\n", "plotter.plot({\"Augmented\": aug_history, \"Non-Augmented\": no_aug_history}, metric=\"loss\")\n", "plt.title(\"Loss\")\n", "plt.ylim([0, 1])" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "data_augmentation.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }