{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Early stopping of Stochastic Gradient Descent\n\nStochastic Gradient Descent is an optimization technique which minimizes a loss\nfunction in a stochastic fashion, performing a gradient descent step sample by\nsample. In particular, it is a very efficient method to fit linear models.\n\nBecause the method is stochastic, the loss function does not necessarily decrease at each\niteration, and convergence is only guaranteed in expectation. For this reason,\nmonitoring the convergence on the loss function can be difficult.\n\nAnother approach is to monitor convergence on a validation score. In this case,\nthe input data is split into a training set and a validation set. The model is\nthen fitted on the training set and the stopping criterion is based on the\nprediction score computed on the validation set. This enables us to find the\nsmallest number of iterations that is sufficient to build a model that\ngeneralizes well to unseen data and reduces the chance of over-fitting the\ntraining data.\n\nThis early stopping strategy is activated if ``early_stopping=True``; otherwise\nthe stopping criterion only uses the training loss on the entire input data. To\nbetter control the early stopping strategy, we can specify a parameter\n``validation_fraction`` which sets the fraction of the input dataset that we\nkeep aside to compute the validation score. The optimization will continue\nuntil the validation score has not improved by at least ``tol`` during the last\n``n_iter_no_change`` iterations. The actual number of iterations is available\nin the attribute ``n_iter_``.\n\nThis example illustrates how early stopping can be used in the\n:class:`~sklearn.linear_model.SGDClassifier` model to achieve almost the same\naccuracy as a model built without early stopping. This can\nsignificantly reduce training time. 
Note that scores differ between the\nstopping criteria even from early iterations because some of the training data\nis held out with the validation stopping criterion.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport sys\nimport time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nfrom sklearn import linear_model\nfrom sklearn.datasets import fetch_openml\nfrom sklearn.exceptions import ConvergenceWarning\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.utils import shuffle\nfrom sklearn.utils._testing import ignore_warnings\n\n\ndef load_mnist(n_samples=None, class_0=\"0\", class_1=\"8\"):\n    \"\"\"Load MNIST, select two classes, shuffle and return only n_samples.\"\"\"\n    # Load data from http://openml.org/d/554\n    mnist = fetch_openml(\"mnist_784\", version=1, as_frame=False)\n\n    # take only two classes for binary classification\n    mask = np.logical_or(mnist.target == class_0, mnist.target == class_1)\n\n    X, y = shuffle(mnist.data[mask], mnist.target[mask], random_state=42)\n    if n_samples is not None:\n        X, y = X[:n_samples], y[:n_samples]\n    return X, y\n\n\n@ignore_warnings(category=ConvergenceWarning)\ndef fit_and_score(estimator, max_iter, X_train, X_test, y_train, y_test):\n    \"\"\"Fit the estimator on the train set and score it on both sets\"\"\"\n    estimator.set_params(max_iter=max_iter)\n    estimator.set_params(random_state=0)\n\n    start = time.time()\n    estimator.fit(X_train, y_train)\n\n    fit_time = time.time() - start\n    n_iter = estimator.n_iter_\n    train_score = estimator.score(X_train, y_train)\n    test_score = estimator.score(X_test, y_test)\n\n    return fit_time, n_iter, train_score, test_score\n\n\n# Define the estimators to compare\nestimator_dict = {\n    \"No stopping criterion\": linear_model.SGDClassifier(n_iter_no_change=3),\n    \"Training loss\": 
linear_model.SGDClassifier(\n        early_stopping=False, n_iter_no_change=3, tol=0.1\n    ),\n    \"Validation score\": linear_model.SGDClassifier(\n        early_stopping=True, n_iter_no_change=3, tol=0.0001, validation_fraction=0.2\n    ),\n}\n\n# Load the dataset\nX, y = load_mnist(n_samples=10000)\nX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n\nresults = []\nfor estimator_name, estimator in estimator_dict.items():\n    print(estimator_name + \": \", end=\"\")\n    for max_iter in range(1, 50):\n        print(\".\", end=\"\")\n        sys.stdout.flush()\n\n        fit_time, n_iter, train_score, test_score = fit_and_score(\n            estimator, max_iter, X_train, X_test, y_train, y_test\n        )\n\n        results.append(\n            (estimator_name, max_iter, fit_time, n_iter, train_score, test_score)\n        )\n    print(\"\")\n\n# Transform the results into a pandas dataframe for easy plotting\ncolumns = [\n    \"Stopping criterion\",\n    \"max_iter\",\n    \"Fit time (sec)\",\n    \"n_iter_\",\n    \"Train score\",\n    \"Test score\",\n]\nresults_df = pd.DataFrame(results, columns=columns)\n\n# Define what to plot\nlines = \"Stopping criterion\"\nx_axis = \"max_iter\"\nstyles = [\"-.\", \"--\", \"-\"]\n\n# First plot: train and test scores\nfig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(12, 4))\nfor ax, y_axis in zip(axes, [\"Train score\", \"Test score\"]):\n    for style, (criterion, group_df) in zip(styles, results_df.groupby(lines)):\n        group_df.plot(x=x_axis, y=y_axis, label=criterion, ax=ax, style=style)\n    ax.set_title(y_axis)\n    ax.legend(title=lines)\nfig.tight_layout()\n\n# Second plot: n_iter and fit time\nfig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))\nfor ax, y_axis in zip(axes, [\"n_iter_\", \"Fit time (sec)\"]):\n    for style, (criterion, group_df) in zip(styles, results_df.groupby(lines)):\n        group_df.plot(x=x_axis, y=y_axis, label=criterion, ax=ax, style=style)\n    ax.set_title(y_axis)\n    ax.legend(title=lines)\nfig.tight_layout()\n\nplt.show()" ] } ], "metadata": { "kernelspec": { 
"display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }