{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "4EFY9e5wRn7v" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "pkTRazeVRwDe" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "VyOckJu6Rs-i" }, "source": [ "# Data augmentation" ] }, { "cell_type": "markdown", "metadata": { "id": "0HEsULqDR7AH" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "PxIOE5RnSQtj" }, "source": [ "## Overview\n", "\n", "This tutorial demonstrates data augmentation: a technique to increase the diversity of your training set by applying random (but realistic) transformations, such as image rotation.\n", "\n", "You will learn how to apply data augmentation in two ways:\n", "\n", "- Use the Keras preprocessing layers, such as `tf.keras.layers.Resizing`, `tf.keras.layers.Rescaling`, `tf.keras.layers.RandomFlip`, and `tf.keras.layers.RandomRotation`.\n", "- Use the `tf.image` methods, such as `tf.image.flip_left_right`, `tf.image.rgb_to_grayscale`, `tf.image.adjust_brightness`, `tf.image.central_crop`, and `tf.image.stateless_random*`." ] }, { "cell_type": "markdown", "metadata": { "id": "-UxHAqXmSXN5" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C2Q5rPenTAJP" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "\n", "from tensorflow.keras import layers" ] }, { "cell_type": "markdown", "metadata": { "id": "Ydx3SSoF4wpG" }, "source": [ "## Download a dataset\n", "\n", "This tutorial uses the [tf_flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset. For convenience, download the dataset using [TensorFlow Datasets](https://www.tensorflow.org/datasets). 
If you would like to learn about other ways of importing data, check out the [load images](https://www.tensorflow.org/tutorials/load_data/images) tutorial.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ytHhsYmO52zy" }, "outputs": [], "source": [ "(train_ds, val_ds, test_ds), metadata = tfds.load(\n", " 'tf_flowers',\n", " split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n", " with_info=True,\n", " as_supervised=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "MjxEJtCwsnmm" }, "source": [ "The flowers dataset has five classes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wKwx7vQuspxz" }, "outputs": [], "source": [ "num_classes = metadata.features['label'].num_classes\n", "print(num_classes)" ] }, { "cell_type": "markdown", "metadata": { "id": "zZAQW44949uw" }, "source": [ "Let's retrieve an image from the dataset and use it to demonstrate data augmentation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kXlx1lCr5Bip" }, "outputs": [], "source": [ "get_label_name = metadata.features['label'].int2str\n", "\n", "image, label = next(iter(train_ds))\n", "_ = plt.imshow(image)\n", "_ = plt.title(get_label_name(label))" ] }, { "cell_type": "markdown", "metadata": { "id": "vdJ6XA4q2nqK" }, "source": [ "## Use Keras preprocessing layers" ] }, { "cell_type": "markdown", "metadata": { "id": "GRMPnfzBB2hw" }, "source": [ "### Resizing and rescaling\n" ] }, { "cell_type": "markdown", "metadata": { "id": "jhG7gSWmUMJx" }, "source": [ "You can use the Keras preprocessing layers to resize your images to a consistent shape (with `tf.keras.layers.Resizing`), and to rescale pixel values (with `tf.keras.layers.Rescaling`)." 
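, "

The `Rescaling` layer multiplies its inputs by `scale` and then adds `offset`. The minimal NumPy sketch below checks that arithmetic for both configurations used in this tutorial, `Rescaling(1./255)` and `Rescaling(1./127.5, offset=-1)`. It applies them to the darkest, middle, and brightest 8-bit pixel values. The `rescale` helper is hypothetical and is not the Keras implementation:

```python
import numpy as np

def rescale(pixels, scale, offset=0.0):
    # Hypothetical helper following the documented Rescaling formula:
    # output = input * scale + offset
    return np.asarray(pixels, dtype=np.float32) * scale + offset

pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)

unit = rescale(pixels, 1.0 / 255)                   # maps [0, 255] onto [0, 1]
signed = rescale(pixels, 1.0 / 127.5, offset=-1.0)  # maps [0, 255] onto [-1, 1]

print(unit)
print(signed)
```
"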
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jMM3b85e3yhd" }, "outputs": [], "source": [ "IMG_SIZE = 180\n", "\n", "resize_and_rescale = tf.keras.Sequential([\n", " layers.Resizing(IMG_SIZE, IMG_SIZE),\n", " layers.Rescaling(1./255)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "4z8AV1WgnYNW" }, "source": [ "Note: The rescaling layer above standardizes pixel values to the `[0, 1]` range. If instead you wanted it to be `[-1, 1]`, you would write `tf.keras.layers.Rescaling(1./127.5, offset=-1)`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MQiTwsHJDHAD" }, "source": [ "You can visualize the result of applying these layers to an image. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X9OLuR1bC1Pd" }, "outputs": [], "source": [ "result = resize_and_rescale(image)\n", "_ = plt.imshow(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "yxAMg8Zql5lw" }, "source": [ "Verify that the pixels are in the `[0, 1]` range:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DPTB8IQmSeKM" }, "outputs": [], "source": [ "print(\"Min and max pixel values:\", result.numpy().min(), result.numpy().max())" ] }, { "cell_type": "markdown", "metadata": { "id": "fL6M7fuivAw4" }, "source": [ "### Data augmentation" ] }, { "cell_type": "markdown", "metadata": { "id": "SL4Suj46ScfU" }, "source": [ "You can use the Keras preprocessing layers for data augmentation as well, such as `tf.keras.layers.RandomFlip` and `tf.keras.layers.RandomRotation`." ] }, { "cell_type": "markdown", "metadata": { "id": "V-4PugTE-4sl" }, "source": [ "Let's create a few preprocessing layers and apply them repeatedly to the same image." 
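, "

`RandomRotation` interprets a single float `factor` as a fraction of a full turn (2π radians). It samples rotation angles from `[-factor * 2π, factor * 2π]`. The short sketch below converts the `0.2` used in the next cell into degrees. It only does the arithmetic and does not call Keras. The `rotation_range_degrees` helper is hypothetical:

```python
import math

def rotation_range_degrees(factor):
    # A float factor f means angles in [-f * 2*pi, f * 2*pi] radians.
    max_radians = factor * 2 * math.pi
    return -math.degrees(max_radians), math.degrees(max_radians)

low, high = rotation_range_degrees(0.2)
print(f'RandomRotation(0.2) rotates by {low:.0f} to {high:.0f} degrees')
# prints: RandomRotation(0.2) rotates by -72 to 72 degrees
```
"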
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Svu_5yfa_Jb7" }, "outputs": [], "source": [ "data_augmentation = tf.keras.Sequential([\n", " layers.RandomFlip(\"horizontal_and_vertical\"),\n", " layers.RandomRotation(0.2),\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kfzEuaNg69iU" }, "outputs": [], "source": [ "# Add the image to a batch.\n", "image = tf.cast(tf.expand_dims(image, 0), tf.float32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eR4wwi5Q_UZK" }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for i in range(9):\n", " augmented_image = data_augmentation(image)\n", " ax = plt.subplot(3, 3, i + 1)\n", " # Pixel values are floats in [0, 255]; cast to uint8 so imshow does not clip them.\n", " plt.imshow(augmented_image[0].numpy().astype(\"uint8\"))\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "jA17pEeS_2_-" }, "source": [ "There is a variety of preprocessing layers you can use for data augmentation, including `tf.keras.layers.RandomContrast`, `tf.keras.layers.RandomCrop`, `tf.keras.layers.RandomZoom`, and others." ] }, { "cell_type": "markdown", "metadata": { "id": "GG5RhIJtE0ng" }, "source": [ "### Two options to use the Keras preprocessing layers\n", "\n", "There are two ways you can use these preprocessing layers, with important trade-offs." 
] }, { "cell_type": "markdown", "metadata": { "id": "MxGvUT727Po6" }, "source": [ "#### Option 1: Make the preprocessing layers part of your model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ULGJQjP6hHvu" }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", " # Add the preprocessing layers you created earlier.\n", " resize_and_rescale,\n", " data_augmentation,\n", " layers.Conv2D(16, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " # Rest of your model.\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "pc6ELneyhJN9" }, "source": [ "There are two important points to be aware of in this case:\n", "\n", "* Data augmentation will run on-device, synchronously with the rest of your layers, and benefit from GPU acceleration.\n", "\n", "* When you export your model using `Model.save`, the preprocessing layers will be saved along with the rest of your model. If you later deploy this model, it will automatically standardize images (according to the configuration of your layers). This saves you from having to reimplement that logic server-side." ] }, { "cell_type": "markdown", "metadata": { "id": "syZwDSpiRXZP" }, "source": [ "Note: Data augmentation is inactive at test time, so input images will only be augmented during calls to `Model.fit` (not `Model.evaluate` or `Model.predict`)." ] }, { "cell_type": "markdown", "metadata": { "id": "B2X3JTeY_vfv" }, "source": [ "#### Option 2: Apply the preprocessing layers to your dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "r1Bt7w5VhVDY" }, "outputs": [], "source": [ "# Resize and rescale first, then apply the random augmentation layers.\n", "aug_ds = train_ds.map(\n", " lambda x, y: (data_augmentation(resize_and_rescale(x), training=True), y))" ] }, { "cell_type": "markdown", "metadata": { "id": "HKqeahG2hVdV" }, "source": [ "With this approach, you use `Dataset.map` to create a dataset that yields augmented images. 
In this case:\n", "\n", "* Data augmentation will happen asynchronously on the CPU and is non-blocking. You can overlap the training of your model on the GPU with data preprocessing using `Dataset.prefetch`, as shown below.\n", "* The preprocessing layers will not be exported with the model when you call `Model.save`. Before exporting, you will need to either attach them to your trained model or reimplement them server-side.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "cgj51k9J7jfc" }, "source": [ "You can find an example of the first option in the [Image classification](classification.ipynb) tutorial. Let's demonstrate the second option here." ] }, { "cell_type": "markdown", "metadata": { "id": "31YwMQdrXKBP" }, "source": [ "### Apply the preprocessing layers to the datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "WUgW-2LOGiOT" }, "source": [ "Configure the training, validation, and test datasets with the Keras preprocessing layers you created earlier. You will also configure the datasets for performance, using parallel mapping and buffered prefetching so that batches are yielded without I/O becoming blocking. (Learn more about dataset performance in the [Better performance with the tf.data API](https://www.tensorflow.org/guide/data_performance) guide.)" ] }, { "cell_type": "markdown", "metadata": { "id": "eI7VdyqK767y" }, "source": [ "Note: Data augmentation should only be applied to the training set." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R5fGVMqlFxF7" }, "outputs": [], "source": [ "batch_size = 32\n", "AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "def prepare(ds, shuffle=False, augment=False):\n", "  # Resize and rescale all datasets.\n", "  ds = ds.map(lambda x, y: (resize_and_rescale(x), y),\n", "              num_parallel_calls=AUTOTUNE)\n", "\n", "  if shuffle:\n", "    ds = ds.shuffle(1000)\n", "\n", "  # Batch all datasets.\n", "  ds = ds.batch(batch_size)\n", "\n", "  # Use data augmentation only on the training set.\n", "  if augment:\n", "    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),\n", "                num_parallel_calls=AUTOTUNE)\n", "\n", "  # Use buffered prefetching on all datasets.\n", "  return ds.prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N86SFGMBHcx-" }, "outputs": [], "source": [ "train_ds = prepare(train_ds, shuffle=True, augment=True)\n", "val_ds = prepare(val_ds)\n", "test_ds = prepare(test_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "9gplDz4ZV6kk" }, "source": [ "### Train a model\n", "\n", "For completeness, you will now train a model using the datasets you have just prepared.\n", "\n", "The [Sequential](https://www.tensorflow.org/guide/keras/sequential_model) model consists of three convolution blocks (`tf.keras.layers.Conv2D`), each followed by a max pooling layer (`tf.keras.layers.MaxPooling2D`). On top of them sits a fully-connected layer (`tf.keras.layers.Dense`) with 128 units and a ReLU activation (`'relu'`), followed by a `Dense` output layer with one unit per class. This model has not been tuned for accuracy; the goal is to show you the mechanics." 
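, "

To make the architecture concrete, the plain-Python sketch below traces the feature-map shape and the parameter count through the model. It assumes 180x180x3 inputs (`IMG_SIZE = 180`) and `'same'` padding in every `Conv2D`, which preserves height and width. It also assumes the default `MaxPooling2D` settings: a 2x2 window, stride 2, and `'valid'` padding, which rounds odd sizes down. The helper name `pool_out` is hypothetical:

```python
IMG_SIZE = 180              # images are resized to IMG_SIZE x IMG_SIZE x 3
FILTERS = [16, 32, 64]      # output channels of the three Conv2D layers
KERNEL = 3                  # 3x3 convolution kernels
NUM_CLASSES = 5             # tf_flowers has five classes

def pool_out(size, pool=2, stride=2):
    # Output length of a 'valid' max pooling window along one axis.
    return (size - pool) // stride + 1

height = width = IMG_SIZE
in_channels = 3
total_params = 0
for filters in FILTERS:
    # Conv2D with padding='same' keeps height and width.
    # Weights: KERNEL * KERNEL * in_channels * filters, plus one bias per filter.
    total_params += KERNEL * KERNEL * in_channels * filters + filters
    in_channels = filters
    height, width = pool_out(height), pool_out(width)
    print(f'after block: {height}x{width}x{filters}')  # 90x90x16, 45x45x32, 22x22x64

flat = height * width * in_channels
total_params += flat * 128 + 128                   # Dense(128)
total_params += 128 * NUM_CLASSES + NUM_CLASSES    # Dense(num_classes)
print('flattened features:', flat)                 # 30976
print('trainable parameters:', total_params)       # 3989285
```

Most of the parameters sit in the first `Dense` layer, because it connects every one of the 30976 flattened features to each of its 128 units. Once the model has been built (for example, after training), `model.summary()` should report the same shapes and total.
"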
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IODSymGhq9N6" }, "outputs": [], "source": [ "model = tf.keras.Sequential([\n", " layers.Conv2D(16, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(32, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(64, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Flatten(),\n", " layers.Dense(128, activation='relu'),\n", " layers.Dense(num_classes)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "86454855f7d9" }, "source": [ "Choose the `tf.keras.optimizers.Adam` optimizer and `tf.keras.losses.SparseCategoricalCrossentropy` loss function. To view training and validation accuracy for each training epoch, pass the `metrics` argument to `Model.compile`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZnRJr95WY68k" }, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "976f718cabc8" }, "source": [ "Train for a few epochs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i_sDl9uZY9Mh" }, "outputs": [], "source": [ "epochs=5\n", "history = model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=epochs\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V9PSf4qgiQJG" }, "outputs": [], "source": [ "loss, acc = model.evaluate(test_ds)\n", "print(\"Accuracy\", acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "0BkRvvsXb6SI" }, "source": [ "### Custom data augmentation\n", "\n", "You can also create custom data augmentation layers.\n", "\n", "This section of the tutorial shows two ways of doing so:\n", "\n", "- First, you will create a `tf.keras.layers.Lambda` layer. 
This is a good way to write concise code.\n", "- Next, you will write a new layer via [subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models), which gives you more control.\n", "\n", "Both layers will randomly invert the colors in an image with a given probability." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nMxEhIVXmAH0" }, "outputs": [], "source": [ "def random_invert_img(x, p=0.5):\n", "  # Invert all pixel values with probability p (a single draw per call).\n", "  if tf.random.uniform([]) < p:\n", "    x = 255 - x\n", "  return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C0huNpxdmDKu" }, "outputs": [], "source": [ "def random_invert(factor=0.5):\n", " return layers.Lambda(lambda x: random_invert_img(x, factor))\n", "\n", "random_invert = random_invert()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wAcOluP0TNG6" }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for i in range(9):\n", " augmented_image = random_invert(image)\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(augmented_image[0].numpy().astype(\"uint8\"))\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Xd9XG2PLM5ZJ" }, "source": [ "Next, implement a custom layer by [subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d11eExc-Ke-7" }, "outputs": [], "source": [ "class RandomInvert(layers.Layer):\n", "  def __init__(self, factor=0.5, **kwargs):\n", "    super().__init__(**kwargs)\n", "    self.factor = factor\n", "\n", "  def call(self, x):\n", "    # Pass the configured probability through instead of the default.\n", "    return random_invert_img(x, self.factor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qX-VQgkRL6fc" }, "outputs": [], "source": [ "_ = plt.imshow(RandomInvert()(image)[0].numpy().astype(\"uint8\"))" ] }, { "cell_type": "markdown", "metadata": { "id": "B0nmllnXZO6T" }, "source": [ "Both of these layers can be used as described in options 1 and 2 above." 
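, "

Note that `random_invert_img` draws one random number per call. So when it receives a batch, either all images in that batch are inverted or none are. If you prefer an independent decision for each image, one option is to draw one value per image and select with a mask. The NumPy sketch below illustrates that variant. The `random_invert_per_image` helper and its use of `np.random.default_rng` are assumptions for illustration, not part of the tutorial's API:

```python
import numpy as np

def random_invert_per_image(batch, p=0.5, seed=None):
    # Hypothetical helper: invert each image of an (N, H, W, C) uint8 batch
    # independently, with probability p per image.
    rng = np.random.default_rng(seed)
    flip = rng.random(batch.shape[0]) < p    # one Bernoulli draw per image
    mask = flip[:, None, None, None]         # broadcast over H, W and C
    return np.where(mask, 255 - batch, batch), flip

batch = np.random.default_rng(0).integers(0, 256, size=(8, 4, 4, 3), dtype=np.uint8)
out, flipped = random_invert_per_image(batch, p=0.5, seed=42)
print('inverted image indices:', np.flatnonzero(flipped))
```

The same masking idea carries over to TensorFlow, where you would draw a per-image random tensor and combine the inverted and original batches with an elementwise select.
"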
] }, { "cell_type": "markdown", "metadata": { "id": "j7-k__2dAfX6" }, "source": [ "## Using tf.image" ] }, { "cell_type": "markdown", "metadata": { "id": "NJco2x35EAMs" }, "source": [ "The Keras preprocessing utilities above are convenient. For finer control, however, you can write your own data augmentation pipelines or layers using `tf.data` and `tf.image`. (You may also want to check out [TensorFlow Addons Image: Operations](https://www.tensorflow.org/addons/tutorials/image_ops) and [TensorFlow I/O: Color Space Conversions](https://www.tensorflow.org/io/tutorials/colorspace).)" ] }, { "cell_type": "markdown", "metadata": { "id": "xR1RvjYkdd_i" }, "source": [ "Since the flowers dataset was previously configured with data augmentation, let's reimport it to start fresh:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JB-lAS0z9ZJY" }, "outputs": [], "source": [ "(train_ds, val_ds, test_ds), metadata = tfds.load(\n", " 'tf_flowers',\n", " split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n", " with_info=True,\n", " as_supervised=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "rQ3pqBTS9hNj" }, "source": [ "Retrieve an image to work with:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dDsPaAi8de_j" }, "outputs": [], "source": [ "image, label = next(iter(train_ds))\n", "_ = plt.imshow(image)\n", "_ = plt.title(get_label_name(label))" ] }, { "cell_type": "markdown", "metadata": { "id": "chelxcPtFiTF" }, "source": [ "Let's use the following function to visualize and compare the original and augmented images side-by-side:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sN1ykjJCHikc" }, "outputs": [], "source": [ "def visualize(original, augmented):\n", " fig = plt.figure()\n", " plt.subplot(1,2,1)\n", " plt.title('Original image')\n", " plt.imshow(original)\n", "\n", " plt.subplot(1,2,2)\n", " plt.title('Augmented image')\n", " plt.imshow(augmented)" ] }, { "cell_type": 
"markdown", "metadata": { "id": "C5X4ijQYHmlt" }, "source": [ "### Data augmentation" ] }, { "cell_type": "markdown", "metadata": { "id": "RRD9oujLHo6c" }, "source": [ "#### Flip an image\n", "\n", "Flip an image horizontally with `tf.image.flip_left_right` (use `tf.image.flip_up_down` for a vertical flip):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1ZjVI24nIH0S" }, "outputs": [], "source": [ "flipped = tf.image.flip_left_right(image)\n", "visualize(image, flipped)" ] }, { "cell_type": "markdown", "metadata": { "id": "6iD_lLibIL9q" }, "source": [ "#### Grayscale an image\n", "\n", "You can grayscale an image with `tf.image.rgb_to_grayscale`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ikaMj0guIRtL" }, "outputs": [], "source": [ "grayscaled = tf.image.rgb_to_grayscale(image)\n", "visualize(image, tf.squeeze(grayscaled))\n", "_ = plt.colorbar()" ] }, { "cell_type": "markdown", "metadata": { "id": "f-5yjIs4IZ7v" }, "source": [ "#### Saturate an image\n", "\n", "Saturate an image with `tf.image.adjust_saturation` by providing a saturation factor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PHz-NosiInmz" }, "outputs": [], "source": [ "saturated = tf.image.adjust_saturation(image, 3)\n", "visualize(image, saturated)" ] }, { "cell_type": "markdown", "metadata": { "id": "FWXiy8qfIqdC" }, "source": [ "#### Change image brightness\n", "\n", "Change the brightness of an image with `tf.image.adjust_brightness` by providing a `delta` to add to the pixel values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1hdG-j46I0nJ" }, "outputs": [], "source": [ "bright = tf.image.adjust_brightness(image, 0.4)\n", "visualize(image, bright)" ] }, { "cell_type": "markdown", "metadata": { "id": "vjEOFEITJOr2" }, "source": [ "#### Center crop an image\n", "\n", "Crop the central part of an image with `tf.image.central_crop`, where `central_fraction` sets the fraction of each dimension to keep:" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { 
"id": "RWkK5GFHJUKT" }, "outputs": [], "source": [ "cropped = tf.image.central_crop(image, central_fraction=0.5)\n", "visualize(image, cropped)" ] }, { "cell_type": "markdown", "metadata": { "id": "unt76GebI3Gc" }, "source": [ "#### Rotate an image\n", "\n", "Rotate an image by 90 degrees counter-clockwise with `tf.image.rot90`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "b19KuAhkJKR-" }, "outputs": [], "source": [ "rotated = tf.image.rot90(image)\n", "visualize(image, rotated)" ] }, { "cell_type": "markdown", "metadata": { "id": "5CPP0vEKB56X" }, "source": [ "### Random transformations\n", "\n", "Warning: There are two sets of random image operations: `tf.image.random*` and `tf.image.stateless_random*`. Using `tf.image.random*` operations is strongly discouraged as they use the old RNGs from TF 1.x. Instead, please use the random image operations introduced in this tutorial. For more information, refer to [Random number generation](../../guide/random_numbers.ipynb).\n", "\n", "Applying random transformations to the images can further help the model generalize and effectively expand the dataset. 
The current `tf.image` API provides eight such random image operations (ops):\n", "\n", "* [`tf.image.stateless_random_brightness`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_brightness)\n", "* [`tf.image.stateless_random_contrast`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_contrast)\n", "* [`tf.image.stateless_random_crop`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_crop)\n", "* [`tf.image.stateless_random_flip_left_right`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_flip_left_right)\n", "* [`tf.image.stateless_random_flip_up_down`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_flip_up_down)\n", "* [`tf.image.stateless_random_hue`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_hue)\n", "* [`tf.image.stateless_random_jpeg_quality`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_jpeg_quality)\n", "* [`tf.image.stateless_random_saturation`](https://www.tensorflow.org/api_docs/python/tf/image/stateless_random_saturation)\n", "\n", "These random image ops are purely functional: the output depends only on the inputs, including the `seed`. This makes them simple to use in high-performance, deterministic input pipelines. They require a `seed` argument at each step. Given the same `seed`, they return the same results independent of how many times they are called.\n", "\n", "Note: `seed` is a `Tensor` of shape `(2,)` whose values are any integers; a Python tuple such as `(i, 0)` also works.\n", "\n", "In the following sections, you will:\n", "1. Go over examples of using random image operations to transform an image.\n", "2. Demonstrate how to apply random transformations to a training dataset." ] }, { "cell_type": "markdown", "metadata": { "id": "251Wy-MqE4La" }, "source": [ "#### Randomly change image brightness\n", "\n", "Randomly change the brightness of `image` using `tf.image.stateless_random_brightness` by providing a `max_delta` and a `seed`. 
The brightness delta is chosen randomly in the range `[-max_delta, max_delta)` and is determined by the given `seed`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-fFd1kh7Fr-_" }, "outputs": [], "source": [ "for i in range(3):\n", " seed = (i, 0) # tuple of shape (2,)\n", " stateless_random_brightness = tf.image.stateless_random_brightness(\n", " image, max_delta=0.95, seed=seed)\n", " visualize(image, stateless_random_brightness)" ] }, { "cell_type": "markdown", "metadata": { "id": "uLaDEmooUfYJ" }, "source": [ "#### Randomly change image contrast\n", "\n", "Randomly change the contrast of `image` using `tf.image.stateless_random_contrast` by providing a contrast range (`lower`, `upper`) and a `seed`. The contrast factor is chosen randomly in the interval `[lower, upper)` and is determined by the given `seed`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GmcYoQHaUoke" }, "outputs": [], "source": [ "for i in range(3):\n", " seed = (i, 0) # tuple of shape (2,)\n", " stateless_random_contrast = tf.image.stateless_random_contrast(\n", " image, lower=0.1, upper=0.9, seed=seed)\n", " visualize(image, stateless_random_contrast)" ] }, { "cell_type": "markdown", "metadata": { "id": "wxb-MP-KVPNz" }, "source": [ "#### Randomly crop an image\n", "\n", "Randomly crop `image` using `tf.image.stateless_random_crop` by providing a target `size` and a `seed`. The portion that gets cropped out of `image` is at a randomly chosen offset, which is determined by the given `seed`." 
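, "

The stateless contract is easy to check: the same image and the same `seed` always give the same crop, while a different `seed` generally picks a different offset. The NumPy sketch below imitates that contract with a hypothetical `seeded_random_crop` helper. It derives the offset from `np.random.default_rng(seed)`. This is an assumption for illustration and does not reproduce the offsets TensorFlow would choose:

```python
import numpy as np

def seeded_random_crop(image, size, seed):
    # Hypothetical helper: crop an (h, w) window from an (H, W, C) array at an
    # offset that depends only on seed, mimicking a stateless random op.
    rng = np.random.default_rng(seed)
    h, w = size
    top = rng.integers(0, image.shape[0] - h + 1)
    left = rng.integers(0, image.shape[1] - w + 1)
    return image[top:top + h, left:left + w, :]

toy_image = np.arange(10 * 12 * 3).reshape(10, 12, 3)
a = seeded_random_crop(toy_image, (6, 8), seed=(1, 0))
b = seeded_random_crop(toy_image, (6, 8), seed=(1, 0))
print(a.shape, np.array_equal(a, b))  # (6, 8, 3) True
```
"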
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vtZQbUw0VOm5" }, "outputs": [], "source": [ "for i in range(3):\n", " seed = (i, 0) # tuple of shape (2,)\n", " stateless_random_crop = tf.image.stateless_random_crop(\n", " image, size=[210, 300, 3], seed=seed)\n", " visualize(image, stateless_random_crop)" ] }, { "cell_type": "markdown", "metadata": { "id": "isrM-MZtpxTq" }, "source": [ "### Apply augmentation to a dataset\n", "\n", "First, reload the image dataset in case it was modified in the previous sections." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xC80NQP809Uo" }, "outputs": [], "source": [ "(train_datasets, val_ds, test_ds), metadata = tfds.load(\n", " 'tf_flowers',\n", " split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],\n", " with_info=True,\n", " as_supervised=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SMo9HTDV0Gaz" }, "source": [ "Next, define a utility function for resizing and rescaling the images. It gives every image in the dataset the same size and scale:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1JKmx06lfcFr" }, "outputs": [], "source": [ "def resize_and_rescale(image, label):\n", " image = tf.cast(image, tf.float32)\n", " image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])\n", " image = (image / 255.0)\n", " return image, label" ] }, { "cell_type": "markdown", "metadata": { "id": "M7OpE_-jWq-I" }, "source": [ "Let's also define the `augment` function, which applies the random transformations to the images. This function will be used on the dataset in the next step." 
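, "

In `augment`, `tf.image.resize_with_crop_or_pad` pads the resized image centrally with zeros up to `IMG_SIZE + 6`, which adds 3 pixels per side. Cropping an `IMG_SIZE` window back out at a random offset then shifts the image by up to 3 pixels along each axis. The NumPy sketch below illustrates this pad-then-crop idea with a hypothetical `pad_and_random_crop` helper. It uses `np.random.default_rng` rather than TensorFlow's stateless seeds:

```python
import numpy as np

def pad_and_random_crop(image, pad, rng):
    # Hypothetical helper: zero-pad pad pixels on every side of an (H, W, C)
    # array, then crop back to (H, W) at a random offset in [0, 2 * pad].
    h, w = image.shape[:2]
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)))
    top = int(rng.integers(0, 2 * pad + 1))
    left = int(rng.integers(0, 2 * pad + 1))
    # Cropping at (top, left) moves the content by (pad - top, pad - left) pixels.
    return padded[top:top + h, left:left + w, :], (pad - top, pad - left)

rng = np.random.default_rng(0)
toy_image = np.ones((8, 8, 3), dtype=np.float32)
shifted, shift = pad_and_random_crop(toy_image, pad=3, rng=rng)
print(shifted.shape, shift)
```

Pixels shifted in from the border are zeros, which is why this acts like a small random translation with black fill.
"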
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KitLdvlpVxPa" }, "outputs": [], "source": [ "def augment(image_label, seed):\n", " image, label = image_label\n", " image, label = resize_and_rescale(image, label)\n", " # Pad to IMG_SIZE + 6 so the random crop below can shift the image.\n", " image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)\n", " # Make a new seed.\n", " new_seed = tf.random.split(seed, num=1)[0, :]\n", " # Random crop back to the original size.\n", " image = tf.image.stateless_random_crop(\n", " image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)\n", " # Random brightness.\n", " image = tf.image.stateless_random_brightness(\n", " image, max_delta=0.5, seed=new_seed)\n", " # Keep pixel values in [0, 1] after the brightness change.\n", " image = tf.clip_by_value(image, 0, 1)\n", " return image, label" ] }, { "cell_type": "markdown", "metadata": { "id": "SlXRsVp70hg8" }, "source": [ "#### Option 1: Using tf.data.experimental.Counter\n", "\n", "Create a `tf.data.experimental.Counter` object (let's call it `counter`) and `Dataset.zip` the dataset with `(counter, counter)`. This associates each image in the dataset with a unique value of shape `(2,)` based on `counter`, which is later passed to the `augment` function as the `seed` for its random transformations." 
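, "

Conceptually, zipping with `(counter, counter)` pairs the `i`-th training example with the seed `(i, i)`. The plain-Python sketch below mimics that pairing with `itertools.count`. It is an analogy for `tf.data.Dataset.zip`, not a use of it, and the placeholder strings stand in for real `(image, label)` pairs:

```python
import itertools

# Placeholder (image, label) pairs standing in for the training set.
examples = [('img0', 0), ('img1', 3), ('img2', 1)]

# Two separate counters are used because one itertools.count object would be
# consumed twice per element; in tf.data, each use of the Counter dataset
# gets its own iterator, so zipping (counter, counter) yields (i, i).
seeds = zip(itertools.count(), itertools.count())
zipped = list(zip(examples, seeds))

for image_label, seed in zipped:
    # Each element has the (image_label, seed) structure that augment expects.
    print(image_label, seed)
```
"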
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SZ6Qq0IWznfi" }, "outputs": [], "source": [ "# Create a `Counter` object and `Dataset.zip` it together with the training set.\n", "counter = tf.data.experimental.Counter()\n", "train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))" ] }, { "cell_type": "markdown", "metadata": { "id": "eF9ybVQ94X9f" }, "source": [ "Map the `augment` function to the training dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wQK9BDKk1_3N" }, "outputs": [], "source": [ "train_ds = (\n", " train_ds\n", " .shuffle(1000)\n", " .map(augment, num_parallel_calls=AUTOTUNE)\n", " .batch(batch_size)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3AQoyA-k3ELk" }, "outputs": [], "source": [ "val_ds = (\n", " val_ds\n", " .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", " .batch(batch_size)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p2IQN3NN3G_M" }, "outputs": [], "source": [ "test_ds = (\n", " test_ds\n", " .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", " .batch(batch_size)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "pvTVY8BY2LpD" }, "source": [ "#### Option 2: Using tf.random.Generator\n", "\n", "- Create a `tf.random.Generator` object with an initial `seed` value. Each call to `make_seeds` on the same generator object advances its state and returns a new `seed` value.\n", "- Define a wrapper function that: 1) calls the `make_seeds` function; and 2) passes the newly generated `seed` value into the `augment` function for random transformations.\n", "\n", "Option 2 is an alternative to Option 1. If you already ran the Option 1 cells, `val_ds` and `test_ds` have already been resized, rescaled, and batched, so reload them with `tfds.load` before running the cells below.\n", "\n", "Note: `tf.random.Generator` objects store RNG state in a `tf.Variable`, so the state can be saved as a [checkpoint](../../guide/checkpoint.ipynb) or in a [SavedModel](../../guide/saved_model.ipynb). 
For more details, please refer to [Random number generation](../../guide/random_numbers.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BQDvedZ33eAy" }, "outputs": [], "source": [ "# Create a generator.\n", "rng = tf.random.Generator.from_seed(123, alg='philox')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eDEkO1nt2ta0" }, "outputs": [], "source": [ "# Create a wrapper function that draws a fresh seed for each element.\n", "def f(x, y):\n", " seed = rng.make_seeds(1)[:, 0]\n", " image, label = augment((x, y), seed)\n", " return image, label" ] }, { "cell_type": "markdown", "metadata": { "id": "PyPC4vUM4MT0" }, "source": [ "Map the wrapper function `f` to the training dataset, and the `resize_and_rescale` function to the validation and test sets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pu2uB7k12xKw" }, "outputs": [], "source": [ "train_ds = (\n", " train_datasets\n", " .shuffle(1000)\n", " .map(f, num_parallel_calls=AUTOTUNE)\n", " .batch(batch_size)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "e6caldPi2HAP" }, "outputs": [], "source": [ "val_ds = (\n", " val_ds\n", " .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", " .batch(batch_size)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ceaCdJnh2I-r" }, "outputs": [], "source": [ "test_ds = (\n", " test_ds\n", " .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)\n", " .batch(batch_size)\n", " .prefetch(AUTOTUNE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "hKwCA6AOjTrc" }, "source": [ "These datasets can now be used to train a model as shown previously." 
] }, { "cell_type": "markdown", "metadata": { "id": "YypDihDlj0no" }, "source": [ "## Next steps\n", "\n", "This tutorial demonstrated data augmentation using Keras preprocessing layers and `tf.image`.\n", "\n", "- To learn how to include preprocessing layers inside your model, refer to the [Image classification](classification.ipynb) tutorial.\n", "- You may also be interested in learning how preprocessing layers can help you classify text, as shown in the [Basic text classification](../keras/text_classification.ipynb) tutorial.\n", "- You can learn more about `tf.data` in the [tf.data guide](../../guide/data.ipynb), and learn how to configure your input pipelines for performance in the [Better performance with the tf.data API](../../guide/data_performance.ipynb) guide." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "data_augmentation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }