{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "pdcMxVGEA9Cd" }, "source": [ "# **Fine-tuning for Image Classification with 🤗 Transformers**\n", "\n", "This notebook shows how to fine-tune any pretrained Vision model for Image Classification on a custom dataset. The idea is to add a randomly initialized classification head on top of a pre-trained encoder, and fine-tune the whole model on a labeled dataset.\n", "\n", "## ImageFolder\n", "\n", "This notebook leverages the [ImageFolder](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature to easily run the notebook on a custom dataset (namely, [EuroSAT](https://github.com/phelber/EuroSAT) in this tutorial). You can load a `Dataset` either from local folders or from local/remote files, like zip or tar archives.\n", "\n", "## Any model\n", "\n", "This notebook is built to run on any image classification dataset with any vision model checkpoint from the [Model Hub](https://huggingface.co/) as long as that model has a TensorFlow version with an Image Classification head, such as:\n", "* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.TFViTForImageClassification)\n", "* [Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin#transformers.TFSwinForImageClassification)\n", "* [ConvNeXT](https://huggingface.co/docs/transformers/master/en/model_doc/convnext#transformers.TFConvNextForImageClassification)\n", "* [RegNet](https://huggingface.co/docs/transformers/master/en/model_doc/regnet#transformers.TFRegNetForImageClassification)\n", "* [ResNet](https://huggingface.co/docs/transformers/master/en/model_doc/resnet#transformers.TFResNetForImageClassification)\n", "\n", "- in short, any model supported by [TFAutoModelForImageClassification](https://huggingface.co/docs/transformers/model_doc/auto#transformers.TFAutoModelForImageClassification).\n", "\n", "## Data augmentation\n", "\n", "This notebook leverages TensorFlow's 
[image](https://www.tensorflow.org/api_docs/python/tf/image) module for applying data augmentation. Alternative notebooks that leverage other libraries, such as [Albumentations](https://albumentations.ai/), are to come!\n", "\n", "---\n", "\n", "Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those two parameters, then the rest of the notebook should run smoothly.\n", "\n", "In this notebook, we'll fine-tune from the [microsoft/swin-tiny-patch4-window7-224](https://huggingface.co/microsoft/swin-tiny-patch4-window7-224) checkpoint, but note that there are many, many more available on the [hub](https://huggingface.co/models?other=vision)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "5WMEawzyCEyG" }, "outputs": [], "source": [ "model_checkpoint = \"microsoft/swin-tiny-patch4-window7-224\" # pre-trained model from which to fine-tune\n", "batch_size = 32 # batch size for training and evaluation" ] }, { "cell_type": "markdown", "metadata": { "id": "NlArTG8KChJf" }, "source": [ "Before we start, let's install the `datasets` and `transformers` libraries."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "L1532RVbJgQV", "outputId": "1d92a15b-0efd-4b09-b006-56384c64943b" }, "outputs": [], "source": [ "!pip install -q datasets transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "snZ1tmaOC412" }, "source": [ "If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", "\n", "First, you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) 
then execute the following cell and input your token:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 386, "referenced_widgets": [ "f1760b8ccf9b4c32977a1e83f3a3af3d", "e276e653c187474dba7e1c4fede10b79", "a6d024d44a6c49eebef7966ef5a836f1", "344c041a34b5458eb1bbb3ea8fa5b315", "18f7dc9f5e6241138534b7cf9aa30adb", "d5e3a1de2a4645639c029a404d04dc1c", "e7e938eb6baf486e829ea5d4734087cf", "730378e114f944908aa06f42bb2faa3d", "f73b5464140d4723b1b3f46796d9b1ca", "16ffe85c44764fa9ad8a31fb21e6432f", "40db3808e98d424cacc5d0fed54b9eaa", "e9bad1a707f0442da6f117fdd2804f72", "a0bac01e342b4793b66d7d4f5bfac2e2", "b447ca05136342d0a167bfb133a353bd", "4dcbcf9b086f452d9d1bc07b4a4cc1d3", "724b578bac3f4b7cb6806ee3c45aff01", "6ca86c47e547426e8a0d4487a786c46c" ] }, "id": "Bkpk_JPlCww8", "outputId": "d80cb8c7-5382-427b-e90b-bfec7afdc052" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2ef1c00510764bb88c9d625d6f7c25dc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='