{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "cZCM65CBt1CJ" }, "source": [ "##### Copyright 2019 The TensorFlow Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "JOgMcEajtkmg" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "rCSP-dbMw88x" }, "source": [ "# Image segmentation" ] }, { "cell_type": "markdown", "metadata": { "id": "NEWs8JXRuGex" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", "View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook\n" ] }, {
 "cell_type": "markdown", "metadata": { "id": "sMP7mglMuGT2" }, "source": [ "This tutorial focuses on the task of image segmentation, using a modified U-Net.\n", "\n", "## What is image segmentation?\n", "\n", "In an image classification task, the network assigns a label (or class) to each input image. However, suppose you want to know the shape of an object in the image, which pixel belongs to which object, and so on. In this case, you need to assign a class to each pixel of the image—this task is known as segmentation. A segmentation model returns much more detailed information about the image. Image segmentation has many applications in medical imaging, self-driving cars, and satellite imaging, to name a few.\n", "\n", "This tutorial uses the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) ([Parkhi et al., 2012](https://www.robots.ox.ac.uk/~vgg/publications/2012/parkhi12a/parkhi12a.pdf)). The dataset consists of images of 37 pet breeds, with 200 images per breed (~100 each in the training and test splits). Each image includes the corresponding labels and pixel-wise masks. The masks are class labels for each pixel. Each pixel is given one of three categories:\n", "\n", "- Class 1: Pixel belonging to the pet.\n", "- Class 2: Pixel bordering the pet.\n", "- Class 3: None of the above/a surrounding pixel." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MQmKthrSBCld" }, "outputs": [], "source": [ "!pip install git+https://github.com/tensorflow/examples.git" ] }, {
 "cell_type": "code", "execution_count": null, "metadata": { "id": "YQX7R4bhZy5h" }, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g87--n2AtyO_" }, "outputs": [], "source": [ "from tensorflow_examples.models.pix2pix import pix2pix\n", "\n", "from IPython.display import clear_output\n", "import matplotlib.pyplot as plt" ] }, {
 "cell_type": "markdown", "metadata": { "id": "oWe0_rQM4JbC" }, "source": [ "## Download the Oxford-IIIT Pets dataset\n", "\n", "The dataset is [available from TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet). The segmentation masks are included in version 3+." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "40ITeStwDwZb" }, "outputs": [], "source": [ "dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)" ] }, {
 "cell_type": "markdown", "metadata": { "id": "rJcVdj_U4vzf" }, "source": [ "In the following preprocessing steps, the images are resized to 128x128 and the image color values are normalized to the `[0, 1]` range. In addition, as mentioned above, the pixels in the segmentation mask are labeled {1, 2, 3}. For convenience, subtract 1 from the segmentation mask so that the labels become {0, 1, 2}."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FD60EbcAQqov" }, "outputs": [], "source": [ "def normalize(input_image, input_mask):\n", " input_image = tf.cast(input_image, tf.float32) / 255.0\n", " input_mask -= 1\n", " return input_image, input_mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zf0S67hJRp3D" }, "outputs": [], "source": [ "def load_image(datapoint):\n", " input_image = tf.image.resize(datapoint['image'], (128, 128))\n", " input_mask = tf.image.resize(\n", " datapoint['segmentation_mask'],\n", " (128, 128),\n", " method = tf.image.ResizeMethod.NEAREST_NEIGHBOR,\n", " )\n", "\n", " input_image, input_mask = normalize(input_image, input_mask)\n", "\n", " return input_image, input_mask" ] }, { "cell_type": "markdown", "metadata": { "id": "65-qHTjX5VZh" }, "source": [ "The dataset already contains the required training and test splits, so continue to use the same splits:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yHwj2-8SaQli" }, "outputs": [], "source": [ "TRAIN_LENGTH = info.splits['train'].num_examples\n", "BATCH_SIZE = 64\n", "BUFFER_SIZE = 1000\n", "STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "39fYScNz9lmo" }, "outputs": [], "source": [ "train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)\n", "test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "T9hGHyg8L3Y1" }, "source": [ "The following class performs a simple augmentation by randomly-flipping an image.\n", "Go to the [Image augmentation](data_augmentation.ipynb) tutorial to learn more.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fUWdDJRTL0PP" }, "outputs": [], "source": [ "class Augment(tf.keras.layers.Layer):\n", " def __init__(self, seed=42):\n", " super().__init__()\n", " # both use the same seed, so they'll make the same random changes.\n", " self.augment_inputs = tf.keras.layers.RandomFlip(mode=\"horizontal\", seed=seed)\n", " self.augment_labels = tf.keras.layers.RandomFlip(mode=\"horizontal\", seed=seed)\n", " \n", " def call(self, inputs, labels):\n", " inputs = self.augment_inputs(inputs)\n", " labels = self.augment_labels(labels)\n", " return inputs, labels" ] }, { "cell_type": "markdown", "metadata": { "id": "xTIbNIBdcgL3" }, "source": [ "Build the input pipeline, applying the augmentation after batching the inputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VPscskQcNCx4" }, "outputs": [], "source": [ "train_batches = (\n", " train_images\n", " .cache()\n", " .shuffle(BUFFER_SIZE)\n", " .batch(BATCH_SIZE)\n", " .repeat()\n", " .map(Augment())\n", " .prefetch(buffer_size=tf.data.AUTOTUNE))\n", "\n", "test_batches = test_images.batch(BATCH_SIZE)" ] }, { "cell_type": "markdown", "metadata": { "id": "Xa3gMAE_9qNa" }, "source": [ "Visualize an image example and its corresponding mask from the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3N2RPAAW9q4W" }, "outputs": [], "source": [ "def display(display_list):\n", " plt.figure(figsize=(15, 15))\n", "\n", " title = ['Input Image', 'True Mask', 'Predicted Mask']\n", "\n", " for i in range(len(display_list)):\n", " plt.subplot(1, len(display_list), i+1)\n", " plt.title(title[i])\n", " plt.imshow(tf.keras.utils.array_to_img(display_list[i]))\n", " plt.axis('off')\n", " plt.show()" ] }, { 
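"cell_type": "markdown", "metadata": { }, "source": [ "Before plotting, you can optionally sanity-check the pipeline by taking a single batch and printing its tensor shapes. This is a minimal check based on the batch and image sizes defined above; it should report image batches of shape `(64, 128, 128, 3)` and mask batches of shape `(64, 128, 128, 1)`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { }, "outputs": [], "source": [ "# Optional: inspect the shapes of one augmented batch from the input pipeline.\n", "for images, masks in train_batches.take(1):\n", "  print('Image batch shape:', images.shape)\n", "  print('Mask batch shape: ', masks.shape)" ] }, {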
"cell_type": "code", "execution_count": null, "metadata": { "id": "a6u_Rblkteqb" }, "outputs": [], "source": [ "for images, masks in train_batches.take(2):\n", " sample_image, sample_mask = images[0], masks[0]\n", " display([sample_image, sample_mask])" ] }, { "cell_type": "markdown", "metadata": { "id": "FAOe93FRMk3w" }, "source": [ "## Define the model\n", "The model being used here is a modified [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler). To learn robust features and reduce the number of trainable parameters, use a pretrained model—[MobileNetV2](https://arxiv.org/abs/1801.04381)—as the encoder. For the decoder, you will use the upsample block, which is already implemented in the [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) example in the TensorFlow Examples repo. (Check out the [pix2pix: Image-to-image translation with a conditional GAN](../generative/pix2pix.ipynb) tutorial in a notebook.)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "W4mQle3lthit" }, "source": [ "As mentioned, the encoder is a pretrained MobileNetV2 model. You will use the model from `tf.keras.applications`. The encoder consists of specific outputs from intermediate layers in the model. Note that the encoder will not be trained during the training process." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "liCeLH0ctjq7" }, "outputs": [], "source": [ "base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)\n", "\n", "# Use the activations of these layers\n", "layer_names = [\n", " 'block_1_expand_relu', # 64x64\n", " 'block_3_expand_relu', # 32x32\n", " 'block_6_expand_relu', # 16x16\n", " 'block_13_expand_relu', # 8x8\n", " 'block_16_project', # 4x4\n", "]\n", "base_model_outputs = [base_model.get_layer(name).output for name in layer_names]\n", "\n", "# Create the feature extraction model\n", "down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)\n", "\n", "down_stack.trainable = False" ] }, { "cell_type": "markdown", "metadata": { "id": "KPw8Lzra5_T9" }, "source": [ "The decoder/upsampler is simply a series of upsample blocks implemented in TensorFlow examples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p0ZbfywEbZpJ" }, "outputs": [], "source": [ "up_stack = [\n", " pix2pix.upsample(512, 3), # 4x4 -> 8x8\n", " pix2pix.upsample(256, 3), # 8x8 -> 16x16\n", " pix2pix.upsample(128, 3), # 16x16 -> 32x32\n", " pix2pix.upsample(64, 3), # 32x32 -> 64x64\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "45HByxpVtrPF" }, "outputs": [], "source": [ "def unet_model(output_channels:int):\n", " inputs = tf.keras.layers.Input(shape=[128, 128, 3])\n", "\n", " # Downsampling through the model\n", " skips = down_stack(inputs)\n", " x = skips[-1]\n", " skips = reversed(skips[:-1])\n", "\n", " # Upsampling and establishing the skip connections\n", " for up, skip in zip(up_stack, skips):\n", " x = up(x)\n", " concat = tf.keras.layers.Concatenate()\n", " x = concat([x, skip])\n", "\n", " # This is the last layer of the model\n", " last = tf.keras.layers.Conv2DTranspose(\n", " filters=output_channels, kernel_size=3, strides=2,\n", " padding='same') #64x64 -> 128x128\n", "\n", " x = last(x)\n", "\n", " return tf.keras.Model(inputs=inputs, outputs=x)" ] }, { "cell_type": "markdown", "metadata": { "id": "LRsjdZuEnZfA" }, "source": [ "Note that the number 
of filters on the last layer is set to the number of `output_channels`. This will be one output channel per class." ] }, { "cell_type": "markdown", "metadata": { "id": "j0DGH_4T0VYn" }, "source": [ "## Train the model\n", "\n", "Now, all that is left to do is compile and train the model.\n", "\n", "Since this is a multiclass classification problem, use the `tf.keras.losses.SparseCategoricalCrossentropy` loss function with the `from_logits` argument set to `True`, because the labels are scalar integers instead of vectors of scores for every class at each pixel.\n", "\n", "When running inference, the label assigned to the pixel is the channel with the highest value. This is what the `create_mask` function defined below does." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6he36HK5uKAc" }, "outputs": [], "source": [ "OUTPUT_CLASSES = 3\n", "\n", "model = unet_model(output_channels=OUTPUT_CLASSES)\n", "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, {
 "cell_type": "markdown", "metadata": { "id": "xVMzbIZLcyEF" }, "source": [ "Plot the resulting model architecture:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sw82qF1Gcovr" }, "outputs": [], "source": [ "tf.keras.utils.plot_model(model, show_shapes=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "Tc3MiEO2twLS" }, "source": [ "Try out the model to check what it predicts before training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UwvIKLZPtxV_" }, "outputs": [], "source": [ "def create_mask(pred_mask):\n", " pred_mask = tf.math.argmax(pred_mask, axis=-1)\n", " pred_mask = pred_mask[..., tf.newaxis]\n", " return pred_mask[0]" ] }, {
 "cell_type": "code", "execution_count": null, "metadata": { "id": "YLNsrynNtx4d" }, "outputs": [], "source": [ "def show_predictions(dataset=None, num=1):\n", " if dataset:\n", " for image, mask in dataset.take(num):\n", " pred_mask = model.predict(image)\n", " display([image[0], mask[0], create_mask(pred_mask)])\n", " else:\n", " display([sample_image, sample_mask,\n", " create_mask(model.predict(sample_image[tf.newaxis, ...]))])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X_1CC0T4dho3" }, "outputs": [], "source": [ "show_predictions()" ] }, {
 "cell_type": "markdown", "metadata": { "id": "22AyVYWQdkgk" }, "source": [ "The callback defined below is used to observe how the model improves while it is training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wHrHsqijdmL6" }, "outputs": [], "source": [ "class DisplayCallback(tf.keras.callbacks.Callback):\n", " def on_epoch_end(self, epoch, logs=None):\n", " clear_output(wait=True)\n", " show_predictions()\n", " print('\\nSample Prediction after epoch {}\\n'.format(epoch+1))" ] }, {
 "cell_type": "code", "execution_count": null, "metadata": { "id": "StKDH_B9t4SD" }, "outputs": [], "source": [ "EPOCHS = 20\n", "VAL_SUBSPLITS = 5\n", "VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n", "\n", "model_history = model.fit(train_batches, epochs=EPOCHS,\n", " steps_per_epoch=STEPS_PER_EPOCH,\n", " validation_steps=VALIDATION_STEPS,\n", " validation_data=test_batches,\n", " callbacks=[DisplayCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P_mu0SAbt40Q" }, "outputs": [], "source": [ "loss = model_history.history['loss']\n", "val_loss = model_history.history['val_loss']\n",
 "\n", "plt.figure()\n", "plt.plot(model_history.epoch, loss, 'r', label='Training loss')\n", "plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')\n", "plt.title('Training and Validation Loss')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss Value')\n", "plt.ylim([0, 1])\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "unP3cnxo_N72" }, "source": [ "## Make predictions" ] }, { "cell_type": "markdown", "metadata": { "id": "7BVXldSo-0mW" }, "source": [ "Now, make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set it higher to achieve more accurate results." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ikrzoG24qwf5" }, "outputs": [], "source": [ "show_predictions(test_batches, 3)" ] }, {
 "cell_type": "markdown", "metadata": { "id": "QAwvlgSNoK3o" }, "source": [ "## Optional: Imbalanced classes and class weights" ] }, { "cell_type": "markdown", "metadata": { "id": "eqtFPqqu2kxP" }, "source": [ "Semantic segmentation datasets can be highly imbalanced, meaning that pixels of one class can appear in images far more often than pixels of other classes. Since segmentation can be treated as a per-pixel classification problem, you can deal with the imbalance by weighting the loss function to account for it. It's a simple and elegant way to handle this problem. Refer to the [Classification on imbalanced data](../structured_data/imbalanced_data.ipynb) tutorial to learn more.\n", "\n", "To [avoid ambiguity](https://github.com/keras-team/keras/issues/3653#issuecomment-243939748), `Model.fit` does not support the `class_weight` argument for targets with 3+ dimensions." ] }, {
 "cell_type": "code", "execution_count": null, "metadata": { "id": "aHt90UEQsZDn" }, "outputs": [], "source": [ "try:\n", " model_history = model.fit(train_batches, epochs=EPOCHS,\n", " steps_per_epoch=STEPS_PER_EPOCH,\n", " class_weight = {0:2.0, 1:2.0, 2:1.0})\n", " assert False\n", "except Exception as e:\n", " print(f\"Expected {type(e).__name__}: {e}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "brbhYODCsvbe" }, "source": [ "So, in this case you need to implement the weighting yourself. You'll do this using sample weights: In addition to `(data, label)` pairs, `Model.fit` also accepts `(data, label, sample_weight)` triples.\n", "\n", "Keras `Model.fit` propagates the `sample_weight` to the losses and metrics, which also accept a `sample_weight` argument. The sample weight is multiplied by the sample's value before the reduction step.
For example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EmHtImJn5Kk-" }, "outputs": [], "source": [ "label = [0, 0]\n", "prediction = [[-3., 0], [-3, 0]]\n", "sample_weight = [1, 10]\n", "\n", "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,\n", " reduction=tf.keras.losses.Reduction.NONE)\n", "loss(label, prediction, sample_weight).numpy()" ] }, {
 "cell_type": "markdown", "metadata": { "id": "Gbwo3DZ-9TxM" }, "source": [ "So, to make sample weights for this tutorial, you need a function that takes a `(data, label)` pair and returns a `(data, label, sample_weight)` triple, where the `sample_weight` is a 1-channel image containing the class weight for each pixel.\n", "\n", "The simplest possible implementation is to use the label as an index into a `class_weights` list:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DlG-n2Ugo8Jc" }, "outputs": [], "source": [ "def add_sample_weights(image, label):\n", " # The weights for each class, with the constraint that:\n", " # sum(class_weights) == 1.0\n", " class_weights = tf.constant([2.0, 2.0, 1.0])\n", " class_weights = class_weights/tf.reduce_sum(class_weights)\n", "\n", " # Create an image of `sample_weights` by using the label at each pixel as an\n", " # index into the `class_weights` list.\n", " sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))\n", "\n", " return image, label, sample_weights" ] }, {
 "cell_type": "markdown", "metadata": { "id": "hLH_NvH2UrXU" }, "source": [ "The resulting dataset elements each contain an image, a label, and a per-pixel sample-weight tensor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SE_ezRSFRCnE" }, "outputs": [], "source": [ "train_batches.map(add_sample_weights).element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "Yc-EpIzaRbSL" }, "source": [ "Now, you can train a model on this weighted dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QDWipedAoOQe" }, "outputs": [], "source": [ "weighted_model = unet_model(OUTPUT_CLASSES)\n", "weighted_model.compile(\n", " optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "btEFKc1xodGR" }, "outputs": [], "source": [ "weighted_model.fit(\n", " train_batches.map(add_sample_weights),\n", " epochs=1,\n", " steps_per_epoch=10)" ] }, {
 "cell_type": "markdown", "metadata": { "id": "R24tahEqmSCk" }, "source": [ "## Next steps\n", "\n", "Now that you have an understanding of what image segmentation is and how it works, you can try this tutorial out with different intermediate layer outputs, or even different pretrained models. You may also challenge yourself by trying out the [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) image masking challenge hosted on Kaggle.\n", "\n", "You may also want to see the [TensorFlow Object Detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/README.md) for another model you can retrain on your own data. Pretrained models are available on [TensorFlow Hub](https://www.tensorflow.org/hub/tutorials/tf2_object_detection#optional)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "segmentation.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }