{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from heapq import heappush, heappushpop\n", "from scipy.cluster.hierarchy import linkage\n", "from sklearn.datasets import load_iris\n", "from sklearn.cluster import AgglomerativeClustering as skAgglomerativeClustering" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class AgglomerativeClustering():\n", " def __init__(self, n_clusters=2, linkage=\"ward\"):\n", " self.n_clusters = n_clusters\n", " self.linkage = linkage\n", "\n", " def _get_descendent(self, node, n_samples):\n", " ind = [node]\n", " ret = []\n", " while len(ind) > 0:\n", " i = ind.pop()\n", " if i < n_samples:\n", " ret.append(i)\n", " else:\n", " ind.extend(self.children_[i - n_samples])\n", " return ret\n", "\n", " def fit(self, X):\n", " Z = linkage(X, method=self.linkage)\n", " self.children_ = Z[:, :2].astype(np.int)\n", " nodes = []\n", " heappush(nodes, -(X.shape[0] * 2 - 2)) # root node\n", " for _ in range(self.n_clusters - 1):\n", " these_children = self.children_[-nodes[0] - X.shape[0]]\n", " heappush(nodes, -these_children[0])\n", " heappushpop(nodes, -these_children[1])\n", " label = np.zeros(X.shape[0])\n", " for i, node in enumerate(nodes):\n", " label[self._get_descendent(-node, X.shape[0])] = i\n", " self.labels_ = label\n", " return self\n", "\n", " def fit_predict(self, X):\n", " self.fit(X)\n", " return self.labels_" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "for method in [\"ward\", \"complete\", \"average\", \"single\"]:\n", " X, _ = load_iris(return_X_y=True)\n", " clf1 = AgglomerativeClustering(n_clusters=3, linkage=method).fit(X)\n", " clf2 = skAgglomerativeClustering(n_clusters=3, linkage=method).fit(X)\n", " assert np.array_equal(clf1.children_, clf2.children_)\n", " assert np.array_equal(clf1.labels_, clf2.labels_)" ] } ], "metadata": { "kernelspec": { "display_name": "dev", "language": "python", "name": "dev" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }