{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Clustering with PyTorch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Clustering techniques are unsupervised learning algorithms that try to group unlabelled data into \"clusters\", using the (typically spatial) structure of the data itself.\n", "\n", "The easiest way to demonstrate how clustering works is to generate some data and watch the algorithm in action. We'll start off by importing the libraries we'll be using today." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import math\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import operator\n", "import torch\n", "\n", "# provides to_gpu(), used in the PyTorch section\n", "from fastai.core import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "n_clusters = 6\n", "n_samples = 250" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To generate our data, we're going to pick 6 random points, which we'll call centroids, and for each centroid we're going to generate 250 random points around it." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# centroids are drawn uniformly from the square [-35, 35] x [-35, 35]\n", "centroids = np.random.uniform(-35, 35, (n_clusters, 2))\n", "# around each centroid, draw n_samples points from a 2D Gaussian with variance 5 along each axis\n", "slices = [np.random.multivariate_normal(centroids[i], np.diag([5., 5.]), n_samples)\n", "          for i in range(n_clusters)]\n", "data = np.concatenate(slices).astype(np.float32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we can see each centroid marked with an X, with the points of each cluster drawn in their own colour." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def plot_data(centroids, data, n_samples):\n", "    # one colour per cluster\n", "    colour = plt.cm.rainbow(np.linspace(0,1,len(centroids)))\n", "    for i, centroid in enumerate(centroids):\n", "        # the samples of cluster i are stored contiguously in data\n", "        samples = data[i*n_samples:(i+1)*n_samples]\n", "        plt.scatter(samples[:,0], samples[:,1], color=colour[i], s=1)\n", "        # mark the centroid with a magenta X drawn over a thicker black X\n", "        plt.plot(centroid[0], centroid[1], markersize=10, marker=\"x\", color='k', mew=5)\n", "        plt.plot(centroid[0], centroid[1], markersize=5, marker=\"x\", color='m', mew=2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "plot_data(centroids, data, n_samples)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Mean shift" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Most people who have come across clustering algorithms have learnt about **k-means**. Mean shift clustering is a newer and less well-known approach, but it has some important advantages:\n", "* It doesn't require selecting the number of clusters in advance; instead it just requires a **bandwidth** to be specified, which can often be chosen automatically.\n", "* It can handle clusters of any shape, whereas k-means (without special extensions) requires that clusters be roughly ball-shaped.\n", "\n", "The algorithm is as follows:\n", "* For each data point x in the sample X, find the distance between x and every point in X\n", "* Create a weight for each point in X by applying the **Gaussian kernel** to that point's distance from x\n", "    * This weighting penalizes points further away from x\n", "    * The rate at which the weights fall to zero is determined by the **bandwidth**, which is the standard deviation of the Gaussian\n", "![Gaussian](http://images.books24x7.com/bookimages/id_5642/fig11-10.jpg)\n", "* Update x to be the weighted average of all the points in X (including x itself), using the weights from the previous step\n", "\n", "Repeating these steps pushes points that are close together even closer, until the points of each cluster sit almost on top of each other." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So here's the definition of the Gaussian kernel, which you may remember from high school..." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from numpy import exp, sqrt, array, abs" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# density of a normal distribution with mean 0 and standard deviation bw, evaluated at d\n", "def gaussian(d, bw): return exp(-0.5*((d/bw))**2) / (bw * math.sqrt(2*math.pi))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This person at the science march certainly remembered!\n", "\n", "Since all of our distances are non-negative, we'll only be using the right-hand side of the Gaussian. Here's what that looks like for a couple of different choices of bandwidth (bw)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "x=np.linspace(0,5)\n", "fig, ax = plt.subplots()\n", "ax.plot(x, gaussian(x, 1), label='bw=1')\n", "ax.plot(x, gaussian(x, 2.5), label='bw=2.5')\n", "ax.legend();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In our implementation, we choose a bandwidth of 2.5. (One easy way to choose the bandwidth is to find which bandwidth covers one third of the data, which you can try implementing as an exercise.)\n", "\n", "We'll also need to be able to calculate the distance between points; here's the function we'll use:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Euclidean distance from point x to every row of X (x is broadcast across X)\n", "def distance(x, X): return sqrt(((x-X)**2).sum(1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try it out. (More on how this function works shortly.)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1.41421, 0. , 3.60555])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d = distance(array([2,3]), array([[1,2],[2,3],[-1,1]])); d" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can feed the distances into our `gaussian` function to see what weights we would get in this case." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.13598, 0.15958, 0.0564 ])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gaussian(d, 2.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can put these steps together to define a single iteration of the algorithm." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def meanshift_inner(x, X, bandwidth):\n", "    # Find distance from point x to every point in X (including x itself)\n", "    dist = distance(x, X)\n", "\n", "    # Use gaussian to turn into array of weights\n", "    weight = gaussian(dist, bandwidth)\n", "    # Weighted average of all points (see the Broadcasting section below for details)\n", "    return (weight[:,None]*X).sum(0) / weight.sum()\n", "\n", "def meanshift_iter(X, bandwidth=2.5):\n", "    # Shift every point once, always measuring distances against the current X\n", "    return np.array([meanshift_inner(x, X, bandwidth) for x in X])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "X=meanshift_iter(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results show that, as we hoped, all the points have moved closer to their \"true\" cluster centers (even though the algorithm doesn't know where the centers actually are)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "plot_data(centroids, X, n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By repeating this a few times, we can make the clusters more accurate." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def meanshift(X, it=0, max_it=5, bandwidth=2.5, eps=0.000001):\n", "    # perform meanshift once\n", "    new_X = meanshift_iter(X, bandwidth=bandwidth)\n", "    # stop if we've reached the maximum number of iterations,\n", "    # or if the points have stopped moving (change relative to abs(X.sum()) below eps)\n", "    if it >= max_it or abs(X-new_X).sum()/abs(X.sum()) < eps:\n", "        return new_X\n", "    else:\n", "        return meanshift(new_X, it+1, max_it, bandwidth, eps)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.09 s, sys: 0 ns, total: 1.09 s\n", "Wall time: 1.09 s\n" ] } ], "source": [ "%time X=meanshift(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that mean shift clustering has almost reproduced our original clustering. The exceptions are clusters that lie very close together; if we really wanted to separate them, we could lower the bandwidth.\n", "\n", "What is impressive is that the algorithm nearly reproduced the original clusters without being told how many clusters there should be. 
(In the chart below we are offsetting the centroids a bit to the right, otherwise we couldn't be able to see the points since they're now on top of each other)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE5VJREFUeJzt3W9sZFd5x/Hvk2xIWdHKYLYhk0R12satlsoNwopA7Ys1f0poVlkXQRVeVCmNbIqoFAISDUWq7ReVQEhNkVoCdiHNC9SUUli2G/4lkVeoUgs4NGwTwjpLWJTEgTim2xJZCk3y9IWvnfGud73rmfF47vl+pJHvnHN97znyzE/XZ86cG5mJJKn+Luh2AyRJ28PAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBViV6sHiIhfAL4BXFwd7/OZORERVwJ3Af3A/cAfZebPz3asV77ylTkwMNBqkySpKPfff//Tmblns/1aDnzgWeANmflMRFwE/FtEfAV4P3BbZt4VEZ8EbgJuP9uBBgYGmJuba0OTJKkcEfGjc9mv5SGdXPFM9fSi6pHAG4DPV+V3AqOtnkuStHVtGcOPiAsj4gHgKeAe4AfAycx8rtrlceCydpxLqpOlpaXzKpda0ZbAz8znM/Nq4HLgGuA3z/V3I2I8IuYiYm5xcbEdzZF6wuTkJENDQ8zPz68rn5+fZ2hoiMnJye40TLXV1lk6mXkSmAVeD/RFxOpnBJcDT5zhd6Yzczgzh/fs2fQzB6kWJicnmZqaYmFhgZGRkbXQn5+fZ2RkhIWFBaampgx9tVXLgR8ReyKir9p+KfBm4GFWgv/t1W43Al9q9VxSHayG/arV0L/77rvXwn6Voa92ascV/qXAbEQcBb4N3JOZh4E/B94fEcdZmZr56TacS+ppS0tLzMzMrCsbZZTlhWX279/P8sIyo6fMb5iZmXFMX23R8rTMzDwKvGaD8kdZGc+XVOnv72d2dnbtSn6UUW7mZg5wgCmmmGCCAQYAOMhBGo0Gs7Oz9Pf3d7fhqgW/aStts8HBQWZnZ2k0GhzhCCc4wQAD3MEdDDDACU5whCNrYT84ONjtJqsmDHypCwYHB5menuYkJ5lial3dFFOc5CTT09OGvdrKwJe6YH5+nvHxcfroY4KJdXUTTNBHH+Pj46dN2ZRaYeBL26x56uU+9q0N47yLd60N7+xj32lTNqVWRWZ2uw1rhoeH07V0VGdLS0sMDQ2tm3o5yihHOMJJTtJHH/vYx0EOrtU3Gg2OHj3qB7c6o4i4PzOHN9vPK3xpG/X39zM2Nrau7CAH2d3YzeHDh9nd2L0u7AHGxsYMe7WFgS9ts8nJSSYmXhy3X52Nc911163N3lk1MTHhF6/UNu1YHlnSRiJe3D5l6HQ1xGdmZtZNvVydsjkyMsLY2Jhhr7ZyDF/qlLME/qqlpaUNh2vOVC5txDF8qQecKdQNe3WCQzpSp+yg/54l8Apfkoph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klSIlgM/Iq6IiNmI+F5EPBQRN1flr4iIeyLikerny1tvriRpq9pxhf8c8IHM3Au8DnhvROwFbgXuy8yrgPuq55KkLmk58DPzycz8TrX9M+Bh4DLgAHBntdudwGir55Ikb
V1bx/AjYgB4DfBN4JLMfLKq+jFwSTvPJUk6P20L/Ih4GfAvwPsy83+b6zIzgQ3v6BwR4xExFxFzi4uL7WqOJOkUbQn8iLiIlbD/bGZ+oSr+SURcWtVfCjy10e9m5nRmDmfm8J49e9rRHEnSBtoxSyeATwMPZ+ZfN1UdAm6stm8EvtTquSRJW7erDcf4HeCPgP+KiAeqsr8APgJ8LiJuAn4E/GEbziVJ2qKWAz8z/w2IM1S/sdXjS5Law2/aSlIhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mFaEvgR8RnIuKpiHiwqewVEXFPRDxS/Xx5O84lSdqadl3h/wNw7SlltwL3ZeZVwH3Vc0lSl7Ql8DPzG8BPTyk+ANxZbd8JjLbjXJKkrenkGP4lmflktf1j4JIOnkuStIlt+dA2MxPIjeoiYjwi5iJibnFxcTuaI0lF6mTg/yQiLgWofj610U6ZOZ2Zw5k5vGfPng42R5LK1snAPwTcWG3fCHypg+eSJG2iXdMy/xH4d+A3IuLxiLgJ+Ajw5oh4BHhT9VyS1CW72nGQzHznGare2I7jS5Ja5zdtJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCdDzwI+LaiDgWEccj4tZOn0+StLGOBn5EXAj8HfBWYC/wzojY28lzSpI21ukr/GuA45n5aGb+HLgLONDhc0qSNtDpwL8MeKzp+eNV2ZqIGI+IuYiYW1xc7HBzJKlcXf/QNjOnM3M4M4f37NnT7eZIUm11OvCfAK5oen55VSZJ2madDvxvA1dFxJUR8RLgBuBQh88pSdrArk4ePDOfi4g/A74GXAh8JjMf6uQ5JUkb62jgA2Tml4Evd/o8kqSz6/qHtpKk7WHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhWgr8iHhHRDwUES9ExPApdR+KiOMRcSwi3tJaMyVJrdrV4u8/CLwN+FRzYUTsBW4AXg00gHsjYjAzn2/xfJKkLWrpCj8zH87MYxtUHQDuysxnM/OHwHHgmlbOJUl1s7S0dF7lrerUGP5lwGNNzx+vyiRJwOTkJENDQ8zPz68rn5+fZ2hoiMnJybafc9PAj4h7I+LBDR4H2tGAiBiPiLmImFtcXGzHISVpR5ucnGRqaoqFhQVGRkbWQn9+fp6RkREWFhaYmppqe+hvOoafmW/awnGfAK5oen55VbbR8aeBaYDh4eHcwrkkqWeshv2q1dCfnp5mfHychYWFtbrV/doV/J0a0jkE3BARF0fElcBVwLc6dC5J6glLS0vMzMysKxtllOWFZfbv38/ywjKjjK6rn5mZaduYfqvTMv8gIh4HXg/cHRFfA8jMh4DPAd8Dvgq81xk6kkrX39/P7OwsjUYDWAn7m7mZ27iNAQa4jdu4mZvXQr/RaDA7O0t/f39bzh+ZO2cUZXh4OOfm5rrdDEnqqNWx+uWF5bWwX3WCE9zCLexu7GZ2dpbBwcFNjxcR92fm8Gb7+U1bSdpmg4ODTE9Pc5KTTDG1rm6KKU5ykunp6XMK+/Nh4EvSNpufn2d8fJw++phgYl3dBBP00cf4+PhpUzZbZeBL0jZqnnq5j30MMMAJTvAu3sUJTjDAAPvYd9qUzXZwDF+StsnS0
hJDQ0Prpl6OMsoRjnCSk/TRxz72cZCDa/WNRoOjR4+e9YNbx/AlaYfp7+9nbGxsXdlBDrK7sZvDhw+zu7F7XdgDjI2NtW2WjoEvSdtocnKSiYkXx+1Xp15ed91166ZsAkxMTLT127atrpYpSTpPqyE+MzOzburl4OAgs7OzjIyMMDY21valFRzDl6Q2uv7H/8kLu+GCZTj0qtecdd+lpaUNh2vOVH4mjuFLUhe8sBsuuGDl52bOFOrtGrM/lYEvSW10wTK88MLKz53GMXxJaqO1YZxf6m47NuIVviQVwsCXpEIY+JJUiGICf7tvFixJO00Rgd+NmwVL0k5T+8Dv1s2CJWmnqXXgn+lmwXffffda2K8y9CXVXW0Dv9s3C5aknaa2gd/tmwVL0k5T28CHF1eeazQaHOHI2t1k7uCOtbvMHOHIWti3+/6RkrST1DrwoXs3C5aknab2gd+tmwVL0k5T68Dv5s2CJWmnqe0NUDp1s2BJ2mmKvwFKt28WLEk7TW0DH7p7s2BJva9ua3C1FPgR8bGI+H5EHI2IL0ZEX1PdhyLieEQci4i3tN7Uszv8zO1rj2aroX/q1MvmKZuGvaRT1XINrszc8gP4PWBXtf1R4KPV9l7gu8DFwJXAD4ALNzvea1/72tyqf/3ZJ9YeG3n66afPq1xSuSYmJhJIIBuNRh47diwzM48dO5aNRmOtbmJiorsNrQBzeQ6Z3dIVfmZ+PTOfq57+B3B5tX0AuCszn83MHwLHgWtaOVertvtmwZJ6U53X4GrnGP6fAF+pti8DHmuqe7wq65j9L3vP2kOStqLua3BtGvgRcW9EPLjB40DTPh8GngM+e74NiIjxiJiLiLnFxcXz/XVJapu6r8G1a7MdMvNNZ6uPiD8G9gNvrMaSAJ4Armja7fKqbKPjTwPTsDIPf/MmS1LnrE7oGBkZ4cjCEQ5wYG0NLqCn1+BqdZbOtcAHgeszc7mp6hBwQ0RcHBFXAlcB32rlXJK0Xeq6BlerY/h/C/wicE9EPBARnwTIzIeAzwHfA74KvDczn2/xXJK0Leq6Blers3R+PTOvyMyrq8efNtX9VWb+Wmb+RmZ+5WzHkaSdos5rcNV2LR1JOl+9ugZX8WvpSNLZXD9zP4/+/TjXz9y/Vlb3NbgMfElF+pv4FL/6f0/xN/GpdeV1XoPLwJdUpPflu3n0ol/mffnu0+rqugaXY/iS2mZpaWnD4Y0zle90vdIfx/DVFXVbTlbnro6rS9ZtDS4DX21Txze8zs3qgmOnTlVsnuLYawuN1dK5LKm5XY9WlkdWd/XacrJqn+a/ffNr4PDhw+v+9r4GOodzXB656yHf/DDwe5Nv+HI9/fTTp/2NRxnNPvoSyD76cpTR014b3oeivc418B3SUUvqvpyszq7uq0vWjYGvlviGV/NUxSMcWVt+4A7uWFuWoFdXl6wbA18t8w2vuq4uWTcGvtrCN3zZ6rq6ZN0Y+GoL3/DlqvPqknVj4KtlvuHLtbS0tO7G3gc5yMf5OLdwCyc4wS3cwsf5+NqCY6uvAT+07w4DXy3xDV+2/t0XMbb/d9eV1Wl1ybox8NWSui8nq/WeuX2UF24f5Znbq6m237+PyaufZWL8D9f2qdPqkrVzLpP1t+vhF696l9+0LcPznziQ+YkDKz8zM5f/J/M7X8hc/p+cmJhY97dftfoa8G/fOZzjF69cLVNtMzk5yczMzGlTL1fH+MfGxry663HP3D7KbmAZeNl7Dp5W3yurS9bNua6WaeDrnE3Fi9sTZ3jZ+IaXtp/LI6sr6racrFQnBr4kFWJXtxug3nGmYRxJvcErfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCrGjllaIiEXgRx049CuBpztw3G6zX73FfvWWXurXr2Tmns122lGB3ykRMXcu60z0GvvVW+xXb6ljvxzSkaRCGPiSVIhSAn+62w3oEPvVW+xXb6ldv4oYw5cklXOFL0nFq3XgR8THIuL7E
XE0Ir4YEX1NdR+KiOMRcSwi3tLNdp6viHhHRDwUES9ExPApdT3bL4CIuLZq+/GIuLXb7dmqiPhMRDwVEQ82lb0iIu6JiEeqny/vZhu3IiKuiIjZiPhe9Rq8uSrv6b5FxC9ExLci4rtVv6aq8isj4pvV6/GfIuIl3W5rK2od+MA9wG9l5hAwD3wIICL2AjcArwauBT4RERd2rZXn70HgbcA3mgt7vV9VW/8OeCuwF3hn1ade9A+s/A2a3Qrcl5lXAfdVz3vNc8AHMnMv8DrgvdXfqNf79izwhsz8beBq4NqIeB3wUeC2zPx14L+Bm7rYxpbVOvAz8+uZ+Vz19D+Ay6vtA8BdmflsZv4QOA5c0402bkVmPpyZxzao6ul+sdLW45n5aGb+HLiLlT71nMz8BvDTU4oPAHdW23cCo9vaqDbIzCcz8zvV9s+Ah4HL6PG+5YpnqqcXVY8E3gB8virvuX6dqtaBf4o/Ab5SbV8GPNZU93hV1ut6vV+93v7NXJKZT1bbPwYu6WZjWhURA8BrgG9Sg75FxIUR8QDwFCujAz8ATjZdNPb867Hnb3EYEfcCr9qg6sOZ+aVqnw+z8q/oZ7ezba04l36pd2VmRkTPTpGLiJcB/wK8LzP/NyLW6nq1b5n5PHB19VnfF4Hf7HKT2q7nAz8z33S2+oj4Y2A/8MZ8cQ7qE8AVTbtdXpXtGJv16wx2fL820evt38xPIuLSzHwyIi5l5Uqy50TERayE/Wcz8wtVcS36BpCZJyNiFng90BcRu6qr/J5/PdZ6SCcirgU+CFyfmctNVYeAGyLi4oi4ErgK+FY32thmvd6vbwNXVTMjXsLKB9CHutymdjoE3Fht3wj03H9qsXIp/2ng4cz866aqnu5bROxZncUXES8F3szK5xOzwNur3XquX6fJzNo+WPnQ8jHggerxyaa6D7MyRncMeGu323qe/foDVsYTnwV+AnytDv2q2v/7rMyo+gErw1ddb9MW+/GPwJPA/1V/q5uAflZmsDwC3Au8otvt3EK/fpeVDzOPNr2vfr/X+wYMAf9Z9etB4C+r8l9l5aLpOPDPwMXdbmsrD79pK0mFqPWQjiTpRQa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mF+H9q5qfTLaUnxQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_data(centroids+2, X, n_samples)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Broadcasting" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "How did our distance function `sqrt(((x-X)**2).sum(1))` work over a matrix without us writing any loops? The trick is that we used *broadcasting*. The term broadcasting was first used by Numpy, although is now used in other libraries such as [Tensorflow](https://www.tensorflow.org/performance/xla/broadcasting) and Matlab; the rules can vary by library.\n", "\n", "From the [Numpy Documentation](https://docs.scipy.org/doc/numpy-1.10.0/user/basics.broadcasting.html):\n", "\n", " The term broadcasting describes how numpy treats arrays with \n", " different shapes during arithmetic operations. Subject to certain \n", " constraints, the smaller array is “broadcast” across the larger \n", " array so that they have compatible shapes.Broadcasting provides a \n", " means of vectorizing array operations so that looping occurs in C\n", " instead of Python. It does this without making needless copies of \n", " data and usually leads to efficient algorithm implementations.\n", " \n", "In addition to the efficiency of broadcasting, it allows developers to write less code, which typically leads to fewer errors.\n", "\n", "Operators (+,-,\\*,/,>,<,==) are usually element-wise. 
Here's some examples of element-wise operations:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(array([12, 14, 3]), array([False, True, True]))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = np.array([10, 6, -4])\n", "b = np.array([2, 8, 7])\n", "\n", "a + b, a < b" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Now this next example clearly can't be element-wise, since the second parameter is a scalar, not a 1d array." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([ True, True, False])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a > 0" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "So how did this work? The trick was that numpy automatically *broadcast* the scalar `0` so that had the same `shape` as a. We can manually see how numpy broadcasts by using `broadcast_to()`." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(3,)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a.shape" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.broadcast_to(0, a.shape)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(3,)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.broadcast_to(0, a.shape).shape" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Here's another example." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([11, 7, -3])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a + 1" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "It works with higher-dimensional arrays too, for instance 2d (matrices):" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[1, 2, 3],\n", " [4, 5, 6],\n", " [7, 8, 9]])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = np.array([[1, 2, 3], [4,5,6], [7,8,9]]); m" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 2, 4, 6],\n", " [ 8, 10, 12],\n", " [14, 16, 18]])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m * 2" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(3, 3)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.shape" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[2, 2, 2],\n", " [2, 2, 2],\n", " [2, 2, 2]])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.broadcast_to(2, m.shape)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "We can use the same trick to broadcast a vector to a matrix:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([10, 20, 30])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = np.array([10,20,30]); c" ] }, { "cell_type": "code", "execution_count": 
29, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[11, 22, 33],\n", " [14, 25, 36],\n", " [17, 28, 39]])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m + c" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Let's see what numpy has done with `c` in this case:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[10, 20, 30],\n", " [10, 20, 30],\n", " [10, 20, 30]])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.broadcast_to(c, m.shape)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Interesting: numpy has copied `c` into every row. What if `c` were a column vector, i.e. a 3x1 array?" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[10],\n", " [20],\n", " [30]])" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Indexing an axis with None adds a unit axis in that location\n", "cc = c[:,None]; cc" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[11, 12, 13],\n", " [24, 25, 26],\n", " [37, 38, 39]])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m + cc" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Let's see what numpy has done with `cc` in this case:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[10, 10, 10],\n", " [20, 20, 20],\n", " [30, 30, 30]])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.broadcast_to(cc, m.shape)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Note that numpy isn't actually replicating the memory of the axes being broadcast: it simply reads the same locations multiple times. This is very efficient both for compute and memory.\n", "\n", "The behaviour of numpy's broadcasting seems quite intuitive, but you'll want to remember the explicit broadcasting rules to use this technique effectively:\n", "\n", "When operating on two arrays, Numpy/PyTorch compares their shapes element-wise. It starts with the **trailing dimensions**, and works its way forward. Two dimensions are **compatible** when\n", "\n", "- They are equal, or\n", "- One of them is 1.\n", "\n", "When axes have the same dimension, no broadcasting is required. Any axes of dimension 1 are replicated to match the other array.\n", "\n", "Arrays do not need to have the same number of dimensions. For example, if you have a 256 x 256 x 3 array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values. Lining up the sizes of the trailing axes of these arrays according to the broadcast rules shows that they are compatible:\n", "\n", "    Image  (3d array): 256 x 256 x 3\n", "    Scale  (1d array):             3\n", "    Result (3d array): 256 x 256 x 3\n", "\n", "Numpy will insert additional unit axes at the front of the array with fewer dimensions, as required, so that the numbers of dimensions match. So in this case the Scale array would first be reshaped automatically to 1x1x3, and then broadcast to 256 x 256 x 3. The [numpy documentation](https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html#general-broadcasting-rules) includes several examples of what dimensions can and cannot be broadcast together." 
] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "We can now see how our `distance()` function works:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[1, 1],\n", " [0, 0],\n", " [9, 4]])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a=array([2,3])\n", "b=array([[1,2],[2,3],[-1,1]])\n", "c=(a-b)**2; c" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 1, 2],\n", " [ 2, 3],\n", " [-1, 1]])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([0.13598, 0.15958, 0.0564 ])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w=gaussian(sqrt(c.sum(1)), 2.5); w" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "...and we can also now pull apart our weighted average:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "((3,), (3, 2))" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w.shape, b.shape" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[0.13598],\n", " [0.15958],\n", " [0.0564 ]])" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w[:,None]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.13598, 0.27196],\n", " [ 0.31915, 0.47873],\n", " [-0.0564 , 0.0564 ]])" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w[:,None]*b" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([1.13288, 2.29314])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(w[:,None]*b).sum(0) / w.sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GPU-accelerated mean shift in PyTorch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now look at using [PyTorch](http://pytorch.org/), a Python framework for dynamic neural networks with GPU acceleration, which was released by Facebook's AI team.\n", "\n", "PyTorch has two overlapping, yet distinct, purposes. As described in the [PyTorch documentation](http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html), it is both a replacement for NumPy that can use the power of GPUs, and a deep learning research platform.\n", "\n", "The neural network functionality of PyTorch is built on top of the Numpy-like functionality for fast matrix computations on a GPU. Although the neural network purpose receives much more attention, both are very useful. Today we'll use PyTorch to accelerate our meanshift algorithm by running it on the GPU.\n", "\n", "If you want to learn more about PyTorch, you can try this [introductory tutorial](http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) or this [tutorial to learn by examples](http://pytorch.org/tutorials/beginner/pytorch_with_examples.html).\n", "\n", "One advantage of PyTorch is that it's very similar to NumPy. In fact, our definitions of `gaussian`, `distance` and `meanshift_inner` can be reused unchanged. So we'll simply import PyTorch's alternate implementations of the two numpy functions they use:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "from torch import exp, sqrt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we'll use almost the same code as before (with `torch.stack` instead of `np.array` to collect the shifted points), but first we convert our numpy array to a GPU PyTorch tensor." 
] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "\n", "def meanshift_iter_torch(X, bandwidth=2.5):\n", " out = torch.stack([meanshift_inner(x, X, bandwidth) for x in X], 0)\n", " return to_gpu(out.cuda())\n", "\n", "def meanshift_torch(X_torch, it=0, max_it=5, bandwidth=2.5, eps=0.000001):\n", " new_X = meanshift_iter_torch(X_torch, bandwidth=bandwidth)\n", " if it >= max_it or abs(X_torch-new_X).sum()/abs(X_torch.sum()) < eps:\n", " return new_X\n", " else:\n", " return meanshift_torch(new_X, it+1, max_it, bandwidth, eps)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try it out..." ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.57 s, sys: 96 ms, total: 1.66 s\n", "Wall time: 1.66 s\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE3RJREFUeJzt3W9sZFd5x/Hvk2wIrGhlcLZpJonqtI1BS+WCsCJQ+2LNnxKaVdZUbRVeVCGN7BZRaQmV2lCk2n5RqagSAaklYBfSvEBNKYVluymFJPIqqlQITkm3CWGdJSxK4kC8blclWik0ydMXvnbGu971rmfG47nn+5FGvnPO9b3nyDM/XZ85c25kJpKk+ruo2w2QJG0NA1+SCmHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUiB2tHiAiXg08CFxaHe9LmTkREdcA9wD9wMPA72XmT891rMsuuywHBgZabZIkFeXhhx8+kZm7Ntqv5cAHXgDekZnPR8QlwL9FxNeAjwB3ZOY9EfEZ4FbgznMdaGBggLm5uTY0SZLKERE/PJ/9Wh7SyWXPV08vqR4JvAP4UlV+NzDa6rkkSZvXljH8iLg4Ih4BngPuA74PnMzMF6tdngaubMe5pDpZWlq6oHKpFW0J/Mx8KTPfDFwFXAe88Xx/NyLGI2IuIuYWFxfb0RypJ0xOTjI0NMT8/Pya8vn5eYaGhpicnOxOw1RbbZ2lk5kngVng7UBfRKx8RnAV8MxZfmc6M4czc3jXrg0/c5BqYXJykqmpKRYWFhgZGVkN/fn5eUZGRlhYWGBqasrQV1u1HPgRsSsi+qrt1wDvBh5nOfh/u9rtZuCrrZ5LqoOVsF+xEvr33nvvativMPTVTu24wr8CmI2II8C3gfsy8xDwp8BHIuIYy1MzP9eGc0k9bWlpiZmZmTVlo4xyauEUe/fu5dTCKUZPm98wMzPjmL7aouVpmZl5BHjLOuVPsjyeL6nS39/P7Ozs6pX8KKPsZz/72McUU0wwwQADABzgAI1Gg9nZWfr7+7vbcNWC37SVttjg4CCz
s7M0Gg0Oc5jjHGeAAe7iLgYY4DjHOczh1bAfHBzsdpNVEwa+1AWDg4NMT09zkpNMMbWmboopTnKS6elpw15tZeBLXTA/P8/4+Dh99DHBxJq6CSboo4/x8fEzpmxKrTDwpS3WPPVyD3tWh3Fu4ZbV4Z097DljyqbUqsjMbrdh1fDwcLqWjupsaWmJoaGhNVMvRxnlMIc5yUn66GMPezjAgdX6RqPBkSNH/OBWZxURD2fm8Eb7eYUvbaH+/n7GxsbWlB3gADsbOzl06BA7GzvXhD3A2NiYYa+2MPClLTY5OcnExCvj9iuzcW644YbV2TsrJiYm/OKV2qYdyyNLWk/EK9unDZ2uhPjMzMyaqZcrUzZHRkYYGxsz7NVWjuFLnXKOwF+xtLS07nDN2cql9TiGL/WAs4W6Ya9OcEhH6pRt9N+zBF7hS1IxDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IK0XLgR8TVETEbEd+NiMciYn9V/vqIuC8inqh+vq715kqSNqsdV/gvAn+cmbuBtwEfiojdwO3AA5l5LfBA9VyS1CUtB35mPpuZ/1Ft/wR4HLgS2AfcXe12NzDa6rkkSZvX1jH8iBgA3gJ8C7g8M5+tqn4EXN7Oc0mSLkzbAj8iXgv8E/DhzPzf5rrMTGDdOzpHxHhEzEXE3OLiYruaI0k6TVsCPyIuYTnsv5CZX66KfxwRV1T1VwDPrfe7mTmdmcOZObxr1652NEeStI52zNIJ4HPA45n5iaaqg8DN1fbNwFdbPZckafN2tOEYvwb8HvBfEfFIVfZnwF8CX4yIW4EfAr/bhnNJkjap5cDPzH8D4izV72z1+JKk9vCbtpJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIdoS+BHx+Yh4LiIebSp7fUTcFxFPVD9f145zSZI2p11X+H8HXH9a2e3AA5l5LfBA9VyS1CVtCfzMfBD479OK9wF3V9t3A6PtOJckaXM6OYZ/eWY+W23/CLi8g+eSJG1gSz60zcwEcr26iBiPiLmImFtcXNyK5khSkToZ+D+OiCsAqp/PrbdTZk5n5nBmDu/atauDzZGksnUy8A8CN1fbNwNf7eC5JEkbaNe0zL8H/h14Q0Q8HRG3An8JvDsingDeVT2XJHXJjnYcJDPff5aqd7bj+JKk1vlNW0kqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEB0P/Ii4PiKORsSxiLi90+eTJK2vo4EfERcDfwO8F9gNvD8idnfynJKk9XX6Cv864FhmPpmZPwXuAfZ1+JySpHV0OvCvBJ5qev50VbYqIsYjYi4i5hYXFzvcHEkqV9c/tM3M6cwczszhXbt2dbs5klRbnQ78Z4Crm55fVZVJkrZYpwP/28C1EXFNRLwKuAk42OFzSpLWsaOTB8/MFyPij4CvAxcDn8/Mxzp5TknS+joa+ACZ+S/Av3T6PJKkc+v6h7aSpK1h4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIVoK/Ij4nYh4LCJejojh0+o+GhHHIuJoRLyntWZKklq1o8XffxT4LeCzzYURsRu4
CXgT0ADuj4jBzHypxfNJkjappSv8zHw8M4+uU7UPuCczX8jMHwDHgOtaOZck1c3S0tIFlbeqU2P4VwJPNT1/uiqTJAGTk5MMDQ0xPz+/pnx+fp6hoSEmJyfbfs4NAz8i7o+IR9d57GtHAyJiPCLmImJucXGxHYeUpG1tcnKSqakpFhYWGBkZWQ39+fl5RkZGWFhYYGpqqu2hv+EYfma+axPHfQa4uun5VVXZesefBqYBhoeHcxPnkqSesRL2K1ZCf3p6mvHxcRYWFlbrVvZrV/B3akjnIHBTRFwaEdcA1wIPdehcktQTlpaWmJmZWVM2yiinFk6xd+9eTi2cYpTRNfUzMzNtG9NvdVrm+yLiaeDtwL0R8XWAzHwM+CLwXeBfgQ85Q0dS6fr7+5mdnaXRaADLYb+f/dzBHQwwwB3cwX72r4Z+o9FgdnaW/v7+tpw/MrfPKMrw8HDOzc11uxmS1FErY/WnFk6thv2K4xznNm5jZ2Mns7OzDA4Obni8iHg4M4c32s9v2krSFhscHGR6epqTnGSKqTV1U0xxkpNMT0+fV9hfCANfkrbY/Pw84+Pj9NHHBBNr6iaYoI8+xsfHz5iy2SoDX5K2UPPUyz3sYYABjnOcW7iF4xxngAH2sOeMKZvt4Bi+JG2RpaUlhoaG1ky9HGWUwxzmJCfpo4897OEAB1brG40GR44cOecHt47hS9I209/fz9jY2JqyAxxgZ2Mnhw4dYmdj55qwBxgbG2vbLB0DX5K20OTkJBMTr4zbr0y9vOGGG9ZM2QSYmJho67dtW10tU5J0gVZCfGZmZs3Uy8HBQWZnZxkZGWFsbKztSys4hi9JbXTjj77DyzvholNw8Offcs59l5aW1h2uOVv52TiGL0ld8PJOuOii5Z8bOVuot2vM/nQGviS10UWn4OWXl39uN47hS1IbrQ7j/Gx327Eer/AlqRAGviQVwsCXpEIUE/hbfbNgSdpuigj8btwsWJK2m9oHfrduFixJ202tA/9sNwu+9957V8N+haEvqe5qG/jdvlmwJG03tQ38bt8sWJK2m9oGPryy8lyj0eAwh1fvJnMXd63eZeYwh1fDvt33j5Sk7aTWgQ/du1mwJG03tQ/8bt0sWJK2m1oHfjdvFixJ201tb4DSqZsFS9J2U/wNULp9s2BJ2m5qG/jQ3ZsFS+p9dVuDq6XAj4i/iojvRcSRiPhKRPQ11X00Io5FxNGIeE/rTT23Q8/fufpothL6p0+9bJ6yadhLOl0t1+DKzE0/gN8AdlTbHwc+Xm3vBv4TuBS4Bvg+cPFGx3vrW9+am/XPP/n06mM9J06cuKBySeWamJhIIIFsNBp59OjRzMw8evRoNhqN1bqJiYnuNrQCzOV5ZHZLV/iZ+Y3MfLF6+k3gqmp7H3BPZr6QmT8AjgHXtXKuVm31zYIl9aY6r8HVzjH83we+Vm1fCTzVVPd0VdYxe1/7wdWHJG1G3dfg2jDwI+L+iHh0nce+pn0+BrwIfOFCGxAR4xExFxFzi4uLF/rrktQ2dV+Da8dGO2Tmu85VHxEfAPYC76zGkgCeAa5u2u2qqmy9408D07A8D3/jJktS56xM6BgZGeHwwmH2sW91DS6gp9fganWWzvXAnwA3ZuappqqDwE0RcWlEXANcCzzUyrkkaavUdQ2uVsfw/xr4GeC+iHgkIj4DkJmPAV8Evgv8K/ChzHypxXNJ0pao6xpcrc7S+eXMvDoz31w9/rCp7i8y85cy8w2Z+bVzHUeStos6r8FV27V0JOlC9eoaXMWvpSNJ53LjzMM8+bfj3Djz8GpZ3dfgMvAlFemT8Vl+8f+e45Px2TXldV6Dy8CXVKQP5x/w5CU/x4fzD86oq+saXI7hS2qbpaWldYc3zla+3fVKfxzDV1fUbTlZnb86ri5ZtzW4DHy1TR3f8Do/KwuOnT5VsXmKY68tNFZL57Ok5lY9WlkeWd3Va8vJqn2a//bNr4FDhw6t+dv7GugcznN55K6HfPPDwO9N
vuHLdeLEiTP+xqOMZh99CWQffTnK6BmvDe9D0V7nG/gO6agldV9OVudW99Ul68bAV0t8w6t5quJhDq8uP3AXd60uS9Crq0vWjYGvlvmGV11Xl6wbA19t4Ru+bHVdXbJuDHy1hW/4ctV5dcm6MfDVMt/w5VpaWlpzY+8DHOBTfIrbuI3jHOc2buNTfGp1wbGV14Af2neHga+W+IYvW91Xl6wbA18t8Q1flufvHOXlO0d5/s5XptrWeXXJ2jmfyfpb9fCLV73Lb9qW4aVP78v89L7ln6eZmJhY87dfsfIa8G/fOZznF69cLVNtMzk5yczMzBlTL1fG+MfGxry663HP3znKTuAU8NoPHjijvldWl6yb810t08DXeZuKV7YnzvKy8Q0vbT2XR1ZX1G05WalODHxJKsSObjdAveNswziSeoNX+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFWJbLa0QEYvADztw6MuAEx04brfZr95iv3pLL/XrFzJz10Y7bavA75SImDufdSZ6jf3qLfart9SxXw7pSFIhDHxJKkQpgT/d7QZ0iP3qLfart9SuX0WM4UuSyrnCl6Ti1TrwI+KvIuJ7EXEkIr4SEX1NdR+NiGMRcTQi3tPNdl6oiPidiHgsIl6OiOHT6nq2XwARcX3V9mMRcXu327NZEfH5iHguIh5tKnt9RNwXEU9UP1/XzTZuRkRcHRGzEfHd6jW4vyrv6b5FxKsj4qGI+M+qX1NV+TUR8a3q9fgPEfGqbre1FbUOfOA+4FcycwiYBz4KEBG7gZuANwHXA5+OiIu71soL9yjwW8CDzYW93q+qrX8DvBfYDby/6lMv+juW/wbNbgceyMxrgQeq573mReCPM3M38DbgQ9XfqNf79gLwjsz8VeDNwPUR8Tbg48AdmfnLwP8At3axjS2rdeBn5jcy88Xq6TeBq6rtfcA9mflCZv4AOAZc1402bkZmPp6ZR9ep6ul+sdzWY5n5ZGb+FLiH5T71nMx8EPjv04r3AXdX23cDo1vaqDbIzGcz8z+q7Z8AjwNX0uN9y2XPV08vqR4JvAP4UlXec/06Xa0D/zS/D3yt2r4SeKqp7umqrNf1er96vf0buTwzn622fwRc3s3GtCoiBoC3AN+iBn2LiIsj4hHgOZZHB74PnGy6aOz512PP3+IwIu4Hfn6dqo9l5lerfT7G8r+iX9jKtrXifPql3pWZGRE9O0UuIl4L/BPw4cz834hYrevVvmXmS8Cbq8/6vgK8sctNarueD/zMfNe56iPiA8Be4J35yhzUZ4Crm3a7qirbNjbq11ls+35toNfbv5EfR8QVmflsRFzB8pVkz4mIS1gO+y9k5per4lr0DSAzT0bELPB2oC8idlRX+T3/eqz1kE5EXA/8CXBjZp5qqjoI3BQRl0bENcC1wEPdaGOb9Xq/vg1cW82MeBXLH0Af7HKb2ukgcHO1fTPQc/+pxfKl/OeAxzPzE01VPd23iNi1MosvIl4DvJvlzydmgd+uduu5fp0hM2v7YPlDy6eAR6rHZ5rqPsbyGN1R4L3dbusF9ut9LI8nvgD8GPh6HfpVtf83WZ5R9X2Wh6+63qZN9uPvgWeB/6v+VrcC/SzPYHkCuB94fbfbuYl+/TrLH2YeaXpf/Wav9w0YAr5T9etR4M+r8l9k+aLpGPCPwKXdbmsrD79pK0mFqPWQjiTpFQa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mF+H/Y7qdg1/eFMQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "X_torch = to_gpu(torch.from_numpy(X))\n", "%time X = meanshift_torch(X_torch).cpu().numpy()\n", "plot_data(centroids+2, X, n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It works, but this implementation actually takes longer. Oh dear! What do you think is causing this?\n", "\n", "Each iteration launches a new CUDA kernel, which takes time and slows the algorithm down as a whole. Furthermore, each iteration doesn't do enough work to keep all of the GPU's threads busy. To use the GPU effectively, we need to process a *batch* of data at a time." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GPU batched algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To process a batch of data, we need batched versions of our functions. Here's a version of `distance()` that works on batches:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "# result[i, j] is the distance from a[j] to b[i]; shape (len(b), len(a))\n", "def distance_b(a,b): return sqrt(((a[None,:] - b[:,None]) ** 2).sum(2))" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 0.6297 0.8461 0.9324\n", " 0.1305 0.2712 0.1951\n", "[torch.FloatTensor of size 2x3]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a=torch.rand(2,2)\n", "b=torch.rand(3,2)\n", "distance_b(b, a)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note how the two parameters to `distance_b()` have a unit axis added in two different places (`a[None,:]` and `b[:,None]`). This broadcasting trick effectively generalizes the concept of an 'outer product' to any elementwise function. 
In this case, we use it to get the distance from every point in `a` (our batch) to every point in `b` (the whole dataset).\n", "\n", "Now that we have a suitable distance function, we can make some minor updates to our meanshift function to handle batches of data:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "def meanshift_gpu(X, it=0, max_it=5, bandwidth=2.5, eps=0.000001):\n", "    # Gaussian weight between every pair of points, shape (n, n)\n", "    weights = gaussian(distance_b(X, X), bandwidth)\n", "    # Weighted sum of all points, computed for each point, shape (n, 2)\n", "    num = (weights[:,:,None] * X).sum(1)\n", "    # Divide by the total weight to get each point's weighted mean\n", "    X_new = num / weights.sum(1)[:,None]\n", "\n", "    # Stop after max_it recursions, or once the relative change drops below eps\n", "    if it >= max_it or abs(X_new - X).sum()/abs(X.sum()) < eps:\n", "        return X_new\n", "    else:\n", "        return meanshift_gpu(X_new, it+1, max_it, bandwidth, eps)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each iteration still has to launch CUDA kernels, but there are now far fewer iterations (at most `max_it + 1` calls, each of which updates every point at once), and the speedup from updating a whole batch of points in parallel more than makes up for the launch overhead." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 60 ms, sys: 60 ms, total: 120 ms\n", "Wall time: 124 ms\n" ] } ], "source": [ "X_torch = to_gpu(torch.from_numpy(data))\n", "%time X = meanshift_gpu(X_torch).cpu().numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's more like it! We've gone from 1660ms to 124ms, a speedup of roughly 13.4x! Oh, and it even gives the right answer!"
] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAE5VJREFUeJzt3W9sZFd5x/Hvk2xIWdHKYLYhk0R12satlsoNwopA7Ys1f0poVlkXQRVeVCmNbIqoFAISDUWq7ReVQEhNkVoCdiHNC9SUUli2G/4lkVeoUgs4NGwTwjpLWJTEgTim2xJZCk3y9IWvnfGud73rmfF47vl+pJHvnHN97znyzE/XZ86cG5mJJKn+Luh2AyRJ28PAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBViV6sHiIhfAL4BXFwd7/OZORERVwJ3Af3A/cAfZebPz3asV77ylTkwMNBqkySpKPfff//Tmblns/1aDnzgWeANmflMRFwE/FtEfAV4P3BbZt4VEZ8EbgJuP9uBBgYGmJuba0OTJKkcEfGjc9mv5SGdXPFM9fSi6pHAG4DPV+V3AqOtnkuStHVtGcOPiAsj4gHgKeAe4AfAycx8rtrlceCydpxLqpOlpaXzKpda0ZbAz8znM/Nq4HLgGuA3z/V3I2I8IuYiYm5xcbEdzZF6wuTkJENDQ8zPz68rn5+fZ2hoiMnJye40TLXV1lk6mXkSmAVeD/RFxOpnBJcDT5zhd6Yzczgzh/fs2fQzB6kWJicnmZqaYmFhgZGRkbXQn5+fZ2RkhIWFBaampgx9tVXLgR8ReyKir9p+KfBm4GFWgv/t1W43Al9q9VxSHayG/arV0L/77rvXwn6Voa92ascV/qXAbEQcBb4N3JOZh4E/B94fEcdZmZr56TacS+ppS0tLzMzMrCsbZZTlhWX279/P8sIyo6fMb5iZmXFMX23R8rTMzDwKvGaD8kdZGc+XVOnv72d2dnbtSn6UUW7mZg5wgCmmmGCCAQYAOMhBGo0Gs7Oz9Pf3d7fhqgW/aStts8HBQWZnZ2k0GhzhCCc4wQAD3MEdDDDACU5whCNrYT84ONjtJqsmDHypCwYHB5menuYkJ5lial3dFFOc5CTT09OGvdrKwJe6YH5+nvHxcfroY4KJdXUTTNBHH+Pj46dN2ZRaYeBL26x56uU+9q0N47yLd60N7+xj32lTNqVWRWZ2uw1rhoeH07V0VGdLS0sMDQ2tm3o5yihHOMJJTtJHH/vYx0EOrtU3Gg2OHj3qB7c6o4i4PzOHN9vPK3xpG/X39zM2Nrau7CAH2d3YzeHDh9nd2L0u7AHGxsYMe7WFgS9ts8nJSSYmXhy3X52Nc911163N3lk1MTHhF6/UNu1YHlnSRiJe3D5l6HQ1xGdmZtZNvVydsjkyMsLY2Jhhr7ZyDF/qlLME/qqlpaUNh2vOVC5txDF8qQecKdQNe3WCQzpSp+yg/54l8Apfkoph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klSIlgM/Iq6IiNmI+F5EPBQRN1flr4iIeyLikerny1tvriRpq9pxhf8c8IHM3Au8DnhvROwFbgXuy8yrgPuq55KkLmk58DPzycz8TrX9M+Bh4DLgAHBntdudwGir55IkbV1bx/AjYgB4DfBN4JLMfLKq+jFwSTvPJUk6P20L/Ih4GfAvwPsy83+b6zIzgQ3v6BwR4xExFxFzi4uL7WqOJOkUbQn8iLiIlbD/bGZ+oSr+SURcWtVfCjy10e9m5nRmDmfm8J49e9rRHEnSBtoxSyeATwMPZ+ZfN1
UdAm6stm8EvtTquSRJW7erDcf4HeCPgP+KiAeqsr8APgJ8LiJuAn4E/GEbziVJ2qKWAz8z/w2IM1S/sdXjS5Law2/aSlIhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mFaEvgR8RnIuKpiHiwqewVEXFPRDxS/Xx5O84lSdqadl3h/wNw7SlltwL3ZeZVwH3Vc0lSl7Ql8DPzG8BPTyk+ANxZbd8JjLbjXJKkrenkGP4lmflktf1j4JIOnkuStIlt+dA2MxPIjeoiYjwi5iJibnFxcTuaI0lF6mTg/yQiLgWofj610U6ZOZ2Zw5k5vGfPng42R5LK1snAPwTcWG3fCHypg+eSJG2iXdMy/xH4d+A3IuLxiLgJ+Ajw5oh4BHhT9VyS1CW72nGQzHznGare2I7jS5Ja5zdtJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCdDzwI+LaiDgWEccj4tZOn0+StLGOBn5EXAj8HfBWYC/wzojY28lzSpI21ukr/GuA45n5aGb+HLgLONDhc0qSNtDpwL8MeKzp+eNV2ZqIGI+IuYiYW1xc7HBzJKlcXf/QNjOnM3M4M4f37NnT7eZIUm11OvCfAK5oen55VSZJ2madDvxvA1dFxJUR8RLgBuBQh88pSdrArk4ePDOfi4g/A74GXAh8JjMf6uQ5JUkb62jgA2Tml4Evd/o8kqSz6/qHtpKk7WHgS1IhDHxJKoSBL0mFMPAlqRAGviQVwsCXpEIY+JJUCANfkgph4EtSIQx8SSqEgS9JhTDwJakQBr4kFcLAl6RCGPiSVAgDX5IKYeBLUiEMfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCmHgS1IhWgr8iHhHRDwUES9ExPApdR+KiOMRcSwi3tJaMyVJrdrV4u8/CLwN+FRzYUTsBW4AXg00gHsjYjAzn2/xfJKkLWrpCj8zH87MYxtUHQDuysxnM/OHwHHgmlbOJUl1s7S0dF7lrerUGP5lwGNNzx+vyiRJwOTkJENDQ8zPz68rn5+fZ2hoiMnJybafc9PAj4h7I+LBDR4H2tGAiBiPiLmImFtcXGzHISVpR5ucnGRqaoqFhQVGRkbWQn9+fp6RkREWFhaYmppqe+hvOoafmW/awnGfAK5oen55VbbR8aeBaYDh4eHcwrkkqWeshv2q1dCfnp5mfHychYWFtbrV/doV/J0a0jkE3BARF0fElcBVwLc6dC5J6glLS0vMzMysKxtllOWFZfbv38/ywjKjjK6rn5mZaduYfqvTMv8gIh4HXg/cHRFfA8jMh4DPAd8Dvgq81xk6kkrX39/P7OwsjUYDWAn7m7mZ27iNAQa4jdu4mZvXQr/RaDA7O0t/f39bzh+ZO2cUZXh4OOfm5rrdDEnqqNWx+uWF5bWwX3WCE9zCLexu7GZ2dpbBwcFNjxcR92fm8Gb7+U1bSdpmg4ODTE9Pc5KTTDG1rm6KKU5ykunp6XMK+/Nh4EvSNpufn2d8fJw++phgYl3dBBP00cf4+PhpUzZbZeBL0jZqnnq5j30MMMAJTvAu3sUJTjDAAPvYd9qUzXZwDF+StsnS0hJDQ0Prpl6OMsoRjnCSk/TRxz72cZCDa/WNRoOjR4+e9YNbx/AlaYfp7+9nbGxsXdlBDrK7sZvDhw+zu7F7XdgDjI2NtW2WjoEvSdtocnKSiYkXx+1Xp15ed91166ZsAkxMTLT127atrpYpSTpPqyE+MzOzburl4O
Ags7OzjIyMMDY21valFRzDl6Q2uv7H/8kLu+GCZTj0qtecdd+lpaUNh2vOVH4mjuFLUhe8sBsuuGDl52bOFOrtGrM/lYEvSW10wTK88MLKz53GMXxJaqO1YZxf6m47NuIVviQVwsCXpEIY+JJUiGICf7tvFixJO00Rgd+NmwVL0k5T+8Dv1s2CJWmnqXXgn+lmwXffffda2K8y9CXVXW0Dv9s3C5aknaa2gd/tmwVL0k5T28CHF1eeazQaHOHI2t1k7uCOtbvMHOHIWti3+/6RkrST1DrwoXs3C5aknab2gd+tmwVL0k5T68Dv5s2CJWmnqe0NUDp1s2BJ2mmKvwFKt28WLEk7TW0DH7p7s2BJva9ua3C1FPgR8bGI+H5EHI2IL0ZEX1PdhyLieEQci4i3tN7Uszv8zO1rj2aroX/q1MvmKZuGvaRT1XINrszc8gP4PWBXtf1R4KPV9l7gu8DFwJXAD4ALNzvea1/72tyqf/3ZJ9YeG3n66afPq1xSuSYmJhJIIBuNRh47diwzM48dO5aNRmOtbmJiorsNrQBzeQ6Z3dIVfmZ+PTOfq57+B3B5tX0AuCszn83MHwLHgWtaOVertvtmwZJ6U53X4GrnGP6fAF+pti8DHmuqe7wq65j9L3vP2kOStqLua3BtGvgRcW9EPLjB40DTPh8GngM+e74NiIjxiJiLiLnFxcXz/XVJapu6r8G1a7MdMvNNZ6uPiD8G9gNvrMaSAJ4Armja7fKqbKPjTwPTsDIPf/MmS1LnrE7oGBkZ4cjCEQ5wYG0NLqCn1+BqdZbOtcAHgeszc7mp6hBwQ0RcHBFXAlcB32rlXJK0Xeq6BlerY/h/C/wicE9EPBARnwTIzIeAzwHfA74KvDczn2/xXJK0Leq6Blers3R+PTOvyMyrq8efNtX9VWb+Wmb+RmZ+5WzHkaSdos5rcNV2LR1JOl+9ugZX8WvpSNLZXD9zP4/+/TjXz9y/Vlb3NbgMfElF+pv4FL/6f0/xN/GpdeV1XoPLwJdUpPflu3n0ol/mffnu0+rqugaXY/iS2mZpaWnD4Y0zle90vdIfx/DVFXVbTlbnro6rS9ZtDS4DX21Txze8zs3qgmOnTlVsnuLYawuN1dK5LKm5XY9WlkdWd/XacrJqn+a/ffNr4PDhw+v+9r4GOodzXB656yHf/DDwe5Nv+HI9/fTTp/2NRxnNPvoSyD76cpTR014b3oeivc418B3SUUvqvpyszq7uq0vWjYGvlviGV/NUxSMcWVt+4A7uWFuWoFdXl6wbA18t8w2vuq4uWTcGvtrCN3zZ6rq6ZN0Y+GoL3/DlqvPqknVj4KtlvuHLtbS0tO7G3gc5yMf5OLdwCyc4wS3cwsf5+NqCY6uvAT+07w4DXy3xDV+2/t0XMbb/d9eV1Wl1ybox8NWSui8nq/WeuX2UF24f5Znbq6m237+PyaufZWL8D9f2qdPqkrVzLpP1t+vhF696l9+0LcPznziQ+YkDKz8zM5f/J/M7X8hc/p+cmJhY97dftfoa8G/fOZzjF69cLVNtMzk5yczMzGlTL1fH+MfGxry663HP3D7KbmAZeNl7Dp5W3yurS9bNua6WaeDrnE3Fi9sTZ3jZ+IaXtp/LI6sr6racrFQnBr4kFWJXtxug3nGmYRxJvcErfEkqhIEvSYUw8CWpEAa+JBXCwJekQhj4klQIA1+SCrGjllaIiEXgRx049CuBpztw3G6zX73FfvWWXurXr2Tmns122lGB3ykRMXcu60z0GvvVW+xXb6ljvxzSkaRCGPiSVIhSAn+62w3oEPvVW+xXb6ldv4oYw5cklXOFL0nFq3XgR8THIuL7EXE0Ir4YEX1NdR+KiOMRcSwi3tLNdp6viHhHRDwUES9ExPApdT3bL4CIuLZq+/GIuLXb7dmqiPhMRDwVEQ82lb0iIu6JiEeqny/vZhu3IiKuiIjZiPhe9Rq8uSrv6b5FxC9ExLci4rtVv6aq8isj4pvV6/GfIuIl3W
5rK2od+MA9wG9l5hAwD3wIICL2AjcArwauBT4RERd2rZXn70HgbcA3mgt7vV9VW/8OeCuwF3hn1ade9A+s/A2a3Qrcl5lXAfdVz3vNc8AHMnMv8DrgvdXfqNf79izwhsz8beBq4NqIeB3wUeC2zPx14L+Bm7rYxpbVOvAz8+uZ+Vz19D+Ay6vtA8BdmflsZv4QOA5c0402bkVmPpyZxzao6ul+sdLW45n5aGb+HLiLlT71nMz8BvDTU4oPAHdW23cCo9vaqDbIzCcz8zvV9s+Ah4HL6PG+5YpnqqcXVY8E3gB8virvuX6dqtaBf4o/Ab5SbV8GPNZU93hV1ut6vV+93v7NXJKZT1bbPwYu6WZjWhURA8BrgG9Sg75FxIUR8QDwFCujAz8ATjZdNPb867Hnb3EYEfcCr9qg6sOZ+aVqnw+z8q/oZ7ezba04l36pd2VmRkTPTpGLiJcB/wK8LzP/NyLW6nq1b5n5PHB19VnfF4Hf7HKT2q7nAz8z33S2+oj4Y2A/8MZ8cQ7qE8AVTbtdXpXtGJv16wx2fL820evt38xPIuLSzHwyIi5l5Uqy50TERayE/Wcz8wtVcS36BpCZJyNiFng90BcRu6qr/J5/PdZ6SCcirgU+CFyfmctNVYeAGyLi4oi4ErgK+FY32thmvd6vbwNXVTMjXsLKB9CHutymdjoE3Fht3wj03H9qsXIp/2ng4cz866aqnu5bROxZncUXES8F3szK5xOzwNur3XquX6fJzNo+WPnQ8jHggerxyaa6D7MyRncMeGu323qe/foDVsYTnwV+AnytDv2q2v/7rMyo+gErw1ddb9MW+/GPwJPA/1V/q5uAflZmsDwC3Au8otvt3EK/fpeVDzOPNr2vfr/X+wYMAf9Z9etB4C+r8l9l5aLpOPDPwMXdbmsrD79pK0mFqPWQjiTpRQa+JBXCwJekQhj4klQIA1+SCmHgS1IhDHxJKoSBL0mF+H9q5qfTLaUnxQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_data(centroids+2, X, n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## course.fast.ai" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "If you found this interesting, you might enjoy the 30+ hours of deep learning lessons at [course.fast.ai](http://course.fast.ai). There's also a very active forum of deep learning practitioners and learners at [forums.fast.ai](http://forums.fast.ai). Hope to see you there! :)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }