{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Compare Stochastic learning strategies for MLPClassifier\n\nThis example visualizes some training loss curves for different stochastic\nlearning strategies, including SGD and Adam. Because of time-constraints, we\nuse several small datasets, for which L-BFGS might be more suitable. The\ngeneral trend shown in these examples seems to carry over to larger datasets,\nhowever.\n\nNote that those results can be highly dependent on the value of\n``learning_rate_init``.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport warnings\n\nimport matplotlib.pyplot as plt\n\nfrom sklearn import datasets\nfrom sklearn.exceptions import ConvergenceWarning\nfrom sklearn.neural_network import MLPClassifier\nfrom sklearn.preprocessing import MinMaxScaler\n\n# different learning rate schedules and momentum parameters\nparams = [\n {\n \"solver\": \"sgd\",\n \"learning_rate\": \"constant\",\n \"momentum\": 0,\n \"learning_rate_init\": 0.2,\n },\n {\n \"solver\": \"sgd\",\n \"learning_rate\": \"constant\",\n \"momentum\": 0.9,\n \"nesterovs_momentum\": False,\n \"learning_rate_init\": 0.2,\n },\n {\n \"solver\": \"sgd\",\n \"learning_rate\": \"constant\",\n \"momentum\": 0.9,\n \"nesterovs_momentum\": True,\n \"learning_rate_init\": 0.2,\n },\n {\n \"solver\": \"sgd\",\n \"learning_rate\": \"invscaling\",\n \"momentum\": 0,\n \"learning_rate_init\": 0.2,\n },\n {\n \"solver\": \"sgd\",\n \"learning_rate\": \"invscaling\",\n \"momentum\": 0.9,\n \"nesterovs_momentum\": False,\n \"learning_rate_init\": 0.2,\n },\n {\n \"solver\": \"sgd\",\n \"learning_rate\": \"invscaling\",\n \"momentum\": 0.9,\n \"nesterovs_momentum\": True,\n \"learning_rate_init\": 0.2,\n },\n {\"solver\": \"adam\", \"learning_rate_init\": 0.01},\n]\n\nlabels = [\n \"constant learning-rate\",\n \"constant with momentum\",\n \"constant with Nesterov's momentum\",\n \"inv-scaling learning-rate\",\n \"inv-scaling with momentum\",\n \"inv-scaling with Nesterov's momentum\",\n \"adam\",\n]\n\nplot_args = [\n {\"c\": \"red\", \"linestyle\": \"-\"},\n {\"c\": \"green\", \"linestyle\": \"-\"},\n {\"c\": \"blue\", \"linestyle\": \"-\"},\n {\"c\": \"red\", \"linestyle\": \"--\"},\n {\"c\": \"green\", \"linestyle\": \"--\"},\n {\"c\": \"blue\", \"linestyle\": \"--\"},\n {\"c\": \"black\", \"linestyle\": \"-\"},\n]\n\n\ndef plot_on_dataset(X, y, ax, name):\n # for each dataset, plot learning for each learning strategy\n print(\"\\nlearning on dataset %s\" % name)\n ax.set_title(name)\n\n X = MinMaxScaler().fit_transform(X)\n mlps = []\n if name == \"digits\":\n # digits is larger but converges fairly quickly\n max_iter = 15\n else:\n max_iter = 400\n\n for label, param in zip(labels, params):\n print(\"training: %s\" % label)\n mlp = MLPClassifier(random_state=0, max_iter=max_iter, **param)\n\n # some parameter combinations will not converge as can be seen on the\n # plots so they are ignored here\n with warnings.catch_warnings():\n warnings.filterwarnings(\n \"ignore\", category=ConvergenceWarning, module=\"sklearn\"\n )\n mlp.fit(X, y)\n\n mlps.append(mlp)\n print(\"Training set score: %f\" % mlp.score(X, y))\n print(\"Training set loss: %f\" % mlp.loss_)\n for mlp, label, args in zip(mlps, labels, plot_args):\n ax.plot(mlp.loss_curve_, label=label, **args)\n\n\nfig, axes = plt.subplots(2, 2, figsize=(15, 10))\n# load / 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport warnings\n\nimport matplotlib.pyplot as plt\n\nfrom sklearn import datasets\nfrom sklearn.exceptions import ConvergenceWarning\nfrom sklearn.neural_network import MLPClassifier\nfrom sklearn.preprocessing import MinMaxScaler\n\n# different learning rate schedules and momentum parameters\nparams = [\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"constant\",\n        \"momentum\": 0,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"constant\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": False,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"constant\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": True,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"invscaling\",\n        \"momentum\": 0,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"invscaling\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": False,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"invscaling\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": True,\n        \"learning_rate_init\": 0.2,\n    },\n    {\"solver\": \"adam\", \"learning_rate_init\": 0.01},\n]\n\nlabels = [\n    \"constant learning-rate\",\n    \"constant with momentum\",\n    \"constant with Nesterov's momentum\",\n    \"inv-scaling learning-rate\",\n    \"inv-scaling with momentum\",\n    \"inv-scaling with Nesterov's momentum\",\n    \"adam\",\n]\n\nplot_args = [\n    {\"c\": \"red\", \"linestyle\": \"-\"},\n    {\"c\": \"green\", \"linestyle\": \"-\"},\n    {\"c\": \"blue\", \"linestyle\": \"-\"},\n    {\"c\": \"red\", \"linestyle\": \"--\"},\n    {\"c\": \"green\", \"linestyle\": \"--\"},\n    {\"c\": \"blue\", \"linestyle\": \"--\"},\n    {\"c\": \"black\", \"linestyle\": \"-\"},\n]\n\n\ndef plot_on_dataset(X, y, ax, name):\n    # for each dataset, plot learning for each learning strategy\n    print(\"\\nlearning on dataset %s\" % name)\n    ax.set_title(name)\n\n    X = MinMaxScaler().fit_transform(X)\n    mlps = []\n    if name == \"digits\":\n        # digits is larger but converges fairly quickly\n        max_iter = 15\n    else:\n        max_iter = 400\n\n    for label, param in zip(labels, params):\n        print(\"training: %s\" % label)\n        mlp = MLPClassifier(random_state=0, max_iter=max_iter, **param)\n\n        # some parameter combinations will not converge, as can be seen in the\n        # plots, so the resulting ConvergenceWarnings are ignored here\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\n                \"ignore\", category=ConvergenceWarning, module=\"sklearn\"\n            )\n            mlp.fit(X, y)\n\n        mlps.append(mlp)\n        print(\"Training set score: %f\" % mlp.score(X, y))\n        print(\"Training set loss: %f\" % mlp.loss_)\n    for mlp, label, args in zip(mlps, labels, plot_args):\n        ax.plot(mlp.loss_curve_, label=label, **args)\n\n\nfig, axes = plt.subplots(2, 2, figsize=(15, 10))\n# load / generate some toy datasets\niris = datasets.load_iris()\nX_digits, y_digits = datasets.load_digits(return_X_y=True)\ndata_sets = [\n    (iris.data, iris.target),\n    (X_digits, y_digits),\n    datasets.make_circles(noise=0.2, factor=0.5, random_state=1),\n    datasets.make_moons(noise=0.3, random_state=0),\n]\n\nfor ax, data, name in zip(\n    axes.ravel(), data_sets, [\"iris\", \"digits\", \"circles\", \"moons\"]\n):\n    plot_on_dataset(*data, ax=ax, name=name)\n\n# ax still refers to the last subplot; reuse its line handles for one shared\n# figure-level legend\nfig.legend(ax.get_lines(), labels, ncol=3, loc=\"upper center\")\nplt.show()" ] },
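{ "cell_type": "markdown", "metadata": {}, "source": [ "\nAs noted in the introduction, these curves can depend strongly on\n``learning_rate_init``. The cell below is a minimal illustrative sketch, not\npart of the original comparison: it reuses the imports and warning handling\nfrom the cell above and refits a constant-rate SGD ``MLPClassifier`` on the\nmoons dataset with a few arbitrary, untuned learning rates.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# A minimal illustrative sketch (not part of the original comparison); it\n# reuses the imports, scaler and warning handling from the cell above. The\n# learning rates below are arbitrary, untuned values chosen only to show how\n# sensitive the training loss curve is to learning_rate_init.\nX, y = datasets.make_moons(noise=0.3, random_state=0)\nX = MinMaxScaler().fit_transform(X)\n\nfig, ax = plt.subplots(figsize=(6, 4))\nfor lr_init in (0.001, 0.01, 0.2, 0.5):\n    mlp = MLPClassifier(\n        solver=\"sgd\",\n        learning_rate=\"constant\",\n        learning_rate_init=lr_init,\n        max_iter=400,\n        random_state=0,\n    )\n    # as above, ignore ConvergenceWarning for runs that do not converge\n    with warnings.catch_warnings():\n        warnings.filterwarnings(\n            \"ignore\", category=ConvergenceWarning, module=\"sklearn\"\n        )\n        mlp.fit(X, y)\n    ax.plot(mlp.loss_curve_, label=\"learning_rate_init=%g\" % lr_init)\nax.set_title(\"moons: constant-rate SGD\")\nax.set_xlabel(\"iteration\")\nax.set_ylabel(\"training loss\")\nax.legend()\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }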