# Transfer Learning with skorch

In this tutorial, you will learn how to train a neural network using transfer learning with the `skorch` API. Transfer learning uses a pretrained model to initialize a network. This tutorial converts the pure PyTorch approach described in [PyTorch's Transfer Learning Tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) to `skorch`.

We will be using `torchvision` for this tutorial. Instructions on how to install `torchvision` for your platform can be found at https://pytorch.org.

<table align="left"><td>
<a target="_blank" href="https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank" href="https://github.com/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

**Note**: If you are running this in [a Colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Transfer_Learning.ipynb), we recommend enabling a free GPU via:

> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**

If you are running in Colab, install the dependencies and download the dataset by running the following cell:

In [1]:
import subprocess

# Installation on Google Colab (the import below only succeeds inside Colab)
try:
    import google.colab
    subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'torchvision'])
    subprocess.run(['mkdir', '-p', 'datasets'])
    subprocess.run(['wget', '-nc', '--no-check-certificate', 'https://download.pytorch.org/tutorial/hymenoptera_data.zip', '-P', 'datasets'])
    # extract the archive into datasets/ ('-d' and its directory are separate arguments)
    subprocess.run(['unzip', '-u', 'datasets/hymenoptera_data.zip', '-d', 'datasets'])
except ImportError:
    pass

In [2]:
import os
from urllib import request
from zipfile import ZipFile

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms

from skorch import NeuralNetClassifier
from skorch.helper import predefined_split

torch.manual_seed(360);

## Preparations

Before we begin, let's download the data needed for this tutorial:

In [3]:
def download_and_extract_data(dataset_dir='datasets'):
    data_zip = os.path.join(dataset_dir, 'hymenoptera_data.zip')
    data_path = os.path.join(dataset_dir, 'hymenoptera_data')
    url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"

    if not os.path.exists(data_path):
        # make sure the target directory exists before writing the zip file
        os.makedirs(dataset_dir, exist_ok=True)
        if not os.path.exists(data_zip):
            print("Starting to download data...")
            data = request.urlopen(url, timeout=15).read()
            with open(data_zip, 'wb') as f:
                f.write(data)

        print("Starting to extract data...")
        with ZipFile(data_zip, 'r') as zip_f:
            zip_f.extractall(dataset_dir)

    print("Data has been downloaded and extracted to {}.".format(dataset_dir))

download_and_extract_data()

Data has been downloaded and extracted to datasets.


## The Problem

We are going to train a neural network to classify **ants** and **bees**. The dataset consists of roughly 120 training images and 75 validation images for each class. First, we create the training and validation datasets:

In [4]:
data_dir = 'datasets/hymenoptera_data'
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

train_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'train'), train_transforms)
val_ds = datasets.ImageFolder(
    os.path.join(data_dir, 'val'), val_transforms)

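The training transforms apply data augmentation: `RandomResizedCrop(224)` crops a region of random size and aspect ratio and resizes it to 224 by 224 pixels, and `RandomHorizontalFlip()` mirrors images at random. The validation transforms are deterministic: each image is resized so that its shorter side is 256 pixels and then center-cropped to 224 by 224 pixels. Both datasets are normalized with mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`, which are the per-channel (RGB) means and standard deviations of the ImageNet images. We use these values because the pretrained model was trained on ImageNet.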
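Per channel, `transforms.Normalize` computes `(x - mean) / std` on the tensor produced by `ToTensor`, whose values lie in [0, 1]. The pure-Python sketch below applies that formula to a single RGB pixel to show what the network actually receives. `normalize_pixel` is a hypothetical helper written for illustration, not part of torchvision.

```python
# ImageNet channel statistics used in the transforms above (R, G, B order)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (x - mean) / std channel by channel to one pixel.

    `rgb` holds values in [0, 1], as produced by transforms.ToTensor().
    Hypothetical helper mirroring the formula used by transforms.Normalize.
    """
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


# A pixel equal to the ImageNet mean maps to zero in every channel
mean_pixel = normalize_pixel(IMAGENET_MEAN)
# A pure white pixel ends up a little over two standard deviations above the mean
white_pixel = normalize_pixel((1.0, 1.0, 1.0))
print(mean_pixel, white_pixel)
```

Inputs are therefore roughly zero-centred with unit spread per channel, which is the distribution the pretrained ResNet saw during ImageNet training.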

## Loading pretrained model

We use a pretrained `ResNet18` model and replace its final layer with a new, randomly initialized fully connected layer:

In [5]:
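class PretrainedModel(nn.Module):
    def __init__(self, output_features):
        super().__init__()
        # ResNet18 with ImageNet weights; on torchvision < 0.13 use
        # models.resnet18(pretrained=True) instead
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        num_ftrs = model.fc.in_features  # 512 for ResNet18
        # replace the 1000-class ImageNet head with a new output layer
        model.fc = nn.Linear(num_ftrs, output_features)
        self.model = model

    def forward(self, x):
        return self.model(x)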

Since we are training a binary classifier (ants vs. bees), the final fully connected layer has 2 outputs, one per class.
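With two outputs, the network produces one raw score (logit) per class for each image. `ImageFolder` assigns class indices by sorting the folder names, so index 0 is `ants` and index 1 is `bees` (`train_ds.classes` should list them in that order). The plain-Python sketch below shows how a pair of logits becomes probabilities via softmax and then a label via argmax. Recent skorch versions do something equivalent in `predict_proba`/`predict` when the criterion is `CrossEntropyLoss`, but the helpers here (`softmax`, `predict_label`) are hypothetical illustrations, not skorch's code.

```python
import math

CLASSES = ["ants", "bees"]  # sorted folder names, as ImageFolder orders them


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits):
    """Pick the class with the highest probability (argmax)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CLASSES[best], probs


# Hypothetical logits for one image from the 2-unit fc layer
label, probs = predict_label([0.3, 1.5])
print(label, [round(p, 3) for p in probs])
```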

## Using skorch's API

In this section, we will create a `skorch.NeuralNetClassifier` to solve our classification problem. 

### Callbacks

First, we create an `LRScheduler` callback, which wraps `torch.optim.lr_scheduler.StepLR` and multiplies the learning rate by `gamma=0.1` every 7 epochs (skorch steps the scheduler once per epoch by default):

In [6]:
from skorch.callbacks import LRScheduler

lrscheduler = LRScheduler(
    policy='StepLR', step_size=7, gamma=0.1)
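Because the scheduler steps once per epoch here, the learning rate used in epoch `k` (counting from 1) is `lr * gamma ** ((k - 1) // step_size)`. The pure-Python sketch below evaluates that formula with this tutorial's values; it should agree with the `lr` column of the training log later in this notebook (0.001 for epochs 1 to 7, then 0.0001 from epoch 8). `step_lr` is a hypothetical helper, not part of PyTorch or skorch.

```python
def step_lr(epoch, base_lr=0.001, step_size=7, gamma=0.1):
    """Learning rate in effect during `epoch` (1-indexed) under StepLR.

    Hypothetical helper: StepLR multiplies the rate by `gamma` once every
    `step_size` scheduler steps, and the scheduler here steps once per epoch.
    """
    return base_lr * gamma ** ((epoch - 1) // step_size)


schedule = [step_lr(epoch) for epoch in range(1, 26)]  # max_epochs=25
for epoch in (1, 7, 8, 14, 15, 22):
    print(epoch, f"{schedule[epoch - 1]:.0e}")
```

With `max_epochs=25`, the rate drops three times, so the last few epochs train with a learning rate 1000 times smaller than the initial one.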

Next, we create a `Checkpoint` callback that saves the model parameters to `best_model.pt` whenever the validation accuracy reaches a new best:

In [7]:
from skorch.callbacks import Checkpoint

checkpoint = Checkpoint(
    f_params='best_model.pt', monitor='valid_acc_best')
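`valid_acc_best` is a per-epoch flag that skorch records in the history when `valid_acc` reaches a new maximum, and `Checkpoint` saves the parameters on the epochs where the monitored flag is true. The sketch below simulates that rule in plain Python on the validation accuracies of the first ten epochs from this tutorial's training log; it should flag epochs 1 and 5, matching the `+` marks in the `cp` column. `best_epochs` is a hypothetical helper, not skorch's implementation.

```python
# valid_acc for epochs 1-10, copied from the training log in this tutorial
valid_acc = [0.9477, 0.9412, 0.9346, 0.9346, 0.9608,
             0.9216, 0.9346, 0.9477, 0.9477, 0.9412]


def best_epochs(scores):
    """Return the 1-indexed epochs where the score beats all earlier ones.

    Hypothetical helper that simulates how a `valid_acc_best` flag decides
    when Checkpoint writes best_model.pt (higher accuracy is better).
    """
    saved, best = [], float("-inf")
    for epoch, score in enumerate(scores, start=1):
        if score > best:  # ties with the current best do not trigger a save
            best = score
            saved.append(epoch)
    return saved


print(best_epochs(valid_acc))
```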

Lastly, we create a `Freezer` callback that freezes all weights except those of the final layer, `model.fc`:

In [8]:
from skorch.callbacks import Freezer

freezer = Freezer(lambda x: not x.startswith('model.fc'))
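`Freezer` calls the given function with each parameter name of the module and freezes (sets `requires_grad=False` on) every parameter for which it returns `True`. Because the ResNet is stored as `self.model`, parameter names carry a `model.` prefix, such as `model.conv1.weight` or `model.fc.bias`. The plain-Python sketch below applies the same predicate to a small, hand-picked subset of ResNet18's parameter names (not the full list) to show which ones stay trainable.

```python
# A few parameter names as they appear on PretrainedModel (note the
# "model." prefix coming from the self.model attribute)
param_names = [
    "model.conv1.weight",
    "model.bn1.weight",
    "model.layer1.0.conv1.weight",
    "model.layer4.1.bn2.bias",
    "model.fc.weight",
    "model.fc.bias",
]


def should_freeze(name):
    """Same predicate as the Freezer above: True means freeze this parameter."""
    return not name.startswith("model.fc")


frozen = [name for name in param_names if should_freeze(name)]
trainable = [name for name in param_names if not should_freeze(name)]
print("frozen:", frozen)
print("trainable:", trainable)
```

Only the new head is updated, which is 512 x 2 weights plus 2 biases for ResNet18's 512 input features. This keeps training fast and reduces the risk of overfitting the small dataset.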

### skorch.NeuralNetClassifier

With all the preparations out of the way, we can now define our `NeuralNetClassifier`:

In [9]:
net = NeuralNetClassifier(
    PretrainedModel, 
    criterion=nn.CrossEntropyLoss,
    lr=0.001,
    batch_size=4,
    max_epochs=25,
    module__output_features=2,
    optimizer=optim.SGD,
    optimizer__momentum=0.9,
    iterator_train__shuffle=True,
    iterator_train__num_workers=2,
    iterator_valid__num_workers=2,
    train_split=predefined_split(val_ds),
    callbacks=[lrscheduler, checkpoint, freezer],
    device='cuda' if torch.cuda.is_available() else 'cpu',  # use the GPU when available
)

That is quite a few parameters! Let's walk through each one:

1. `PretrainedModel`: the module class wrapping our `ResNet18`; skorch instantiates it for us
2. `criterion=nn.CrossEntropyLoss`: the loss function
3. `lr`: the initial learning rate, which `LRScheduler` later reduces
4. `batch_size`: the number of images per batch
5. `max_epochs`: the number of epochs to train
6. `module__output_features`: passed to `__init__` of our `PretrainedModel` class to set the number of classes
7. `optimizer=optim.SGD`: the optimizer class
8. `optimizer__momentum`: the momentum passed to SGD
9. `iterator_train__shuffle` and `iterator_{train,valid}__num_workers`: parameters passed to the training and validation `DataLoader`s
10. `train_split`: `predefined_split(val_ds)` tells skorch to use `val_ds` as the validation set instead of splitting off part of the training data
11. `callbacks`: the callbacks defined above
12. `device`: the device to train on; `'cuda'` trains on a GPU, `'cpu'` on the CPU
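The double-underscore names follow skorch's convention: everything after `module__` is passed to the module's constructor, everything after `optimizer__` to the optimizer, and everything after `iterator_train__` to the training `DataLoader`. The sketch below is a hypothetical helper, `params_for`, that mimics this routing in plain Python on the keyword arguments used above. skorch does this internally, so you never need to write it yourself.

```python
def params_for(prefix, kwargs):
    """Collect kwargs starting with `prefix__` and strip that prefix.

    Hypothetical helper illustrating skorch's double-underscore convention.
    """
    marker = prefix + "__"
    return {key[len(marker):]: value
            for key, value in kwargs.items() if key.startswith(marker)}


net_kwargs = {
    "module__output_features": 2,
    "optimizer__momentum": 0.9,
    "iterator_train__shuffle": True,
    "iterator_train__num_workers": 2,
    "iterator_valid__num_workers": 2,
}

print(params_for("module", net_kwargs))          # kwargs for PretrainedModel(...)
print(params_for("optimizer", net_kwargs))       # kwargs for optim.SGD(...)
print(params_for("iterator_train", net_kwargs))  # kwargs for the training DataLoader
```

This convention is also what lets you tune such values with scikit-learn tools like `GridSearchCV`, using keys such as `module__output_features`.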

Now we are ready to train our neural network:

In [10]:
net.fit(train_ds, y=None);

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

  epoch    train_loss    valid_acc    valid_loss    cp      lr     dur
-------  ------------  -----------  ------------  ----  ------  ------
      1        0.6488       0.9477        0.1860     +  0.0010  9.3038
      2        0.4275       0.9412        0.1697        0.0010  3.0520
      3        0.4977       0.9346        0.1728        0.0010  3.1005
      4        0.5072       0.9346        0.1766        0.0010  3.1506
      5        0.5104       0.9608        0.1548     +  0.0010  3.4832
      6        0.3861       0.9216        0.1879        0.0010  3.2256
      7        0.4329       0.9346        0.1839        0.0010  3.0548
      8        0.3634       0.9477        0.1604        0.0001  3.3032
      9        0.3625       0.9477        0.1606        0.0001  3.0581
     10        0.3444       0.9412        0.1796        0.0001  3.0689
     11        0.3334       0.9346

The parameters of the best model are stored in `best_model.pt`, with a validation accuracy of roughly 0.96.

Congratulations! You now know how to fine-tune a neural network using `skorch`. Feel free to explore the other tutorials to learn more about using `skorch`.