{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from sklearn.cluster import KMeans as skKmeans"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 1\n",
    "- similar to scikit-learn init=\"random\", algorithm=\"full\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KMeans():\n",
    "    \"\"\"K-means clustering: random initial centers + Lloyd iterations.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_clusters : int, default=8\n",
    "        Number of clusters (and centers) to form.\n",
    "    max_iter : int, default=300\n",
    "        Maximum number of Lloyd iterations; must be at least 1.\n",
    "    tol : float, default=1e-4\n",
    "        Relative tolerance. It is multiplied by the mean per-feature\n",
    "        variance of X and compared with the total squared center shift.\n",
    "    random_state : int, default=0\n",
    "        Seed for the ``np.random.RandomState`` that picks the initial centers.\n",
    "\n",
    "    Attributes (set by ``fit``)\n",
    "    ---------------------------\n",
    "    cluster_centers_ : ndarray, one row per cluster\n",
    "    labels_ : integer ndarray, index of the nearest center for each sample\n",
    "    inertia_ : sum of squared distances of the samples to their center\n",
    "    n_iter_ : number of Lloyd iterations that were run\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_clusters=8,\n",
    "                 max_iter=300, tol=1e-4, random_state=0):\n",
    "        self.n_clusters = n_clusters\n",
    "        self.max_iter = max_iter\n",
    "        self.tol = tol\n",
    "        self.random_state = random_state\n",
    "\n",
    "    def _labels_inertia(self, X, centers):\n",
    "        \"\"\"Assign each sample of X to its nearest center.\n",
    "\n",
    "        Returns ``(labels, inertia)``: the integer label of every sample and\n",
    "        the sum of the squared distances to the assigned centers.\n",
    "        \"\"\"\n",
    "        # Integer dtype: labels are cluster indices, not measurements.\n",
    "        labels = np.zeros(X.shape[0], dtype=int)\n",
    "        inertia = 0\n",
    "        for sample_idx in range(X.shape[0]):\n",
    "            min_dis = np.inf\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                d = np.sum(np.square(X[sample_idx] - centers[center_idx]))\n",
    "                # Strict '<' keeps the lowest center index on exact ties.\n",
    "                if d < min_dis:\n",
    "                    min_dis = d\n",
    "                    labels[sample_idx] = center_idx\n",
    "            inertia += min_dis\n",
    "        return labels, inertia\n",
    "\n",
    "    def fit(self, X):\n",
    "        \"\"\"Compute the cluster centers of X and return ``self``.\n",
    "\n",
    "        Raises ``ValueError`` if ``max_iter`` is smaller than 1.\n",
    "        \"\"\"\n",
    "        if self.max_iter < 1:\n",
    "            raise ValueError(\n",
    "                \"max_iter must be >= 1, got %r\" % (self.max_iter,))\n",
    "        rng = np.random.RandomState(self.random_state)\n",
    "        # consistent with scikit-learn\n",
    "        tol = np.mean(np.var(X, axis=0)) * self.tol\n",
    "        # Indexing with an index array copies, so updating centers leaves X intact.\n",
    "        centers = X[rng.permutation(X.shape[0])[:self.n_clusters]]\n",
    "        for i in range(self.max_iter):\n",
    "            centers_old = centers.copy()\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                members = X[labels == center_idx]\n",
    "                # An empty cluster keeps its previous center: the mean of\n",
    "                # zero samples is NaN and would prevent convergence.\n",
    "                if members.shape[0] > 0:\n",
    "                    centers[center_idx] = np.mean(members, axis=0)\n",
    "            center_shift_total = np.sum(np.square(centers_old - centers))\n",
    "            if center_shift_total <= tol:\n",
    "                break\n",
    "        # The centers moved after the last assignment step, so recompute the\n",
    "        # labels and inertia to match the centers that are stored.\n",
    "        if center_shift_total > 0:\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "        self.cluster_centers_ = centers\n",
    "        self.labels_ = labels\n",
    "        self.inertia_ = inertia\n",
    "        self.n_iter_ = i + 1\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Return the index of the nearest fitted center for each sample of X.\"\"\"\n",
    "        return self._labels_inertia(X, self.cluster_centers_)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, _ = load_iris(return_X_y=True)\n",
    "clf1 = KMeans(n_clusters=3, random_state=0).fit(X)\n",
    "clf2 = skKmeans(n_clusters=3, init=\"random\", n_init=1, algorithm=\"full\", random_state=0).fit(X)\n",
    "assert np.allclose(clf1.cluster_centers_, clf2.cluster_centers_)\n",
    "assert np.array_equal(clf1.labels_, clf2.labels_)\n",
    "assert np.allclose(clf1.inertia_, clf2.inertia_)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 2\n",
    "- similar to scikit-learn init=\"k-means++\", algorithm=\"full\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KMeans():\n",
    "    \"\"\"K-means clustering: k-means++ initial centers + Lloyd iterations.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_clusters : int, default=8\n",
    "        Number of clusters (and centers) to form.\n",
    "    max_iter : int, default=300\n",
    "        Maximum number of Lloyd iterations; must be at least 1.\n",
    "    tol : float, default=1e-4\n",
    "        Relative tolerance. It is multiplied by the mean per-feature\n",
    "        variance of X and compared with the total squared center shift.\n",
    "    random_state : int, default=0\n",
    "        Seed for the ``np.random.RandomState`` used by the seeding.\n",
    "\n",
    "    Attributes (set by ``fit``)\n",
    "    ---------------------------\n",
    "    cluster_centers_ : ndarray, one row per cluster\n",
    "    labels_ : integer ndarray, index of the nearest center for each sample\n",
    "    inertia_ : sum of squared distances of the samples to their center\n",
    "    n_iter_ : number of Lloyd iterations that were run\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_clusters=8,\n",
    "                 max_iter=300, tol=1e-4, random_state=0):\n",
    "        self.n_clusters = n_clusters\n",
    "        self.max_iter = max_iter\n",
    "        self.tol = tol\n",
    "        self.random_state = random_state\n",
    "\n",
    "    def _labels_inertia(self, X, centers):\n",
    "        \"\"\"Assign each sample of X to its nearest center.\n",
    "\n",
    "        Returns ``(labels, inertia)``: the integer label of every sample and\n",
    "        the sum of the squared distances to the assigned centers.\n",
    "        \"\"\"\n",
    "        # Integer dtype: labels are cluster indices, not measurements.\n",
    "        labels = np.zeros(X.shape[0], dtype=int)\n",
    "        inertia = 0\n",
    "        for sample_idx in range(X.shape[0]):\n",
    "            min_dis = np.inf\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                d = np.sum(np.square(X[sample_idx] - centers[center_idx]))\n",
    "                # Strict '<' keeps the lowest center index on exact ties.\n",
    "                if d < min_dis:\n",
    "                    min_dis = d\n",
    "                    labels[sample_idx] = center_idx\n",
    "            inertia += min_dis\n",
    "        return labels, inertia\n",
    "\n",
    "    def _init_centers(self, X, rng):\n",
    "        \"\"\"Choose the initial centers with k-means++ seeding, drawing from rng.\"\"\"\n",
    "        n_samples = X.shape[0]\n",
    "        centers = np.empty((self.n_clusters, X.shape[1]), dtype=X.dtype)\n",
    "        # The first center is a sample drawn uniformly at random.\n",
    "        centers[0] = X[rng.randint(n_samples)]\n",
    "        closest_dist_sq = euclidean_distances(centers[0, np.newaxis], X, squared=True)\n",
    "        n_local_trials = 2 + int(np.log(self.n_clusters))\n",
    "        for center_idx in range(1, self.n_clusters):\n",
    "            # Draw candidates with probability proportional to the squared\n",
    "            # distance to the closest center chosen so far.\n",
    "            rand_vals = rng.random_sample(n_local_trials) * np.sum(closest_dist_sq)\n",
    "            candidate_ids = np.searchsorted(np.cumsum(closest_dist_sq), rand_vals)\n",
    "            # Rounding in the cumulative sum can yield an index of n_samples.\n",
    "            candidate_ids = np.minimum(candidate_ids, n_samples - 1)\n",
    "            # squared=True is required: the result is combined with\n",
    "            # closest_dist_sq, which holds squared distances.\n",
    "            distance_to_candidates = euclidean_distances(X[candidate_ids], X, squared=True)\n",
    "            distance_to_candidates = np.minimum(closest_dist_sq, distance_to_candidates)\n",
    "            # Keep the candidate that leaves the smallest total potential.\n",
    "            candidates_pot = distance_to_candidates.sum(axis=1)\n",
    "            best_candidate = np.argmin(candidates_pot)\n",
    "            closest_dist_sq = distance_to_candidates[best_candidate]\n",
    "            centers[center_idx] = X[candidate_ids[best_candidate]]\n",
    "        return centers\n",
    "\n",
    "    def fit(self, X):\n",
    "        \"\"\"Compute the cluster centers of X and return ``self``.\n",
    "\n",
    "        Raises ``ValueError`` if ``max_iter`` is smaller than 1.\n",
    "        \"\"\"\n",
    "        if self.max_iter < 1:\n",
    "            raise ValueError(\n",
    "                \"max_iter must be >= 1, got %r\" % (self.max_iter,))\n",
    "        # consistent with scikit-learn\n",
    "        tol = np.mean(np.var(X, axis=0)) * self.tol\n",
    "        rng = np.random.RandomState(self.random_state)\n",
    "        centers = self._init_centers(X, rng)\n",
    "        for i in range(self.max_iter):\n",
    "            centers_old = centers.copy()\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                members = X[labels == center_idx]\n",
    "                # An empty cluster keeps its previous center: the mean of\n",
    "                # zero samples is NaN and would prevent convergence.\n",
    "                if members.shape[0] > 0:\n",
    "                    centers[center_idx] = np.mean(members, axis=0)\n",
    "            center_shift_total = np.sum(np.square(centers_old - centers))\n",
    "            if center_shift_total <= tol:\n",
    "                break\n",
    "        # The centers moved after the last assignment step, so recompute the\n",
    "        # labels and inertia to match the centers that are stored.\n",
    "        if center_shift_total > 0:\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "        self.cluster_centers_ = centers\n",
    "        self.labels_ = labels\n",
    "        self.inertia_ = inertia\n",
    "        self.n_iter_ = i + 1\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Return the index of the nearest fitted center for each sample of X.\"\"\"\n",
    "        return self._labels_inertia(X, self.cluster_centers_)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, _ = load_iris(return_X_y=True)\n",
    "clf1 = KMeans(n_clusters=3, random_state=0).fit(X)\n",
    "clf2 = skKmeans(n_clusters=3, n_init=1, algorithm=\"full\", random_state=0).fit(X)\n",
    "assert np.allclose(clf1.cluster_centers_, clf2.cluster_centers_)\n",
    "assert np.array_equal(clf1.labels_, clf2.labels_)\n",
    "assert np.allclose(clf1.inertia_, clf2.inertia_)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}