{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.metrics.pairwise import euclidean_distances\n",
    "from sklearn.cluster import KMeans as skKmeans"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 1\n",
    "- similat to scikit-learn init=\"random\", algorithm=\"full\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KMeans():\n",
    "    def __init__(self, n_clusters=8,\n",
    "                 max_iter=300, tol=1e-4, random_state=0):\n",
    "        self.n_clusters = n_clusters\n",
    "        self.max_iter = max_iter\n",
    "        self.tol = tol\n",
    "        self.random_state = random_state\n",
    "\n",
    "    def _labels_inertia(self, X, centers):\n",
    "        labels = np.zeros(X.shape[0])\n",
    "        inertia = 0\n",
    "        for sample_idx in range(X.shape[0]):\n",
    "            min_dis = np.inf\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                d = np.sum(np.square(X[sample_idx] - centers[center_idx]))\n",
    "                if d < min_dis:\n",
    "                    min_dis = d\n",
    "                    labels[sample_idx] = center_idx\n",
    "            inertia += min_dis\n",
    "        return labels, inertia\n",
    "\n",
    "    def fit(self, X):\n",
    "        rng = np.random.RandomState(self.random_state)\n",
    "        # consistent with scikit-learn\n",
    "        tol = np.mean(np.var(X, axis=0)) * self.tol\n",
    "        centers = X[rng.permutation(X.shape[0])[:self.n_clusters]]\n",
    "        for i in range(self.max_iter):\n",
    "            centers_old = centers.copy()\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                centers[center_idx] = np.mean(X[labels == center_idx], axis=0)\n",
    "            center_shift_total = np.sum(np.square(centers_old - centers))\n",
    "            if center_shift_total <= tol:\n",
    "                break\n",
    "        if center_shift_total > 0:\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "        self.cluster_centers_ = centers\n",
    "        self.labels_ = labels\n",
    "        self.inertia_ = inertia\n",
    "        self.n_iter_ = i + 1\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        return self._labels_inertia(X, self.cluster_centers_)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, _ = load_iris(return_X_y=True)\n",
    "clf1 = KMeans(n_clusters=3, random_state=0).fit(X)\n",
    "clf2 = skKmeans(n_clusters=3, init=\"random\", n_init=1, algorithm=\"full\", random_state=0).fit(X)\n",
    "assert np.allclose(clf1.cluster_centers_, clf2.cluster_centers_)\n",
    "assert np.array_equal(clf1.labels_, clf2.labels_)\n",
    "assert np.allclose(clf1.inertia_, clf2.inertia_)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementation 2\n",
    "- similat to scikit-learn init=\"k-means++\", algorithm=\"full\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KMeans():\n",
    "    def __init__(self, n_clusters=8,\n",
    "                 max_iter=300, tol=1e-4, random_state=0):\n",
    "        self.n_clusters = n_clusters\n",
    "        self.max_iter = max_iter\n",
    "        self.tol = tol\n",
    "        self.random_state = random_state\n",
    "\n",
    "    def _labels_inertia(self, X, centers):\n",
    "        labels = np.zeros(X.shape[0])\n",
    "        inertia = 0\n",
    "        for sample_idx in range(X.shape[0]):\n",
    "            min_dis = np.inf\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                d = np.sum(np.square(X[sample_idx] - centers[center_idx]))\n",
    "                if d < min_dis:\n",
    "                    min_dis = d\n",
    "                    labels[sample_idx] = center_idx\n",
    "            inertia += min_dis\n",
    "        return labels, inertia\n",
    "\n",
    "    def fit(self, X):\n",
    "        # consistent with scikit-learn\n",
    "        tol = np.mean(np.var(X, axis=0)) * self.tol\n",
    "        rng = np.random.RandomState(self.random_state)\n",
    "        centers = np.empty((self.n_clusters, X.shape[1]), dtype=X.dtype)\n",
    "        centers[0] = X[rng.randint(X.shape[0])]\n",
    "        closest_dist_sq = euclidean_distances(centers[0, np.newaxis], X, squared=True)\n",
    "        n_local_trials = 2 + int(np.log(self.n_clusters))\n",
    "        for i in range(1, self.n_clusters):\n",
    "            rand_vals = rng.random_sample(n_local_trials) * np.sum(closest_dist_sq)\n",
    "            candidate_ids = np.searchsorted(np.cumsum(closest_dist_sq), rand_vals)\n",
    "            distance_to_candidates = euclidean_distances(X[candidate_ids], X)\n",
    "            distance_to_candidates = np.minimum(closest_dist_sq, distance_to_candidates)\n",
    "            candidates_pot = distance_to_candidates.sum(axis=1)\n",
    "            best_candidate = np.argmin(candidates_pot)\n",
    "            closest_dist_sq = distance_to_candidates[best_candidate]\n",
    "            centers[i] = X[candidate_ids[best_candidate]]\n",
    "        for i in range(self.max_iter):\n",
    "            centers_old = centers.copy()\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "            for center_idx in range(self.n_clusters):\n",
    "                centers[center_idx] = np.mean(X[labels == center_idx], axis=0)\n",
    "            center_shift_total = np.sum(np.square(centers_old - centers))\n",
    "            if center_shift_total <= tol:\n",
    "                break\n",
    "        if center_shift_total > 0:\n",
    "            labels, inertia = self._labels_inertia(X, centers)\n",
    "        self.cluster_centers_ = centers\n",
    "        self.labels_ = labels\n",
    "        self.inertia_ = inertia\n",
    "        self.n_iter_ = i + 1\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        return self._labels_inertia(X, self.cluster_centers_)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, _ = load_iris(return_X_y=True)\n",
    "clf1 = KMeans(n_clusters=3, random_state=0).fit(X)\n",
    "clf2 = skKmeans(n_clusters=3, n_init=1, algorithm=\"full\", random_state=0).fit(X)\n",
    "assert np.allclose(clf1.cluster_centers_, clf2.cluster_centers_)\n",
    "assert np.array_equal(clf1.labels_, clf2.labels_)\n",
    "assert np.allclose(clf1.inertia_, clf2.inertia_)\n",
    "pred1 = clf1.predict(X)\n",
    "pred2 = clf2.predict(X)\n",
    "assert np.array_equal(pred1, pred2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}