{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# SGD: Weighted samples\n",
    "\n",
    "Plot the decision function of a weighted dataset, where the size of each point\n",
    "is proportional to its weight.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Authors: The scikit-learn developers\n",
    "# SPDX-License-Identifier: BSD-3-Clause\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from sklearn import linear_model\n",
    "\n",
    "\n",
    "def plot_decision_boundary(ax, X, y, xx, yy, linestyle, sample_weight=None):\n",
    "    \"\"\"Fit an SGDClassifier and draw its decision boundary on ``ax``.\n",
    "\n",
    "    A fresh ``SGDClassifier(alpha=0.01, max_iter=100)`` is fitted on\n",
    "    ``(X, y)``, evaluated on every node of the ``(xx, yy)`` mesh, and the\n",
    "    zero level of its decision function is drawn with ``linestyle``.\n",
    "\n",
    "    ``sample_weight`` is forwarded to ``fit``; the default ``None`` is the\n",
    "    same as calling ``fit(X, y)`` without weights.\n",
    "\n",
    "    Returns the contour set so the caller can build a legend entry from it.\n",
    "    \"\"\"\n",
    "    clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100)\n",
    "    clf.fit(X, y, sample_weight=sample_weight)\n",
    "    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])\n",
    "    Z = Z.reshape(xx.shape)\n",
    "    return ax.contour(xx, yy, Z, levels=[0], linestyles=[linestyle])\n",
    "\n",
    "\n",
    "# we create 20 points: the first 10 (label +1) are shifted by (1, 1), the\n",
    "# last 10 (label -1) are left centered at the origin\n",
    "np.random.seed(0)\n",
    "X = np.r_[np.random.randn(10, 2) + [1, 1], np.random.randn(10, 2)]\n",
    "y = [1] * 10 + [-1] * 10\n",
    "sample_weight = 100 * np.abs(np.random.randn(20))\n",
    "# and assign a 10 times bigger weight to the first 10 samples\n",
    "sample_weight[:10] *= 10\n",
    "\n",
    "# plot the weighted data points; each marker's size is its sample weight\n",
    "xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(\n",
    "    X[:, 0],\n",
    "    X[:, 1],\n",
    "    c=y,\n",
    "    s=sample_weight,\n",
    "    alpha=0.9,\n",
    "    cmap=plt.cm.bone,\n",
    "    edgecolor=\"black\",\n",
    ")\n",
    "\n",
    "# fit the unweighted model first, then the weighted one.\n",
    "# NOTE: no random_state is passed to SGDClassifier, so per scikit-learn's\n",
    "# random_state convention both fits draw from the global NumPy RNG seeded\n",
    "# above; keep this call order to reproduce the same figure.\n",
    "no_weights = plot_decision_boundary(ax, X, y, xx, yy, \"solid\")\n",
    "samples_weights = plot_decision_boundary(\n",
    "    ax, X, y, xx, yy, \"dashed\", sample_weight=sample_weight\n",
    ")\n",
    "\n",
    "# build the legend from the first line of each contour set\n",
    "no_weights_handles, _ = no_weights.legend_elements()\n",
    "weights_handles, _ = samples_weights.legend_elements()\n",
    "ax.legend(\n",
    "    [no_weights_handles[0], weights_handles[0]],\n",
    "    [\"no weights\", \"with weights\"],\n",
    "    loc=\"lower left\",\n",
    ")\n",
    "\n",
    "ax.set(xticks=(), yticks=())\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}