{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "nEF2_0sn2Tua"
},
"source": [
"# Image Classification using BigTransfer (BiT)\n",
"\n",
"**Author:** [Sayan Nath](https://twitter.com/sayannath2350)<br>\n",
"**Date created:** 2021/09/24<br>\n",
"**Last modified:** 2021/09/24<br>\n",
"**Description:** BigTransfer (BiT): state-of-the-art transfer learning for image classification."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5LUKJJ0l2Tuc"
},
"source": [
"## Introduction\n",
"\n",
"BigTransfer (also known as BiT) is a state-of-the-art transfer learning method for image\n",
"classification. Transferring pre-trained representations improves sample efficiency and\n",
"simplifies hyperparameter tuning when training deep neural networks for vision. BiT\n",
"revisits the paradigm of pre-training on large supervised datasets and fine-tuning the\n",
"model on a target task. The key is to appropriately choose the normalization layers and\n",
"to scale the architecture capacity as the amount of pre-training data increases.\n",
"\n",
"BigTransfer (BiT) models are trained on public datasets, and code is available in\n",
"[TF2, JAX and PyTorch](https://github.com/google-research/big_transfer). This helps anyone reach\n",
"state-of-the-art performance on their task of interest, even with just a handful of\n",
"labeled images per class.\n",
"\n",
"You can find BiT models pre-trained on\n",
"[ImageNet](https://image-net.org/challenges/LSVRC/2012/index) and ImageNet-21k in\n",
"[TFHub](https://tfhub.dev/google/collections/bit/1) as TensorFlow2 SavedModels that you\n",
"can easily use as Keras layers. There are a variety of sizes, ranging from a standard\n",
"ResNet50 to a ResNet152x4 (152 layers deep, 4x wider than a typical ResNet50), for users\n",
"with larger computational and memory budgets and higher accuracy requirements.\n",
"\n",
"![](https://i.imgur.com/XeWVfe7.jpeg)\n",
"Figure: The x-axis shows the number of images used per class, ranging from 1 to the full\n",
"dataset. In the plots on the left, the upper blue curve is the BiT-L model, while the\n",
"curve below it is a ResNet-50 pre-trained on ImageNet (ILSVRC-2012)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hvzPOZc72Tud"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x2bXlydz2Tue"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import tensorflow_hub as hub\n",
"import tensorflow_datasets as tfds\n",
"\n",
"tfds.disable_progress_bar()\n",
"\n",
"SEEDS = 42\n",
"\n",
"np.random.seed(SEEDS)\n",
"tf.random.set_seed(SEEDS)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iSr9snUg2Tuf"
},
"source": [
"## Gather Flower Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jTf0_ElG2Tuf"
},
"outputs": [],
"source": [
"train_ds, validation_ds = tfds.load(\n",
" \"tf_flowers\",\n",
" split=[\"train[:85%]\", \"train[85%:]\"],\n",
" as_supervised=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RW6zvyF22Tuh"
},
"source": [
"## Visualise the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6W-AiCml2Tui"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for i, (image, label) in enumerate(train_ds.take(9)):\n",
" ax = plt.subplot(3, 3, i + 1)\n",
" plt.imshow(image)\n",
" plt.title(int(label))\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0CsqwOoT2Tui"
},
"source": [
"## Define hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5hy87EV42Tuj"
},
"outputs": [],
"source": [
"RESIZE_TO = 384\n",
"CROP_TO = 224\n",
"BATCH_SIZE = 64\n",
"STEPS_PER_EPOCH = 10\n",
"AUTO = tf.data.AUTOTUNE # optimise the pipeline performance\n",
"NUM_CLASSES = 5 # number of classes\n",
"SCHEDULE_LENGTH = (\n",
" 500 # we will train on lower resolution images and will still attain good results\n",
")\n",
"SCHEDULE_BOUNDARIES = [\n",
" 200,\n",
" 300,\n",
" 400,\n",
"] # the larger the dataset, the longer the schedule"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EV7hIbAh2Tuk"
},
"source": [
"Hyperparameters like `SCHEDULE_LENGTH` and `SCHEDULE_BOUNDARIES` are determined based\n",
"on empirical results. The method is explained in the [original\n",
"paper](https://arxiv.org/abs/1912.11370) and in the [Google AI Blog\n",
"Post](https://ai.googleblog.com/2020/05/open-sourcing-bit-exploring-large-scale.html).\n",
"\n",
"The `SCHEDULE_LENGTH` also determines whether to use [MixUp\n",
"Augmentation](https://arxiv.org/abs/1710.09412) or not. You can also find an easy MixUp\n",
"implementation in the [Keras code examples](https://keras.io/examples/vision/mixup/).\n",
"\n",
"![](https://i.imgur.com/oSaIBYZ.jpeg)"
]
},
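{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since this example's pipeline does not apply MixUp itself, the following is a minimal,\n",
"standalone NumPy sketch of the idea; the `mixup` helper, the toy inputs, and\n",
"`alpha=0.2` are illustrative assumptions rather than part of the original pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"\n",
"\n",
"def mixup(x1, y1, x2, y2, alpha=0.2):\n",
"    # Blend two samples with a Beta-distributed weight (Zhang et al., 2017).\n",
"    lam = rng.beta(alpha, alpha)\n",
"    return lam * x1 + (1.0 - lam) * x2, lam * y1 + (1.0 - lam) * y2\n",
"\n",
"\n",
"# Two toy \"images\" in [0, 1] and one-hot labels over 5 classes.\n",
"x1, x2 = np.zeros((4, 4, 3)), np.ones((4, 4, 3))\n",
"y1, y2 = np.eye(5)[0], np.eye(5)[3]\n",
"x, y = mixup(x1, y1, x2, y2)\n",
"print(x.min() >= 0.0 and x.max() <= 1.0)  # mixed image stays in [0, 1]\n",
"print(np.isclose(y.sum(), 1.0))  # mixed label still sums to 1"
]
},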
{
"cell_type": "markdown",
"metadata": {
"id": "9fc7_2AS2Tul"
},
"source": [
"## Define preprocessing helper functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zUQI1SWo2Tum"
},
"outputs": [],
"source": [
"SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE\n",
"\n",
"\n",
"@tf.function\n",
"def preprocess_train(image, label):\n",
" image = tf.image.random_flip_left_right(image)\n",
" image = tf.image.resize(image, (RESIZE_TO, RESIZE_TO))\n",
" image = tf.image.random_crop(image, (CROP_TO, CROP_TO, 3))\n",
" image = image / 255.0\n",
" return (image, label)\n",
"\n",
"\n",
"@tf.function\n",
"def preprocess_test(image, label):\n",
" image = tf.image.resize(image, (RESIZE_TO, RESIZE_TO))\n",
" image = image / 255.0\n",
" return (image, label)\n",
"\n",
"\n",
"DATASET_NUM_TRAIN_EXAMPLES = train_ds.cardinality().numpy()\n",
"\n",
"repeat_count = int(\n",
" SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH\n",
")\n",
"repeat_count += 50 + 1 # To ensure there are at least 50 epochs of training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7b8-1Z7y2Tum"
},
"source": [
"## Define the data pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t8xK0KS12Tum"
},
"outputs": [],
"source": [
"# Training pipeline\n",
"pipeline_train = (\n",
" train_ds.shuffle(10000)\n",
" .repeat(repeat_count) # Repeat the dataset to cover the full training schedule\n",
" .map(preprocess_train, num_parallel_calls=AUTO)\n",
" .batch(BATCH_SIZE)\n",
" .prefetch(AUTO)\n",
")\n",
"\n",
"# Validation pipeline\n",
"pipeline_validation = (\n",
" validation_ds.map(preprocess_test, num_parallel_calls=AUTO)\n",
" .batch(BATCH_SIZE)\n",
" .prefetch(AUTO)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xUMadZeC2Tum"
},
"source": [
"## Visualise the training samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TrILDQhX2Tun"
},
"outputs": [],
"source": [
"image_batch, label_batch = next(iter(pipeline_train))\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"for n in range(25):\n",
" ax = plt.subplot(5, 5, n + 1)\n",
" plt.imshow(image_batch[n])\n",
" plt.title(label_batch[n].numpy())\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mp-qVs622Tun"
},
"source": [
"## Load pretrained TF-Hub model into a `KerasLayer`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zo_SlVWr2Tuo"
},
"outputs": [],
"source": [
"bit_model_url = \"https://tfhub.dev/google/bit/m-r50x1/1\"\n",
"bit_module = hub.KerasLayer(bit_model_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gRxcFeya2Tuo"
},
"source": [
"## Create BigTransfer (BiT) model\n",
"\n",
"To create the new model, we:\n",
"\n",
"1. Cut off the BiT model’s original head. This leaves us with the “pre-logits” output.\n",
"We do not have to do this if we use the ‘feature extractor’ models (i.e. all those in\n",
"subdirectories titled `feature_vectors`), since for those models the head has already\n",
"been cut off.\n",
"\n",
"2. Add a new head with the number of outputs equal to the number of classes of our new\n",
"task. Note that it is important that we initialise the head to all zeroes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9dfabz842Tup"
},
"outputs": [],
"source": [
"\n",
"class MyBiTModel(keras.Model):\n",
" def __init__(self, num_classes, module, **kwargs):\n",
" super().__init__(**kwargs)\n",
"\n",
" self.num_classes = num_classes\n",
" self.head = keras.layers.Dense(num_classes, kernel_initializer=\"zeros\")\n",
" self.bit_model = module\n",
"\n",
" def call(self, images):\n",
" bit_embedding = self.bit_model(images)\n",
" return self.head(bit_embedding)\n",
"\n",
"\n",
"model = MyBiTModel(num_classes=NUM_CLASSES, module=bit_module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CQ4Q8cu2Tup"
},
"source": [
"## Define optimizer and loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2AcOBGkb2Tup"
},
"outputs": [],
"source": [
"learning_rate = 0.003 * BATCH_SIZE / 512\n",
"\n",
"# Decay learning rate by a factor of 10 at SCHEDULE_BOUNDARIES.\n",
"lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(\n",
" boundaries=SCHEDULE_BOUNDARIES,\n",
" values=[\n",
" learning_rate,\n",
" learning_rate * 0.1,\n",
" learning_rate * 0.01,\n",
" learning_rate * 0.001,\n",
" ],\n",
")\n",
"optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)\n",
"\n",
"loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)"
]
},
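{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the decay pattern concretely, here is a small pure-Python sketch that mirrors\n",
"what `PiecewiseConstantDecay` computes at a few sample steps; the `piecewise_lr`\n",
"helper is an illustrative stand-in, not a Keras API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def piecewise_lr(step, boundaries, values):\n",
"    # Return the value for the interval that contains `step`.\n",
"    for boundary, value in zip(boundaries, values):\n",
"        if step <= boundary:\n",
"            return value\n",
"    return values[-1]\n",
"\n",
"\n",
"base_lr = 0.003 * 64 / 512  # 0.000375 for BATCH_SIZE = 64\n",
"values = [base_lr, base_lr * 0.1, base_lr * 0.01, base_lr * 0.001]\n",
"print(piecewise_lr(0, [200, 300, 400], values))  # base_lr\n",
"print(piecewise_lr(250, [200, 300, 400], values))  # base_lr * 0.1\n",
"print(piecewise_lr(450, [200, 300, 400], values))  # base_lr * 0.001"
]
},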
{
"cell_type": "markdown",
"metadata": {
"id": "m7pBODp82Tuq"
},
"source": [
"## Compile the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R_FYzSq92Tuq"
},
"outputs": [],
"source": [
"model.compile(optimizer=optimizer, loss=loss_fn, metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "54D9h9F42Tuq"
},
"source": [
"## Set up callbacks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UGSXyb-c2Tus"
},
"outputs": [],
"source": [
"train_callbacks = [\n",
" keras.callbacks.EarlyStopping(\n",
" monitor=\"val_accuracy\", patience=2, restore_best_weights=True\n",
" )\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7dMcs7yZ2Tus"
},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9B4wTD_u2Tus"
},
"outputs": [],
"source": [
"history = model.fit(\n",
" pipeline_train,\n",
" batch_size=BATCH_SIZE,\n",
" epochs=int(SCHEDULE_LENGTH / STEPS_PER_EPOCH),\n",
" steps_per_epoch=STEPS_PER_EPOCH,\n",
" validation_data=pipeline_validation,\n",
" callbacks=train_callbacks,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9x9dUuIE2Tut"
},
"source": [
"## Plot the training and validation metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5s8YQTw52Tut"
},
"outputs": [],
"source": [
"\n",
"def plot_hist(hist):\n",
" plt.plot(hist.history[\"accuracy\"])\n",
" plt.plot(hist.history[\"val_accuracy\"])\n",
" plt.plot(hist.history[\"loss\"])\n",
" plt.plot(hist.history[\"val_loss\"])\n",
" plt.title(\"Training Progress\")\n",
" plt.ylabel(\"Accuracy/Loss\")\n",
" plt.xlabel(\"Epochs\")\n",
" plt.legend([\"train_acc\", \"val_acc\", \"train_loss\", \"val_loss\"], loc=\"upper left\")\n",
" plt.show()\n",
"\n",
"\n",
"plot_hist(history)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v5M_9ejY2Tut"
},
"source": [
"## Evaluate the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yhQdkrrY2Tuu"
},
"outputs": [],
"source": [
"accuracy = model.evaluate(pipeline_validation)[1] * 100\n",
"print(\"Accuracy: {:.2f}%\".format(accuracy))"
]
},
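{
"cell_type": "markdown",
"metadata": {},
"source": [
"To turn the model's logits into flower names, take an argmax over the class axis. Below\n",
"is a minimal sketch with a hypothetical logits batch standing in for `model.predict`;\n",
"the `tf_flowers` label order shown is an assumption worth verifying via the `tfds`\n",
"dataset metadata:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Assumed tf_flowers label order; verify with tfds.builder(\"tf_flowers\").info\n",
"class_names = [\"dandelion\", \"daisy\", \"tulips\", \"sunflowers\", \"roses\"]\n",
"\n",
"# Hypothetical logits for a batch of 2 images over the 5 classes.\n",
"logits = np.array([[2.0, 0.1, -1.0, 0.3, 0.0], [0.0, 0.2, 0.1, 3.0, -0.5]])\n",
"predictions = logits.argmax(axis=-1)\n",
"print([class_names[i] for i in predictions])  # ['dandelion', 'sunflowers']"
]
},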
{
"cell_type": "markdown",
"metadata": {
"id": "j7ZRc0sD2Tuu"
},
"source": [
"## Conclusion\n",
"\n",
"BiT performs well across a surprisingly wide range of data regimes\n",
"-- from 1 example per class to 1M total examples. BiT achieves 87.5% top-1 accuracy on\n",
"ILSVRC-2012, 99.4% on CIFAR-10, and 76.3% on the 19 task Visual Task Adaptation Benchmark\n",
"(VTAB). On small datasets, BiT attains 76.8% on ILSVRC-2012 with 10 examples per class,\n",
"and 97.0% on CIFAR-10 with 10 examples per class.\n",
"\n",
"![](https://i.imgur.com/b1Lw5fz.png)\n",
"\n",
"You can experiment further with the BigTransfer Method by following the\n",
"[original paper](https://arxiv.org/abs/1912.11370).\n",
"\n",
"\n",
"**Example available on HuggingFace**\n",
"| Trained Model | Demo |\n",
"| :--: | :--: |\n",
"| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-bit-black.svg)](https://huggingface.co/keras-io/bit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-bit-black.svg)](https://huggingface.co/spaces/keras-io/siamese-contrastive) |"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "bit",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}