{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "nEF2_0sn2Tua" }, "source": [ "# Image Classification using BigTransfer (BiT)\n", "\n", "**Author:** [Sayan Nath](https://twitter.com/sayannath2350)
\n", "**Date created:** 2021/09/24
\n", "**Last modified:** 2021/09/24
\n", "**Description:** BigTransfer (BiT): state-of-the-art transfer learning for image classification." ] }, { "cell_type": "markdown", "metadata": { "id": "5LUKJJ0l2Tuc" }, "source": [ "## Introduction\n", "\n", "BigTransfer (also known as BiT) is a state-of-the-art transfer learning method for image\n", "classification. Transfer of pre-trained representations improves sample efficiency and\n", "simplifies hyperparameter tuning when training deep neural networks for vision. BiT\n", "revisits the paradigm of pre-training on large supervised datasets and fine-tuning the\n", "model on a target task. It highlights the importance of appropriately choosing normalization\n", "layers and scaling the architecture capacity as the amount of pre-training data increases.\n", "\n", "BigTransfer (BiT) models are pre-trained on public datasets and released along with code in\n", "[TF2, JAX and PyTorch](https://github.com/google-research/big_transfer). This helps anyone reach\n", "state-of-the-art performance on their task of interest, even with just a handful of\n", "labeled images per class.\n", "\n", "You can find BiT models pre-trained on\n", "[ImageNet](https://image-net.org/challenges/LSVRC/2012/index) and ImageNet-21k in\n", "[TFHub](https://tfhub.dev/google/collections/bit/1) as TensorFlow2 SavedModels that you\n", "can use easily as Keras Layers. There are a variety of sizes ranging from a standard\n", "ResNet50 to a ResNet152x4 (152 layers deep, 4x wider than a typical ResNet50) for users\n", "with larger computational and memory budgets and higher accuracy requirements.\n", "\n", "![](https://i.imgur.com/XeWVfe7.jpeg)\n", "Figure: The x-axis shows the number of images used per class, ranging from 1 to the full\n", "dataset. On the plots on the left, the upper curve (in blue) is the BiT-L model, whereas\n", "the lower curve is a ResNet-50 pre-trained on ImageNet (ILSVRC-2012)." 
] }, { "cell_type": "markdown", "metadata": { "id": "hvzPOZc72Tud" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x2bXlydz2Tue" }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "import tensorflow_hub as hub\n", "import tensorflow_datasets as tfds\n", "\n", "tfds.disable_progress_bar()\n", "\n", "SEEDS = 42\n", "\n", "np.random.seed(SEEDS)\n", "tf.random.set_seed(SEEDS)" ] }, { "cell_type": "markdown", "metadata": { "id": "iSr9snUg2Tuf" }, "source": [ "## Gather Flower Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jTf0_ElG2Tuf" }, "outputs": [], "source": [ "train_ds, validation_ds = tfds.load(\n", "    \"tf_flowers\",\n", "    split=[\"train[:85%]\", \"train[85%:]\"],\n", "    as_supervised=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "RW6zvyF22Tuh" }, "source": [ "## Visualise the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6W-AiCml2Tui" }, "outputs": [], "source": [ "plt.figure(figsize=(10, 10))\n", "for i, (image, label) in enumerate(train_ds.take(9)):\n", "    ax = plt.subplot(3, 3, i + 1)\n", "    plt.imshow(image)\n", "    plt.title(int(label))\n", "    plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0CsqwOoT2Tui" }, "source": [ "## Define hyperparameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5hy87EV42Tuj" }, "outputs": [], "source": [ "RESIZE_TO = 384\n", "CROP_TO = 224\n", "BATCH_SIZE = 64\n", "STEPS_PER_EPOCH = 10\n", "AUTO = tf.data.AUTOTUNE  # optimise the pipeline performance\n", "NUM_CLASSES = 5  # number of classes\n", "SCHEDULE_LENGTH = (\n", "    500  # we will train on lower resolution images and will still attain good results\n", ")\n", "SCHEDULE_BOUNDARIES = [\n", "    200,\n", "    300,\n", "    400,\n", "]  # the larger the dataset, the longer the schedule" ] }, { "cell_type": "markdown", "metadata": { "id": "EV7hIbAh2Tuk" }, "source": [ "Hyperparameters such as `SCHEDULE_LENGTH` and `SCHEDULE_BOUNDARIES` are determined based\n", "on empirical results. The method is explained in the [original\n", "paper](https://arxiv.org/abs/1912.11370) and in the accompanying [Google AI Blog\n", "Post](https://ai.googleblog.com/2020/05/open-sourcing-bit-exploring-large-scale.html).\n", "\n", "`SCHEDULE_LENGTH` also depends on whether [MixUp\n", "Augmentation](https://arxiv.org/abs/1710.09412) is used. You can find an easy MixUp\n", "implementation in [Keras Coding Examples](https://keras.io/examples/vision/mixup/).\n", "\n", "![](https://i.imgur.com/oSaIBYZ.jpeg)" ] }, { "cell_type": "markdown", "metadata": { "id": "9fc7_2AS2Tul" }, "source": [ "## Define preprocessing helper functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zUQI1SWo2Tum" }, "outputs": [], "source": [ "SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE\n", "\n", "\n", "@tf.function\n", "def preprocess_train(image, label):\n", "    image = tf.image.random_flip_left_right(image)\n", "    image = tf.image.resize(image, (RESIZE_TO, RESIZE_TO))\n", "    image = tf.image.random_crop(image, (CROP_TO, CROP_TO, 3))\n", "    image = image / 255.0\n", "    return (image, label)\n", "\n", "\n", "@tf.function\n", "def preprocess_test(image, label):\n", "    image = tf.image.resize(image, (RESIZE_TO, RESIZE_TO))\n", "    image = image / 255.0\n", "    return (image, label)\n", "\n", "\n", "DATASET_NUM_TRAIN_EXAMPLES = train_ds.cardinality().numpy()\n", "\n", "repeat_count = int(\n", "    SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH\n", ")\n", "repeat_count += 50 + 1  # Ensure there are at least 50 epochs of training" ] }, { "cell_type": "markdown", "metadata": { "id": "7b8-1Z7y2Tum" }, "source": [ "## Define the data pipeline" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"t8xK0KS12Tum" }, "outputs": [], "source": [ "# Training pipeline\n", "pipeline_train = (\n", "    train_ds.shuffle(10000)\n", "    .repeat(repeat_count)  # Repeat dataset_size / num_steps\n", "    .map(preprocess_train, num_parallel_calls=AUTO)\n", "    .batch(BATCH_SIZE)\n", "    .prefetch(AUTO)\n", ")\n", "\n", "# Validation pipeline\n", "pipeline_validation = (\n", "    validation_ds.map(preprocess_test, num_parallel_calls=AUTO)\n", "    .batch(BATCH_SIZE)\n", "    .prefetch(AUTO)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "xUMadZeC2Tum" }, "source": [ "## Visualise the training samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TrILDQhX2Tun" }, "outputs": [], "source": [ "image_batch, label_batch = next(iter(pipeline_train))\n", "\n", "plt.figure(figsize=(10, 10))\n", "for n in range(25):\n", "    ax = plt.subplot(5, 5, n + 1)\n", "    plt.imshow(image_batch[n])\n", "    plt.title(label_batch[n].numpy())\n", "    plt.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Mp-qVs622Tun" }, "source": [ "## Load pretrained TF-Hub model into a `KerasLayer`" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zo_SlVWr2Tuo" }, "outputs": [], "source": [ "bit_model_url = \"https://tfhub.dev/google/bit/m-r50x1/1\"\n", "bit_module = hub.KerasLayer(bit_model_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "gRxcFeya2Tuo" }, "source": [ "## Create BigTransfer (BiT) model\n", "\n", "To create the new model, we:\n", "\n", "1. Cut off the BiT model’s original head. This leaves us with the “pre-logits” output.\n", "We do not have to do this if we use the ‘feature extractor’ models (i.e. all those in\n", "subdirectories titled `feature_vectors`), since for those models the head has already\n", "been cut off.\n", "\n", "2. Add a new head with the number of outputs equal to the number of classes of our new\n", "task. Note that it is important that we initialise the head to all zeroes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9dfabz842Tup" }, "outputs": [], "source": [ "\n", "class MyBiTModel(keras.Model):\n", "    def __init__(self, num_classes, module, **kwargs):\n", "        super().__init__(**kwargs)\n", "\n", "        self.num_classes = num_classes\n", "        self.head = keras.layers.Dense(num_classes, kernel_initializer=\"zeros\")\n", "        self.bit_model = module\n", "\n", "    def call(self, images):\n", "        bit_embedding = self.bit_model(images)\n", "        return self.head(bit_embedding)\n", "\n", "\n", "model = MyBiTModel(num_classes=NUM_CLASSES, module=bit_module)" ] }, { "cell_type": "markdown", "metadata": { "id": "_CQ4Q8cu2Tup" }, "source": [ "## Define optimizer and loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2AcOBGkb2Tup" }, "outputs": [], "source": [ "learning_rate = 0.003 * BATCH_SIZE / 512\n", "\n", "# Decay learning rate by a factor of 10 at SCHEDULE_BOUNDARIES.\n", "lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(\n", "    boundaries=SCHEDULE_BOUNDARIES,\n", "    values=[\n", "        learning_rate,\n", "        learning_rate * 0.1,\n", "        learning_rate * 0.01,\n", "        learning_rate * 0.001,\n", "    ],\n", ")\n", "optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)\n", "\n", "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "m7pBODp82Tuq" }, "source": [ "## Compile the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R_FYzSq92Tuq" }, "outputs": [], "source": [ "model.compile(optimizer=optimizer, loss=loss_fn, metrics=[\"accuracy\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "54D9h9F42Tuq" }, "source": [ "## Set up callbacks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UGSXyb-c2Tus" }, "outputs": [], "source": [ "train_callbacks = [\n", "    keras.callbacks.EarlyStopping(\n", "        monitor=\"val_accuracy\", patience=2, 
restore_best_weights=True\n", "    )\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "7dMcs7yZ2Tus" }, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9B4wTD_u2Tus" }, "outputs": [], "source": [ "history = model.fit(\n", "    pipeline_train,\n", "    batch_size=BATCH_SIZE,\n", "    epochs=int(SCHEDULE_LENGTH / STEPS_PER_EPOCH),\n", "    steps_per_epoch=STEPS_PER_EPOCH,\n", "    validation_data=pipeline_validation,\n", "    callbacks=train_callbacks,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "9x9dUuIE2Tut" }, "source": [ "## Plot the training and validation metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5s8YQTw52Tut" }, "outputs": [], "source": [ "\n", "def plot_hist(hist):\n", "    plt.plot(hist.history[\"accuracy\"])\n", "    plt.plot(hist.history[\"val_accuracy\"])\n", "    plt.plot(hist.history[\"loss\"])\n", "    plt.plot(hist.history[\"val_loss\"])\n", "    plt.title(\"Training Progress\")\n", "    plt.ylabel(\"Accuracy/Loss\")\n", "    plt.xlabel(\"Epochs\")\n", "    plt.legend([\"train_acc\", \"val_acc\", \"train_loss\", \"val_loss\"], loc=\"upper left\")\n", "    plt.show()\n", "\n", "\n", "plot_hist(history)" ] }, { "cell_type": "markdown", "metadata": { "id": "v5M_9ejY2Tut" }, "source": [ "## Evaluate the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yhQdkrrY2Tuu" }, "outputs": [], "source": [ "accuracy = model.evaluate(pipeline_validation)[1] * 100\n", "print(\"Accuracy: {:.2f}%\".format(accuracy))" ] }, { "cell_type": "markdown", "metadata": { "id": "j7ZRc0sD2Tuu" }, "source": [ "## Conclusion\n", "\n", "BiT performs well across a surprisingly wide range of data regimes\n", "-- from 1 example per class to 1M total examples. BiT achieves 87.5% top-1 accuracy on\n", "ILSVRC-2012, 99.4% on CIFAR-10, and 76.3% on the 19-task Visual Task Adaptation Benchmark\n", "(VTAB). On small datasets, BiT attains 76.8% on ILSVRC-2012 with 10 examples per class,
and 97.0% on CIFAR-10 with 10 examples per class.

![](https://i.imgur.com/b1Lw5fz.png)

You can experiment further with the BigTransfer Method by following the
[original paper](https://arxiv.org/abs/1912.11370).
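
As one starting point for such experiments, the schedule arithmetic used in this example can be reproduced in plain Python, without TensorFlow. The sketch below is an illustration and not part of the original example: the helper names (`scaled_schedule_length`, `scaled_learning_rate`, `learning_rate_at`, `repeat_count_for`) and the constant `BASE_SCHEDULE_LENGTH` are hypothetical, the numeric values are copied from the hyperparameter cells of this example, and the lookup assumes the Keras `PiecewiseConstantDecay` convention that a step equal to a boundary still uses the earlier value.

```python
from bisect import bisect_left

# Values copied from the hyperparameter cells of this example.
BATCH_SIZE = 64
STEPS_PER_EPOCH = 10
BASE_SCHEDULE_LENGTH = 500  # hypothetical name for the un-scaled SCHEDULE_LENGTH
SCHEDULE_BOUNDARIES = [200, 300, 400]


def scaled_schedule_length(base_length, batch_size, reference_batch_size=512):
    """Rescale the step budget; the example treats 512 as the reference batch size."""
    return base_length * reference_batch_size / batch_size


def scaled_learning_rate(batch_size, base_lr=0.003, reference_batch_size=512):
    """Scale the base learning rate linearly with the batch size."""
    return base_lr * batch_size / reference_batch_size


def learning_rate_at(step, boundaries, values):
    """Piecewise-constant lookup: values[i] applies up to and including boundaries[i]."""
    return values[bisect_left(boundaries, step)]


def repeat_count_for(num_train_examples, schedule_length, batch_size, steps_per_epoch):
    """Mirror the repeat_count expression from the preprocessing cell."""
    repeats = int(schedule_length * batch_size / num_train_examples * steps_per_epoch)
    return repeats + 50 + 1


schedule_length = scaled_schedule_length(BASE_SCHEDULE_LENGTH, BATCH_SIZE)
base_lr = scaled_learning_rate(BATCH_SIZE)
values = [base_lr, base_lr * 0.1, base_lr * 0.01, base_lr * 0.001]
epochs = int(schedule_length / STEPS_PER_EPOCH)

print("schedule length (steps):", schedule_length)
print("maximum epochs:", epochs)
for step in (0, 200, 201, 301, 401):
    print(step, learning_rate_at(step, SCHEDULE_BOUNDARIES, values))
```

With the values used in this example, the budget is 4000 steps (at most 400 epochs of 10 steps each, although early stopping may end training sooner) and the initial learning rate is 0.000375. Because `SCHEDULE_BOUNDARIES` is used as-is while the step budget is rescaled, all three decays fall within the first 400 steps in this configuration. This is worth keeping in mind when changing `BATCH_SIZE` or the schedule.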


**Example available on HuggingFace**
| Trained Model | Demo |
| :--: | :--: |
| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-bit-black.svg)](https://huggingface.co/keras-io/bit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-bit-black.svg)](https://huggingface.co/spaces/keras-io/siamese-contrastive) |