{ "cells": [ { "cell_type": "markdown", "id": "c7923002", "metadata": {}, "source": [ "# Demonstration of using Hugging Face vision transformer with skorch" ] }, { "cell_type": "markdown", "id": "7f915413", "metadata": {}, "source": [ "Based on this [blog post](https://huggingface.co/blog/fine-tune-vit), we show how, with only a few lines of custom code, we can fine-tune a Vision Transformer model for a classification task." ] }, { "cell_type": "markdown", "id": "5ad2395e", "metadata": {}, "source": [ "In addition to installing torch and skorch, you need the `transformers` and `datasets` libraries:\n", "\n", "    $ python -m pip install transformers datasets" ] }, { "cell_type": "markdown", "id": "830b6c67", "metadata": {}, "source": [ "
" ] }, { "cell_type": "code", "execution_count": 1, "id": "30687543", "metadata": {}, "outputs": [], "source": [ "! [ ! -z \"$COLAB_GPU\" ] && pip install torch skorch transformers datasets" ] }, { "cell_type": "markdown", "id": "78c87b90", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "id": "fc22f469", "metadata": {}, "outputs": [], "source": [ "from functools import partial" ] }, { "cell_type": "code", "execution_count": 3, "id": "bd8bb4c3", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from datasets import load_dataset\n", "from sklearn.base import BaseEstimator, TransformerMixin\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.pipeline import Pipeline\n", "from skorch import NeuralNetClassifier\n", "from skorch.callbacks import ProgressBar, LRScheduler\n", "from torch import nn\n", "from torch.optim.lr_scheduler import LambdaLR\n", "from transformers import ViTFeatureExtractor, ViTForImageClassification" ] }, { "cell_type": "code", "execution_count": 4, "id": "07d458ed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.manual_seed(1234)" ] }, { "cell_type": "markdown", "id": "70e830d2", "metadata": {}, "source": [ "## Load beans dataset" ] }, { "cell_type": "markdown", "id": "7bb492c5", "metadata": {}, "source": [ "More details on the dataset can be found on [its datasets page](https://huggingface.co/datasets/beans). For our purposes, what's important is that we have image inputs and the target we're trying to predict is one of three classes for each image." 
] }, { "cell_type": "code", "execution_count": 5, "id": "f061f7d8", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset beans (/home/vinh/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "925060bd008e459ca47c6184f553332a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds['train'][0]['image']" ] }, { "cell_type": "markdown", "id": "9fc156f4", "metadata": {}, "source": [ "## Custom code" ] }, { "cell_type": "markdown", "id": "a46024c3", "metadata": {}, "source": [ "We wrap the vision transformer feature extractor in an sklearn `Transformer`. It does little more than load the feature extractor and return the pixel values of the features. It also takes care of setting the device.\n", "\n", "The reason to have a separate step for the feature extractor is that it needs to be called on the images only once, given that the output is deterministic. If we put it inside the `nn.Module`, we would call it on the same data once per epoch, which is wasteful." 
] }, { "cell_type": "code", "execution_count": 8, "id": "d615c114", "metadata": {}, "outputs": [], "source": [ "class FeatureExtractor(BaseEstimator, TransformerMixin):\n", "    def __init__(self, model_name, device='cpu'):\n", "        self.model_name = model_name\n", "        self.device = device\n", "\n", "    def fit(self, X, y=None, **fit_params):\n", "        self.extractor_ = ViTFeatureExtractor.from_pretrained(\n", "            self.model_name, device=self.device,\n", "        )\n", "        return self\n", "\n", "    def transform(self, X):\n", "        return self.extractor_(X, return_tensors='pt')['pixel_values']" ] }, { "cell_type": "markdown", "id": "01889021", "metadata": {}, "source": [ "The vision transformer module itself is modified to return the logits." ] }, { "cell_type": "code", "execution_count": 9, "id": "1185fec9", "metadata": {}, "outputs": [], "source": [ "class VitModule(nn.Module):\n", "    def __init__(self, model_name, num_classes):\n", "        super().__init__()\n", "        self.model = ViTForImageClassification.from_pretrained(\n", "            model_name, num_labels=num_classes\n", "        )\n", "\n", "    def forward(self, X):\n", "        X = self.model(X)\n", "        return X.logits" ] }, { "cell_type": "markdown", "id": "c007b8bc", "metadata": {}, "source": [ "To stick close to the original blog post, we use the same learning rate schedule." 
] }, { "cell_type": "code", "execution_count": 10, "id": "e8b81ea1", "metadata": {}, "outputs": [], "source": [ "def lr_lambda(current_step: int, num_warmup_steps, num_training_steps):\n", "    if current_step < num_warmup_steps:\n", "        return float(current_step) / float(max(1, num_warmup_steps))\n", "    return max(\n", "        0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n", "    )" ] }, { "cell_type": "markdown", "id": "d5b72951", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "markdown", "id": "a32d7a6c", "metadata": {}, "source": [ "### Hyperparameters" ] }, { "cell_type": "code", "execution_count": 11, "id": "afcd8260", "metadata": {}, "outputs": [], "source": [ "vit_model = 'google/vit-base-patch32-224-in21k'\n", "max_epochs = 4\n", "batch_size = 16\n", "optimizer = torch.optim.AdamW\n", "learning_rate = 2e-4\n", "weight_decay = 0.0\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "lr_lambda_schedule = partial(lr_lambda, num_warmup_steps=0.0, num_training_steps=max_epochs)" ] }, { "cell_type": "markdown", "id": "3673b56f", "metadata": {}, "source": [ "### Model pipeline" ] }, { "cell_type": "markdown", "id": "47c30a05", "metadata": {}, "source": [ "The model definition is straightforward. We use an sklearn `Pipeline` to chain the feature extractor and the model together. The model itself is a skorch `NeuralNetClassifier`, because we're dealing with a classification task. The module we need to pass to `NeuralNetClassifier` is the `VitModule` we defined above.\n", "\n", "As always in skorch, we use the double-underscore notation to pass sub-parameters. For example, to pass the `num_classes` argument to the module, we set `module__num_classes=3`.\n", "\n", "The arguments used here are all fairly standard. 
Note that we use the `LRScheduler` callback from skorch to apply the aforementioned learning rate schedule, and we add the `ProgressBar` callback too, which, as the name suggests, displays a progress bar.\n", "\n", "To stick close to the blog post, we also set `train_split=False`, so that skorch uses the whole training data for training. By default, skorch would instead hold out part of the training data for internal validation. But since the dataset already defines a validation split, this is not necessary." ] }, { "cell_type": "code", "execution_count": 12, "id": "2e3b9d5b", "metadata": {}, "outputs": [], "source": [ "pipe = Pipeline([\n", "    ('feature_extractor', FeatureExtractor(\n", "        vit_model,\n", "        device=device,\n", "    )),\n", "    ('net', NeuralNetClassifier(\n", "        VitModule,\n", "        module__model_name=vit_model,\n", "        module__num_classes=3,\n", "        criterion=nn.CrossEntropyLoss,\n", "        max_epochs=max_epochs,\n", "        batch_size=batch_size,\n", "        optimizer=optimizer,\n", "        optimizer__weight_decay=weight_decay,\n", "        lr=learning_rate,\n", "        device=device,\n", "        iterator_train__shuffle=True,\n", "        train_split=False,\n", "        callbacks=[\n", "            LRScheduler(LambdaLR, lr_lambda=lr_lambda_schedule),\n", "            ProgressBar(),\n", "        ],\n", "    )),\n", "])" ] }, { "cell_type": "markdown", "id": "b7107df9", "metadata": {}, "source": [ "Now we're ready to train the model. As always, we just call `fit` with our training data. skorch will automatically show the progress bar and some training metrics like the train loss." ] }, { "cell_type": "code", "execution_count": 13, "id": "9097c093", "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9a97c6ef078e4a3b928edf9798b72597", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/160 [00:00