{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "sLMSHWM-oiyf" }, "source": [ "# Transfer Learning with skorch" ] }, { "cell_type": "markdown", "metadata": { "id": "2KJVmPqYoiyj" }, "source": [ "In this tutorial, you will learn how to train a neural network using transfer learning with the `skorch` API. Transfer learning uses a pretrained model to initialize a network. This tutorial converts the pure PyTorch approach described in [PyTorch's Transfer Learning Tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) to `skorch`.\n", "\n", "We will be using `torchvision` for this tutorial. Instructions on how to install `torchvision` for your platform can be found at https://pytorch.org.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "uFUihT98oiyp" }, "source": [ "**Note**: If you are running this in [a colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb), we recommend you enable a free GPU by going:\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n", "\n", "If you are running in colab, you should install the dependencies and download the dataset by running the following cell:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "i1U_Mi4gon2r" }, "outputs": [], "source": [ "import subprocess\n", "\n", "# Installation on Google Colab\n", "try:\n", " import os\n", " import google.colab\n", " subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'torchvision'])\n", " subprocess.run(['mkdir', '-p', 'datasets'])\n", " subprocess.run(['wget', '-nc', '--no-check-certificate', 'https://download.pytorch.org/tutorial/hymenoptera_data.zip', '-P', 'datasets'])\n", " subprocess.run(['unzip', '-u', 'datasets/hymenoptera_data.zip', '-d' 'datasets'])\n", "except ImportError:\n", " pass" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "8EctsAZCoiyy" }, "outputs": [], "source": [ "import os\n", "from urllib import request\n", "from zipfile import ZipFile\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import numpy as np\n", "from torchvision import datasets, models, transforms\n", "\n", "from skorch import NeuralNetClassifier\n", "from skorch.helper import predefined_split\n", "\n", "torch.manual_seed(360);" ] }, { "cell_type": "markdown", "metadata": { "id": "_vaY9ew5oiy1" }, "source": [ "## Preparations" ] }, { "cell_type": "markdown", "metadata": { "id": "Ybj2whoboiy3" }, "source": [ "Before we begin, lets download the data needed for this tutorial:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" 
}, "id": "vXrGCvunoiy4", "outputId": "155c8008-d1e7-4c9f-bbfc-0abbe3d61cbd" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Data has been downloaded and extracted to datasets.\n" ] } ], "source": [ "def download_and_extract_data(dataset_dir='datasets'):\n", "    data_zip = os.path.join(dataset_dir, 'hymenoptera_data.zip')\n", "    data_path = os.path.join(dataset_dir, 'hymenoptera_data')\n", "    url = \"https://download.pytorch.org/tutorial/hymenoptera_data.zip\"\n", "\n", "    # Skip the download and extraction if the data is already in place\n", "    if not os.path.exists(data_path):\n", "        if not os.path.exists(data_zip):\n", "            print(\"Starting to download data...\")\n", "            data = request.urlopen(url, timeout=15).read()\n", "            with open(data_zip, 'wb') as f:\n", "                f.write(data)\n", "\n", "        print(\"Starting to extract data...\")\n", "        with ZipFile(data_zip, 'r') as zip_f:\n", "            zip_f.extractall(dataset_dir)\n", "\n", "    print(\"Data has been downloaded and extracted to {}.\".format(dataset_dir))\n", "\n", "download_and_extract_data()" ] }, { "cell_type": "markdown", "metadata": { "id": "V0K5OAKPoiy7" }, "source": [ "## The Problem" ] }, { "cell_type": "markdown", "metadata": { "id": "m52cP6kaoiy8" }, "source": [ "We are going to train a neural network to classify **ants** and **bees**. The dataset consists of 120 training images and 75 validation images for each class. 
First, we create the training and validation datasets:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "y3lg-xiaoiy9" }, "outputs": [], "source": [ "data_dir = 'datasets/hymenoptera_data'\n", "train_transforms = transforms.Compose([\n", "    transforms.RandomResizedCrop(224),\n", "    transforms.RandomHorizontalFlip(),\n", "    transforms.ToTensor(),\n", "    transforms.Normalize([0.485, 0.456, 0.406],\n", "                         [0.229, 0.224, 0.225])\n", "])\n", "val_transforms = transforms.Compose([\n", "    transforms.Resize(256),\n", "    transforms.CenterCrop(224),\n", "    transforms.ToTensor(),\n", "    transforms.Normalize([0.485, 0.456, 0.406],\n", "                         [0.229, 0.224, 0.225])\n", "])\n", "\n", "train_ds = datasets.ImageFolder(\n", "    os.path.join(data_dir, 'train'), train_transforms)\n", "val_ds = datasets.ImageFolder(\n", "    os.path.join(data_dir, 'val'), val_transforms)" ] }, { "cell_type": "markdown", "metadata": { "id": "qfbGs4wfoiy_" }, "source": [ "The training dataset uses data augmentation: random resized crops of size 224 and random horizontal flips. The validation dataset is only resized to 256 and center-cropped to 224. Both datasets are normalized with mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`. These values are the per-channel means and standard deviations of the ImageNet images. We use them because the pretrained model was trained on ImageNet." 
] }, { "cell_type": "markdown", "metadata": { "id": "APIEJMARoizA" }, "source": [ "## Loading pretrained model" ] }, { "cell_type": "markdown", "metadata": { "id": "TUNjvYdwoizA" }, "source": [ "We use a pretrained `ResNet18` neural network model with its final layer replaced with a fully connected layer:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "cegvloOPoizB" }, "outputs": [], "source": [ "class PretrainedModel(nn.Module):\n", "    def __init__(self, output_features):\n", "        super().__init__()\n", "        model = models.resnet18(pretrained=True)\n", "        num_ftrs = model.fc.in_features\n", "        model.fc = nn.Linear(num_ftrs, output_features)\n", "        self.model = model\n", "\n", "    def forward(self, x):\n", "        return self.model(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "al8ehonloizD" }, "source": [ "Since we are training a binary classifier, the output of the final fully connected layer has size 2." ] }, { "cell_type": "markdown", "metadata": { "id": "vNmUTqT6oizD" }, "source": [ "## Using skorch's API" ] }, { "cell_type": "markdown", "metadata": { "id": "j2hJ-CUUoizE" }, "source": [ "In this section, we will create a `skorch.NeuralNetClassifier` to solve our classification problem. 
" ] }, { "cell_type": "markdown", "metadata": { "id": "ob1FqKNloizF" }, "source": [ "### Callbacks" ] }, { "cell_type": "markdown", "metadata": { "id": "aHVRYecRoizF" }, "source": [ "First, we create a `LRScheduler` callback which is a learning rate scheduler that uses `torch.optim.lr_scheduler.StepLR` to scale learning rates by `gamma=0.1` every 7 steps:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "ZJEUie2VoizG" }, "outputs": [], "source": [ "from skorch.callbacks import LRScheduler\n", "\n", "lrscheduler = LRScheduler(\n", " policy='StepLR', step_size=7, gamma=0.1)" ] }, { "cell_type": "markdown", "metadata": { "id": "Bli6ngl5oizG" }, "source": [ "Next, we create a `Checkpoint` callback which saves the best model by by monitoring the validation accuracy. " ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "rGPmqqhWoizH" }, "outputs": [], "source": [ "from skorch.callbacks import Checkpoint\n", "\n", "checkpoint = Checkpoint(\n", " f_params='best_model.pt', monitor='valid_acc_best')" ] }, { "cell_type": "markdown", "metadata": { "id": "XQS39O_GoizJ" }, "source": [ "Lastly, we create a `Freezer` to freeze all weights besides the final layer named `model.fc`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "kXOe41FBoizJ" }, "outputs": [], "source": [ "from skorch.callbacks import Freezer\n", "\n", "freezer = Freezer(lambda x: not x.startswith('model.fc'))" ] }, { "cell_type": "markdown", "metadata": { "id": "4_JqX5wEoizK" }, "source": [ "### skorch.NeuralNetClassifier" ] }, { "cell_type": "markdown", "metadata": { "id": "2ahqnHCZoizL" }, "source": [ "With all the preparations out of the way, we can now define our `NeuralNetClassifier`:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "xQCwRTaZoizL" }, "outputs": [], "source": [ "net = NeuralNetClassifier(\n", " PretrainedModel, \n", " criterion=nn.CrossEntropyLoss,\n", " lr=0.001,\n", " batch_size=4,\n", " max_epochs=25,\n", 
" module__output_features=2,\n", " optimizer=optim.SGD,\n", " optimizer__momentum=0.9,\n", " iterator_train__shuffle=True,\n", " iterator_train__num_workers=2,\n", " iterator_valid__num_workers=2,\n", " train_split=predefined_split(val_ds),\n", " callbacks=[lrscheduler, checkpoint, freezer],\n", " device='cuda' # comment to train on cpu\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "ZggANL3DoizL" }, "source": [ "That is quite a few parameters! Lets walk through each one:\n", "\n", "1. `model_ft`: Our `ResNet18` neural network\n", "2. `criterion=nn.CrossEntropyLoss`: loss function\n", "3. `lr`: Initial learning rate\n", "4. `batch_size`: Size of a batch\n", "5. `max_epochs`: Number of epochs to train\n", "6. `module__output_features`: Used by `__init__` in our `PretrainedModel` class to set the number of classes.\n", "7. `optimizer`: Our optimizer\n", "8. `optimizer__momentum`: The initial momentum\n", "9. `iterator_{train,valid}__{shuffle,num_workers}`: Parameters that are passed to the dataloader.\n", "10. `train_split`: A wrapper around `val_ds` to use our validation dataset.\n", "11. `callbacks`: Our callbacks \n", "12. 
`device`: Set to `cuda` to train on gpu.\n", "\n", "Now we are ready to train our neural network:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 641, "referenced_widgets": [ "2e427ec51ada45dbb95b3a8d432aa83a", "0f274dff5dd34ab3baaf9158c7ce5de5", "01efc735f0d34e478dc2820eb212f108", "b088563fd5f24fd39516b4999fdfc003", "bd612f740e534feeaca87c8d44825569", "12e62a72458c49fabaf223dcc8c7e32b", "05c0e1e9e6a443c8b5443b7beb34afb2", "8e1c5fb23d074fa68f607f1ebcadb289", "5e9cecf6e979425292cf7703b8822183", "1c35ef4c3d72459f84c2354fe57c61a9", "bd92813035144bca9fa48f555f8d218a" ] }, "id": "sYjoTUdyoizM", "outputId": "f88743c0-cf16-4037-fa84-a66502d0f2fb" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", " warnings.warn(\n", "/usr/local/lib/python3.8/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n", " warnings.warn(msg)\n", "Downloading: \"https://download.pytorch.org/models/resnet18-f37072fd.pth\" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ " 0%| | 0.00/44.7M [00:00