{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "TBFXQGKYUc4X" }, "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "1z4xy2gTUc4a" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "FE7KNzPPVrVV" }, "source": [ "# Image classification" ] }, { "cell_type": "markdown", "metadata": { "id": "KwQtSOz0VrVX" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "gN7G9GFmVrVY" }, "source": [ "This tutorial shows how to classify images of flowers using a `tf.keras.Sequential` model and load data using `tf.keras.utils.image_dataset_from_directory`. It demonstrates the following concepts:\n", "\n", "\n", "* Efficiently loading a dataset off disk.\n", "* Identifying overfitting and applying techniques to mitigate it, including data augmentation and dropout.\n", "\n", "This tutorial follows a basic machine learning workflow:\n", "\n", "1. Examine and understand data\n", "2. Build an input pipeline\n", "3. Build the model\n", "4. Train the model\n", "5. Test the model\n", "6. Improve the model and repeat the process\n", "\n", "In addition, the notebook demonstrates how to convert a [saved model](../../../guide/saved_model.ipynb) to a [TensorFlow Lite](https://www.tensorflow.org/lite/) model for on-device machine learning on mobile, embedded, and IoT devices." ] }, { "cell_type": "markdown", "metadata": { "id": "zF9uvbXNVrVY" }, "source": [ "## Setup\n", "\n", "Import TensorFlow and other necessary libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L1WtoaOHVrVh" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import PIL\n", "import tensorflow as tf\n", "\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "from tensorflow.keras.models import Sequential" ] }, { "cell_type": "markdown", "metadata": { "id": "UZZI6lNkVrVm" }, "source": [ "## Download and explore the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "DPHx8-t-VrVo" }, "source": [ "This tutorial uses a dataset of about 3,700 photos of flowers. 
The dataset contains five sub-directories, one per class:\n", "\n", "```\n", "flower_photos/\n", " daisy/\n", " dandelion/\n", " roses/\n", " sunflowers/\n", " tulips/\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "57CcilYSG0zv" }, "outputs": [], "source": [ "import pathlib\n", "\n", "dataset_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n", "data_dir = tf.keras.utils.get_file('flower_photos.tar', origin=dataset_url, extract=True)\n", "data_dir = pathlib.Path(data_dir).with_suffix('')" ] }, { "cell_type": "markdown", "metadata": { "id": "VpmywIlsVrVx" }, "source": [ "After downloading, you should now have a copy of the dataset available. There are 3,670 total images:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SbtTDYhOHZb6" }, "outputs": [], "source": [ "image_count = len(list(data_dir.glob('*/*.jpg')))\n", "print(image_count)" ] }, { "cell_type": "markdown", "metadata": { "id": "PVmwkOSdHZ5A" }, "source": [ "Here are some roses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N1loMlbYHeiJ" }, "outputs": [], "source": [ "roses = list(data_dir.glob('roses/*'))\n", "PIL.Image.open(str(roses[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RQbZBOTLHiUP" }, "outputs": [], "source": [ "PIL.Image.open(str(roses[1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "DGEqiBbRHnyI" }, "source": [ "And some tulips:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HyQkfPGdHilw" }, "outputs": [], "source": [ "tulips = list(data_dir.glob('tulips/*'))\n", "PIL.Image.open(str(tulips[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wtlhWJPAHivf" }, "outputs": [], "source": [ "PIL.Image.open(str(tulips[1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "gIjgz7_JIo_m" }, "source": [ "## Load data using a Keras utility\n", "\n", "Next, load 
these images off disk using the `tf.keras.utils.image_dataset_from_directory` utility. This will take you from a directory of images on disk to a `tf.data.Dataset` in just a couple of lines of code. If you like, you can also write your own data loading code from scratch by visiting the [Load and preprocess images](../load_data/images.ipynb) tutorial." ] }, { "cell_type": "markdown", "metadata": { "id": "xyDNn9MbIzfT" }, "source": [ "### Create a dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "anqiK_AGI086" }, "source": [ "Define some parameters for the loader:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "H74l2DoDI2XD" }, "outputs": [], "source": [ "batch_size = 32\n", "img_height = 180\n", "img_width = 180" ] }, { "cell_type": "markdown", "metadata": { "id": "pFBhRrrEI49z" }, "source": [ "It's good practice to use a validation split when developing your model. Use 80% of the images for training and 20% for validation, and pass the same `seed` to both calls so that the training and validation subsets don't overlap." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fIR0kRZiI_AT" }, "outputs": [], "source": [ "train_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=0.2,\n", " subset=\"training\",\n", " seed=123,\n", " image_size=(img_height, img_width),\n", " batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iscU3UoVJBXj" }, "outputs": [], "source": [ "val_ds = tf.keras.utils.image_dataset_from_directory(\n", " data_dir,\n", " validation_split=0.2,\n", " subset=\"validation\",\n", " seed=123,\n", " image_size=(img_height, img_width),\n", " batch_size=batch_size)" ] }, { "cell_type": "markdown", "metadata": { "id": "WLQULyAvJC3X" }, "source": [ "You can find the class names in the `class_names` attribute on these datasets. These correspond to the directory names in alphabetical order."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZHAxkHX5JD3k" }, "outputs": [], "source": [ "class_names = train_ds.class_names\n", "print(class_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "_uoVvxSLJW9m" }, "source": [ "## Visualize the data\n", "\n", "Here are the first nine images from the training dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wBmEA9c0JYes" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.figure(figsize=(10, 10))\n", "for images, labels in train_ds.take(1):\n", " for i in range(9):\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(images[i].numpy().astype(\"uint8\"))\n", " plt.title(class_names[labels[i]])\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "5M6BXtXFJdW0" }, "source": [ "You will pass these datasets to the Keras `Model.fit` method for training later in this tutorial. If you like, you can also manually iterate over the dataset and retrieve batches of images:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2-MfMoenJi8s" }, "outputs": [], "source": [ "for image_batch, labels_batch in train_ds:\n", " print(image_batch.shape)\n", " print(labels_batch.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "Wj4FrKxxJkoW" }, "source": [ "The `image_batch` is a tensor of shape `(32, 180, 180, 3)`. This is a batch of 32 images of shape `180x180x3` (the last dimension holds the RGB color channels). The `labels_batch` is a tensor of shape `(32,)` containing the corresponding labels for those 32 images.\n", "\n", "You can call `.numpy()` on the `image_batch` and `labels_batch` tensors to convert them to a `numpy.ndarray`.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4Dr0at41KcAU" }, "source": [ "## Configure the dataset for performance\n", "\n", "Make sure to use buffered prefetching, so you can yield data from disk without having I/O become blocking. 
These are two important methods you should use when loading data:\n", "\n", "- `Dataset.cache` keeps the images in memory after they're loaded off disk during the first epoch. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache.\n", "- `Dataset.prefetch` overlaps data preprocessing and model execution while training.\n", "\n", "Interested readers can learn more about both methods, as well as how to cache data to disk in the *Prefetching* section of the [Better performance with the tf.data API](../../guide/data_performance.ipynb) guide." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nOjJSm7DKoZA" }, "outputs": [], "source": [ "AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)\n", "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "8GUnmPF4JvEf" }, "source": [ "## Standardize the data" ] }, { "cell_type": "markdown", "metadata": { "id": "e56VXHMWJxYT" }, "source": [ "The RGB channel values are in the `[0, 255]` range. This is not ideal for a neural network; in general you should seek to make your input values small.\n", "\n", "Here, you will standardize values to be in the `[0, 1]` range by using `tf.keras.layers.Rescaling`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PEYxo2CTJvY9" }, "outputs": [], "source": [ "normalization_layer = layers.Rescaling(1./255)" ] }, { "cell_type": "markdown", "metadata": { "id": "Bl4RmanbJ4g0" }, "source": [ "There are two ways to use this layer. 
You can apply it to the dataset by calling `Dataset.map`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X9o9ESaJJ502" }, "outputs": [], "source": [ "normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))\n", "image_batch, labels_batch = next(iter(normalized_ds))\n", "first_image = image_batch[0]\n", "# Notice the pixel values are now in `[0,1]`.\n", "print(np.min(first_image), np.max(first_image))" ] }, { "cell_type": "markdown", "metadata": { "id": "XWEOmRSBJ9J8" }, "source": [ "Or, you can include the layer inside your model definition, which can simplify deployment. Use the second approach here." ] }, { "cell_type": "markdown", "metadata": { "id": "XsRk1xCwKZR4" }, "source": [ "Note: You previously resized images using the `image_size` argument of `tf.keras.utils.image_dataset_from_directory`. If you want to include the resizing logic in your model as well, you can use the `tf.keras.layers.Resizing` layer." ] }, { "cell_type": "markdown", "metadata": { "id": "WcUTyDOPKucd" }, "source": [ "## A basic Keras model\n", "\n", "### Create the model\n", "\n", "The Keras [Sequential](https://www.tensorflow.org/guide/keras/sequential_model) model consists of three convolution blocks (`tf.keras.layers.Conv2D`) with a max pooling layer (`tf.keras.layers.MaxPooling2D`) in each of them. There's a fully-connected layer (`tf.keras.layers.Dense`) with 128 units on top of it that is activated by a ReLU activation function (`'relu'`). This model has not been tuned for high accuracy; the goal of this tutorial is to show a standard approach." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QR6argA1K074" }, "outputs": [], "source": [ "num_classes = len(class_names)\n", "\n", "model = Sequential([\n", " layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),\n", " layers.Conv2D(16, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(32, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(64, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Flatten(),\n", " layers.Dense(128, activation='relu'),\n", " layers.Dense(num_classes)\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "EaKFzz72Lqpg" }, "source": [ "### Compile the model\n", "\n", "For this tutorial, choose the `tf.keras.optimizers.Adam` optimizer and `tf.keras.losses.SparseCategoricalCrossentropy` loss function. To view training and validation accuracy for each training epoch, pass the `metrics` argument to `Model.compile`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jloGNS1MLx3A" }, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "markdown", "metadata": { "id": "aMJ4DnuJL55A" }, "source": [ "### Model summary\n", "\n", "View all the layers of the network using the Keras `Model.summary` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "llLYH-BXL7Xe" }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "NiYHcbvaL9H-" }, "source": [ "### Train the model" ] }, { "cell_type": "markdown", "metadata": { "id": "j30F69T4sIVN" }, "source": [ "Train the model for 10 epochs with the Keras `Model.fit` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5fWToCqYMErH" }, "outputs": [], "source": [ "epochs=10\n", "history = model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=epochs\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "SyFKdQpXMJT4" }, "source": [ "## Visualize training results" ] }, { "cell_type": "markdown", "metadata": { "id": "dFvOvmAmMK9w" }, "source": [ "Create plots of the loss and accuracy on the training and validation sets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jWnopEChMMCn" }, "outputs": [], "source": [ "acc = history.history['accuracy']\n", "val_acc = history.history['val_accuracy']\n", "\n", "loss = history.history['loss']\n", "val_loss = history.history['val_loss']\n", "\n", "epochs_range = range(epochs)\n", "\n", "plt.figure(figsize=(8, 8))\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs_range, acc, label='Training Accuracy')\n", "plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n", "plt.legend(loc='lower right')\n", "plt.title('Training and Validation Accuracy')\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot(epochs_range, 
loss, label='Training Loss')\n", "plt.plot(epochs_range, val_loss, label='Validation Loss')\n", "plt.legend(loc='upper right')\n", "plt.title('Training and Validation Loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "hO_jT7HwMrEn" }, "source": [ "The plots show that training accuracy and validation accuracy are off by a large margin, and that the model achieves only around 60% accuracy on the validation set.\n", "\n", "The following tutorial sections show how to inspect what went wrong and try to increase the overall performance of the model." ] }, { "cell_type": "markdown", "metadata": { "id": "hqtyGodAMvNV" }, "source": [ "## Overfitting" ] }, { "cell_type": "markdown", "metadata": { "id": "ixsz9XFfMxcu" }, "source": [ "In the plots above, the training accuracy increases roughly linearly over time, whereas validation accuracy stalls at around 60% during training. The large gap between training and validation accuracy is a sign of [overfitting](https://www.tensorflow.org/tutorials/keras/overfit_and_underfit).\n", "\n", "When there are only a few training examples, the model sometimes learns from noise or unwanted details in those examples, to an extent that it hurts the model's performance on new examples. This phenomenon is known as overfitting, and it means that the model will have a difficult time generalizing to a new dataset.\n", "\n", "There are multiple ways to fight overfitting in the training process. In this tutorial, you'll use *data augmentation* and add *dropout* to your model." ] }, { "cell_type": "markdown", "metadata": { "id": "BDMfYqwmM1C-" }, "source": [ "## Data augmentation" ] }, { "cell_type": "markdown", "metadata": { "id": "GxYwix81M2YO" }, "source": [ "Overfitting generally occurs when there are only a few training examples. 
[Data augmentation](./data_augmentation.ipynb) takes the approach of generating additional training data from your existing examples by augmenting them using random transformations that yield believable-looking images. This exposes the model to more aspects of the data and helps it generalize better.\n", "\n", "You will implement data augmentation using the following Keras preprocessing layers: `tf.keras.layers.RandomFlip`, `tf.keras.layers.RandomRotation`, and `tf.keras.layers.RandomZoom`. Like other layers, these can be included inside your model and can run on the GPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9J80BAbIMs21" }, "outputs": [], "source": [ "data_augmentation = keras.Sequential(\n", " [\n", " layers.RandomFlip(\"horizontal\",\n", " input_shape=(img_height,\n", " img_width,\n", " 3)),\n", " layers.RandomRotation(0.1),\n", " layers.RandomZoom(0.1),\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "PN4k1dK3S6eV" }, "source": [ "Visualize a few augmented examples by applying data augmentation to the same image several times:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7Z90k539S838" }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for images, _ in train_ds.take(1):\n", " for i in range(9):\n", " augmented_images = data_augmentation(images)\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(augmented_images[0].numpy().astype(\"uint8\"))\n", " plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tsjXCBLYYNs5" }, "source": [ "You will add data augmentation to your model before training in the next step."
] }, { "cell_type": "markdown", "metadata": { "id": "ZeD3bXepYKXs" }, "source": [ "## Dropout\n", "\n", "Another technique to reduce overfitting is to introduce [dropout](https://developers.google.com/machine-learning/glossary#dropout_regularization){:.external} regularization to the network.\n", "\n", "When you apply dropout to a layer, it randomly drops out (by setting the activation to zero) a fraction of the layer's output units during training. Dropout takes a fractional rate as its input value, such as 0.1, 0.2, or 0.4, which means randomly dropping out 10%, 20%, or 40% of the layer's output units, respectively.\n", "\n", "Create a new neural network that includes `tf.keras.layers.Dropout`, then train it using the augmented images:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2Zeg8zsqXCsm" }, "outputs": [], "source": [ "model = Sequential([\n", " data_augmentation,\n", " layers.Rescaling(1./255),\n", " layers.Conv2D(16, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(32, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Conv2D(64, 3, padding='same', activation='relu'),\n", " layers.MaxPooling2D(),\n", " layers.Dropout(0.2),\n", " layers.Flatten(),\n", " layers.Dense(128, activation='relu'),\n", " layers.Dense(num_classes, name=\"outputs\")\n", "])" ] }, { "cell_type": "markdown", "metadata": { "id": "L4nEcuqgZLbi" }, "source": [ "## Compile and train the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EvyAINs9ZOmJ" }, "outputs": [], "source": [ "model.compile(optimizer='adam',\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wWLkKoKjZSoC" }, "outputs": [], "source": [ "model.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LWS-vvNaZDag" }, 
"outputs": [], "source": [ "epochs = 15\n", "history = model.fit(\n", " train_ds,\n", " validation_data=val_ds,\n", " epochs=epochs\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Lkdl8VsBbZOu" }, "source": [ "## Visualize training results\n", "\n", "After applying data augmentation and `tf.keras.layers.Dropout`, there is less overfitting than before, and training and validation accuracy are more closely aligned:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dduoLfKsZVIA" }, "outputs": [], "source": [ "acc = history.history['accuracy']\n", "val_acc = history.history['val_accuracy']\n", "\n", "loss = history.history['loss']\n", "val_loss = history.history['val_loss']\n", "\n", "epochs_range = range(epochs)\n", "\n", "plt.figure(figsize=(8, 8))\n", "plt.subplot(1, 2, 1)\n", "plt.plot(epochs_range, acc, label='Training Accuracy')\n", "plt.plot(epochs_range, val_acc, label='Validation Accuracy')\n", "plt.legend(loc='lower right')\n", "plt.title('Training and Validation Accuracy')\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot(epochs_range, loss, label='Training Loss')\n", "plt.plot(epochs_range, val_loss, label='Validation Loss')\n", "plt.legend(loc='upper right')\n", "plt.title('Training and Validation Loss')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "dtv5VbaVb-3W" }, "source": [ "## Predict on new data" ] }, { "cell_type": "markdown", "metadata": { "id": "10buWpJbcCQz" }, "source": [ "Use your model to classify an image that wasn't included in the training or validation sets." ] }, { "cell_type": "markdown", "metadata": { "id": "NKgMZ4bDcHf7" }, "source": [ "Note: Data augmentation and dropout layers are inactive at inference time."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dC40sRITBSsQ" }, "outputs": [], "source": [ "sunflower_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg\"\n", "sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)\n", "\n", "img = tf.keras.utils.load_img(\n", " sunflower_path, target_size=(img_height, img_width)\n", ")\n", "img_array = tf.keras.utils.img_to_array(img)\n", "img_array = tf.expand_dims(img_array, 0) # Create a batch\n", "\n", "predictions = model.predict(img_array)\n", "score = tf.nn.softmax(predictions[0])\n", "\n", "print(\n", " \"This image most likely belongs to {} with a {:.2f} percent confidence.\"\n", " .format(class_names[np.argmax(score)], 100 * np.max(score))\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "aOc3PZ2N2r18" }, "source": [ "## Use TensorFlow Lite\n", "\n", "TensorFlow Lite is a set of tools that enables on-device machine learning by helping developers run their models on mobile, embedded, and edge devices." 
] }, { "cell_type": "markdown", "metadata": { "id": "cThu25rh4LPP" }, "source": [ "### Convert the Keras Sequential model to a TensorFlow Lite model\n", "\n", "To use the trained model with on-device applications, first [convert it](https://www.tensorflow.org/lite/models/convert) to a smaller and more efficient model format called a [TensorFlow Lite](https://www.tensorflow.org/lite/) model.\n", "\n", "In this example, take the trained Keras Sequential model and use `tf.lite.TFLiteConverter.from_keras_model` to generate a [TensorFlow Lite](https://www.tensorflow.org/lite/) model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mXo6ftuL2ufx" }, "outputs": [], "source": [ "# Convert the model.\n", "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "tflite_model = converter.convert()\n", "\n", "# Save the model.\n", "with open('model.tflite', 'wb') as f:\n", " f.write(tflite_model)" ] }, { "cell_type": "markdown", "metadata": { "id": "4R26OU4gGKhh" }, "source": [ "The TensorFlow Lite model you saved in the previous step can contain several function signatures. The Keras model converter API uses the default signature automatically. Learn more about [TensorFlow Lite signatures](https://www.tensorflow.org/lite/guide/signatures)." 
] }, { "cell_type": "markdown", "metadata": { "id": "7fjQfXaV2l-5" }, "source": [ "### Run the TensorFlow Lite model\n", "\n", "You can access the TensorFlow Lite saved model signatures in Python via the `tf.lite.Interpreter` class.\n", "\n", "Load the model with the `Interpreter`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cHYcip_FOaHq" }, "outputs": [], "source": [ "TF_MODEL_FILE_PATH = 'model.tflite' # The default path to the saved TensorFlow Lite model\n", "\n", "interpreter = tf.lite.Interpreter(model_path=TF_MODEL_FILE_PATH)" ] }, { "cell_type": "markdown", "metadata": { "id": "nPUXY6BdHDHo" }, "source": [ "Print the signatures from the converted model to obtain the names of the inputs (and outputs):\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZdDl00E2OaHq" }, "outputs": [], "source": [ "interpreter.get_signature_list()" ] }, { "cell_type": "markdown", "metadata": { "id": "4eVFqT0je3YG" }, "source": [ "In this example, you have one default signature called `serving_default`. In addition, the name of the `'inputs'` is `'sequential_1_input'`, while the `'outputs'` are called `'outputs'`. 
You can look up these first and last Keras layer names when running `Model.summary`, as demonstrated earlier in this tutorial.\n", "\n", "Now you can test the loaded TensorFlow Lite model by performing inference on a sample image with `tf.lite.Interpreter.get_signature_runner`, passing the signature name as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yFoT_7W_OaHq" }, "outputs": [], "source": [ "classify_lite = interpreter.get_signature_runner('serving_default')\n", "classify_lite" ] }, { "cell_type": "markdown", "metadata": { "id": "b1mfRcBOnEx0" }, "source": [ "Similar to what you did earlier in the tutorial, you can use the TensorFlow Lite model to classify images that weren't included in the training or validation sets.\n", "\n", "You have already converted that image to a tensor and saved it as `img_array`. Now, pass it to the loaded TensorFlow Lite model (`classify_lite`) as the keyword argument named after the input (`sequential_1_input`), compute softmax activations on the resulting output (`predictions_lite`), and then print the prediction for the class with the highest computed probability."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sEqR27YcnFvc" }, "outputs": [], "source": [ "predictions_lite = classify_lite(sequential_1_input=img_array)['outputs']\n", "score_lite = tf.nn.softmax(predictions_lite)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZKP_GFeKUWb5" }, "outputs": [], "source": [ "print(\n", " \"This image most likely belongs to {} with a {:.2f} percent confidence.\"\n", " .format(class_names[np.argmax(score_lite)], 100 * np.max(score_lite))\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Poz_iYgeUg_U" }, "source": [ "The predictions generated by the TensorFlow Lite model should be almost identical to those generated by the original model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "InXXDJL8UYC1" }, "outputs": [], "source": [ "print(np.max(np.abs(predictions - predictions_lite)))" ] }, { "cell_type": "markdown", "metadata": { "id": "5hJzY8XijM7N" }, "source": [ "Of the five classes (`'daisy'`, `'dandelion'`, `'roses'`, `'sunflowers'`, and `'tulips'`), the model should predict that the image belongs to sunflowers, which is the same result as before the TensorFlow Lite conversion.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "1RlfCY9v2_ir" }, "source": [ "## Next steps\n", "\n", "This tutorial showed how to train a model for image classification, test it, convert it to the TensorFlow Lite format for on-device applications (such as an image classification app), and perform inference with the TensorFlow Lite model with the Python API.\n", "\n", "You can learn more about TensorFlow Lite through [tutorials](https://www.tensorflow.org/lite/tutorials) and [guides](https://www.tensorflow.org/lite/guide)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "classification.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }