{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "3XX46cTrh6iD" }, "source": [ "##### Copyright 2021 The TensorFlow Hub Authors. \n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sKrlWr6Kh-mF" }, "outputs": [], "source": [ "#@title Copyright 2021 The TensorFlow Hub Authors. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "metadata": { "id": "DMVmlJ0fAMkH" }, "source": [ "# Fine tuning models for plant disease detection\n" ] }, { "cell_type": "markdown", "metadata": { "id": "hk5u_9KN1m-t" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "OEHq-hV5sWYO" }, "source": [ "This notebook shows you how to **fine-tune CropNet models from TensorFlow Hub** on a dataset from TFDS or your own crop disease detection dataset.\n", "\n", "You will:\n", "- Load the TFDS cassava dataset or your own data\n", "- Enrich the data with unknown (negative) examples to get a more robust model\n", "- Apply image augmentations to the data\n", "- Load and fine tune a [CropNet model](https://tfhub.dev/s?module-type=image-feature-vector&q=cropnet) from TF Hub\n", "- Export a TFLite model, ready to be deployed on your app with [Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier), [MLKit](https://developers.google.com/ml-kit/vision/image-labeling/custom-models/android) or [TFLite](https://www.tensorflow.org/lite/guide/inference) directly" ] }, { "cell_type": "markdown", "metadata": { "id": "dQvS4p807mZf" }, "source": [ "## Imports and Dependencies\n", "\n", "Before starting, you'll need to install some of the dependencies that will be needed like [Model Maker](https://www.tensorflow.org/lite/guide/model_maker#installation) and the latest version of TensorFlow Datasets." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5BDTEMtexXE3" }, "outputs": [], "source": [ "!sudo apt install -q libportaudio2\n", "## image_classifier library requires numpy <= 1.23.5\n", "!pip install \"numpy<=1.23.5\"\n", "!pip install --use-deprecated=legacy-resolver tflite-model-maker-nightly\n", "!pip install -U tensorflow-datasets\n", "## scann library requires tensorflow < 2.9.0\n", "!pip install \"tensorflow<2.9.0\"\n", "!pip install \"tensorflow-datasets~=4.8.0\" # protobuf>=3.12.2\n", "!pip install tensorflow-metadata~=1.10.0 # protobuf>=3.13\n", "## tensorflowjs requires packaging < 20.10\n", "!pip install \"packaging<20.10\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nekG9Iwgxbx0" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import os\n", "import seaborn as sns\n", "\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "\n", "from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat\n", "from tensorflow_examples.lite.model_maker.core.task import image_preprocessing\n", "\n", "from tflite_model_maker import image_classifier\n", "from tflite_model_maker import ImageClassifierDataLoader\n", "from tflite_model_maker.image_classifier import ModelSpec" ] }, { "cell_type": "markdown", "metadata": { "id": "fV0k2Q4x4N_4" }, "source": [ "## Load a TFDS dataset to fine-tune on\n", "\n", "Let's use the publicly available [Cassava Leaf Disease dataset](https://www.tensorflow.org/datasets/catalog/cassava) from TFDS." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TTaD5W_1xjUz" }, "outputs": [], "source": [ "tfds_name = 'cassava'\n", "(ds_train, ds_validation, ds_test), ds_info = tfds.load(\n", " name=tfds_name,\n", " split=['train', 'validation', 'test'],\n", " with_info=True,\n", " as_supervised=True)\n", "TFLITE_NAME_PREFIX = tfds_name" ] }, { "cell_type": "markdown", "metadata": { "id": "xDuDGUAxyHtA" }, "source": [ "## Or alternatively load your own data to fine-tune on\n", "\n", "Instead of using a TFDS dataset, you can also train on your own data. This code snippet shows how to load your own custom dataset. See [this](https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder) link for the supported structure of the data. An example is provided here using the publicly available [Cassava Leaf Disease dataset](https://www.tensorflow.org/datasets/catalog/cassava)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k003tLvflHpC" }, "outputs": [], "source": [ "# data_root_dir = tf.keras.utils.get_file(\n", "# 'cassavaleafdata.zip',\n", "# 'https://storage.googleapis.com/emcassavadata/cassavaleafdata.zip',\n", "# extract=True)\n", "# data_root_dir = os.path.splitext(data_root_dir)[0] # Remove the .zip extension\n", "\n", "# builder = tfds.ImageFolder(data_root_dir)\n", "\n", "# ds_info = builder.info\n", "# ds_train = builder.as_dataset(split='train', as_supervised=True)\n", "# ds_validation = builder.as_dataset(split='validation', as_supervised=True)\n", "# ds_test = builder.as_dataset(split='test', as_supervised=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "hs3XCVLo4Fa1" }, "source": [ "## Visualize samples from train split\n", "\n", "Let's take a look at some examples from the dataset including the class id and the class name for the image samples and their labels." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "89GkD60Eyfe0" }, "outputs": [], "source": [ "_ = tfds.show_examples(ds_train, ds_info)" ] }, { "cell_type": "markdown", "metadata": { "id": "-KW-n0lV4AZ-" }, "source": [ "## Add images to be used as Unknown examples from TFDS datasets\n", "\n", "Add additional unknown (negative) examples to the training dataset and assign a new unknown class label number to them. The goal is to have a model that, when used in practice (e.g. in the field), has the option of predicting \"Unknown\" when it sees something unexpected.\n", "\n", "Below you can see a list of datasets that will be used to sample the additional unknown imagery. It includes 3 completely different datasets to increase diversity. One of them is a beans leaf disease dataset, so that the model has exposure to diseased plants other than cassava.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SYDMjRhDkDnd" }, "outputs": [], "source": [ "UNKNOWN_TFDS_DATASETS = [{\n", " 'tfds_name': 'imagenet_v2/matched-frequency',\n", " 'train_split': 'test[:80%]',\n", " 'test_split': 'test[80%:]',\n", " 'num_examples_ratio_to_normal': 1.0,\n", "}, {\n", " 'tfds_name': 'oxford_flowers102',\n", " 'train_split': 'train',\n", " 'test_split': 'test',\n", " 'num_examples_ratio_to_normal': 1.0,\n", "}, {\n", " 'tfds_name': 'beans',\n", " 'train_split': 'train',\n", " 'test_split': 'test',\n", " 'num_examples_ratio_to_normal': 1.0,\n", "}]" ] }, { "cell_type": "markdown", "metadata": { "id": "XUM_d0evktGi" }, "source": [ "The UNKNOWN datasets are also loaded from TFDS." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5DdWgBTe8uKR" }, "outputs": [], "source": [ "# Load unknown datasets.\n", "weights = [\n", " spec['num_examples_ratio_to_normal'] for spec in UNKNOWN_TFDS_DATASETS\n", "]\n", "num_unknown_train_examples = sum(\n", " int(w * ds_train.cardinality().numpy()) for w in weights)\n", "ds_unknown_train = tf.data.Dataset.sample_from_datasets([\n", " tfds.load(\n", " name=spec['tfds_name'], split=spec['train_split'],\n", " as_supervised=True).repeat(-1) for spec in UNKNOWN_TFDS_DATASETS\n", "], weights).take(num_unknown_train_examples)\n", "ds_unknown_train = ds_unknown_train.apply(\n", " tf.data.experimental.assert_cardinality(num_unknown_train_examples))\n", "ds_unknown_tests = [\n", " tfds.load(\n", " name=spec['tfds_name'], split=spec['test_split'], as_supervised=True)\n", " for spec in UNKNOWN_TFDS_DATASETS\n", "]\n", "ds_unknown_test = ds_unknown_tests[0]\n", "for ds in ds_unknown_tests[1:]:\n", " ds_unknown_test = ds_unknown_test.concatenate(ds)\n", "\n", "# All examples from the unknown datasets will get a new class label number.\n", "num_normal_classes = len(ds_info.features['label'].names)\n", "unknown_label_value = tf.convert_to_tensor(num_normal_classes, tf.int64)\n", "ds_unknown_train = ds_unknown_train.map(lambda image, _:\n", " (image, unknown_label_value))\n", "ds_unknown_test = ds_unknown_test.map(lambda image, _:\n", " (image, unknown_label_value))\n", "\n", "# Merge the normal train dataset with the unknown train dataset.\n", "weights = [\n", " ds_train.cardinality().numpy(),\n", " ds_unknown_train.cardinality().numpy()\n", "]\n", "ds_train_with_unknown = tf.data.Dataset.sample_from_datasets(\n", " [ds_train, ds_unknown_train], [float(w) for w in weights])\n", "ds_train_with_unknown = ds_train_with_unknown.apply(\n", " tf.data.experimental.assert_cardinality(sum(weights)))\n", "\n", "print((f\"Added {ds_unknown_train.cardinality().numpy()} negative examples. \"\n", " f\"Training 
dataset now has {ds_train_with_unknown.cardinality().numpy()}\"\n", " ' examples in total.'))" ] }, { "cell_type": "markdown", "metadata": { "id": "am6eKbzt7raH" }, "source": [ "## Apply augmentations" ] }, { "cell_type": "markdown", "metadata": { "id": "sxIUP0Flk35V" }, "source": [ "To make the images more diverse, you'll apply augmentations such as changes in:\n", "- Brightness\n", "- Contrast\n", "- Saturation\n", "- Hue\n", "- Crop\n", "\n", "These types of augmentations help make the model more robust to variations in image inputs.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q_BiOkXjqRju" }, "outputs": [], "source": [ "def random_crop_and_random_augmentations_fn(image):\n", " # preprocess_for_train does random crop and resize internally.\n", " image = image_preprocessing.preprocess_for_train(image)\n", " image = tf.image.random_brightness(image, 0.2)\n", " image = tf.image.random_contrast(image, 0.5, 2.0)\n", " image = tf.image.random_saturation(image, 0.75, 1.25)\n", " image = tf.image.random_hue(image, 0.1)\n", " return image\n", "\n", "\n", "def random_crop_fn(image):\n", " # preprocess_for_train does random crop and resize internally.\n", " image = image_preprocessing.preprocess_for_train(image)\n", " return image\n", "\n", "\n", "def resize_and_center_crop_fn(image):\n", " image = tf.image.resize(image, (256, 256))\n", " image = image[16:240, 16:240]\n", " return image\n", "\n", "\n", "no_augment_fn = lambda image: image\n", "\n", "train_augment_fn = lambda image, label: (\n", " random_crop_and_random_augmentations_fn(image), label)\n", "eval_augment_fn = lambda image, label: (resize_and_center_crop_fn(image), label)" ] }, { "cell_type": "markdown", "metadata": { "id": "RUfqE1c3l6my" }, "source": [ "To apply the augmentations, use the `map` method of the `tf.data.Dataset` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Uq-NCtaH_h8j" }, "outputs": [], "source": [ "ds_train_with_unknown = ds_train_with_unknown.map(train_augment_fn)\n", "ds_validation = ds_validation.map(eval_augment_fn)\n", "ds_test = ds_test.map(eval_augment_fn)\n", "ds_unknown_test = ds_unknown_test.map(eval_augment_fn)" ] }, { "cell_type": "markdown", "metadata": { "id": "DvnwolLiCqYX" }, "source": [ "## Wrap the data into Model Maker friendly format\n", "\n", "To use these datasets with Model Maker, each one needs to be wrapped in an `ImageClassifierDataLoader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OXPWEDFDRlVu" }, "outputs": [], "source": [ "label_names = ds_info.features['label'].names + ['UNKNOWN']\n", "\n", "train_data = ImageClassifierDataLoader(ds_train_with_unknown,\n", " ds_train_with_unknown.cardinality(),\n", " label_names)\n", "validation_data = ImageClassifierDataLoader(ds_validation,\n", " ds_validation.cardinality(),\n", " label_names)\n", "test_data = ImageClassifierDataLoader(ds_test, ds_test.cardinality(),\n", " label_names)\n", "unknown_test_data = ImageClassifierDataLoader(ds_unknown_test,\n", " ds_unknown_test.cardinality(),\n", " label_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "j2iDwq2Njpb_" }, "source": [ "## Run training\n", "\n", "[TensorFlow Hub](https://tfhub.dev) has multiple models available for Transfer Learning.\n", "\n", "Here you can choose one, and you can keep experimenting with the others to try to get better results.\n", "\n", "If you want even more models to try, you can add them from this [collection](https://tfhub.dev/google/collections/image/1).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "5UhNpR0Ex_5-" }, "outputs": [], "source": [ "#@title Choose a base model\n", "\n", "model_name = 'mobilenet_v3_large_100_224' #@param ['cropnet_cassava', 'cropnet_concat', 'cropnet_imagenet', 
'mobilenet_v3_large_100_224']\n", "\n", "map_model_name = {\n", " 'cropnet_cassava':\n", " 'https://tfhub.dev/google/cropnet/feature_vector/cassava_disease_V1/1',\n", " 'cropnet_concat':\n", " 'https://tfhub.dev/google/cropnet/feature_vector/concat/1',\n", " 'cropnet_imagenet':\n", " 'https://tfhub.dev/google/cropnet/feature_vector/imagenet/1',\n", " 'mobilenet_v3_large_100_224':\n", " 'https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/feature_vector/5',\n", "}\n", "\n", "model_handle = map_model_name[model_name]" ] }, { "cell_type": "markdown", "metadata": { "id": "Y1ecXlQgR5Uk" }, "source": [ "To fine-tune the model, you will use Model Maker. This simplifies the overall workflow because, after training, Model Maker also converts the model to TFLite.\n", "\n", "Model Maker takes care of this conversion and packages the information needed to easily deploy the model on-device later.\n", "\n", "The model spec is how you tell Model Maker which base model you'd like to use." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L8P-VTqJ8GaF" }, "outputs": [], "source": [ "image_model_spec = ModelSpec(uri=model_handle)" ] }, { "cell_type": "markdown", "metadata": { "id": "AnWN3kk6jCHf" }, "source": [ "One important detail here is setting `train_whole_model`, which fine-tunes the weights of the base model during training instead of keeping them frozen. This makes training slower, but the final model is usually more accurate. Setting `shuffle` ensures the model sees the data in a randomly shuffled order, which is a best practice for training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KRbSDbnA6Xap" }, "outputs": [], "source": [ "model = image_classifier.create(\n", " train_data,\n", " model_spec=image_model_spec,\n", " batch_size=128,\n", " learning_rate=0.03,\n", " epochs=5,\n", " shuffle=True,\n", " train_whole_model=True,\n", " validation_data=validation_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "buFDW0izBqIQ" }, "source": [ "## Evaluate model on test split" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OYIZ1rlV7lxm" }, "outputs": [], "source": [ "model.evaluate(test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "YJaReZ_OVU71" }, "source": [ "To have an even better understanding of the fine tuned model, it's good to analyse the confusion matrix. This will show how often one class is predicted as another." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o9_vs1nNKOLF" }, "outputs": [], "source": [ "def predict_class_label_number(dataset):\n", " \"\"\"Runs inference and returns predictions as class label numbers.\"\"\"\n", " rev_label_names = {l: i for i, l in enumerate(label_names)}\n", " return [\n", " rev_label_names[o[0][0]]\n", " for o in model.predict_top_k(dataset, batch_size=128)\n", " ]\n", "\n", "def show_confusion_matrix(cm, labels):\n", " plt.figure(figsize=(10, 8))\n", " sns.heatmap(cm, xticklabels=labels, yticklabels=labels, \n", " annot=True, fmt='g')\n", " plt.xlabel('Prediction')\n", " plt.ylabel('Label')\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7BWZCKerCNF_" }, "outputs": [], "source": [ "confusion_mtx = tf.math.confusion_matrix(\n", " list(ds_test.map(lambda x, y: y)),\n", " predict_class_label_number(test_data),\n", " num_classes=len(label_names))\n", "\n", "show_confusion_matrix(confusion_mtx, label_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "ksu9BFULBvmj" }, "source": [ "## Evaluate model on unknown 
test data\n", "\n", "In this evaluation we expect the model to have an accuracy of almost 1. None of the images the model is tested on are related to the normal dataset, so we expect the model to predict the \"Unknown\" class label for them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5wvZwliZcJP" }, "outputs": [], "source": [ "model.evaluate(unknown_test_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "jm47Odo5Vaiq" }, "source": [ "Print the confusion matrix." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E_gEX3oWH1YT" }, "outputs": [], "source": [ "unknown_confusion_mtx = tf.math.confusion_matrix(\n", " list(ds_unknown_test.map(lambda x, y: y)),\n", " predict_class_label_number(unknown_test_data),\n", " num_classes=len(label_names))\n", "\n", "show_confusion_matrix(unknown_confusion_mtx, label_names)" ] }, { "cell_type": "markdown", "metadata": { "id": "o2agDx2fCHyd" }, "source": [ "## Export the model as TFLite and SavedModel\n", "\n", "Now we can export the trained model in TFLite and SavedModel formats, for deploying on-device and for running inference in TensorFlow." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bAFvBmMr7owW" }, "outputs": [], "source": [ "tflite_filename = f'{TFLITE_NAME_PREFIX}_model_{model_name}.tflite'\n", "model.export(export_dir='.', tflite_filename=tflite_filename)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pz0-6To2C4yM" }, "outputs": [], "source": [ "# Export saved model version.\n", "model.export(export_dir='.', export_format=ExportFormat.SAVED_MODEL)" ] }, { "cell_type": "markdown", "metadata": { "id": "4V4GdQqxjEU7" }, "source": [ "## Next steps\n", "\n", "The model that you've just trained can be used on mobile devices and even deployed in the field!\n", "\n", "**To download the model, click the folder icon for the Files menu on the left side of the Colab, and choose the download option.**\n", "\n", "The same technique used here could be applied to other plant disease tasks that might be more suitable for your use case, or to any other type of image classification task. If you want to follow up and deploy on an Android app, you can continue with this [Android quickstart guide](https://www.tensorflow.org/lite/android/quickstart)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "3XX46cTrh6iD", "xDuDGUAxyHtA" ], "name": "cropnet_on_device.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }