{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Clustering" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Clustering techniques are unsupervised learning algorithms that try to group unlabelled data into \"clusters\", using the (typically spatial) structure of the data itself. Clustering has many [applications](https://en.wikipedia.org/wiki/Cluster_analysis#Applications).\n", "\n", "The easiest way to demonstrate how clustering works is to generate some data and watch the algorithms in action. We'll start by importing the libraries we'll be using today." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math, matplotlib.pyplot as plt, operator, torch\n", "from functools import partial" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(42)\n", "torch.set_printoptions(precision=3, linewidth=140, sci_mode=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_clusters=6\n", "n_samples =250" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To generate our data, we're going to pick 6 random points, which we'll call centroids, and for each centroid we're going to generate 250 random points around it."
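, "\n", "\n", "In symbols (the names $c_k$ and $x_{k,i}$ are notation introduced here for illustration, not names from the notebook), the code cells that follow draw each centroid as $c_k \\sim \\mathrm{Uniform}([-35, 35)^2)$ and each sample as\n", "\n", "$$x_{k,i} \\sim \\mathcal{N}(c_k,\\ 5 I_2), \\qquad k = 1,\\dots,6,\\quad i = 1,\\dots,250,$$\n", "\n", "so every cluster is an isotropic Gaussian blob with standard deviation $\\sqrt{5} \\approx 2.24$ along each axis, giving $6 \\times 250 = 1500$ points in total."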
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "centroids = torch.rand(n_clusters, 2)*70-35" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.distributions.multivariate_normal import MultivariateNormal\n", "from torch import tensor" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sample(m): return MultivariateNormal(m, torch.diag(tensor([5.,5.]))).sample((n_samples,))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1500, 2])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "slices = [sample(c) for c in centroids]\n", "data = torch.cat(slices)\n", "data.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we can see each centroid marked with an X, and each cluster drawn in its own color." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_data(centroids, data, n_samples, ax=None):\n", "    if ax is None: _,ax = plt.subplots()\n", "    for i, centroid in enumerate(centroids):\n", "        samples = data[i*n_samples:(i+1)*n_samples]\n", "        ax.scatter(samples[:,0], samples[:,1], s=1)\n", "        ax.plot(*centroid, markersize=10, marker=\"x\", color='k', mew=5)\n", "        ax.plot(*centroid, markersize=5, marker=\"x\", color='m', mew=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_data(centroids, data, n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Mean shift" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Most people that have come across clustering algorithms have learnt about **k-means**. Mean shift clustering is a newer and less well-known approach, but it has some important advantages:\n", "* It doesn't require selecting the number of clusters in advance, but instead just requires a **bandwidth** to be specified, which can be easily chosen automatically\n", "* It can handle clusters of any shape, whereas k-means (without using special extensions) requires that clusters be roughly ball shaped." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The algorithm is as follows:\n", "* For each data point x in the sample X, find the distance between that point x and every other point in X\n", "* Create weights for each point in X by using the **Gaussian kernel** of that point's distance to x\n", " * This weighting approach penalizes points further away from x\n", " * The rate at which the weights fall to zero is determined by the **bandwidth**, which is the standard deviation of the Gaussian\n", "* Update x as the weighted average of all other points in X, weighted based on the previous step\n", "\n", "This will iteratively push points that are close together even closer until they are next to each other." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 9.222, 11.604])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "midp = data.mean(0)\n", "midp" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_data([midp]*6, data, n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So here's the definition of the gaussian kernel, which you may remember from high school...\n", " This person at the science march certainly remembered!\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def gaussian(d, bw): return torch.exp(-0.5*((d/bw))**2) / (bw*math.sqrt(2*math.pi))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_func(f):\n", " x = torch.linspace(0,10,100)\n", " plt.plot(x, f(x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_func(partial(gaussian, bw=2.5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "functools.partial" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "partial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In our implementation, we choose the bandwidth to be 2.5. \n", "\n", "One easy way to choose bandwidth is to find which bandwidth covers one third of the data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def tri(d, i): return (-d+i).clamp_min(0)/i" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_func(partial(tri, i=8))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X = data.clone()\n", "x = data[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([26.204, 26.349])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([2]), torch.Size([1500, 2]), torch.Size([1, 2]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.shape,X.shape,x[None].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.000, 0.000],\n", " [ 0.513, -3.865],\n", " [-4.227, -2.345],\n", " [ 0.557, -3.685],\n", " [-5.033, -3.745],\n", " [-4.073, -0.638],\n", " [-3.415, -5.601],\n", " [-1.920, -5.686]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(x[None]-X)[:8]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.000, 0.000],\n", " [ 0.513, -3.865],\n", " [-4.227, -2.345],\n", " [ 0.557, -3.685],\n", " [-5.033, -3.745],\n", " [-4.073, -0.638],\n", " [-3.415, -5.601],\n", " [-1.920, -5.686]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(x-X)[:8]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.000, 3.899, 4.834, 3.726, 6.273, 4.122, 6.560, 6.002])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# rewrite using torch.einsum\n", "dist = ((x-X)**2).sum(1).sqrt()\n", "dist[:8]" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 0.160, 0.047, 0.025, ..., 0.000, 0.000, 0.000])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight = gaussian(dist, 2.5)\n", "weight" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1500]), torch.Size([1500, 2]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight.shape,X.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1500, 1])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight[:,None].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 4.182, 4.205],\n", " [ 1.215, 1.429],\n", " [ 0.749, 0.706],\n", " ...,\n", " [ 0.000, 0.000],\n", " [ 0.000, 0.000],\n", " [ 0.000, 0.000]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight[:,None]*X" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def one_update(X):\n", " for i, x in enumerate(X):\n", " dist = torch.sqrt(((x-X)**2).sum(1))\n", "# weight = gaussian(dist, 2.5)\n", " weight = tri(dist, 8)\n", " X[i] = (weight[:,None]*X).sum(0)/weight.sum()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def meanshift(data):\n", " X = data.clone()\n", " for it in range(5): one_update(X)\n", " return X" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 453 ms, sys: 0 ns, total: 453 ms\n", "Wall time: 452 ms\n" ] } ], "source": [ "%time X=meanshift(data)" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUzElEQVR4nO3dUWyk13mf8eeNpDhZpC1tehNrJG0pIGJq2WEciBBctBekrcBKtaiYAELli2KrGCQCuMI68IXl6GJmCggwUDRboHHQkrA3ujCiCki8EnaTprJKwi2QRKECZy1lvaNFTNkLCtJqEjYNCCiQ/faCM9TMktw1OTMc8vD5AQPOd87MnHOw5H/PnDnfN5GZSJLK9GPD7oAkaXAMeUkqmCEvSQUz5CWpYIa8JBXs1mF3oNMHP/jBHBsbG3Y3JOlQefnll9/OzOPb1R2okB8bG2N5eXnY3ZCkQyUiXt+pzuUaSSqYIV+AZrO5q3JJR4chf8jVajUmJiZoNBpd5Y1Gg4mJCWq12nA6JulAMOQPsVqtRr1eZ3V1lenp6c2gbzQaTE9Ps7q6Sr1eN+ilI8yQP6TaAd/WDvoLFy5sBnybQS8dXYb8IdRsNllYWOgqm2GG9dV1Tp48yfrqOjPMdNUvLCy4Ri8dQYb8ITQ6Osri4iKVSgXYCPjTnOYMZxhjjDOc4TSnN4O+UqmwuLjI6OjoEHstaRgM+UNqfHx8M+iXWGKFFcYY4yxnGWOMFVZYYmkz4MfHx4fdZUlDYMgfYuPj48zPz7PGGnXqXXV16qyxxvz8vAEvHWGG/CHWaDSYm5tjhBGqVLvqqlQZYYS5ubkt2yslHR2G/CHVuU1yiqnNJZrHeGxz6WaKqS3bKyUdLXGQvv5vcnIyvXbNzTWbTSYmJrq2Sc4wwxJLrLHGCCNMMcU5zm3WVyoVLl686IevUoEi4uXMnNyuzpn8ITQ6Osrs7GxX2TnOcaxyjPPnz3Oscqwr4AFmZ2cNeGkXSrlciCF/SNVqNarV99bh27toHnrooa7tlQDVatWToaRdKOpyIZl5YG733XdfqkP1H7932+kh1WpWKpW8fPlyV/nly5ezUqlktVodcCelslSr1QQS6Prbav9NtesO0t8WsJw75Kpr8gdZ7Z903P+/Oz6s2WxuuxSzU7mk7V1/uRDYeJc8Pz/P3Nxc1+dgcHDeJbsmX7idgtyAl350pV4u5EB9M5Suc4PZu6T+al8upL01uX25kId5mDp1qlQZYwzY2OhwWC4X4kxeklpKvFxIzyEfET8RES9FxF9GxKsRUW+VfyAiXoiI11o/3997dyVpsEq7XEg/ZvLvAJ/IzF8APgY8GBEfB54AXszMe4AXW8eSdKCVdrmQnkO+tYPn71uHt7VuCTwMPN0qfxqu+8RCkg6YEi8X0pctlBFxC/Ay8LPAlzPzCxGxlpkjHY/528zcsmQTEXPAHMCJEyfue/3113vujyTt1mG+XMjAt1Bm5g8y82PAncD9EfHRXTx3PjMnM3Py+PHj/eiOJO1aqZcL6esWysxci4gl4EHgzYi4PTPfiIjbgbf62ZYk7cXYExc276986aGuuvaJTe0Tojp30XRur4SDcyLUzfRjd83xiBhp3f9J4AHgO8DzwKnWw04Bz/XaliQNWvu6UNdvk+zcXnlYAh76M5O/HXi6tS7/Y8CzmXk+Iv4EeDYiPgN8D3ikD21J0sDVajUef/zxLUsx4+PjB2INfje8do0kHXJeu0aSjihDXpIKZshLUsEMeUkqmCEvSQUz5CWpYIa8JBXMkJekghnyklQwQ16SCmbIS1LBDHlJKtiRCvlms7mrckk67I5MyNdqNSYmJrZ8J2Oj0WBiYuLQXBtaknbjSIR8rVajXq9v+fLdzi/trdfrBr2k4hQf8u2Ab2sH/YULF7q+ygsw6CUV
p+iQbzabLCwsdJXNMMP66jonT55kfXWdGWa66hcWFlyjl1SMokN+dHR08zsZYSPgT3OaM5xhjDHOcIbTnN4M+vZ3Oh6mr/aSpBspOuSh+8t3l1hihRXGGOMsZxljjBVWWGJpy5f2SlIJig952Aj6+fl51lijTr2rrk6dNdaYn5834CUV50iEfKPRYG5ujhFGqFLtqqtSZYQR5ubmtmyvlKTDrviQ79wmOcXU5hLNYzy2uXQzxdSW7ZWSVILIzGH3YdPk5GQuLy/37fWazSYTExNd2yRnmGGJJdZYY4QRppjiHOc26yuVChcvXvTDV0mHRkS8nJmT29UVPZMfHR1ldna2q+wc5zhWOcb58+c5VjnWFfAAs7OzBrykYvQc8hFxV0QsRsSliHg1Ik63yj8QES9ExGutn+/vvbu7V6vVqFbfW4dv76J56KGHurZXAlSrVU+GklSUW/vwGu8Cn8/Mv4iIfwS8HBEvAP8OeDEzvxQRTwBPAF/oQ3vbuvTPPrx5/8PfudRV1w7uhYWFrm2S7e2V09PTzM7OGvCSitP3NfmIeA747dZtKjPfiIjbgaXM/LkbPbeXNfkbhXxbs9ncdilmp3JJOgz2bU0+IsaAXwT+DPiZzHwDoPXzp3d4zlxELEfE8rVr1/rZnS12CnIDXlKp+rFcA0BE/BTw+8DnMvPvIuJHel5mzgPzsDGT32v7O83eJeko68tMPiJuYyPgv5aZf9AqfrO1TEPr51v9aEuS9KPrx+6aAL4CXMrM3+qoeh441bp/Cniu17YkSbvTj+WafwH8W+DbEfGtVtlvAl8Cno2IzwDfAx7pQ1uSpF3oOeQz8/8AOy3Af7LX15ck7V3RZ7xK0lFnyEtSwQx5SSqYIS9JBTPkJalghrwkFcyQl6SCGfKSVDBDXpIKZshLUsEMeUkqmCEvSQUz5CWpYIa8JBXMkJekghnyklQwQ16SCmbIS1LBDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8gVqNpu7KpdUrr6EfER8NSLeiohXOso+EBEvRMRrrZ/v70dburFarcbExASNRqOrvNFoMDExQa1WG07HJA1Fv2byvws8eF3ZE8CLmXkP8GLrWANUq9Wo1+usrq4yPT29GfSNRoPp6WlWV1ep1+sGvTRE+/1Ouy8hn5nfBP7muuKHgadb958GZvrRlrbXDvi2dtBfuHBhM+DbDHppOIbyTjsz+3IDxoBXOo7Xrqv/2x2eNwcsA8snTpxI7d7bb7+dlUolgc3bDDM5wkgCOcJIzjDTVV+pVPLtt98edtelI6NarXb9/V2+fDkzMy9fvtz191utVnf92sBy7pTNO1Xs9rbXkO+83XfffbsenDZ0/qLMMJOLLOZZzuYYY3mWs7nI4mbQd/6CSRq8zoDvDPrz589vmaDtJehvFPKD3F3zZkTcDtD6+dYA2zryxsfHWVxcpFKpsMQSK6wwxhhnOcsYY6ywwhJLVCoVFhcXGR8fH3aXpSOh2WyysLDQVTbDDOur65w8eZL11XVmrlvNXlhY6Nsa/SBD/nngVOv+KeC5AbYlNoJ+fn6eNdaoU++qq1NnjTXm5+cNeGkfjY6Obk7AYCPgT3OaM5xhjDHOcIbTnN4M+vZEbHR0tC/t92sL5e8BfwL8XERcjYjPAF8CfikiXgN+qXWsAWo0GszNzTHCCFWqXXVVqowwwtzc3JYPfSQN1jDfafdrd82nM/P2zLwtM+/MzK9kZjMzP5mZ97R+Xr/7Rn3UuU1yiqnNX5zHeGzzF2qKqS3bKyXtj2G9046NNfuDYXJyMpeXl4fdjUOn2WwyMTHRtU1yhhmWWGKNNUYYYYopznFus75SqXDx4sW+vSWUdGPtidj66vrmUk3bCiv8Br/BscqxPc3kI+LlzJzcrs7LGhRgdHSU2dnZrrJznONY5Rjnz5/nWOVYV8ADzM7OGvDSPhnqO+2dtt0M4+YWyt4Mch+upL3Zj/NYGNIWSvXRzz/985u3ndRqNarV
6pYPbzo/9KlWq57tKu2jYb/Tdk3+kOgM92+f+vYNH9tsNrf9BdmpXNLgdV56pHMi1rmUA+xpInajNflbe+u2DqKdgtyAlwbnP/2bk5v3P//fz2+pbwf3wsLCtu+0p6enmZ2d7fs7bWfyktQHNwv5tkG803Z3jSQdEPv9TtvlGknqgxvN3ofJmbwkFcyQl6SCGfKSVDBDXpIKZshLUsEMeUkqmCEvSQUz5CWpYIa8JBXMkJekghnyklQwQ16SCmbIS1LBDHlJKpghL0kFG3jIR8SDEXE5Iq5ExBODbk+S9J6BhnxE3AJ8Gfhl4F7g0xFx7yDblCS9Z9Az+fuBK5n515n5D8AzwMMDblOS1DLokL8D+H7H8dVW2aaImIuI5YhYvnbt2oC7I0lHy6BDPrYpy66DzPnMnMzMyePHjw+4O5J0tAw65K8Cd3Uc3wmsDrhNSVLLoEP+z4F7IuLuiPhx4FHg+QG3KUlquXWQL56Z70bEvwf+GLgF+GpmvjrINiVJ7xloyANk5h8CfzjodiRJW3nGqyQVzJCXpIIZ8pJUMENekgpmyEtSwQx5SSqYIS9JBTPkJalghrwkFcyQl6SCGfKSVDBDXpIKZshLUsEMeUkqmCEvSQUz5CWpYIa8JBXMkJekghnyklQwQ16SCmbIS1LBDHlJKpghL0kFM+QlqWA9hXxEPBIRr0bEDyNi8rq6L0bElYi4HBGf6q2bkqS9uLXH578C/Crw3zoLI+Je4FHgI0AF+EZEjGfmD3psT5K0Cz3N5DPzUmZe3qbqYeCZzHwnM78LXAHu76UtSdLuDWpN/g7g+x3HV1tlkqR9dNPlmoj4BvChbaqezMzndnraNmW5w+vPAXMAJ06cuFl3JEm7cNOQz8wH9vC6V4G7Oo7vBFZ3eP15YB5gcnJy2/8IJEl7M6jlmueBRyPifRFxN3AP8NKA2pIk7aDXLZS/EhFXgX8OXIiIPwbIzFeBZ4G/Av4H8Fl31kjS/utpC2Vmfh34+g51TwFP9fL6kqTeeMarJBXMkJekghnyklQwQ16SCmbIS1LBDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8pJUMENekgpmyEtSwQx5SSqYIS9JBTPkJalghrwkFcyQl6SCGfKSVDBDXpIKZshLUsEMeUkqmCEvSQUz5CWpYIa8JBWsp5CPiP8YEd+JiIsR8fWIGOmo+2JEXImIyxHxqZ57KknatV5n8i8AH83MCaABfBEgIu4FHgU+AjwI/E5E3NJjW5KkXeop5DPzf2bmu63DPwXubN1/GHgmM9/JzO8CV4D7e2lLkrR7/VyT/zXgj1r37wC+31F3tVW2RUTMRcRyRCxfu3atj92RJN16swdExDeAD21T9WRmPtd6zJPAu8DX2k/b5vG53etn5jwwDzA5ObntYyRJe3PTkM/MB25UHxGngJPAJzOzHdJXgbs6HnYnsLrXTkqS9qbX3TUPAl8A/nVmrndUPQ88GhHvi4i7gXuAl3ppS5K0ezedyd/EbwPvA16ICIA/zcxfz8xXI+JZ4K/YWMb5bGb+oMe2JEm71FPIZ+bP3qDuKeCpXl5fkkrWbDYZHR39kcv3wjNeJWkIarUaExMTNBqNrvJGo8HExAS1Wq0v7RjykrTParUa9Xqd1dVVpqenN4O+0WgwPT3N6uoq9Xq9L0FvyEvSPmoHfFs76C9cuLAZ8G39CHpDXpL2SbPZZGFhoatshhnWV9c5efIk66vrzDDTVb+wsECz2dxzm4a8JO2T0dFRFhcXqVQqwEbAn+Y0ZzjDGGOc4QynOb0Z9JVKhcXFxZ4+hDXkJWkfjY+Pbwb9EkussMIYY5zlLGOMscIKSyxtBvz4+HhP7RnykrTPxsfHmZ+fZ4016tS76urUWWON+fn5ngMeDHlJ2neNRoO5uTlGGKFKtauuSpURRpibm9uyvXIvDHlJ2ked2ySnmNpconmMxzaXbqaY2rK9cq/ivWuKDd/k5GQuLy8PuxuSNBDNZpOJ
iYmubZIzzLDEEmusMcIIU0xxjnOb9ZVKhYsXL97ww9eIeDkzJ7ercyYvSftkdHSU2dnZrrJznONY5Rjnz5/nWOVYV8ADzM7O9rS7ptcLlEmSrvPlX/9fm/c/+18/0VXXPrmpfUJU5y6axcXFrhOiqtVqzydDGfKStM/awb2wsNC1TbIz6GdnZ/tyWQNDXpKGoFar8fjjj29ZihkfH7/pGvxuGPKS1GfXL9HsZKcg71fAgx+8SlLRDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8pJUsAN1gbKIuAa8Pux+7IMPAm8PuxP7zDEfDY55OP5pZh7fruJAhfxRERHLO10xrlSO+WhwzAePyzWSVDBDXpIKZsgPx/ywOzAEjvlocMwHjGvyklQwZ/KSVDBDXpIKZsjvo4h4JCJejYgfRsTkdXVfjIgrEXE5Ij41rD4OQkQ82BrXlYh4Ytj9GYSI+GpEvBURr3SUfSAiXoiI11o/3z/MPvZbRNwVEYsRcan1e326VV7suCPiJyLipYj4y9aY663yAztmQ35/vQL8KvDNzsKIuBd4FPgI8CDwOxFxy/53r/9a4/gy8MvAvcCnW+Mtze+y8W/X6Qngxcy8B3ixdVySd4HPZ+aHgY8Dn23925Y87neAT2TmLwAfAx6MiI9zgMdsyO+jzLyUmZe3qXoYeCYz38nM7wJXgPv3t3cDcz9wJTP/OjP/AXiGjfEWJTO/CfzNdcUPA0+37j8NzOxnnwYtM9/IzL9o3f9/wCXgDgoed274+9bhba1bcoDHbMgfDHcA3+84vtoqK0HJY7uZn8nMN2AjEIGfHnJ/BiYixoBfBP6MwscdEbdExLeAt4AXMvNAj9nveO2ziPgG8KFtqp7MzOd2eto2ZaXsbS15bAIi4qeA3wc+l5l/F7HdP3k5MvMHwMciYgT4ekR8dMhduiFDvs8y84E9PO0qcFfH8Z3Aan96NHQlj+1m3oyI2zPzjYi4nY2ZX1Ei4jY2Av5rmfkHreLixw2QmWsRscTGZzEHdswu1xwMzwOPRsT7IuJu4B7gpSH3qV/+HLgnIu6OiB9n4wPm54fcp/3yPHCqdf8UsNM7uUMpNqbsXwEuZeZvdVQVO+6ION6awRMRPwk8AHyHAzxmz3jdRxHxK8B/AY4Da8C3MvNTrbongV9jY8fC5zLzj4bVz36LiH8F/GfgFuCrmfnUcHvUfxHxe8AUG5edfROoAueAZ4ETwPeARzLz+g9nD62I+JfA/wa+DfywVfybbKzLFznuiJhg44PVW9iYJD+bmf8hIkY5oGM25CWpYC7XSFLBDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8pJUsP8Pfzg4GNqH4TYAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_data(centroids+2, X, n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Animation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib.animation import FuncAnimation\n", "from IPython.display import HTML" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def do_one(d):\n", " if d: one_update(X)\n", " ax.clear()\n", " plot_data(centroids+2, X, n_samples, ax=ax)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create your own animation\n", "X = data.clone()\n", "fig,ax = plt.subplots()\n", "ani = FuncAnimation(fig, do_one, frames=5, interval=500, repeat=False)\n", "plt.close()\n", "HTML(ani.to_jshtml())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GPU batched algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To truly accelerate the algorithm, we need to update a whole batch of points at each step, instead of one point at a time as we did above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([5, 2]), torch.Size([1500, 2]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bs=5\n", "X = data.clone()\n", "x = X[:bs]\n", "x.shape,X.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def dist_b(a,b): return (((a[None]-b[:,None])**2).sum(2)).sqrt()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.000, 3.899, 4.834, ..., 17.628, 22.610, 21.617],\n", " [ 3.899, 0.000, 4.978, ..., 21.499, 26.508, 25.500],\n", " [ 4.834, 4.978, 0.000, ..., 19.373, 24.757, 23.396],\n", " [ 3.726, 0.185, 4.969, ..., 21.335, 26.336, 25.333],\n", " [ 6.273, 5.547, 1.615, ..., 20.775, 26.201, 24.785]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist_b(X, x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 1500])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dist_b(X, x).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"(torch.Size([1, 1500, 2]), torch.Size([5, 1, 2]), torch.Size([5, 1500, 2]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X[None,:].shape, x[:,None].shape, (X[None,:]-x[:,None]).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.199, 0.030, 0.011, ..., 0.000, 0.000, 0.000],\n", " [ 0.030, 0.199, 0.009, ..., 0.000, 0.000, 0.000],\n", " [ 0.011, 0.009, 0.199, ..., 0.000, 0.000, 0.000],\n", " [ 0.035, 0.199, 0.009, ..., 0.000, 0.000, 0.000],\n", " [ 0.001, 0.004, 0.144, ..., 0.000, 0.000, 0.000]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight = gaussian(dist_b(X, x), 2)\n", "weight" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([5, 1500]), torch.Size([1500, 2]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight.shape,X.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([5, 1500, 1]), torch.Size([1, 1500, 2]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight[...,None].shape, X[None].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 2])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num = (weight[...,None]*X[None]).sum(1)\n", "num.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[367.870, 386.231],\n", " [518.332, 588.680],\n", " [329.665, 330.782],\n", " [527.617, 598.217],\n", " [231.302, 234.155]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[367.870, 386.231],\n", " [518.332, 588.680],\n", " [329.665, 330.782],\n", " [527.617, 598.218],\n", " [231.302, 234.155]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.einsum('ij,jk->ik', weight, X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[367.870, 386.231],\n", " [518.332, 588.680],\n", " [329.665, 330.782],\n", " [527.617, 598.218],\n", " [231.302, 234.155]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weight@X" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 1])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "div = weight.sum(1, keepdim=True)\n", "div.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[26.376, 27.692],\n", " [26.101, 29.643],\n", " [28.892, 28.990],\n", " [26.071, 29.559],\n", " [29.323, 29.685]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num/div" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def meanshift(data, bs=500):\n", " n = len(data)\n", " X = data.clone()\n", " for it in range(5):\n", " for i in range(0, n, bs):\n", " s = slice(i, min(i+bs,n))\n", " weight = gaussian(dist_b(X, X[s]), 2.5)\n", "# weight = tri(dist_b(X, X[s]), 8)\n", " div = weight.sum(1, keepdim=True)\n", " X[s] = weight@X/div\n", " return X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Although each iteration still has to launch a new cuda kernel, there are now fewer iterations, and the acceleration from updating a batch of points more than makes up for 
it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = data.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X = meanshift(data).cpu()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)\n" ] } ], "source": [ "%timeit -n 5 _=meanshift(data, 1250).cpu()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUzElEQVR4nO3dUWyk13mf8eeNpDhZpC1tehNrJG0pIGJq2WEciBBctBekrcBKtaiYAELli2KrGCQCuMI68IXl6GJmCggwUDRboHHQkrA3ujCiCki8EnaTprJKwi2QRKECZy1lvaNFTNkLCtJqEjYNCCiQ/faCM9TMktw1OTMc8vD5AQPOd87MnHOw5H/PnDnfN5GZSJLK9GPD7oAkaXAMeUkqmCEvSQUz5CWpYIa8JBXs1mF3oNMHP/jBHBsbG3Y3JOlQefnll9/OzOPb1R2okB8bG2N5eXnY3ZCkQyUiXt+pzuUaSSqYIV+AZrO5q3JJR4chf8jVajUmJiZoNBpd5Y1Gg4mJCWq12nA6JulAMOQPsVqtRr1eZ3V1lenp6c2gbzQaTE9Ps7q6Sr1eN+ilI8yQP6TaAd/WDvoLFy5sBnybQS8dXYb8IdRsNllYWOgqm2GG9dV1Tp48yfrqOjPMdNUvLCy4Ri8dQYb8ITQ6Osri4iKVSgXYCPjTnOYMZxhjjDOc4TSnN4O+UqmwuLjI6OjoEHstaRgM+UNqfHx8M+iXWGKFFcYY4yxnGWOMFVZYYmkz4MfHx4fdZUlDYMgfYuPj48zPz7PGGnXqXXV16qyxxvz8vAEvHWGG/CHWaDSYm5tjhBGqVLvqqlQZYYS5ubkt2yslHR2G/CHVuU1yiqnNJZrHeGxz6WaKqS3bKyUdLXGQvv5vcnIyvXbNzTWbTSYmJrq2Sc4wwxJLrLHGCCNMMcU5zm3WVyoVLl686IevUoEi4uXMnNyuzpn8ITQ6Osrs7GxX2TnOcaxyjPPnz3Oscqwr4AFmZ2cNeGkXSrlciCF/SNVqNarV99bh27toHnrooa7tlQDVatWToaRdKOpyIZl5YG733Xdf6jrVf/zebbvqajUrlUpevny5q/zy5ctZqVSyWq3uQyelclSr1QQS6Prbav9NtesO0t8WsJw75Kpr8gdd7Z903P+/2z6k2WxuuxSzU7mk7V1/uRDYeJc8Pz/P3Nxc1+dgcHDeJbsmX7idgtyAl350pV4u5EB9M5S2scPsXVJ/tS8X0t6a3L5cyMM8TJ06VaqMMQZsbHQ4LJcLcSYvSS0lXi6k55CPiJ+IiJci4i8j4tWIqLfKPxARL0TEa62f7++9u5I0WKVdLqQfM/l3gE9k5i8AHwMejIiPA08AL2bmPcCLrWNJOtBKu1xIzyHf2sH
z963D21q3BB4Gnm6VPw3XfWIhSQdMiZcL6csWyoi4BXgZ+Fngy5n5hYhYy8yRjsf8bWZuWbKJiDlgDuDEiRP3vf766z33R5J26zBfLmTgWygz8weZ+THgTuD+iPjoLp47n5mTmTl5/PjxfnRHknat1MuF9HULZWauRcQS8CDwZkTcnplvRMTtwFv9bEuS9mLsiQub91e+9FBXXfvEpvYJUZ27aDq3V8LBORHqZvqxu+Z4RIy07v8k8ADwHeB54FTrYaeA53ptS5IGrX1dqOu3SXZurzwsAQ/9mcnfDjzdWpf/MeDZzDwfEX8CPBsRnwG+BzzSh7YkaeBqtRqPP/74lqWY8fHxA7EGvxteu0aSDjmvXSNJR5QhL0kFM+QlqWCGvCQVzJCXpIIZ8pJUMENekgpmyEtSwQx5SSqYIS9JBTPkJalghrwkFexIhXyz2dxVuSQddkcm5Gu1GhMTE1u+k7HRaDAxMXForg0tSbtxJEK+VqtRr9e3fPlu55f21ut1g15ScYoP+XbAt7WD/sKFC11f5QUY9JKKU3TIN5tNFhYWuspmmGF9dZ2TJ0+yvrrODDNd9QsLC67RSypG0SE/Ojq6+Z2MsBHwpznNGc4wxhhnOMNpTm8Gffs7HQ/TV3tJ0o0UHfLQ/eW7SyyxwgpjjHGWs4wxxgorLLG05Ut7JakExYc8bAT9/Pw8a6xRp95VV6fOGmvMz88b8JKKcyRCvtFoMDc3xwgjVKl21VWpMsIIc3NzW7ZXStJhV3zId26TnGJqc4nmMR7bXLqZYmrL9kpJKkFk5rD7sGlycjKXl5f79nrNZpOJiYmubZIzzLDEEmusMcIIU0xxjnOb9ZVKhYsXL/rhq6RDIyJezszJ7eqKnsmPjo4yOzvbVXaOcxyrHOP8+fMcqxzrCniA2dlZA15SMXoO+Yi4KyIWI+JSRLwaEadb5R+IiBci4rXWz/f33t3dq9VqVKvvrcO3d9E89NBDXdsrAarVqidDSSrKrX14jXeBz2fmX0TEPwJejogXgH8HvJiZX4qIJ4AngC/0ob1tXfpnH968/+HvXOqqawf3wsJC1zbJ9vbK6elpZmdnDXhJxen7mnxEPAf8dus2lZlvRMTtwFJm/tyNntvLmvyNQr6t2WxuuxSzU7kkHQb7tiYfEWPALwJ/BvxMZr4B0Pr50zs8Zy4iliNi+dq1a/3szhY7BbkBL6lU/ViuASAifgr4feBzmfl3EfEjPS8z54F52JjJ77X9nWbvknSU9WUmHxG3sRHwX8vMP2gVv9lapqH1861+tCVJ+tH1Y3dNAF8BLmXmb3VUPQ+cat0/BTzXa1uSpN3px3LNvwD+LfDtiPhWq+w3gS8Bz0bEZ4DvAY/0oS1J0i70HPKZ+X+AnRbgP9nr60uS9q7oM14l6agz5CWpYIa8JBXMkJekghnyklQwQ16SCmbIS1LBDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8pJUMENekgpmyEtSwQx5SSqYIS9JBTPkJalghrwkFcyQl6SCGfKSVDBDXpIKZshLUsEM+QI1m81dlUsqV19CPiK+GhFvRcQrHWUfiIgXIuK11s/396Mt3VitVmNiYoJGo9FV3mg0mJiYoFarDadjkoaiXzP53wUevK7sCeDFzLwHeLF1rAGq1WrU63VWV1eZnp7eDPpGo8H09DSrq6vU63WDXhqi/X6n3ZeQz8xvAn9zXfHDwNOt+08DM/1oS9trB3xbO+gvXLiwGfBtBr00HEN5p52ZfbkBY8ArHcdr19X/7Q7PmwOWgeUTJ06kdu/tt9/OSqWSwOZthpkcYSSBHGEkZ5jpqq9UKvn2228Pu+vSkVGtVrv+/i5fvpyZmZcvX+76+61Wq7t+bWA5d8rmnSp2e9tryHfe7rvvvl0PThs6f1FmmMlFFvMsZ3OMsTzL2VxkcTPoO3/BJA1eZ8B3Bv358+e3TND2EvQ3CvlB7q55MyJuB2j9fGuAbR154+PjLC4uUqlUWGK
JFVYYY4yznGWMMVZYYYklKpUKi4uLjI+PD7vL0pHQbDZZWFjoKpthhvXVdU6ePMn66joz161mLyws9G2NfpAh/zxwqnX/FPDcANsSG0E/Pz/PGmvUqXfV1amzxhrz8/MGvLSPRkdHNydgsBHwpznNGc4wxhhnOMNpTm8GfXsiNjo62pf2+7WF8veAPwF+LiKuRsRngC8BvxQRrwG/1DrWADUaDebm5hhhhCrVrroqVUYYYW5ubsuHPpIGa5jvtPu1u+bTmXl7Zt6WmXdm5lcys5mZn8zMe1o/r999oz7q3CY5xdTmL85jPLb5CzXF1JbtlZL2x7DeacfGmv3BMDk5mcvLy8PuxqHTbDaZmJjo2iY5wwxLLLHGGiOMMMUU5zi3WV+pVLh48WLf3hJKurH2RGx9dX1zqaZthRV+g9/gWOXYnmbyEfFyZk5uV+dlDQowOjrK7OxsV9k5znGscozz589zrHKsK+ABZmdnDXhpnwz1nfZO226GcXMLZW8GuQ9X0t7sx3ksDGkLpfro55/++c3bTmq1GtVqdcuHN50f+lSrVc92lfbRsN9puyZ/SHSG+7dPffuGj202m9v+guxULmnwOi890jkR61zKAfY0EbvRmvytvXVbB9FOQW7AS4Pzn/7Nyc37n//v57fUt4N7YWFh23fa09PTzM7O9v2dtjN5SeqDm4V82yDeabu7RpIOiP1+p+1yjST1wY1m78PkTF6SCmbIS1LBDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8pJUMENekgpmyEtSwQx5SSqYIS9JBTPkJalghrwkFcyQl6SCDTzkI+LBiLgcEVci4olBtydJes9AQz4ibgG+DPwycC/w6Yi4d5BtSpLeM+iZ/P3Alcz868z8B+AZ4OEBtylJahl0yN8BfL/j+GqrbFNEzEXEckQsX7t2bcDdkaSjZdAhH9uUZddB5nxmTmbm5PHjxwfcHUk6WgYd8leBuzqO7wRWB9ymJKll0CH/58A9EXF3RPw48Cjw/IDblCS13DrIF8/MdyPi3wN/DNwCfDUzXx1km5Kk9ww05AEy8w+BPxx0O5KkrTzjVZIKZshLUsEMeUkqmCEvSQUz5CWpYIa8JBXMkJekghnyklQwQ16SCmbIS1LBDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8pJUMENekgpmyEtSwQx5SSqYIS9JBTPkJalghrwkFcyQl6SCGfKSVLCeQj4iHomIVyPihxExeV3dFyPiSkRcjohP9dZNSdJe3Nrj818BfhX4b52FEXEv8CjwEaACfCMixjPzBz22J0nahZ5m8pl5KTMvb1P1MPBMZr6Tmd8FrgD399KWJGn3BrUmfwfw/Y7jq60ySdI+uulyTUR8A/jQNlVPZuZzOz1tm7Lc4fXngDmAEydO3Kw7kqRduGnIZ+YDe3jdq8BdHcd3Aqs7vP48MA8wOTm57X8EkqS9GdRyzfPAoxHxvoi4G7gHeGlAbUmSdtDrFspfiYirwD8HLkTEHwNk5qvAs8BfAf8D+Kw7ayRp//W0hTIzvw58fYe6p4Cnenl9SVJvPONVkgpmyEtSwQx5SSqYIS9JBTPkJalghrwkFcyQl6SCGfKSVDBDXpIKZshLUsEMeUkqmCEvSQUz5CWpYIa8JBXMkJekghnyklQwQ16SCmbIS1LBDHlJKpghL0kFM+QlqWCGvCQVzJCXpIIZ8pJUMENekgrWU8hHxH+MiO9ExMWI+HpEjHTUfTEirkTE5Yj4VM89lSTtWq8z+ReAj2bmBNAAvggQEfcCjwIfAR4EficibumxLUnSLvUU8pn5PzPz3dbhnwJ3tu4/DDyTme9k5neBK8D9vbQlSdq9fq7J/xrwR637dwDf76i72irbIiLmImI5IpavXbvWx+5Ikm692QMi4hvAh7apejIzn2s95kngXeBr7adt8/jc7vUzcx6YB5icnNz2MZKkvblpyGfmAzeqj4hTwEngk5nZDumrwF0dD7sTWN1
rJyVJe9Pr7poHgS8A/zoz1zuqngcejYj3RcTdwD3AS720JUnavZvO5G/it4H3AS9EBMCfZuavZ+arEfEs8FdsLON8NjN/0GNbkqRd6inkM/Nnb1D3FPBUL68vSSVrNpuMjo7+yOV74RmvkjQEtVqNiYkJGo1GV3mj0WBiYoJardaXdgx5SdpntVqNer3O6uoq09PTm0HfaDSYnp5mdXWVer3el6A35CVpH7UDvq0d9BcuXNgM+LZ+BL0hL0n7pNlssrCw0FU2wwzrq+ucPHmS9dV1Zpjpql9YWKDZbO65TUNekvbJ6Ogoi4uLVCoVYCPgT3OaM5xhjDHOcIbTnN4M+kqlwuLiYk8fwhrykrSPxsfHN4N+iSVWWGGMMc5yljHGWGGFJZY2A358fLyn9gx5Sdpn4+PjzM/Ps8YadepddXXqrLHG/Px8zwEPhrwk7btGo8Hc3BwjjFCl2lVXpcoII8zNzW3ZXrkXhrwk7aPObZJTTG0u0TzGY5tLN1NMbdleuVfx3jXFhm9ycjKXl5eH3Q1JGohms8nExETXNskZZlhiiTXWGGGEKaY4x7nN+kqlwsWLF2/44WtEvJyZk9vVOZOXpH0yOjrK7OxsV9k5znGscozz589zrHKsK+ABZmdne9pd0+sFyiRJ1/nyr/+vzfuf/a+f6Kprn9zUPiGqcxfN4uJi1wlR1Wq155OhDHlJ2mft4F5YWOjaJtkZ9LOzs325rIEhL0lDUKvVePzxx7csxYyPj990DX43DHlJ6rPrl2h2slOQ9yvgwQ9eJalohrwkFcyQl6SCGfKSVDBDXpIKZshLUsEMeUkq2IG6QFlEXANeH3Y/9sEHgbeH3Yl95piPBsc8HP80M49vV3GgQv6oiIjlna4YVyrHfDQ45oPH5RpJKpghL0kFM+SHY37YHRgCx3w0OOYDxjV5SSqYM3lJKpghL0kFM+T3UUQ8EhGvRsQPI2LyurovRsSViLgcEZ8aVh8HISIebI3rSkQ8Mez+DEJEfDUi3oqIVzrKPhARL0TEa62f7x9mH/stIu6KiMWIuNT6vT7dKi923BHxExHxUkT8ZWvM9Vb5gR2zIb+/XgF+FfhmZ2FE3As8CnwEeBD4nYi4Zf+713+tcXwZ+GXgXuDTrfGW5nfZ+Lfr9ATwYmbeA7zYOi7Ju8DnM/PDwMeBz7b+bUse9zvAJzLzF4CPAQ9GxMc5wGM25PdRZl7KzMvbVD0MPJOZ72Tmd4ErwP3727uBuR+4kpl/nZn/ADzDxniLkpnfBP7muuKHgadb958GZvazT4OWmW9k5l+07v8/4BJwBwWPOzf8fevwttYtOcBjNuQPhjuA73ccX22VlaDksd3Mz2TmG7ARiMBPD7k/AxMRY8AvAn9G4eOOiFsi4lvAW8ALmXmgx+x3vPZZRHwD+NA2VU9m5nM7PW2bslL2tpY8NgER8VPA7wOfy8y/i9jun7wcmfkD4GMRMQJ8PSI+OuQu3ZAh32eZ+cAennYVuKvj+E5gtT89GrqSx3Yzb0bE7Zn5RkTczsbMrygRcRsbAf+1zPyDVnHx4wbIzLWIWGLjs5gDO2aXaw6G54FHI+J9EXE3cA/w0pD71C9/DtwTEXdHxI+z8QHz80Pu0355HjjVun8K2Omd3KEUG1P2rwCXMvO3OqqKHXdEHG/N4ImInwQeAL7DAR6zZ7zuo4j4FeC/AMeBNeBbmfmpVt2TwK+xsWPhc5n5R8PqZ79FxL8C/jNwC/DVzHxquD3qv4j4PWCKjcvOvglUgXPAs8AJ4HvAI5l5/Yezh1ZE/EvgfwPfBn7YKv5NNtblixx3REyw8cHqLWxMkp/NzP8QEaMc0DEb8pJUMJdrJKlghrwkFcyQl6SCGfKSVDBDXpIKZshLUsEMeUkq2P8HgsA4GJpUh/wAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_data(centroids+2, X, n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Homework:** implement k-means clustering, dbscan, locality sensitive hashing, or some other clustering, fast nearest neighbors, or similar algorithm of your choice, on the GPU. Check if your version is faster than a pure python or CPU version.\n", "\n", "Bonus: Implement it in APL too!\n", "\n", "Super bonus: Invent a new meanshift algorithm which picks only the closest points, to avoid quadratic time.\n", "\n", "Super super bonus: Publish a paper that describes it :D" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }