{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic usage" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*`skorch`* is designed to maximize interoperability between `sklearn` and `pytorch`. The aim is to keep 99% of the flexibility of `pytorch` while being able to leverage most features of `sklearn`. Below, we show the basic usage of `skorch` and how it can be combined with `sklearn`.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook shows you how to use the basic functionality of `skorch`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Table of contents" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* [Definition of the pytorch module](#Definition-of-the-pytorch-module)\n", "* [Training a classifier](#Training-a-classifier-and-making-predictions)\n", " * [Dataset](#A-toy-binary-classification-task)\n", " * [pytorch module](#Definition-of-the-pytorch-classification-module)\n", " * [Model training](#Defining-and-training-the-neural-net-classifier)\n", " * [Inference](#Making-predictions,-classification)\n", "* [Training a regressor](#Training-a-regressor)\n", " * [Dataset](#A-toy-regression-task)\n", " * [pytorch module](#Definition-of-the-pytorch-regression-module)\n", " * [Model training](#Defining-and-training-the-neural-net-regressor)\n", " * [Inference](#Making-predictions,-regression)\n", "* [Saving and loading a model](#Saving-and-loading-a-model)\n", " * [Whole model](#Saving-the-whole-model)\n", " * [Only parameters](#Saving-only-the-model-parameters)\n", "* [Usage with an sklearn Pipeline](#Usage-with-an-sklearn-Pipeline)\n", "* [Callbacks](#Callbacks)\n", "* [Grid search](#Usage-with-sklearn-GridSearchCV)\n", " * [Special prefixes](#Special-prefixes)\n", " * [Performing a grid search](#Performing-a-grid-search)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "! [ ! 
-z \"$COLAB_GPU\" ] && pip install torch skorch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "\n", "torch.manual_seed(0);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a classifier and making predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A toy binary classification task" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We load a toy classification task from `sklearn`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import make_classification" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n", "X = X.astype(np.float32)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "((1000, 20), (1000,), 0.5)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X.shape, y.shape, y.mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Definition of the `pytorch` classification `module`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define a vanilla neural network with two hidden layers. The output layer should have 2 output units since there are two classes. In addition, it should have a softmax nonlinearity, because later, when calling `predict_proba`, the output from the `forward` call will be used." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class ClassifierModule(nn.Module):\n", "    def __init__(\n", "            self,\n", "            num_units=10,\n", "            nonlin=F.relu,\n", "            dropout=0.5,\n", "    ):\n", "        super(ClassifierModule, self).__init__()\n", "        self.num_units = num_units\n", "        self.nonlin = nonlin\n", "\n", "        # 20 input features -> num_units -> 10 -> 2 classes\n", "        self.dense0 = nn.Linear(20, num_units)\n", "        self.dropout = nn.Dropout(dropout)\n", "        self.dense1 = nn.Linear(num_units, 10)\n", "        self.output = nn.Linear(10, 2)\n", "\n", "    def forward(self, X, **kwargs):\n", "        X = self.nonlin(self.dense0(X))\n", "        X = self.dropout(X)\n", "        X = F.relu(self.dense1(X))\n", "        # softmax so that each output row holds class probabilities\n", "        X = F.softmax(self.output(X), dim=-1)\n", "        return X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining and training the neural net classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use `NeuralNetClassifier` because we're dealing with a classification task. The first argument should be the `pytorch module`. As additional arguments, we pass the number of epochs and the learning rate (`lr`), but those are optional.\n", "\n", "*Note*: To use the CUDA backend, pass `device='cuda'` as an additional argument." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from skorch import NeuralNetClassifier" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "net = NeuralNetClassifier(\n", "    ClassifierModule,\n", "    max_epochs=20,\n", "    lr=0.1,\n", "    # device='cuda',  # uncomment this to train with CUDA\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As in `sklearn`, we call `fit`, passing the input data `X` and the targets `y`. By default, `NeuralNetClassifier` makes a `StratifiedKFold` split on the data (80/20) to track the validation loss. For each epoch, the validation loss is printed along with the training loss and the accuracy on the validation set."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Automatic pdb calling has been turned ON\n" ] } ], "source": [ "pdb on" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.6905\u001b[0m \u001b[32m0.6150\u001b[0m \u001b[35m0.6749\u001b[0m 0.0235\n", " 2 \u001b[36m0.6648\u001b[0m \u001b[32m0.6450\u001b[0m \u001b[35m0.6633\u001b[0m 0.0213\n", " 3 \u001b[36m0.6619\u001b[0m \u001b[32m0.6750\u001b[0m \u001b[35m0.6533\u001b[0m 0.0219\n", " 4 \u001b[36m0.6429\u001b[0m \u001b[32m0.6800\u001b[0m \u001b[35m0.6399\u001b[0m 0.0207\n", " 5 \u001b[36m0.6307\u001b[0m \u001b[32m0.6950\u001b[0m \u001b[35m0.6254\u001b[0m 0.0192\n", " 6 \u001b[36m0.6291\u001b[0m \u001b[32m0.7000\u001b[0m \u001b[35m0.6134\u001b[0m 0.0202\n", " 7 \u001b[36m0.6102\u001b[0m \u001b[32m0.7100\u001b[0m \u001b[35m0.6033\u001b[0m 0.0220\n", " 8 \u001b[36m0.6050\u001b[0m 0.7000 \u001b[35m0.5931\u001b[0m 0.0210\n", " 9 \u001b[36m0.5966\u001b[0m 0.7000 \u001b[35m0.5844\u001b[0m 0.0217\n", " 10 \u001b[36m0.5636\u001b[0m 0.7100 \u001b[35m0.5689\u001b[0m 0.0226\n", " 11 0.5757 \u001b[32m0.7200\u001b[0m \u001b[35m0.5628\u001b[0m 0.0196\n", " 12 0.5757 0.7200 \u001b[35m0.5520\u001b[0m 0.0190\n", " 13 \u001b[36m0.5559\u001b[0m \u001b[32m0.7300\u001b[0m \u001b[35m0.5459\u001b[0m 0.0218\n", " 14 \u001b[36m0.5541\u001b[0m 0.7300 \u001b[35m0.5424\u001b[0m 0.0206\n", " 15 0.5659 \u001b[32m0.7350\u001b[0m \u001b[35m0.5378\u001b[0m 0.0215\n", " 16 \u001b[36m0.5364\u001b[0m 0.7350 \u001b[35m0.5322\u001b[0m 0.0192\n", " 17 0.5456 0.7300 \u001b[35m0.5239\u001b[0m 0.0221\n", " 18 0.5476 \u001b[32m0.7450\u001b[0m 0.5260 0.0188\n", " 19 0.5499 \u001b[32m0.7500\u001b[0m 0.5249 0.0213\n", " 20 \u001b[36m0.5273\u001b[0m 0.7350 
0.5251 0.0206\n" ] }, { "data": { "text/plain": [ "[initialized](\n", " module_=ClassifierModule(\n", " (dense0): Linear(in_features=20, out_features=10, bias=True)\n", " (dropout): Dropout(p=0.5)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=2, bias=True)\n", " ),\n", ")" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Also, as in `sklearn`, you may call `predict` or `predict_proba` on the fitted model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making predictions, classification" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 0, 0])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = net.predict(X[:5])\n", "y_pred" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.5349464 , 0.46505365],\n", " [0.8685093 , 0.1314907 ],\n", " [0.6860039 , 0.31399614],\n", " [0.9126012 , 0.08739878],\n", " [0.69675475, 0.30324525]], dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba = net.predict_proba(X[:5])\n", "y_proba" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a regressor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A toy regression task" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_regression" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)\n", "X_regr = X_regr.astype(np.float32)\n", "y_regr = y_regr.astype(np.float32) / 100\n", "y_regr = 
y_regr.reshape(-1, 1)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "((1000, 20), (1000, 1), -6.4901485, 6.154505)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Note*: Regression currently requires the target to be 2-dimensional, hence the need to reshape. This should be fixed with an upcoming version of pytorch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Definition of the `pytorch` regression `module`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, we define a vanilla neural network with two hidden layers. The main difference is that the output layer has only one unit and does not apply a softmax nonlinearity." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class RegressorModule(nn.Module):\n", "    def __init__(\n", "            self,\n", "            num_units=10,\n", "            nonlin=F.relu,\n", "    ):\n", "        super(RegressorModule, self).__init__()\n", "        self.num_units = num_units\n", "        self.nonlin = nonlin\n", "\n", "        # 20 input features -> num_units -> 10 -> 1 regression target\n", "        self.dense0 = nn.Linear(20, num_units)\n", "        self.dense1 = nn.Linear(num_units, 10)\n", "        self.output = nn.Linear(10, 1)\n", "\n", "    def forward(self, X, **kwargs):\n", "        X = self.nonlin(self.dense0(X))\n", "        X = F.relu(self.dense1(X))\n", "        # no output nonlinearity: the raw value is the prediction\n", "        X = self.output(X)\n", "        return X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining and training the neural net regressor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training a regressor is almost the same as training a classifier. The main difference is that we use `NeuralNetRegressor` instead of `NeuralNetClassifier` (this is the same terminology as in `sklearn`)."
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from skorch import NeuralNetRegressor" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "net_regr = NeuralNetRegressor(\n", " RegressorModule,\n", " max_epochs=20,\n", " lr=0.1,\n", "# device='cuda', # uncomment this to train with CUDA\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " epoch train_loss valid_loss dur\n", "------- ------------ ------------ ------\n", " 1 \u001b[36m4.4168\u001b[0m \u001b[32m3.0788\u001b[0m 0.0292\n", " 2 \u001b[36m2.0120\u001b[0m \u001b[32m0.4565\u001b[0m 0.0270\n", " 3 \u001b[36m0.3343\u001b[0m \u001b[32m0.2262\u001b[0m 0.0263\n", " 4 \u001b[36m0.1851\u001b[0m \u001b[32m0.2223\u001b[0m 0.0257\n", " 5 \u001b[36m0.1491\u001b[0m \u001b[32m0.1068\u001b[0m 0.0242\n", " 6 \u001b[36m0.0946\u001b[0m 0.1207 0.0263\n", " 7 \u001b[36m0.0739\u001b[0m \u001b[32m0.0663\u001b[0m 0.0290\n", " 8 \u001b[36m0.0554\u001b[0m 0.0706 0.0298\n", " 9 \u001b[36m0.0437\u001b[0m \u001b[32m0.0461\u001b[0m 0.0337\n", " 10 \u001b[36m0.0372\u001b[0m 0.0469 0.0273\n", " 11 \u001b[36m0.0291\u001b[0m \u001b[32m0.0343\u001b[0m 0.0263\n", " 12 \u001b[36m0.0270\u001b[0m \u001b[32m0.0333\u001b[0m 0.0285\n", " 13 \u001b[36m0.0207\u001b[0m \u001b[32m0.0265\u001b[0m 0.0281\n", " 14 \u001b[36m0.0196\u001b[0m \u001b[32m0.0249\u001b[0m 0.0344\n", " 15 \u001b[36m0.0152\u001b[0m \u001b[32m0.0215\u001b[0m 0.0286\n", " 16 \u001b[36m0.0151\u001b[0m \u001b[32m0.0198\u001b[0m 0.0281\n", " 17 \u001b[36m0.0120\u001b[0m \u001b[32m0.0182\u001b[0m 0.0283\n", " 18 \u001b[36m0.0119\u001b[0m \u001b[32m0.0167\u001b[0m 0.0266\n", " 19 \u001b[36m0.0100\u001b[0m \u001b[32m0.0159\u001b[0m 0.0266\n", " 20 \u001b[36m0.0097\u001b[0m \u001b[32m0.0149\u001b[0m 0.0259\n" ] }, { "data": { "text/plain": [ "[initialized](\n", " module_=RegressorModule(\n", " (dense0): 
Linear(in_features=20, out_features=10, bias=True)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=1, bias=True)\n", " ),\n", ")" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net_regr.fit(X_regr, y_regr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making predictions, regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may call `predict` or `predict_proba` on the fitted model. For regressions, both methods return the same value." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.4903931 ],\n", " [-1.4224019 ],\n", " [-0.77500594],\n", " [-0.06901944],\n", " [-0.3867012 ]], dtype=float32)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = net_regr.predict(X_regr[:5])\n", "y_pred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving and loading a model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save and load either the whole model by using pickle or just the learned model parameters by calling `save_params` and `load_params`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving the whole model" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import pickle" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "file_name = '/tmp/mymodel.pkl'" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/thomasfan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:241: UserWarning: Couldn't retrieve source code for container of type ClassifierModule. It won't be checked for correctness upon loading.\n", " \"type \" + obj.__name__ + \". 
It won't be checked \"\n" ] } ], "source": [ "with open(file_name, 'wb') as f:\n", "    pickle.dump(net, f)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "with open(file_name, 'rb') as f:\n", "    new_net = pickle.load(f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving only the model parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calling `save_params` and `load_params` only saves and loads the learned `module` parameters, meaning that hyperparameters such as `lr` and `max_epochs` are not saved. Therefore, to load the model, we have to re-initialize it beforehand." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "net.save_params(f_params=file_name)  # a file handle also works" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# first initialize the model\n", "new_net = NeuralNetClassifier(\n", "    ClassifierModule,\n", "    max_epochs=20,\n", "    lr=0.1,\n", ").initialize()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "new_net.load_params(f_params=file_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage with an `sklearn Pipeline`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is possible to put the `NeuralNetClassifier` inside an `sklearn Pipeline`, as you would with any `sklearn` classifier."
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import StandardScaler" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "pipe = Pipeline([\n", " ('scale', StandardScaler()),\n", " ('net', net),\n", "])" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Re-initializing module!\n", " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.7243\u001b[0m \u001b[32m0.5000\u001b[0m \u001b[35m0.7105\u001b[0m 0.0184\n", " 2 \u001b[36m0.7057\u001b[0m 0.5000 \u001b[35m0.6996\u001b[0m 0.0207\n", " 3 \u001b[36m0.6971\u001b[0m 0.5000 \u001b[35m0.6949\u001b[0m 0.0192\n", " 4 \u001b[36m0.6936\u001b[0m \u001b[32m0.5050\u001b[0m \u001b[35m0.6929\u001b[0m 0.0224\n", " 5 \u001b[36m0.6923\u001b[0m \u001b[32m0.5400\u001b[0m \u001b[35m0.6916\u001b[0m 0.0210\n", " 6 \u001b[36m0.6905\u001b[0m 0.5000 \u001b[35m0.6906\u001b[0m 0.0189\n", " 7 \u001b[36m0.6894\u001b[0m 0.5100 \u001b[35m0.6899\u001b[0m 0.0194\n", " 8 \u001b[36m0.6891\u001b[0m 0.5150 \u001b[35m0.6892\u001b[0m 0.0186\n", " 9 0.6899 0.5250 \u001b[35m0.6885\u001b[0m 0.0202\n", " 10 \u001b[36m0.6844\u001b[0m 0.5300 \u001b[35m0.6876\u001b[0m 0.0189\n", " 11 0.6853 \u001b[32m0.5650\u001b[0m \u001b[35m0.6865\u001b[0m 0.0199\n", " 12 \u001b[36m0.6842\u001b[0m \u001b[32m0.5700\u001b[0m \u001b[35m0.6855\u001b[0m 0.0183\n", " 13 \u001b[36m0.6821\u001b[0m \u001b[32m0.5850\u001b[0m \u001b[35m0.6844\u001b[0m 0.0199\n", " 14 \u001b[36m0.6821\u001b[0m \u001b[32m0.6050\u001b[0m \u001b[35m0.6832\u001b[0m 0.0189\n", " 15 \u001b[36m0.6820\u001b[0m \u001b[32m0.6100\u001b[0m \u001b[35m0.6820\u001b[0m 0.0206\n", " 16 \u001b[36m0.6769\u001b[0m 0.6100 \u001b[35m0.6800\u001b[0m 0.0188\n", " 17 0.6784 \u001b[32m0.6200\u001b[0m 
\u001b[35m0.6780\u001b[0m 0.0219\n", " 18 \u001b[36m0.6763\u001b[0m \u001b[32m0.6450\u001b[0m \u001b[35m0.6761\u001b[0m 0.0233\n", " 19 \u001b[36m0.6704\u001b[0m \u001b[32m0.6550\u001b[0m \u001b[35m0.6729\u001b[0m 0.0254\n", " 20 \u001b[36m0.6691\u001b[0m \u001b[32m0.6750\u001b[0m \u001b[35m0.6699\u001b[0m 0.0252\n" ] }, { "data": { "text/plain": [ "Pipeline(memory=None,\n", " steps=[('scale', StandardScaler(copy=True, with_mean=True, with_std=True)), ('net', [initialized](\n", " module_=ClassifierModule(\n", " (dense0): Linear(in_features=20, out_features=10, bias=True)\n", " (dropout): Dropout(p=0.5)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=2, bias=True)\n", " ),\n", "))])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.5064775 , 0.49352255],\n", " [0.53243965, 0.46756038],\n", " [0.57306874, 0.42693123],\n", " [0.54179883, 0.45820117],\n", " [0.5528906 , 0.44710937]], dtype=float32)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba = pipe.predict_proba(X[:5])\n", "y_proba" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To save the whole pipeline, including the pytorch module, use `pickle`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding a new callback to the model is straightforward. Below we show how to add a callback that computes the area under the ROC curve (AUC) after each epoch." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "from skorch.callbacks import EpochScoring" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is a scoring callback in skorch, `EpochScoring`, which we use for this. 
We have to specify which score to calculate. There are three options:\n", "\n", "* Passing a string: This should be a valid `sklearn` metric. For a list of all existing scores, see [the `sklearn` documentation](http://scikit-learn.org/stable/modules/classes.html#sklearn-metrics-metrics).\n", "* Passing `None`: If you implement your own `.score` method on your neural net, passing `scoring=None` will tell `skorch` to use that.\n", "* Passing a function or callable: If we want to define our own scoring function, we pass a function with the signature `func(model, X, y) -> score`, which is then used.\n", "\n", "Note that this works exactly like scoring in `sklearn`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For our case here, since `sklearn` already implements AUC, we just pass the corresponding string `'roc_auc'`. We should also tell the callback that higher scores are better; by default, lower scores are assumed to be better, so without this the colors printed below would highlight the wrong epochs. Furthermore, we may specify a `name` argument for `EpochScoring`, and whether to use training data (by setting `on_train=True`) or validation data (which is the default)." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "auc = EpochScoring(scoring='roc_auc', lower_is_better=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we pass the scoring callback to the `callbacks` parameter as a list and then call `fit`. Notice that we get the printed scores and color highlighting for free."
] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "net = NeuralNetClassifier(\n", " ClassifierModule,\n", " max_epochs=20,\n", " lr=0.1,\n", " callbacks=[auc],\n", ")" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " epoch roc_auc train_loss valid_acc valid_loss dur\n", "------- --------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m0.6112\u001b[0m \u001b[32m0.7076\u001b[0m \u001b[35m0.5550\u001b[0m \u001b[31m0.6802\u001b[0m 0.0188\n", " 2 \u001b[36m0.6766\u001b[0m \u001b[32m0.6750\u001b[0m \u001b[35m0.6150\u001b[0m \u001b[31m0.6626\u001b[0m 0.0204\n", " 3 \u001b[36m0.7031\u001b[0m \u001b[32m0.6560\u001b[0m \u001b[35m0.6500\u001b[0m \u001b[31m0.6498\u001b[0m 0.0244\n", " 4 \u001b[36m0.7201\u001b[0m \u001b[32m0.6364\u001b[0m \u001b[35m0.6650\u001b[0m \u001b[31m0.6381\u001b[0m 0.0193\n", " 5 \u001b[36m0.7316\u001b[0m \u001b[32m0.6176\u001b[0m \u001b[35m0.6900\u001b[0m \u001b[31m0.6285\u001b[0m 0.0203\n", " 6 \u001b[36m0.7447\u001b[0m \u001b[32m0.6094\u001b[0m \u001b[35m0.7200\u001b[0m \u001b[31m0.6183\u001b[0m 0.0222\n", " 7 \u001b[36m0.7522\u001b[0m 0.6170 0.7200 \u001b[31m0.6090\u001b[0m 0.0188\n", " 8 \u001b[36m0.7567\u001b[0m \u001b[32m0.5786\u001b[0m 0.7150 \u001b[31m0.6032\u001b[0m 0.0197\n", " 9 \u001b[36m0.7630\u001b[0m 0.5850 0.7100 \u001b[31m0.5954\u001b[0m 0.0214\n", " 10 \u001b[36m0.7706\u001b[0m \u001b[32m0.5770\u001b[0m 0.7200 \u001b[31m0.5889\u001b[0m 0.0207\n", " 11 \u001b[36m0.7735\u001b[0m \u001b[32m0.5740\u001b[0m 0.7050 \u001b[31m0.5842\u001b[0m 0.0188\n", " 12 0.7729 0.5771 0.7100 0.5859 0.0186\n", " 13 \u001b[36m0.7792\u001b[0m \u001b[32m0.5557\u001b[0m 0.7000 \u001b[31m0.5745\u001b[0m 0.0178\n", " 14 \u001b[36m0.7825\u001b[0m 0.5810 0.7050 \u001b[31m0.5691\u001b[0m 0.0204\n", " 15 0.7824 0.5634 0.7200 \u001b[31m0.5691\u001b[0m 0.0194\n", " 16 0.7817 0.5778 0.7150 0.5704 0.0205\n", " 17 
\u001b[36m0.7871\u001b[0m 0.5624 0.7150 \u001b[31m0.5633\u001b[0m 0.0188\n", " 18 0.7855 0.5613 0.7200 0.5660 0.0202\n", " 19 0.7792 0.5637 \u001b[35m0.7250\u001b[0m 0.5722 0.0194\n", " 20 0.7823 \u001b[32m0.5516\u001b[0m 0.7150 0.5681 0.0184\n" ] }, { "data": { "text/plain": [ "[initialized](\n", " module_=ClassifierModule(\n", " (dense0): Linear(in_features=20, out_features=10, bias=True)\n", " (dropout): Dropout(p=0.5)\n", " (dense1): Linear(in_features=10, out_features=10, bias=True)\n", " (output): Linear(in_features=10, out_features=2, bias=True)\n", " ),\n", ")" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For information on how to write custom callbacks, have a look at the [Advanced_Usage](https://nbviewer.jupyter.org/github/dnouri/skorch/blob/master/notebooks/Advanced_Usage.ipynb) notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage with sklearn `GridSearchCV`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Special prefixes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `NeuralNet` class allows you to directly access parameters of the `pytorch module` by using the `module__` prefix. For example, if you defined the `module` to have a `num_units` parameter, you can set it via the `module__num_units` argument. This is exactly the same mechanism that gives access to estimator parameters in `sklearn Pipeline`s and `FeatureUnion`s." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This feature is useful in several ways. For one, it allows you to set those parameters in the model definition. Furthermore, it allows you to set parameters in an `sklearn GridSearchCV` as shown below."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to the parameters prefixed by `module__`, you may access a couple of other attributes, such as those of the optimizer by using the `optimizer__` prefix (again, see below). All those special prefixes are stored in the `prefixes_` attribute:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "module, iterator_train, iterator_valid, optimizer, criterion, callbacks, dataset\n" ] } ], "source": [ "print(', '.join(net.prefixes_))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Performing a grid search" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we show how to perform a grid search over the learning rate (`lr`), the module's number of hidden units (`module__num_units`), the module's dropout rate (`module__dropout`), and whether the SGD optimizer should use Nesterov momentum or not (`optimizer__nesterov`)." 
] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "net = NeuralNetClassifier(\n", " ClassifierModule,\n", " max_epochs=20,\n", " lr=0.1,\n", " verbose=0,\n", " optimizer__momentum=0.9,\n", ")" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "params = {\n", " 'lr': [0.05, 0.1],\n", " 'module__num_units': [10, 20],\n", " 'module__dropout': [0, 0.5],\n", " 'optimizer__nesterov': [False, True],\n", "}" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.3s remaining: 0.0s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, 
module__dropout=0, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, 
optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.05, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, 
optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0, 
module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=10, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=False, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, 
optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True \n", "[CV] lr=0.1, module__dropout=0.5, module__num_units=20, optimizer__nesterov=True, total= 0.3s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Done 48 out of 48 | elapsed: 15.7s finished\n" ] }, { "data": { "text/plain": [ "GridSearchCV(cv=3, error_score='raise-deprecating',\n", " estimator=[uninitialized](\n", " module=,\n", "),\n", " fit_params=None, iid='warn', n_jobs=None,\n", " param_grid={'lr': [0.05, 0.1], 'module__num_units': [10, 20], 'module__dropout': [0, 0.5], 'optimizer__nesterov': [False, True]},\n", " pre_dispatch='2*n_jobs', refit=False, return_train_score='warn',\n", " scoring='accuracy', verbose=2)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gs.fit(X, y)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.862 {'lr': 0.05, 'module__dropout': 0, 'module__num_units': 20, 'optimizer__nesterov': False}\n" ] } ], "source": [ "print(gs.best_score_, gs.best_params_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Of course, we could further nest the `NeuralNetClassifier` within an `sklearn` `Pipeline`, in which case each parameter name is additionally prefixed with the name of the pipeline step that holds the net (e.g. `net__module__num_units` for a step named `net`)." 
] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }