{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook was put together by [Jake Vanderplas](http://www.vanderplas.com). Source and license info is on [GitHub](https://github.com/jakevdp/sklearn_tutorial/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Supervised Learning In-Depth: Support Vector Machines" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Previously we introduced supervised machine learning.\n", "There are many supervised learning algorithms available; here we'll go into brief detail one of the most powerful and interesting methods: **Support Vector Machines (SVMs)**." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from scipy import stats\n", "\n", "plt.style.use('seaborn')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Motivating Support Vector Machines" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Support Vector Machines (SVMs) are a powerful supervised learning algorithm used for **classification** or for **regression**. SVMs are a **discriminative** classifier: that is, they draw a boundary between clusters of data.\n", "\n", "Let's show a quick example of support vector classification. First we need to create a dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets.samples_generator import make_blobs\n", "X, y = make_blobs(n_samples=50, centers=2,\n", " random_state=0, cluster_std=0.60)\n", "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A discriminative classifier attempts to draw a line between the two sets of data. Immediately we see a problem: such a line is ill-posed! For example, we could come up with several possibilities which perfectly discriminate between the classes in this example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xfit = np.linspace(-1, 3.5)\n", "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n", "\n", "for m, b in [(1, 0.65), (0.5, 1.6), (-0.2, 2.9)]:\n", " plt.plot(xfit, m * xfit + b, '-k')\n", "\n", "plt.xlim(-1, 3.5);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These are three *very* different separaters which perfectly discriminate between these samples. Depending on which you choose, a new data point will be classified almost entirely differently!\n", "\n", "How can we improve on this?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Support Vector Machines: Maximizing the *Margin*\n", "\n", "Support vector machines are one way to address this.\n", "What support vector machined do is to not only draw a line, but consider a *region* about the line of some given width. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Support Vector Machines: Maximizing the *Margin*\n", "\n", "Support vector machines are one way to address this.\n", "What support vector machines do is not only to draw a line, but to consider a *region* about the line of some given width. Here's an example of what it might look like:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xfit = np.linspace(-1, 3.5)\n", "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n", "\n", "for m, b, d in [(1, 0.65, 0.33), (0.5, 1.6, 0.55), (-0.2, 2.9, 0.2)]:\n", "    yfit = m * xfit + b\n", "    plt.plot(xfit, yfit, '-k')\n", "    plt.fill_between(xfit, yfit - d, yfit + d, edgecolor='none', color='#AAAAAA', alpha=0.4)\n", "\n", "plt.xlim(-1, 3.5);" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Notice here that if we want to maximize this width, the middle fit is clearly the best.\n", "This is the intuition of **support vector machines**, which optimize a linear discriminant model in conjunction with a **margin** representing the perpendicular distance to the nearest points of each class." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Fitting a Support Vector Machine\n", "\n", "Now we'll fit a Support Vector Machine classifier to these points. While the mathematical details of the underlying optimization are interesting, we'll let you read about those elsewhere. Instead, we'll just treat the scikit-learn algorithm as a black box which accomplishes the above task." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.svm import SVC  # \"Support Vector Classifier\"\n", "clf = SVC(kernel='linear')\n", "clf.fit(X, y)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To better visualize what's happening here, let's create a quick convenience function that will plot SVM decision boundaries for us:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_svc_decision_function(clf, ax=None):\n", "    \"\"\"Plot the decision function for a 2D SVC\"\"\"\n", "    if ax is None:\n", "        ax = plt.gca()\n", "    # evaluate the decision function on a 30x30 grid spanning the current axes\n", "    x = np.linspace(plt.xlim()[0], plt.xlim()[1], 30)\n", "    y = np.linspace(plt.ylim()[0], plt.ylim()[1], 30)\n", "    Y, X = np.meshgrid(y, x)\n", "    P = np.zeros_like(X)\n", "    for i, xi in enumerate(x):\n", "        for j, yj in enumerate(y):\n", "            P[i, j] = clf.decision_function([[xi, yj]])[0]\n", "    # plot the decision boundary (level 0) and the margins (levels +/- 1)\n", "    ax.contour(X, Y, P, colors='k',\n", "               levels=[-1, 0, 1], alpha=0.5,\n", "               linestyles=['--', '-', '--'])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n", "plot_svc_decision_function(clf);" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the dashed lines touch a couple of the points: these points are the pivotal pieces of this fit, and are known as the *support vectors* (giving the algorithm its name).\n", "In scikit-learn, these are stored in the ``support_vectors_`` attribute of the classifier:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n", "plot_svc_decision_function(clf)\n", "plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],\n", "            s=200, facecolors='none');" ] },
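{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick numerical check on this picture (a minimal sketch, relying only on standard attributes of the fitted ``clf``): for a linear kernel the weight vector is stored in ``clf.coef_``, the margin half-width is $1/\\|w\\|$, and the decision function should evaluate to $\\pm 1$ at support vectors lying exactly on the margin edges:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# for a linear kernel, the separating hyperplane is w . x + b = 0;\n", "# the distance from the boundary to each margin edge is 1 / ||w||\n", "w = clf.coef_[0]\n", "print('margin half-width:', 1 / np.linalg.norm(w))\n", "\n", "# support vectors on the margin edges have decision function values of +1 or -1\n", "print('decision function at support vectors:',\n", "      clf.decision_function(clf.support_vectors_))" ] },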
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's use the ``interact`` functionality from ``ipywidgets`` to explore how the distribution of points affects the support vectors and the discriminative fit.\n", "(This requires the ``ipywidgets`` package, and will not work in a static view.)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ipywidgets import interact\n", "\n", "def plot_svm(N=10):\n", "    X, y = make_blobs(n_samples=200, centers=2,\n", "                      random_state=0, cluster_std=0.60)\n", "    X = X[:N]\n", "    y = y[:N]\n", "    clf = SVC(kernel='linear')\n", "    clf.fit(X, y)\n", "    plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n", "    plt.xlim(-1, 4)\n", "    plt.ylim(-1, 6)\n", "    plot_svc_decision_function(clf, plt.gca())\n", "    plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],\n", "                s=200, facecolors='none')\n", "\n", "interact(plot_svm, N=[10, 200]);" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Notice the unique thing about SVMs: only the support vectors matter! That is, if you moved any of the other points without letting them cross the margin boundaries, they would have no effect on the classification results." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Going further: Kernel Methods\n", "\n", "Where SVM gets incredibly exciting is when it is used in conjunction with *kernels*.\n", "To motivate the need for kernels, let's look at some data which is not linearly separable:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_circles\n", "X, y = make_circles(100, factor=.1, noise=.1)\n", "\n", "clf = SVC(kernel='linear').fit(X, y)\n", "\n", "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n", "plot_svc_decision_function(clf);" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Clearly, no linear discrimination will ever separate these data.\n", "One way we can adjust this is to apply a **kernel**, which is some functional transformation of the input data.\n", "\n", "For example, one simple model we could use is a **radial basis function** centered on the middle clump:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "r = np.exp(-(X[:, 0] ** 2 + X[:, 1] ** 2))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "If we plot this along with our data, we can see the effect of it:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from mpl_toolkits import mplot3d\n", "\n", "def plot_3D(elev=30, azim=30):\n", "    ax = plt.subplot(projection='3d')\n", "    ax.scatter3D(X[:, 0], X[:, 1], r, c=y, s=50, cmap='spring')\n", "    ax.view_init(elev=elev, azim=azim)\n", "    ax.set_xlabel('x')\n", "    ax.set_ylabel('y')\n", "    ax.set_zlabel('r')\n", "\n", "interact(plot_3D, elev=(-90, 90), azim=(-180, 180));" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can see that with this additional dimension, the data becomes trivially linearly separable!\n", "This is a relatively simple kernel; SVM has a more sophisticated version of this kernel built into the process. This is accomplished by using ``kernel='rbf'``, short for *radial basis function*:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clf = SVC(kernel='rbf')\n", "clf.fit(X, y)\n", "\n", "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n", "plot_svc_decision_function(clf)\n", "plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],\n", "            s=200, facecolors='none');" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here there are effectively $N$ basis functions: one centered at each point! Through a clever piece of mathematics known as the \"kernel trick\", this fit can be computed very efficiently, without ever constructing the full high-dimensional representation of the kernel projection. Only the pairwise kernel evaluations between points are needed, as the sketch below illustrates." ] },
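{ "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of this point (the choice ``gamma=1.0`` is arbitrary, fixed here only so the two models agree): we can compute the matrix of pairwise kernel evaluations ourselves and hand it to ``SVC`` via ``kernel='precomputed'``, reproducing the built-in RBF fit without any explicit feature map:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics.pairwise import rbf_kernel\n", "\n", "# build the N x N matrix of pairwise kernel evaluations explicitly...\n", "gamma = 1.0\n", "K = rbf_kernel(X, X, gamma=gamma)\n", "\n", "# ...and fit on it directly; predictions should match the built-in RBF kernel\n", "clf_pre = SVC(kernel='precomputed').fit(K, y)\n", "clf_rbf = SVC(kernel='rbf', gamma=gamma).fit(X, y)\n", "\n", "# at prediction time, 'precomputed' expects kernels between new and training points\n", "print(np.array_equal(clf_pre.predict(rbf_kernel(X, X, gamma=gamma)),\n", "                     clf_rbf.predict(X)))" ] },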
{ "cell_type": "markdown", "metadata": {}, "source": [ "We'll leave SVMs for the time being and take a look at another classification algorithm: Random Forests." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 1 }