{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from heapq import heappush, heappushpop\n",
    "from scipy.cluster.hierarchy import linkage\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.cluster import AgglomerativeClustering as skAgglomerativeClustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AgglomerativeClustering():\n",
    "    \"\"\"Agglomerative clustering built on SciPy's hierarchical linkage.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_clusters : int, default=2\n",
    "        Number of clusters to cut the merge tree into.\n",
    "    linkage : str, default=\"ward\"\n",
    "        Forwarded to ``scipy.cluster.hierarchy.linkage`` as ``method``.\n",
    "\n",
    "    Attributes\n",
    "    ----------\n",
    "    children_ : ndarray of shape (n_samples - 1, 2), integer dtype\n",
    "        Children of each internal node. Ids below ``n_samples`` are\n",
    "        samples; id ``n_samples + i`` is the node created by row ``i``.\n",
    "    labels_ : ndarray of shape (n_samples,), integer dtype\n",
    "        Cluster label assigned to each sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_clusters=2, linkage=\"ward\"):\n",
    "        self.n_clusters = n_clusters\n",
    "        self.linkage = linkage\n",
    "\n",
    "    def _get_descendent(self, node, n_samples):\n",
    "        \"\"\"Return the sample ids (leaves) located beneath ``node``.\"\"\"\n",
    "        pending = [node]\n",
    "        leaves = []\n",
    "        while pending:\n",
    "            current = pending.pop()\n",
    "            if current < n_samples:\n",
    "                leaves.append(current)\n",
    "            else:\n",
    "                # Internal node: row (id - n_samples) of children_ holds its two children.\n",
    "                pending.extend(self.children_[current - n_samples])\n",
    "        return leaves\n",
    "\n",
    "    def fit(self, X):\n",
    "        \"\"\"Build the merge tree for ``X`` and cut it into ``n_clusters`` clusters.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like of shape (n_samples, n_features)\n",
    "            Observations to cluster.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If ``n_clusters`` is not between 1 and ``n_samples`` inclusive.\n",
    "        \"\"\"\n",
    "        X = np.asarray(X)\n",
    "        n_samples = X.shape[0]\n",
    "        if not 1 <= self.n_clusters <= n_samples:\n",
    "            raise ValueError(\n",
    "                f\"n_clusters must be between 1 and n_samples={n_samples}, \"\n",
    "                f\"got {self.n_clusters}\"\n",
    "            )\n",
    "        Z = linkage(X, method=self.linkage)\n",
    "        # Builtin int replaces the np.int alias, which was removed in NumPy 1.24.\n",
    "        self.children_ = Z[:, :2].astype(int)\n",
    "        # Min-heap of negated ids, so nodes[0] is always the highest id,\n",
    "        # i.e. the most recent merge still present in the cut.\n",
    "        nodes = [-(2 * n_samples - 2)]  # root node\n",
    "        for _ in range(self.n_clusters - 1):\n",
    "            left, right = self.children_[-nodes[0] - n_samples]\n",
    "            heappush(nodes, -left)\n",
    "            # Children have lower ids than their parent, so this pops the parent.\n",
    "            heappushpop(nodes, -right)\n",
    "        labels = np.zeros(n_samples, dtype=np.intp)\n",
    "        for i, node in enumerate(nodes):\n",
    "            labels[self._get_descendent(-node, n_samples)] = i\n",
    "        self.labels_ = labels\n",
    "        return self\n",
    "\n",
    "    def fit_predict(self, X):\n",
    "        \"\"\"Fit on ``X`` and return the resulting ``labels_``.\"\"\"\n",
    "        return self.fit(X).labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the SciPy-based implementation against scikit-learn on the Iris data.\n",
    "# The data does not depend on the linkage method, so load it once.\n",
    "X, _ = load_iris(return_X_y=True)\n",
    "for method in [\"ward\", \"complete\", \"average\", \"single\"]:\n",
    "    ours = AgglomerativeClustering(n_clusters=3, linkage=method).fit(X)\n",
    "    reference = skAgglomerativeClustering(n_clusters=3, linkage=method).fit(X)\n",
    "    assert np.array_equal(ours.children_, reference.children_), (\n",
    "        f\"children_ differ for linkage={method!r}\"\n",
    "    )\n",
    "    assert np.array_equal(ours.labels_, reference.labels_), (\n",
    "        f\"labels_ differ for linkage={method!r}\"\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}