{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "V920LTuiq40d" }, "source": [ "# Basic usage" ] }, { "cell_type": "markdown", "metadata": { "id": "VYNcdzMLq40j" }, "source": [ "*`skorch`* is designed to maximize interoperability between `sklearn` and `pytorch`. The aim is to keep 99% of the flexibility of `pytorch` while being able to leverage most features of `sklearn`. Below, we show the basic usage of `skorch` and how it can be combined with `sklearn`.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "EOEEblzDq40m" }, "source": [ "This notebook shows you how to use the basic functionality of `skorch`." ] }, { "cell_type": "markdown", "metadata": { "id": "fpiUBxsUq40o" }, "source": [ "### Table of contents" ] }, { "cell_type": "markdown", "metadata": { "id": "pevxrQfSq40q" }, "source": [ "* [Training a classifier](#Training-a-classifier-and-making-predictions)\n", "    * [Dataset](#A-toy-binary-classification-task)\n", "    * [pytorch module](#Definition-of-the-pytorch-classification-module)\n", "    * [Model training](#Defining-and-training-the-neural-net-classifier)\n", "    * [Inference](#Making-predictions,-classification)\n", "* [Training a regressor](#Training-a-regressor)\n", "    * [Dataset](#A-toy-regression-task)\n", "    * [pytorch module](#Definition-of-the-pytorch-regression-module)\n", "    * [Model training](#Defining-and-training-the-neural-net-regressor)\n", "    * [Inference](#Making-predictions,-regression)\n", "* [Saving and loading a model](#Saving-and-loading-a-model)\n", "    * [Whole model](#Saving-the-whole-model)\n", "    * [Only parameters](#Saving-only-the-model-parameters)\n", "* [Usage with an sklearn Pipeline](#Usage-with-an-sklearn-Pipeline)\n", "* [Callbacks](#Callbacks)\n", "* [Grid search](#Usage-with-sklearn-GridSearchCV)\n", "    * [Special prefixes](#Special-prefixes)\n", "    * [Performing a grid search](#Performing-a-grid-search)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "utYcb97jq40t" }, "outputs": [], "source": [ "import subprocess\n", "import sys\n", "\n", "# Installation on Google Colab\n", "try:\n", "    import google.colab\n", "    # use the running interpreter's pip so the packages are installed into this kernel's environment\n", "    subprocess.run([sys.executable, '-m', 'pip', 'install', 'skorch', 'torch'])\n", "except ImportError:\n", "    pass" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "WZ3Y_KHvq40x" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import
torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "D9d6X0ZZq40z" }, "outputs": [], "source": [ "torch.manual_seed(0)\n", "torch.cuda.manual_seed(0)" ] }, { "cell_type": "markdown", "metadata": { "id": "dAnY8yaDq400" }, "source": [ "## Training a classifier and making predictions" ] }, { "cell_type": "markdown", "metadata": { "id": "nKHxWMzKq401" }, "source": [ "### A toy binary classification task" ] }, { "cell_type": "markdown", "metadata": { "id": "jpsnS1HDq403" }, "source": [ "We load a toy classification task from `sklearn`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "H55IvQdyq403" }, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import make_classification" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "CzJ18ICiq404" }, "outputs": [], "source": [ "# This is a toy dataset for binary classification, 1000 data points with 20 features each\n", "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n", "X, y = X.astype(np.float32), y.astype(np.int64)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true, "id": "2DdtOvQuq406", "outputId": "280ae85a-fd8f-4e82-e877-c7abef6415ee", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((1000, 20), (1000,), 0.5)" ] }, "metadata": {}, "execution_count": 8 } ], "source": [ "X.shape, y.shape, y.mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "0w2mm41yq407" }, "source": [ "### Definition of the `pytorch` classification `module`" ] }, { "cell_type": "markdown", "metadata": { "id": "oMFoiitJq407" }, "source": [ "We define a vanilla neural network with two hidden layers. The output layer should have 2 output units since there are two classes. 
In addition, it should apply a softmax nonlinearity, because later, when calling `predict_proba`, the output of the `forward` call is used as the class probabilities." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "B7eNyYKzq408" }, "outputs": [], "source": [ "class ClassifierModule(nn.Module):\n", "    def __init__(\n", "            self,\n", "            num_units=10,\n", "            nonlin=F.relu,\n", "            dropout=0.5,\n", "    ):\n", "        super(ClassifierModule, self).__init__()\n", "        self.num_units = num_units\n", "        self.nonlin = nonlin\n", "        self.dropout = dropout\n", "\n", "        self.dense0 = nn.Linear(20, num_units)\n", "        self.dropout = nn.Dropout(dropout)\n", "        self.dense1 = nn.Linear(num_units, 10)\n", "        self.output = nn.Linear(10, 2)\n", "\n", "    def forward(self, X, **kwargs):\n", "        X = self.nonlin(self.dense0(X))\n", "        X = self.dropout(X)\n", "        X = F.relu(self.dense1(X))\n", "        X = F.softmax(self.output(X), dim=-1)\n", "        return X" ] }, { "cell_type": "markdown", "metadata": { "id": "qN9hYiWvq409" }, "source": [ "### Defining and training the neural net classifier" ] }, { "cell_type": "markdown", "metadata": { "id": "-CqI683Wq40-" }, "source": [ "We use `NeuralNetClassifier` because we're dealing with a classification task. The first argument should be the `pytorch module`. Here we pass the class `ClassifierModule` itself, and skorch instantiates it for us. As additional arguments, we pass the number of epochs and the learning rate (`lr`), but those are optional.\n", "\n", "*Note*: To use the CUDA backend, pass `device='cuda'` as an additional argument."
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "14HVaJZDq40-" }, "outputs": [], "source": [ "from skorch import NeuralNetClassifier" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "vltXCNfgq40_" }, "outputs": [], "source": [ "net = NeuralNetClassifier(\n", "    ClassifierModule,\n", "    max_epochs=20,\n", "    lr=0.1,\n", "    # device='cuda',  # uncomment this to train with CUDA\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "mNxs9BRJq41A" }, "source": [ "As in `sklearn`, we call `fit`, passing the input data `X` and the targets `y`. By default, `NeuralNetClassifier` holds out a stratified 20% of the data as a validation set (an 80/20 split made with `StratifiedKFold`) to track the validation loss. The printed table shows this validation loss together with the training loss and the accuracy on the validation set." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": false, "id": "juu_iujpq41B", "outputId": "ed85d9f9-b0f1-4c2f-a8bc-bf6722e54326", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.6905\u001b[0m \u001b[32m0.6150\u001b[0m \u001b[35m0.6749\u001b[0m 0.2269\n", " 2 \u001b[36m0.6740\u001b[0m \u001b[32m0.6200\u001b[0m \u001b[35m0.6668\u001b[0m 0.0308\n", " 3 \u001b[36m0.6594\u001b[0m \u001b[32m0.6750\u001b[0m \u001b[35m0.6554\u001b[0m 0.0185\n", " 4 \u001b[36m0.6482\u001b[0m \u001b[32m0.6900\u001b[0m \u001b[35m0.6452\u001b[0m 0.0183\n", " 5 \u001b[36m0.6423\u001b[0m \u001b[32m0.7050\u001b[0m \u001b[35m0.6333\u001b[0m 0.0189\n", " 6 \u001b[36m0.6231\u001b[0m 0.7000 \u001b[35m0.6188\u001b[0m 0.0198\n", " 7 \u001b[36m0.6081\u001b[0m \u001b[32m0.7100\u001b[0m \u001b[35m0.6064\u001b[0m 0.0199\n", " 8 \u001b[36m0.6003\u001b[0m 0.7000 \u001b[35m0.5940\u001b[0m 0.0183\n", " 9 
\u001b[36m0.5937\u001b[0m \u001b[32m0.7250\u001b[0m \u001b[35m0.5836\u001b[0m 0.0216\n", " 10
\u001b[36m0.5830\u001b[0m 0.7150 \u001b[35m0.5725\u001b[0m 0.0194\n", " 11 \u001b[36m0.5686\u001b[0m 0.7100 \u001b[35m0.5660\u001b[0m 0.0202\n", " 12 0.5701 0.7150 \u001b[35m0.5577\u001b[0m 0.0200\n", " 13 0.5751 0.7200 \u001b[35m0.5499\u001b[0m 0.0201\n", " 14 \u001b[36m0.5662\u001b[0m 0.7250 \u001b[35m0.5438\u001b[0m 0.0167\n", " 15 \u001b[36m0.5422\u001b[0m 0.7250 \u001b[35m0.5354\u001b[0m 0.0204\n", " 16 \u001b[36m0.5363\u001b[0m 0.7250 \u001b[35m0.5310\u001b[0m 0.0268\n", " 17 0.5378 \u001b[32m0.7350\u001b[0m \u001b[35m0.5263\u001b[0m 0.0281\n", " 18 0.5517 0.7250 0.5276 0.0193\n", " 19 0.5448 0.7250 0.5291 0.0182\n", " 20 \u001b[36m0.5280\u001b[0m 0.7350 \u001b[35m0.5247\u001b[0m 0.0219\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[initialized](\n", " module_=ClassifierModule(\n", " (dense0): Linear(in_features=20, out_features=10, bias=True)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=2, bias=True)\n", " ),\n", ")" ] }, "metadata": {}, "execution_count": 12 } ], "source": [ "# Training the network\n", "net.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": { "id": "keJT0oM2q41C" }, "source": [ "Also, as in `sklearn`, you may call `predict` or `predict_proba` on the fitted model." 
] }, { "cell_type": "markdown", "metadata": { "id": "rohwh7Ugq41C" }, "source": [ "### Making predictions, classification" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "mxWtMQCtq41D", "outputId": "06030606-bc4e-4396-80ec-d5724af50077", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0, 0, 0, 0])" ] }, "metadata": {}, "execution_count": 13 } ], "source": [ "# Make predictions for the first 5 data points of X\n", "y_pred = net.predict(X[:5])\n", "y_pred" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "IWqQm6WXq41E", "outputId": "59e7af8c-6ba5-432b-f874-9e65ce0d1789", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[0.5603605 , 0.43963954],\n", " [0.782588 , 0.21741195],\n", " [0.6924924 , 0.30750763],\n", " [0.8895971 , 0.1104029 ],\n", " [0.7074626 , 0.2925373 ]], dtype=float32)" ] }, "metadata": {}, "execution_count": 14 } ], "source": [ "# Predicted probability of each class for the first 5 data points of X\n", "y_proba = net.predict_proba(X[:5])\n", "y_proba" ] }, { "cell_type": "markdown", "metadata": { "id": "eX9F9i9Nq41E" }, "source": [ "## Training a regressor" ] }, { "cell_type": "markdown", "metadata": { "id": "z-WnqycIq41F" }, "source": [ "### A toy regression task" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "FtDMYd7sq41F" }, "outputs": [], "source": [ "from sklearn.datasets import make_regression" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "OfHPKjK3q41G" }, "outputs": [], "source": [ "# This is a toy dataset for regression, 1000 data points with 20 features each\n", "X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)\n", "X_regr = X_regr.astype(np.float32)\n", "# rescale the targets to a smaller range\n", "y_regr = y_regr.astype(np.float32) / 100\n", "# reshape the targets to 2 
dimensions, shape (n_samples, 1)\n", "y_regr = y_regr.reshape(-1, 1)" ] }, { "cell_type": 
"code", "execution_count": 17, "metadata": { "scrolled": true, "id": "35aMo_r2q41G", "outputId": "657fa561-a867-49b1-ed51-eea21d183a84", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "((1000, 20), (1000, 1), -6.4901485, 6.154505)" ] }, "metadata": {}, "execution_count": 17 } ], "source": [ "X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()" ] }, { "cell_type": "markdown", "metadata": { "id": "P3iQWXCbq41H" }, "source": [ "*Note*: For regression, skorch expects the target to be 2-dimensional (here of shape `(1000, 1)`), hence the need to reshape." ] }, { "cell_type": "markdown", "metadata": { "id": "Xl62UaRiq41H" }, "source": [ "### Definition of the `pytorch` regression `module`" ] }, { "cell_type": "markdown", "metadata": { "id": "l1FTciDpq41I" }, "source": [ "Again, we define a vanilla neural network with two hidden layers. The main difference is that the output layer has only one unit and does not apply a softmax nonlinearity." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "rI0pjTzVq41J" }, "outputs": [], "source": [ "class RegressorModule(nn.Module):\n", "    def __init__(\n", "            self,\n", "            num_units=10,\n", "            nonlin=F.relu,\n", "    ):\n", "        super(RegressorModule, self).__init__()\n", "        self.num_units = num_units\n", "        self.nonlin = nonlin\n", "\n", "        self.dense0 = nn.Linear(20, num_units)\n", "        self.dense1 = nn.Linear(num_units, 10)\n", "        self.output = nn.Linear(10, 1)\n", "\n", "    def forward(self, X, **kwargs):\n", "        X = self.nonlin(self.dense0(X))\n", "        X = F.relu(self.dense1(X))\n", "        X = self.output(X)\n", "        return X" ] }, { "cell_type": "markdown", "metadata": { "id": "DGbddgBCq41K" }, "source": [ "### Defining and training the neural net regressor" ] }, { "cell_type": "markdown", "metadata": { "id": "_tYQwfKgq41K" }, "source": [ "Training a regressor is almost the same as training a 
classifier.
The main difference is that we use `NeuralNetRegressor` instead of `NeuralNetClassifier` (the same classifier/regressor naming as in `sklearn`)." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "UYPYtJAIq41M" }, "outputs": [], "source": [ "from skorch import NeuralNetRegressor" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "id": "GfSzKX-2q41N" }, "outputs": [], "source": [ "net_regr = NeuralNetRegressor(\n", "    RegressorModule,\n", "    max_epochs=20,\n", "    lr=0.1,\n", "    # device='cuda',  # uncomment this to train with CUDA\n", ")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "yIoepl_bq41O", "outputId": "02987b06-1233-4103-bea3-8af6a7525fd7", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch train_loss valid_loss dur\n", "------- ------------ ------------ ------\n", " 1 \u001b[36m4.4794\u001b[0m \u001b[32m3.4054\u001b[0m 0.0327\n", " 2 \u001b[36m2.6630\u001b[0m \u001b[32m0.5670\u001b[0m 0.0201\n", " 3 \u001b[36m0.3102\u001b[0m \u001b[32m0.2004\u001b[0m 0.0278\n", " 4 \u001b[36m0.2250\u001b[0m 0.5211 0.0184\n", " 5 0.3249 \u001b[32m0.1675\u001b[0m 0.0190\n", " 6 \u001b[36m0.1716\u001b[0m 0.2142 0.0201\n", " 7 \u001b[36m0.1175\u001b[0m \u001b[32m0.1192\u001b[0m 0.0222\n", " 8 0.1917 0.2653 0.0210\n", " 9 0.1455 0.1196 0.0212\n", " 10 \u001b[36m0.1144\u001b[0m \u001b[32m0.0803\u001b[0m 0.0228\n", " 11 \u001b[36m0.0434\u001b[0m \u001b[32m0.0712\u001b[0m 0.0218\n", " 12 0.0819 \u001b[32m0.0690\u001b[0m 0.0211\n", " 13 \u001b[36m0.0419\u001b[0m 0.0737 0.0219\n", " 14 0.0748 \u001b[32m0.0498\u001b[0m 0.0222\n", " 15 \u001b[36m0.0310\u001b[0m 0.0586 0.0217\n", " 16 0.0522 \u001b[32m0.0312\u001b[0m 0.0263\n", " 17 \u001b[36m0.0189\u001b[0m 0.0419 0.0223\n", " 18 0.0357 \u001b[32m0.0219\u001b[0m 0.0204\n", " 19 \u001b[36m0.0134\u001b[0m 0.0345 0.0215\n", " 20 0.0277 \u001b[32m0.0161\u001b[0m 0.0203\n" ] }, { "output_type": 
"execute_result", "data": {
"text/plain": [ "[initialized](\n", " module_=RegressorModule(\n", " (dense0): Linear(in_features=20, out_features=10, bias=True)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=1, bias=True)\n", " ),\n", ")" ] }, "metadata": {}, "execution_count": 21 } ], "source": [ "net_regr.fit(X_regr, y_regr)" ] }, { "cell_type": "markdown", "metadata": { "id": "TSG_7WKZq41O" }, "source": [ "### Making predictions, regression" ] }, { "cell_type": "markdown", "metadata": { "id": "xseQDPkZq41P" }, "source": [ "You can call `predict` or `predict_proba` on the fitted model. For regression, both methods return the same values." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "EMh7DoB1q41P", "outputId": "14618cc8-0cfb-4f48-a932-f04a93fbba45", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[ 0.62908685],\n", " [-1.5245112 ],\n", " [-0.48306593],\n", " [-0.27282855],\n", " [-0.42769447]], dtype=float32)" ] }, "metadata": {}, "execution_count": 22 } ], "source": [ "# Make predictions for the first 5 data points of X_regr\n", "y_pred = net_regr.predict(X_regr[:5])\n", "y_pred" ] }, { "cell_type": "markdown", "metadata": { "id": "h2Z7LBzaq41P" }, "source": [ "## Saving and loading a model" ] }, { "cell_type": "markdown", "metadata": { "id": "sDdnm_6mq41Q" }, "source": [ "You can save and load either the whole model, using `pickle`, or only the learned module parameters, using `save_params` and `load_params`."
] }, { "cell_type": "markdown", "metadata": { "id": "gZPKIAI4q41Q" }, "source": [ "### Saving the whole model" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "8CwJ2on5q41Q" }, "outputs": [], "source": [ "import pickle" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "tqtrU30vq41Q" }, "outputs": [], "source": [ "file_name = '/tmp/mymodel.pkl'" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "LdC37Q18q41R" }, "outputs": [], "source": [ "with open(file_name, 'wb') as f:\n", "    pickle.dump(net, f)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "MXcDEmrMq41R" }, "outputs": [], "source": [ "with open(file_name, 'rb') as f:\n", "    new_net = pickle.load(f)" ] }, { "cell_type": "markdown", "metadata": { "id": "P8QcpcM2q41S" }, "source": [ "### Saving only the model parameters" ] }, { "cell_type": "markdown", "metadata": { "id": "R0UDJ4zPq41T" }, "source": [ "This saves and loads only the learned `module` parameters. Hyperparameters such as `lr` and `max_epochs` are not saved. Therefore, before loading the parameters, we have to create and initialize a net with the same settings."
] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "4ob1VxTuq41W" }, "outputs": [], "source": [ "net.save_params(f_params=file_name)  # a file handle also works" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "ji8LnbI8q41W" }, "outputs": [], "source": [ "# first create and initialize a net with the same module\n", "new_net = NeuralNetClassifier(\n", "    ClassifierModule,\n", "    max_epochs=20,\n", "    lr=0.1,\n", ").initialize()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "2jjSZz10q41W" }, "outputs": [], "source": [ "new_net.load_params(f_params=file_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "eF9LrDtxq41X" }, "source": [ "## Usage with an `sklearn Pipeline`" ] }, { "cell_type": "markdown", "metadata": { "id": "Cw56Sra0q41X" }, "source": [ "It is possible to put the `NeuralNetClassifier` inside an `sklearn Pipeline`, as you would with any `sklearn` classifier." ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "ernAgNliq41X" }, "outputs": [], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import StandardScaler" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "dQuWgYgcq41X" }, "outputs": [], "source": [ "pipe = Pipeline([\n", "    ('scale', StandardScaler()),\n", "    ('net', net),\n", "])" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "-85dViwCq41Y", "outputId": "7747e985-77af-4053-e31a-fd09c4813b01", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Re-initializing module.\n", "Re-initializing criterion.\n", "Re-initializing optimizer.\n", " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.7188\u001b[0m \u001b[32m0.4350\u001b[0m \u001b[35m0.7027\u001b[0m 0.0177\n", " 2 \u001b[36m0.6989\u001b[0m \u001b[32m0.4500\u001b[0m \u001b[35m0.6988\u001b[0m 0.0184\n", " 3 0.7031
\u001b[32m0.4750\u001b[0m \u001b[35m0.6956\u001b[0m 0.0254\n", " 4 \u001b[36m0.6956\u001b[0m \u001b[32m0.5250\u001b[0m \u001b[35m0.6931\u001b[0m 0.0204\n", " 5 \u001b[36m0.6892\u001b[0m 0.5250 \u001b[35m0.6912\u001b[0m 0.0196\n", " 6 0.6905 \u001b[32m0.5300\u001b[0m \u001b[35m0.6890\u001b[0m 0.0233\n", " 7 \u001b[36m0.6888\u001b[0m \u001b[32m0.5400\u001b[0m \u001b[35m0.6866\u001b[0m 0.0302\n", " 8 \u001b[36m0.6842\u001b[0m \u001b[32m0.5700\u001b[0m \u001b[35m0.6842\u001b[0m 0.0293\n", " 9 \u001b[36m0.6815\u001b[0m \u001b[32m0.5950\u001b[0m \u001b[35m0.6815\u001b[0m 0.0193\n", " 10 \u001b[36m0.6761\u001b[0m 0.5900 \u001b[35m0.6787\u001b[0m 0.0200\n", " 11 0.6777 0.5850 \u001b[35m0.6761\u001b[0m 0.0195\n", " 12 \u001b[36m0.6677\u001b[0m \u001b[32m0.6050\u001b[0m \u001b[35m0.6730\u001b[0m 0.0207\n", " 13 \u001b[36m0.6646\u001b[0m \u001b[32m0.6250\u001b[0m \u001b[35m0.6695\u001b[0m 0.0188\n", " 14 \u001b[36m0.6620\u001b[0m \u001b[32m0.6350\u001b[0m \u001b[35m0.6647\u001b[0m 0.0195\n", " 15 \u001b[36m0.6560\u001b[0m 0.6350 \u001b[35m0.6604\u001b[0m 0.0230\n", " 16 \u001b[36m0.6478\u001b[0m \u001b[32m0.6400\u001b[0m \u001b[35m0.6562\u001b[0m 0.0208\n", " 17 0.6499 0.6400 \u001b[35m0.6519\u001b[0m 0.0198\n", " 18 \u001b[36m0.6275\u001b[0m \u001b[32m0.6450\u001b[0m \u001b[35m0.6459\u001b[0m 0.0216\n", " 19 0.6342 \u001b[32m0.6550\u001b[0m \u001b[35m0.6412\u001b[0m 0.0230\n", " 20 \u001b[36m0.6192\u001b[0m 0.6550 \u001b[35m0.6356\u001b[0m 0.0221\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "Pipeline(steps=[('scale', StandardScaler()),\n", " ('net',\n", " [initialized](\n", " module_=ClassifierModule(\n", " (dense0): Linear(in_features=20, out_features=10, bias=True)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=2, bias=True)\n", " ),\n", "))])" ] }, "metadata": {}, "execution_count": 32 } ], "source": [ "pipe.fit(X, y)" ] }, { 
"cell_type": "code", "execution_count": 33, "metadata": { "id": "kdp1dLugq41Y", "outputId": "10fc1a21-1c15-42af-cf1e-d9065fad963e", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([[0.45305932, 0.5469407 ],\n", " [0.6833223 , 0.3166777 ],\n", " [0.7487094 , 0.2512906 ],\n", " [0.7011023 , 0.29889768],\n", " [0.7423797 , 0.25762028]], dtype=float32)" ] }, "metadata": {}, "execution_count": 33 } ], "source": [ "y_proba = pipe.predict_proba(X[:5])\n", "y_proba" ] }, { "cell_type": "markdown", "metadata": { "id": "tGZJIOCXq41Z" }, "source": [ "To save the whole pipeline, including the pytorch module, use `pickle`." ] }, { "cell_type": "markdown", "metadata": { "id": "GlFZurPUq41Z" }, "source": [ "## Callbacks" ] }, { "cell_type": "markdown", "metadata": { "id": "X21MVEqtq41Z" }, "source": [ "Adding a new callback to the model is straightforward. Below, we show how to add a callback that computes the area under the ROC curve (ROC AUC) after each epoch." ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "vxHK8z4Zq41a" }, "outputs": [], "source": [ "from skorch.callbacks import EpochScoring" ] }, { "cell_type": "markdown", "metadata": { "id": "GrV9YE4sq41b" }, "source": [ "skorch provides a scoring callback, `EpochScoring`, which we use for this. We have to specify which score to calculate, and there are 3 ways to do so:\n", "\n", "* Passing a string: This should be the name of a valid `sklearn` scorer, such as `'accuracy'` or `'roc_auc'`.
For a list of all available scores, see [here](http://scikit-learn.org/stable/modules/classes.html#sklearn-metrics-metrics).\n", "* Passing `None`: If you implement your own `.score` method on your neural net, passing `scoring=None` tells `skorch` to use that method.\n", "* Passing a function or callable: If you define your own scoring function, pass a callable with the signature `func(model, X, y) -> score`, and it is used instead.\n", "\n", "Note that this works the same way as the `scoring` parameter in `sklearn`." ] }, { "cell_type": "markdown", "metadata": { "id": "YM-ImwJuq41c" }, "source": [ "In our case, since `sklearn` already implements ROC AUC, we just pass the string `'roc_auc'`. We should also tell the callback that higher scores are better, so that the correct colors are printed below (by default, lower scores are assumed to be better). Furthermore, we may pass a `name` argument to `EpochScoring`, and we may choose whether to score on the training data (by setting `on_train=True`) or on the validation data (the default)." ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "AsbHe77Uq41c" }, "outputs": [], "source": [ "auc = EpochScoring(scoring='roc_auc', lower_is_better=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "pz-tGJaeq41e" }, "source": [ "Finally, we pass the scoring callback as a list to the `callbacks` parameter and then call `fit`. Notice that we get the printed scores and color highlighting for free."
] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "O7HFlTkoq41e" }, "outputs": [], "source": [ "net = NeuralNetClassifier(\n", " ClassifierModule,\n", " max_epochs=20,\n", " lr=0.1,\n", " callbacks=[auc],\n", ")" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "id": "YhOWx_7Tq41e", "outputId": "2ac14ec6-3b4c-4b1f-84e7-676a0dfef1be", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch roc_auc train_loss valid_acc valid_loss dur\n", "------- --------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.6936\u001b[0m \u001b[32m0.7299\u001b[0m \u001b[35m0.5550\u001b[0m \u001b[31m0.6742\u001b[0m 0.0176\n", " 2 \u001b[36m0.7103\u001b[0m \u001b[32m0.6848\u001b[0m \u001b[35m0.6600\u001b[0m \u001b[31m0.6601\u001b[0m 0.0208\n", " 3 \u001b[36m0.7155\u001b[0m \u001b[32m0.6550\u001b[0m \u001b[35m0.6900\u001b[0m \u001b[31m0.6536\u001b[0m 0.0244\n", " 4 \u001b[36m0.7255\u001b[0m \u001b[32m0.6355\u001b[0m \u001b[35m0.7200\u001b[0m \u001b[31m0.6485\u001b[0m 0.0179\n", " 5 \u001b[36m0.7340\u001b[0m 0.6380 \u001b[35m0.7250\u001b[0m \u001b[31m0.6422\u001b[0m 0.0186\n", " 6 \u001b[36m0.7373\u001b[0m \u001b[32m0.6268\u001b[0m \u001b[35m0.7400\u001b[0m \u001b[31m0.6363\u001b[0m 0.0200\n", " 7 \u001b[36m0.7445\u001b[0m \u001b[32m0.6157\u001b[0m 0.7400 \u001b[31m0.6317\u001b[0m 0.0244\n", " 8 \u001b[36m0.7477\u001b[0m \u001b[32m0.6128\u001b[0m \u001b[35m0.7450\u001b[0m \u001b[31m0.6258\u001b[0m 0.0195\n", " 9 \u001b[36m0.7573\u001b[0m \u001b[32m0.6068\u001b[0m 0.7150 \u001b[31m0.6153\u001b[0m 0.0182\n", " 10 \u001b[36m0.7616\u001b[0m \u001b[32m0.5958\u001b[0m 0.7350 \u001b[31m0.6105\u001b[0m 0.0266\n", " 11 \u001b[36m0.7684\u001b[0m \u001b[32m0.5819\u001b[0m 0.7300 \u001b[31m0.6010\u001b[0m 0.0188\n", " 12 \u001b[36m0.7712\u001b[0m 0.5847 0.7000 \u001b[31m0.5935\u001b[0m 0.0197\n", " 13 \u001b[36m0.7719\u001b[0m \u001b[32m0.5659\u001b[0m 0.7250 
\u001b[31m0.5895\u001b[0m 0.0168\n", " 14 \u001b[36m0.7746\u001b[0m \u001b[32m0.5561\u001b[0m 0.7300 \u001b[31m0.5831\u001b[0m 0.0202\n", " 15 \u001b[36m0.7789\u001b[0m 0.5632 0.7400 \u001b[31m0.5779\u001b[0m 0.0194\n", " 16 \u001b[36m0.7839\u001b[0m \u001b[32m0.5352\u001b[0m 0.7250 \u001b[31m0.5730\u001b[0m 0.0197\n", " 17 0.7839 0.5495 0.7350 \u001b[31m0.5693\u001b[0m 0.0214\n", " 18 0.7816 0.5473 0.7350 0.5721 0.0225\n", " 19 \u001b[36m0.7903\u001b[0m 0.5436 0.7450 \u001b[31m0.5592\u001b[0m 0.0167\n", " 20 \u001b[36m0.7948\u001b[0m 0.5430 0.7450 \u001b[31m0.5519\u001b[0m 0.0186\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[initialized](\n", " module_=ClassifierModule(\n", " (dense0): Linear(in_features=20, out_features=10, bias=True)\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=2, bias=True)\n", " ),\n", ")" ] }, "metadata": {}, "execution_count": 37 } ], "source": [ "net.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": { "id": "Hz3yg7E9q41f" }, "source": [ "For information on how to write custom callbacks, have a look at the [Advanced_Usage](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb) notebook." ] }, { "cell_type": "markdown", "metadata": { "id": "9ntHaTTXq41f" }, "source": [ "## Usage with sklearn `GridSearchCV`" ] }, { "cell_type": "markdown", "metadata": { "id": "vS_Fo2zRq41g" }, "source": [ "### Special prefixes" ] }, { "cell_type": "markdown", "metadata": { "id": "humHlLnxq41g" }, "source": [ "The `NeuralNet` class lets you directly set parameters of the `pytorch module` by using the `module__` prefix. For example, if you defined the `module` to take a `num_units` parameter, you can set it via the `module__num_units` argument. This is the same double-underscore logic that `sklearn` uses to address the parameters of estimators inside `Pipeline`s and `FeatureUnion`s."
] }, { "cell_type": "markdown", "metadata": { "id": "vW3xlLHfq41h" }, "source": [ "This feature is useful in several ways. For one, it lets you set those parameters directly when defining the net, e.g. `NeuralNetClassifier(ClassifierModule, module__num_units=20)`. Furthermore, it lets you search over them with `sklearn`'s `GridSearchCV`, as shown below." ] }, { "cell_type": "markdown", "metadata": { "id": "q9o8HamMq41h" }, "source": [ "In addition to the parameters prefixed by `module__`, you can set the parameters of several other components, such as the optimizer, by using the corresponding prefix, e.g. `optimizer__` (again, see below). All those special prefixes are stored in the `prefixes_` attribute:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "JiEKubsQq41h", "outputId": "c3280363-6e4e-4fb6-bfe7-6b4b30a77973", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "iterator_train, iterator_valid, callbacks, dataset, module, criterion, optimizer\n" ] } ], "source": [ "print(', '.join(net.prefixes_))" ] }, { "cell_type": "markdown", "metadata": { "id": "o7mFj9sfq41h" }, "source": [ "### Performing a grid search" ] }, { "cell_type": "markdown", "metadata": { "id": "4yWN9gsqq41i" }, "source": [ "Below we show how to perform a grid search over the learning rate (`lr`), the module's number of hidden units (`module__num_units`), the module's dropout rate (`module__dropout`), and whether the SGD optimizer should use Nesterov momentum or not (`optimizer__nesterov`)."
] }, { "cell_type": "code", "execution_count": 39, "metadata": { "id": "McrqjYv-q41i" }, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "id": "CqxJik7Cq41i" }, "outputs": [], "source": [ "net = NeuralNetClassifier(\n", " ClassifierModule,\n", " max_epochs=20,\n", " lr=0.1,\n", " optimizer__momentum=0.9,\n", " verbose=0,\n", " train_split=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "0Bkur8xQq41j" }, "source": [ "*Note*: We set the verbosity level to zero (`verbose=0`) to prevent too much print output from being shown. Also, we disable the skorch-internal train-validation split (`train_split=False`) because `GridSearchCV` already splits the training data for us. We only have to leave the skorch-internal split enabled for some specific uses, e.g. to perform `EarlyStopping`." ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "id": "8RMclj0wq41j" }, "outputs": [], "source": [ "params = {\n", " 'lr': [0.05, 0.1],\n", " 'module__num_units': [10, 20],\n", " 'module__dropout': [0, 0.5],\n", " 'optimizer__nesterov': [False, True],\n", "}" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "id": "fNkU7Zf1q41k" }, "outputs": [], "source": [ "gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "id": "Thjiwvb4q41k", "outputId": "b1690c32-a635-46e7-9c3d-66ab4410479e", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=10, 
optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False; total time= 
0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False; total time= 0.2s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, 
module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n", "[CV] END lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True; total time= 0.3s\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "GridSearchCV(cv=3,\n", " estimator=<class 'skorch.classifier.NeuralNetClassifier'>[uninitialized](\n", " module=<class '__main__.ClassifierModule'>,\n", "),\n", " param_grid={'lr': [0.05, 0.1], 'module__dropout': [0, 0.5],\n", " 'module__num_units': [10, 20],\n", " 'optimizer__nesterov': [False, True]},\n", " refit=False, scoring='accuracy', verbose=2)" ] }, "metadata": {}, "execution_count": 43 } ], "source": [ "gs.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "scrolled": true, "id": "GBEqgMYSq41k", "outputId": "8914f39b-5df7-4061-a701-3eaa3a1e543c", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.8780367193540846 {'lr': 0.1, 'module__dropout': 0, 'module__num_units': 20, 'optimizer__nesterov': False}\n" ] } ], "source": [ "print(gs.best_score_, gs.best_params_)" ] }, { "cell_type": "markdown", "metadata": { "id": "W2wFkMO5q41l" }, "source": [ "Of course, we could further nest the `NeuralNetClassifier` within an sklearn `Pipeline`, in which case each parameter name is additionally prefixed with the name of the pipeline step (e.g. `net__module__num_units` if the step is called `net`)."
] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13 (default, Mar 28 2022, 08:03:21) [MSC v.1916 64 bit (AMD64)]" }, "vscode": { "interpreter": { "hash": "bd97b8bffa4d3737e84826bc3d37be3046061822757ce35137ab82ad4c5a2016" } }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 0 }