{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.cluster import DBSCAN as skDBSCAN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 1\n",
    "- based on brute force: the full pairwise distance matrix is computed\n",
    "- similar to scikit-learn's `DBSCAN` with `algorithm='brute'`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DBSCAN():\n",
    "    \"\"\"Brute-force DBSCAN clustering.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    eps : float, default=0.5\n",
    "        Two samples are neighbors when their distance is <= eps.\n",
    "    min_samples : int, default=5\n",
    "        Minimum neighborhood size, counting the sample itself, for a\n",
    "        sample to be a core sample.\n",
    "\n",
    "    Attributes\n",
    "    ----------\n",
    "    core_sample_indices_ : ndarray\n",
    "        Indices of the core samples; set by fit.\n",
    "    labels_ : ndarray\n",
    "        Cluster label of every sample, -1 for noise; set by fit.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, eps=0.5, min_samples=5):\n",
    "        self.eps = eps\n",
    "        self.min_samples = min_samples\n",
    "\n",
    "    def fit(self, X):\n",
    "        \"\"\"Cluster the rows of X and return self.\"\"\"\n",
    "        X = np.asarray(X)\n",
    "        # Full pairwise distance matrix (pdist uses the Euclidean metric by default).\n",
    "        dist = squareform(pdist(X))\n",
    "        # Neighborhoods differ in size, so keep them in a plain list:\n",
    "        # np.array() on a ragged list raises ValueError on NumPy >= 1.24.\n",
    "        neighbors = [np.flatnonzero(row <= self.eps) for row in dist]\n",
    "        n_neighbors = np.array([len(nbrs) for nbrs in neighbors])\n",
    "        core_samples = n_neighbors >= self.min_samples\n",
    "        labels = np.full(X.shape[0], -1)\n",
    "        label_num = 0\n",
    "        for i in range(len(labels)):\n",
    "            if labels[i] != -1 or not core_samples[i]:\n",
    "                continue\n",
    "            # Grow a new cluster depth-first from this unlabeled core sample.\n",
    "            # Samples are labeled when pushed, so each one enters the stack\n",
    "            # at most once and no membership scan of the stack is needed.\n",
    "            labels[i] = label_num\n",
    "            stack = [i]\n",
    "            while stack:\n",
    "                cur = stack.pop()\n",
    "                if not core_samples[cur]:\n",
    "                    continue  # border sample: joins the cluster, never extends it\n",
    "                nbrs = neighbors[cur]\n",
    "                unlabeled = nbrs[labels[nbrs] == -1]\n",
    "                labels[unlabeled] = label_num\n",
    "                stack.extend(unlabeled)\n",
    "            label_num += 1\n",
    "        self.core_sample_indices_ = np.where(core_samples)[0]\n",
    "        self.labels_ = labels\n",
    "        return self\n",
    "\n",
    "    def fit_predict(self, X):\n",
    "        \"\"\"Fit on X and return the cluster labels.\"\"\"\n",
    "        return self.fit(X).labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate against scikit-learn on the iris data set (default parameters).\n",
    "X, _ = load_iris(return_X_y=True)\n",
    "clust1 = DBSCAN().fit(X)\n",
    "clust2 = skDBSCAN().fit(X)\n",
    "# Indices and labels are integers, so compare them exactly; unlike np.allclose,\n",
    "# np.array_equal returns False instead of raising when the lengths differ.\n",
    "assert np.array_equal(clust1.core_sample_indices_, clust2.core_sample_indices_)\n",
    "assert np.array_equal(clust1.labels_, clust2.labels_)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
"name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }