{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "explainable_kmeans.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "TuTD6V71BLs8", "colab_type": "text" }, "source": [ "### created by Takuya Matsuda at YNU" ] }, { "cell_type": "code", "metadata": { "id": "NQ8_Pv0X7ZIP", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "3ff2095e-a18e-4be8-95f8-3d15d1cf9637" }, "source": [ "!pip install graphviz" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Requirement already satisfied: graphviz in /usr/local/lib/python3.6/dist-packages (0.10.1)\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "evtU8xgl52Wf", "colab_type": "text" }, "source": [ "

\n", "# Explainable k-means" ] }, { "cell_type": "code", "metadata": { "id": "vO0mgUpX5j4x", "colab_type": "code", "colab": {} }, "source": [ "import queue,graphviz\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from graphviz import Digraph\n", "from sklearn.cluster import KMeans\n", "from sklearn.tree import export_graphviz,DecisionTreeClassifier" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "TM5mgXdl57nV", "colab_type": "text" }, "source": [ "

\n", "## Define funcs for k=2\n", "I did not create Class of their algorithm, so this is a little experiment for me." ] }, { "cell_type": "code", "metadata": { "id": "N8o4DrKq5zRk", "colab_type": "code", "colab": {} }, "source": [
"# k=2: find the best single-coordinate split by sorting each coordinate and\n",
"# updating prefix/suffix sums incrementally (the \"dynamic programming\" step).\n",
"def optimal_threshold_2means(X):\n",
"    \"\"\"Return the axis-aligned cut that minimises the 2-means cost of X.\n",
"\n",
"    For every coordinate i the rows are sorted by X[:, i]; sweeping them while\n",
"    maintaining the running sums S_L / S_R of the two parts gives the cost of each\n",
"    cut in O(d): cost = sum ||x||^2 - ||S_L||^2 / n_L - ||S_R||^2 / n_R.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    X : array-like, shape (n_samples, n_features)\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict with keys\n",
"        'cost'       : 2-means cost of the best cut (np.inf if no valid cut exists)\n",
"        'coordinate' : index i of the splitting feature (None if no valid cut exists)\n",
"        'threshold'  : value t; rows with X[:, i] <= t form one cluster, the rest the other\n",
"    \"\"\"\n",
"    X_raw = np.asarray(X)\n",
"    # Work in float64: squaring int32 data (such as cust_array) overflows above 46340.\n",
"    X = X_raw.astype(np.float64)\n",
"    bests_split = {'cost': np.inf, 'coordinate': None, 'threshold': None}\n",
"    n_samples, n_features = X.shape\n",
"    total_sq_norm = np.sum(X * X)\n",
"    total_sum = np.sum(X, axis=0)\n",
"\n",
"    for i in range(n_features):\n",
"        order = X[:, i].argsort()\n",
"        sorted_X = X[order]\n",
"        left_sum = np.zeros(n_features)\n",
"        right_sum = total_sum.copy()\n",
"        for j in range(n_samples - 1):\n",
"            left_sum += sorted_X[j]\n",
"            right_sum -= sorted_X[j]\n",
"            # A cut can only fall between two *distinct* values of coordinate i, so\n",
"            # compare sorted neighbours, not rows of the unsorted X.\n",
"            if sorted_X[j, i] == sorted_X[j + 1, i]:\n",
"                continue\n",
"            n_left = j + 1\n",
"            cost = (total_sq_norm\n",
"                    - np.sum(left_sum * left_sum) / n_left\n",
"                    - np.sum(right_sum * right_sum) / (n_samples - n_left))\n",
"            if cost < bests_split['cost']:\n",
"                bests_split['cost'] = cost\n",
"                bests_split['coordinate'] = i\n",
"                # Report the threshold in the input's dtype (70, not 70.0).\n",
"                bests_split['threshold'] = X_raw[order[j], i]\n",
"\n",
"    return bests_split\n",
"\n",
"\n",
"# Cluster the data with the k=2 split returned by optimal_threshold_2means.\n",
"def clustering_2means_by_tree(bests_split, X):\n",
"    \"\"\"Label every row of X by the single cut in bests_split.\n",
"\n",
"    Rows with X[:, coordinate] > threshold get label 0.0, all other rows 1.0.\n",
"    \"\"\"\n",
"    X = np.asarray(X)\n",
"    feature = X[:, bests_split['coordinate']]\n",
"    return np.where(feature > bests_split['threshold'], 0.0, 1.0)\n",
"\n",
"\n",
"# Centre (mean) of each cluster of the k=2 tree clustering.\n",
"def get_mean(X, approx_labels):\n",
"    \"\"\"Return the per-cluster means of X, one row per label 0..K-1.\n",
"\n",
"    K is the number of distinct labels; labels are assumed to be exactly the\n",
"    integers 0..K-1 (possibly stored as floats). A label with no rows yields NaN.\n",
"    \"\"\"\n",
"    X = np.asarray(X, dtype=np.float64)\n",
"    approx_labels = np.asarray(approx_labels)\n",
"    n_clusters = len(np.unique(approx_labels))\n",
"    return np.array([X[approx_labels == k].mean(axis=0) for k in range(n_clusters)])\n",
"\n",
"\n",
"# Approximation ratio of the tree clustering; the paper's upper bound for k=2 is 4.\n",
"def approx_score(approx_labels, kmeans_model, X):\n",
"    \"\"\"Return cost(tree clustering) / cost(k-means clustering) on X.\n",
"\n",
"    Both costs are sums of squared Euclidean distances: k-means to its own\n",
"    cluster_centers_ (via labels_), the tree clustering to the means of its own\n",
"    clusters (get_mean). Both raw costs are printed before the ratio is returned.\n",
"    \"\"\"\n",
"    X = np.asarray(X, dtype=np.float64)\n",
"    approx_labels = np.asarray(approx_labels)\n",
"    kmeans_centers = kmeans_model.cluster_centers_\n",
"    kmeans_cost = np.sum((X - kmeans_centers[kmeans_model.labels_]) ** 2)\n",
"\n",
"    mean = get_mean(X, approx_labels)\n",
"    approx_cost = 0.0\n",
"    for k in range(kmeans_model.n_clusters):\n",
"        members = X[approx_labels == k]\n",
"        if len(members):\n",
"            approx_cost += np.sum((members - mean[k]) ** 2)\n",
"    print(kmeans_cost)\n",
"    print(approx_cost)\n",
"    return approx_cost / kmeans_cost" ], "execution_count": 76, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "ZRJETWfD6ALL", "colab_type": "text" }, "source": [ "

\n", "## Two Datasets\n", "First is from [this good article](https://qiita.com/ynakayama/items/1223b6844a1a044e2e3b).\n", "\n", "Second is from uci repo and I get the code from [this good article](https://pythondatascience.plavox.info/scikit-learn/%E3%82%AF%E3%83%A9%E3%82%B9%E3%82%BF%E5%88%86%E6%9E%90-k-means)." ] }, { "cell_type": "code", "metadata": { "id": "MfULm2_S5_Cd", "colab_type": "code", "colab": {} }, "source": [
"# First dataset\n",
"# Each row holds one student's scores in Japanese, math and English\n",
"X = np.array([\n",
"    [ 80, 85, 100 ],\n",
"    [ 96, 100, 100 ],\n",
"    [ 54, 83, 98 ],\n",
"    [ 80, 98, 98 ],\n",
"    [ 90, 92, 91 ],\n",
"    [ 84, 78, 82 ],\n",
"    [ 79, 100, 96 ],\n",
"    [ 88, 92, 92 ],\n",
"    [ 98, 73, 72 ],\n",
"    [ 75, 84, 85 ],\n",
"    [ 92, 100, 96 ],\n",
"    [ 96, 92, 90 ],\n",
"    [ 99, 76, 91 ],\n",
"    [ 75, 82, 88 ],\n",
"    [ 90, 94, 94 ],\n",
"    [ 54, 84, 87 ],\n",
"    [ 92, 89, 62 ],\n",
"    [ 88, 94, 97 ],\n",
"    [ 42, 99, 80 ],\n",
"    [ 70, 98, 70 ],\n",
"    [ 94, 78, 83 ],\n",
"    [ 52, 73, 87 ],\n",
"    [ 94, 88, 72 ],\n",
"    [ 70, 73, 80 ],\n",
"    [ 95, 84, 90 ],\n",
"    [ 95, 88, 84 ],\n",
"    [ 75, 97, 89 ],\n",
"    [ 49, 81, 86 ],\n",
"    [ 83, 72, 80 ],\n",
"    [ 75, 73, 88 ],\n",
"    [ 79, 82, 76 ],\n",
"    [ 100, 77, 89 ],\n",
"    [ 88, 63, 79 ],\n",
"    [ 100, 50, 86 ],\n",
"    [ 55, 96, 84 ],\n",
"    [ 92, 74, 77 ],\n",
"    [ 97, 50, 73 ],\n",
"    ])" ], "execution_count": 77, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "YTfJApum6Kh1", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "d0d6bd72-0588-40c9-e006-585f1314b9cf" }, "source": [
"# Second dataset (UCI Wholesale customers)\n",
"cust_df = pd.read_csv(\"http://pythondatascience.plavox.info/wp-content/uploads/2016/05/Wholesale_customers_data.csv\")\n",
"# Cluster only the product-category columns; Channel and Region are dropped.\n",
"cust_df = cust_df.drop(columns=['Channel', 'Region'])\n",
"# List each column exactly once: a repeated column (e.g. 'Milk') would be\n",
"# double-weighted in every distance computed on cust_array.\n",
"feature_columns = ['Fresh', 'Milk', 'Grocery', 'Frozen', 'Detergents_Paper', 'Delicassen']\n",
"cust_array = cust_df[feature_columns].to_numpy(dtype=np.int32)\n",
"cust_array.shape" ], "execution_count": 78, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(440, 6)" ] }, "metadata": { "tags": [] }, "execution_count": 78 } ] }, { "cell_type": "markdown", "metadata": { "id": "2e3Q7CEq6RKa", "colab_type": "text" }, "source": [ "

\n", "## First data and k=2 " ] }, { "cell_type": "code", "metadata": { "id": "gjNuH4Cq6NLo", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "outputId": "80b02360-68aa-477b-fb58-e0e7492ecc24" }, "source": [ "#K-meansクラスタリングをおこなう\n", "kmeans_model = KMeans(n_clusters=2, random_state=10).fit(X)\n", "#分類先となったラベルを取得する\n", "labels = kmeans_model.labels_\n", "\n", "#提案手法による近似アルゴリズムで取得する\n", "bests_split = optimal_threshold_2means(X)\n", "approx_labels = clustering_2means_by_tree(bests_split,X)\n", "print(bests_split)\n", "print(approx_labels)" ], "execution_count": 79, "outputs": [ { "output_type": "stream", "text": [ "{'cost': 11316.030172413797, 'coordinate': 0, 'threshold': 70}\n", "[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 0. 1.\n", " 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "Qw2hFW3X6TGF", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "outputId": "ad450568-5d37-4014-d419-5f2d1e55ced6" }, "source": [ "approx_score(approx_labels,kmeans_model,X) #近似アルゴリズムのコスト ➗ k-meansのコスト" ], "execution_count": 80, "outputs": [ { "output_type": "stream", "text": [ "11316.030172413793\n", "11316.030172413793\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "1.0" ] }, "metadata": { "tags": [] }, "execution_count": 80 } ] }, { "cell_type": "markdown", "metadata": { "id": "2QDeGFmP6Yr4", "colab_type": "text" }, "source": [ "

\n", "## Second data and k=2" ] }, { "cell_type": "code", "metadata": { "id": "UCkO03gN6VXa", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "b21c8de2-f566-403b-bf73-5a55fcf42ff4" }, "source": [ "#K-meansクラスタリングをおこなう\n", "kmeans_model = KMeans(n_clusters=2, random_state=10).fit(cust_array)\n", "#分類先となったラベルを取得する\n", "labels = kmeans_model.labels_\n", "\n", "#提案手法による近似アルゴリズムで取得する\n", "bests_split = optimal_threshold_2means(cust_array)\n", "approx_labels = clustering_2means_by_tree(bests_split,cust_array)\n", "print(bests_split)" ], "execution_count": 81, "outputs": [ { "output_type": "stream", "text": [ "{'cost': 41288343652.5358, 'coordinate': 2, 'threshold': 16483}\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "CT-5ZyGP6aly", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "outputId": "ca24df6d-ded3-44e9-8313-952323f7e760" }, "source": [ "approx_score(approx_labels,kmeans_model,cust_array) #近似アルゴリズムのコスト ➗ k-meansのコスト" ], "execution_count": 82, "outputs": [ { "output_type": "stream", "text": [ "132340344661.33641\n", "135777624164.53575\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "1.0259730281948067" ] }, "metadata": { "tags": [] }, "execution_count": 82 } ] }, { "cell_type": "markdown", "metadata": { "id": "mKb_viSO6e53", "colab_type": "text" }, "source": [ "

\n", "## IMM Algorithm\n", "I'm working on this...." ] }, { "cell_type": "markdown", "metadata": { "id": "uayor9-46g6H", "colab_type": "text" }, "source": [ "## Define funcs\n", "I did not create Class of their algorithm, so this is a little experiment for me." ] }, { "cell_type": "code", "metadata": { "id": "ssapF7ca6cIQ", "colab_type": "code", "colab": {} }, "source": [
"class TreeNode:\n",
"    \"\"\"Node of the IMM threshold tree.\n",
"\n",
"    A leaf stores its k-means cluster id in `cluster`. An internal node stores\n",
"    `condition = (i, threshold)`: points with x[i] <= threshold go to `left`,\n",
"    the others to `right` (see make_next_data).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, cluster=None, left=None, right=None, condition=None):\n",
"        self.cluster = cluster  # k-means label for a leaf, None for an internal node\n",
"        self.left = left\n",
"        self.right = right\n",
"        # Honour the `condition` argument (it used to be ignored); default (0, 0).\n",
"        self.condition = (0, 0) if condition is None else condition\n",
"\n",
"\n",
"def minimum_center(i, labels, centers):\n",
"    \"\"\"Smallest i-th coordinate among the centers whose ids occur in labels (inf if none).\"\"\"\n",
"    present = np.unique(labels).astype(int)\n",
"    return np.asarray(centers)[present, i].min() if present.size else np.inf\n",
"\n",
"\n",
"def maximum_center(i, labels, centers):\n",
"    \"\"\"Largest i-th coordinate among the centers whose ids occur in labels (-inf if none).\"\"\"\n",
"    present = np.unique(labels).astype(int)\n",
"    return np.asarray(centers)[present, i].max() if present.size else -np.inf\n",
"\n",
"\n",
"def mistake(x, center, i, threshold):\n",
"    \"\"\"1 if the cut x_i <= threshold puts x and its center on different sides, else 0.\"\"\"\n",
"    return int((x[i] <= threshold) != (center[i] <= threshold))\n",
"\n",
"\n",
"def _separated_from_center(X, labels, centers, i, threshold):\n",
"    \"\"\"Boolean mask over rows of X: True where the cut separates a point from its center.\"\"\"\n",
"    X = np.asarray(X)\n",
"    labels = np.asarray(labels).astype(int)\n",
"    point_center = np.asarray(centers)[labels]\n",
"    return (X[:, i] <= threshold) != (point_center[:, i] <= threshold)\n",
"\n",
"\n",
"def delete_mistakes_data(X, labels, centers, i, threshold):\n",
"    \"\"\"Drop the points that the cut x_i <= threshold separates from their own center.\"\"\"\n",
"    X = np.asarray(X)\n",
"    labels = np.asarray(labels)\n",
"    keep = ~_separated_from_center(X, labels, centers, i, threshold)\n",
"    return X[keep], labels[keep]\n",
"\n",
"\n",
"def make_next_data(X, labels, i, threshold):\n",
"    \"\"\"Split (X, labels) into the x_i <= threshold side and the x_i > threshold side.\n",
"\n",
"    Boolean masks keep the 2-D shape even when a side is empty.\n",
"    \"\"\"\n",
"    X = np.asarray(X)\n",
"    labels = np.asarray(labels)\n",
"    go_left = X[:, i] <= threshold\n",
"    return X[go_left], labels[go_left], X[~go_left], labels[~go_left]\n",
"\n",
"\n",
"def count_mistakes(X, l, i, labels, centers):\n",
"    \"\"\"Number of points that the cut x_i <= l[i] separates from their own center.\"\"\"\n",
"    return int(np.sum(_separated_from_center(X, labels, centers, i, l[i])))\n",
"\n",
"\n",
"def get_best_splits(X, l, r, labels, centers):\n",
"    \"\"\"Return the cut (i, threshold) with l[i] <= threshold < r[i] making the fewest mistakes.\n",
"\n",
"    The mistake count only changes at data coordinates and at center coordinates,\n",
"    so trying every such value inside [l[i], r[i]) covers all distinct cuts.\n",
"    (Trying data values alone can miss the best cut, or find none when no data\n",
"    point lies in [l[i], r[i]); l[i] is itself a center coordinate.)\n",
"    Ties go to the smaller coordinate, then the smaller threshold.\n",
"    Returns (None, None) if no coordinate separates the centers.\n",
"    \"\"\"\n",
"    X = np.asarray(X)\n",
"    labels = np.asarray(labels).astype(int)\n",
"    centers = np.asarray(centers)\n",
"    present_centers = centers[np.unique(labels)]\n",
"    point_center = centers[labels]  # each point's own center, row-aligned with X\n",
"    best_mistakes, best_coordinate, best_threshold = np.inf, None, None\n",
"\n",
"    for i in range(X.shape[1]):\n",
"        candidates = np.union1d(X[:, i], present_centers[:, i])\n",
"        candidates = candidates[(candidates >= l[i]) & (candidates < r[i])]\n",
"        # Naive O(n) recount per candidate; a sorted sweep could update it incrementally.\n",
"        for threshold in candidates:\n",
"            n_mistakes = int(np.sum((X[:, i] <= threshold) != (point_center[:, i] <= threshold)))\n",
"            if n_mistakes < best_mistakes:\n",
"                best_mistakes, best_coordinate, best_threshold = n_mistakes, i, threshold\n",
"\n",
"    print(\"num of mistakes at this node => {}\".format(best_mistakes))\n",
"    return best_coordinate, best_threshold\n",
"\n",
"\n",
"def build_tree(X, labels, centers):\n",
"    \"\"\"Recursively build the IMM threshold tree for points X with k-means labels/centers.\n",
"\n",
"    Returns the root TreeNode. Line numbers below refer to the paper's pseudo-code.\n",
"    \"\"\"\n",
"    X = np.asarray(X)\n",
"    labels = np.asarray(labels)\n",
"    node = TreeNode()\n",
"\n",
"    if labels.size == 0:\n",
"        raise ValueError(\"build_tree reached a node with no points\")\n",
"\n",
"    # Lines 2-4: all remaining points share one label -> leaf.\n",
"    if len(np.unique(labels)) == 1:\n",
"        node.cluster = labels[0]\n",
"        return node\n",
"\n",
"    # Lines 6-9: per-coordinate range [l_i, r_i] of the centers present at this node.\n",
"    l = [minimum_center(i, labels, centers) for i in range(X.shape[1])]\n",
"    r = [maximum_center(i, labels, centers) for i in range(X.shape[1])]\n",
"\n",
"    # Lines 10-13: best cut, drop the points it separates from their center, split.\n",
"    i, threshold = get_best_splits(X, l, r, labels, centers)\n",
"    if i is None:\n",
"        raise ValueError(\"no axis-aligned cut separates the centers at this node\")\n",
"    X, labels = delete_mistakes_data(X, labels, centers, i, threshold)\n",
"    left_data, left_labels, right_data, right_labels = make_next_data(X, labels, i, threshold)\n",
"\n",
"    # Lines 14-16: recurse into both children.\n",
"    node.condition = (i, threshold)\n",
"    node.left = build_tree(left_data, left_labels, centers)\n",
"    node.right = build_tree(right_data, right_labels, centers)\n",
"    return node" ], "execution_count": 83, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "viQyjZlK6ryE", "colab_type": "text" }, "source": [ "

\n", "## First data and k=3" ] }, { "cell_type": "code", "metadata": { "id": "lWw2YW_w6jds", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 531 }, "outputId": "a344e197-c168-42dd-9157-657548eb4ad1" }, "source": [ "#IMM procedure\n", "kmeans_model = KMeans(n_clusters=3, random_state=10).fit(X)\n", "centers = kmeans_model.cluster_centers_\n", "labels = kmeans_model.labels_\n", "root = build_tree(X,labels,centers)\n", "make_tree(root,kmeans_model.n_clusters)" ], "execution_count": 87, "outputs": [ { "output_type": "stream", "text": [ "num of mistakes at this node => 1\n", "num of mistakes at this node => 2\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "" ], "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n0\n\nX_0 <= 75\n\n\n\n1\n\n0\n\n\n\n0->1\n\n\nTrue\n\n\n\n2\n\nX_1 <= 82\n\n\n\n0->2\n\n\nFalse\n\n\n\n3\n\n2\n\n\n\n2->3\n\n\nTrue\n\n\n\n4\n\n1\n\n\n\n2->4\n\n\nFalse\n\n\n\n" }, "metadata": { "tags": [] }, "execution_count": 87 } ] }, { "cell_type": "markdown", "metadata": { "id": "cYEF7HEWKTbj", "colab_type": "text" }, "source": [ "⇩クラスタ0におけるmistakeは1つだけ" ] }, { "cell_type": "code", "metadata": { "id": "f9AXdugNJ95V", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "cdd04e03-18ef-4669-9b68-07df7f0b2bcc" }, "source": [ "labels[X[:,0] <= 75]" ], "execution_count": 88, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 88 } ] }, { "cell_type": "markdown", "metadata": { "id": "PIdjQh4JK4n1", "colab_type": "text" }, "source": [ "⇩クラスタ2におけるmistakeはない(0は考えなくて良い)." 
] }, { "cell_type": "code", "metadata": { "id": "MITFO8KkJ6ZA", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "7eeb9fe7-7228-4f52-f5f4-718c5a854fd6" }, "source": [ "labels[X[:,1] <= 82]" ], "execution_count": 89, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([2, 2, 2, 0, 2, 0, 0, 0, 2, 0, 2, 2, 2, 2, 2, 2], dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 89 } ] }, { "cell_type": "markdown", "metadata": { "id": "DVWlMj5ELCRt", "colab_type": "text" }, "source": [ "⇩クラスタ1におけるmistakeは2つ" ] }, { "cell_type": "code", "metadata": { "id": "Kx0rArMgKNlU", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "outputId": "341c3bd8-722e-4f40-8bc9-554f5a4cb7c3" }, "source": [ "labels[X[:,1] > 82]" ], "execution_count": 90, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 2, 1, 0, 0, 2, 1, 1, 1, 0],\n", " dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 90 } ] }, { "cell_type": "markdown", "metadata": { "id": "F_qNy2J16uEP", "colab_type": "text" }, "source": [ "

\n", "## Second data and k=3" ] }, { "cell_type": "code", "metadata": { "id": "0ibTR82J6liX", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 595 }, "outputId": "a9588cb0-051f-4e41-9d25-f4ed76b38b40" }, "source": [ "#IMM procedure\n", "kmeans_model = KMeans(n_clusters=3, random_state=10).fit(cust_array)\n", "centers = kmeans_model.cluster_centers_\n", "labels = kmeans_model.labels_\n", "\n", "root = build_tree(cust_array,labels,centers)\n", "make_tree(root,kmeans_model.n_clusters)" ], "execution_count": 97, "outputs": [ { "output_type": "stream", "text": [ "num of mistakes at this node => 8\n", "num of mistakes at this node => 11\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "" ], "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\n0\n\nX_0 > 20049\n\n\n\n1\n\n0\n\n\n\n0->1\n\n\nTrue\n\n\n\n2\n\nX_1 <= 12220\n\n\n\n0->2\n\n\nFalse\n\n\n\n3\n\n2\n\n\n\n2->3\n\n\nTrue\n\n\n\n4\n\n1\n\n\n\n2->4\n\n\nFalse\n\n\n\n" }, "metadata": { "tags": [] }, "execution_count": 97 } ] }, { "cell_type": "code", "metadata": { "id": "glKXXfuVLXV5", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 85 }, "outputId": "47fa3082-32db-4ffe-9d14-409881cb5963" }, "source": [ "labels[cust_array[:,0] > 20049]" ], "execution_count": 93, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 2, 0, 1, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 93 } ] }, { "cell_type": "code", "metadata": { "id": "wpL565elMakS", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 340 }, "outputId": "91e0e6ee-7cb8-48ee-b1dc-38f5220fc7ac" }, "source": [ "labels[cust_array[:,1] <= 12220]" ], 
"execution_count": 94, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, 0,\n", " 0, 2, 2, 2, 0, 2, 2, 0, 0, 2, 2, 0, 2, 0, 0, 2, 2, 1, 2, 2, 2, 2,\n", " 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 0, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0,\n", " 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 1, 2, 2, 2,\n", " 0, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2,\n", " 2, 2, 0, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 0, 0, 0, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2,\n", " 2, 0, 2, 2, 0, 2, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 0, 2, 2, 2, 2,\n", " 0, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 0, 2, 2, 2, 2, 2, 0, 1, 2, 0, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2,\n", " 2, 0, 2, 2, 1, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2,\n", " 0, 0, 2, 2, 2, 2, 2, 0, 2, 2, 0, 2, 0, 2, 2, 2, 0, 2, 2, 2, 2, 2,\n", " 0, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 0, 0, 2, 2],\n", " dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 94 } ] }, { "cell_type": "code", "metadata": { "id": "ScazzrEdMf2B", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "outputId": "98c91de4-3efc-4435-8f2e-31819d6180c9" }, "source": [ "labels[cust_array[:,1] > 12220]" ], "execution_count": 95, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1,\n", " 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 0, 1, 1, 2, 1, 1, 0, 1],\n", " dtype=int32)" ] }, 
"metadata": { "tags": [] }, "execution_count": 95 } ] }, { "cell_type": "markdown", "metadata": { "id": "pM5V20ceM3F6", "colab_type": "text" }, "source": [ "

\n", "func for visualization \n", "sorry for dirty code...\n", "\n", "Note: the IMM cells above call `make_tree`, so on a fresh kernel run this cell before them." ] }, { "cell_type": "code", "metadata": { "id": "QH-P2P-k_g8h", "colab_type": "code", "colab": {} }, "source": [
"def make_tree(root, n_clusters):\n",
"    \"\"\"Render an IMM threshold tree (TreeNode) as a graphviz Digraph.\n",
"\n",
"    Internal nodes are drawn as \"X_i <= t\" with the True edge to the left child.\n",
"    When only the right child is a leaf, the node is drawn as the equivalent\n",
"    \"X_i > t\" so that the leaf hangs off the True edge. Every internal node is\n",
"    expanded, so trees where both children are internal (possible for k >= 4)\n",
"    are drawn completely. Leaves show their k-means cluster id; node ids are\n",
"    assigned breadth-first, True child before False child.\n",
"\n",
"    n_clusters is accepted for backward compatibility; the layout is derived\n",
"    from the tree itself.\n",
"    \"\"\"\n",
"    G = Digraph(format='png')\n",
"    G.attr('node', shape='circle')\n",
"\n",
"    def is_leaf(node):\n",
"        return node.cluster is not None\n",
"\n",
"    def orient(node):\n",
"        # -> (label, child on the True edge, child on the False edge)\n",
"        i, threshold = node.condition\n",
"        if is_leaf(node.right) and not is_leaf(node.left):\n",
"            return \"X_{} > {}\".format(i, threshold), node.right, node.left\n",
"        return \"X_{} <= {}\".format(i, threshold), node.left, node.right\n",
"\n",
"    q = queue.Queue()\n",
"    q.put((root, 0))\n",
"    next_id = 1\n",
"    while not q.empty():\n",
"        node, node_id = q.get()\n",
"        if is_leaf(node):\n",
"            G.node(str(node_id), str(node.cluster))\n",
"            continue\n",
"        label, true_child, false_child = orient(node)\n",
"        G.node(str(node_id), label)\n",
"        for child, edge_label in ((true_child, 'True'), (false_child, 'False')):\n",
"            G.edge(str(node_id), str(next_id), label=edge_label)\n",
"            q.put((child, next_id))\n",
"            next_id += 1\n",
"    return G" ], "execution_count": 85, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "hqNdbgdr8Kuk", "colab_type": "text" }, "source": [ "

\n", "## Decision tree simulation" ] }, { "cell_type": "code", "metadata": { "id": "Aj1mdsRq8Rzu", "colab_type": "code", "colab": {} }, "source": [
"def make_toydata(v, size=200, random_state=None):\n",
"    \"\"\"Toy data for the decision-tree simulation.\n",
"\n",
"    Returns (data_1, data_2, data_3):\n",
"      data_1 : `size` points drawn from N((2, 0), 0.3 * I)\n",
"      data_2 : `size` points drawn from N((-2, 0), 0.3 * I)\n",
"      data_3 : the two points (-2, v) and (2, v)\n",
"    random_state (int) makes the draw reproducible; the default None samples\n",
"    from NumPy's global random state.\n",
"    \"\"\"\n",
"    rng = np.random if random_state is None else np.random.RandomState(random_state)\n",
"    cov = np.array([[0.3, 0], [0, 0.3]])  # shared isotropic covariance\n",
"    data_1 = rng.multivariate_normal(np.array([2, 0]), cov, size=size)\n",
"    data_2 = rng.multivariate_normal(np.array([-2, 0]), cov, size=size)\n",
"    data_3 = np.array([[-2, v], [2, v]])\n",
"    return data_1, data_2, data_3" ], "execution_count": 62, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "W1j_ehJgADif", "colab_type": "code", "colab": {} }, "source": [
"data_1, data_2, data_3 = make_toydata(v=100, random_state=0)  # seeded for reproducibility" ], "execution_count": 63, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "jJNCN64BALIP", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 282 }, "outputId": "8ab3dba7-d3bc-43a4-87d7-6806b436f9f2" }, "source": [ "plt.scatter(data_1[:,0],data_1[:,1])\n", "plt.scatter(data_2[:,0],data_2[:,1])" ], "execution_count": 64, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 64 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "id": "6wSIHB1S_xMJ", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 282 }, "outputId": "df77f67e-8e47-42e3-cb50-5732b109f28b" }, "source": [ "plt.scatter(data_1[:,0],data_1[:,1])\n", "plt.scatter(data_2[:,0],data_2[:,1])\n", "plt.scatter(data_3[:,0],data_3[:,1])" ], "execution_count": 65, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 65 }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAS+0lEQVR4nO3df5DcdX3H8ef7Nhu5RM0PucGYxAaU0aIi4A3+oONYUUSxEq1FGKdFywzjVCtqq4B2JNo64tBBsJ06pWKNM4w1AwgoWqSItdZCvQDyG6UokhjgFIL8CHJJ3v1jvxf2Lrt3t7t32csnz8fMze338/1+dt+33/2+9nuf7/e7G5mJJKksA/0uQJI0+wx3SSqQ4S5JBTLcJalAhrskFWhBvwsA2H///XPNmjX9LkOS9iobN278dWYOtZo3L8J9zZo1jIyM9LsMSdqrRMS97eY5LCNJBTLcJalAhrskFchwl6QCGe6SVKBpwz0ivhwRD0bErU1tyyPi6oj4WfV7WdUeEfGFiLg7Im6OiCPmsnjt7sp7ruSYi4/h0PWHcszFx3DlPVf2uySpLV+vc2cme+5fAY6d1HYGcE1mHgxcU00DvBk4uPo5Ffji7JSpmbjynitZ96N1bHl8C0my5fEtrPvROjcYzUu+XufWtOGemT8AHprUfDywvrq9Hljb1P7VbLgOWBoRK2arWE3t/BvO58kdT05oe3LHk5x/w/l9qkhqz9fr3Op2zP2AzNxS3b4fOKC6vRK4r2m5TVXbbiLi1IgYiYiR0dHRLstQs/sfv7+jdqmffL3OrZ4PqGbj2z46/saPzLwgM4czc3hoqOXVs+rQcxc/t6N2qZ98vc6tbsP9gfHhlur3g1X7ZmB103KrqjbtAacdcRr71fab0LZfbT9OO+K0PlUktefrdW51G+5XACdXt08GLm9q/7PqrJlXAY80Dd9ojh130HGse806VixeQRCsWLyCda9Zx3EHHdfv0qTd+HqdWzHdd6hGxNeA1wH7Aw8AZwGXARuA5wP3Aidk5kMREcA/0ji75gngvZk57SeCDQ8Ppx8cJkmdiYiNmTncat60nwqZmSe1mXV0i2UTeH9n5UmSZptXqEpSgQx3SSqQ4S5JBTLcJalAhrskFchwl6QCGe6SVCDDXZIKZLhLUoEMd0kqkOEuSQUy3CWpQIa7JBXIcJekAhnuklQgw12SCmS4S1KBDHdJKpDhLkkFMtwlqUCGuyQVyHCXpAIZ7pJUIMNdkgpkuEtSgQx3SSqQ4S5JBTLcJalAhrskFchwl6QC9RTuEfHhiLgtIm6NiK9FxH4RcWBEXB8Rd0fE1yNi4WwVK0mama7DPSJWAh8EhjPzpUA
NOBH4HPD5zHwh8DBwymwUKkmauV6HZRYAgxGxAFgEbAFeD1xczV8PrO3xMSRJHeo63DNzM/D3wC9phPojwEZga2ZurxbbBKxs1T8iTo2IkYgYGR0d7bYMSVILvQzLLAOOBw4EngcsBo6daf/MvCAzhzNzeGhoqNsyJEkt9DIs8wbg55k5mpljwKXAUcDSapgGYBWwuccaJUkd6iXcfwm8KiIWRUQARwO3A9cC76yWORm4vLcSJUmd6mXM/XoaB05vAG6p7usC4HTgIxFxN/Ac4MJZqFOS1IEF0y/SXmaeBZw1qfke4Mhe7leS1BuvUJWkAhnuklQgw12SCmS4S1KBDHdJKpDhLkkFMtwlqUCGuyQVyHCXpAIZ7pJUIMNdkgpkuEtSgQx3SSqQ4S5JBTLcJalAhrskFchwl6QCGe6SVCDDXZIKZLhLUoEMd0kqkOEuSQUy3CWpQIa7JBXIcJekAhnuklQgw12SCmS4S1KBDHdJKpDhLkkF6incI2JpRFwcEXdGxB0R8eqIWB4RV0fEz6rfy2arWEnSzPS6534+8O+Z+WLg5cAdwBnANZl5MHBNNS1J2oO6DveIWAK8FrgQIDOfysytwPHA+mqx9cDaXouUJHWmlz33A4FR4F8j4saI+FJELAYOyMwt1TL3Awe06hwRp0bESESMjI6O9lCGJGmyXsJ9AXAE8MXMPBx4nElDMJmZQLbqnJkXZOZwZg4PDQ31UIYkabJewn0TsCkzr6+mL6YR9g9ExAqA6veDvZUoSepU1+GemfcD90XEi6qmo4HbgSuAk6u2k4HLe6pQktSxBT32/0vgoohYCNwDvJfGG8aGiDgFuBc4ocfHkCR1qKdwz8ybgOEWs47u5X4lSb3xClVJKpDhLkkFMtwlqUCGuyQVyHCXpAIZ7pJUIMNdkgpkuEtSgQx3SSqQ4S5JBTLcJalAhrskFchwl6QCGe6SVCDDXZIKZLhLUoEMd0kqkOEuSQUy3CWpQIa7JBXIcJekAhnuklQgw12SCmS4S1KBDHdJKpDhLkkFMtwlqUCGuyQVyHCXpAIZ7pJUoJ7DPSJqEXFjRHyrmj4wIq6PiLsj4usRsbD3MiVJnZiNPffTgDuapj8HfD4zXwg8DJwyC48hSepAT+EeEauA44AvVdMBvB64uFpkPbC2l8eQJHWu1z3384CPATur6ecAWzNzezW9CVjZqmNEnBoRIxExMjo62mMZkqRmXYd7RLwVeDAzN3bTPzMvyMzhzBweGhrqtgxJUgsLeuh7FPC2iHgLsB/wbOB8YGlELKj23lcBm3svU5LUia733DPzzMxclZlrgBOB72Xmu4FrgXdWi50MXN5zlZKkjszFee6nAx+JiLtpjMFfOAePIUmaQi/DMrtk5veB71e37wGOnI37lSR1xytUJalAhrskFchwl6QCGe6SVCDDXZIKZLhLUoEMd0kqkOEuSQUy3CWpQIa7JBXIcJekAhnuklQgw12SCmS4S1KBDHdJKpDhLkkFMtwlqUCGuyQVyHCXpAIZ7pJUIMNdkgpkuEtSgQx3SSqQ4S5JBTLcJalAhrskFchwl6QCGe6SVCDDXZIKZLhLUoG6DveIWB0R10bE7RFxW0ScVrUvj4irI+Jn1e9ls1euJGkmetlz3w78VWYeArwKeH9EHAKcAVyTmQcD11TTkqQ9qOtwz8wtmXlDdftR4A5gJXA8sL5abD2wttciJUmdmZUx94hYAxwOXA8ckJlbqln3Awe06XNqRIxExMjo6OhslCFJqvQc7hHxTOAS4EOZ+dvmeZmZQLbql5kXZOZwZg4PDQ31WoYkqUlP4R4RdRrBflFmXlo1PxARK6r5K4AHeytRktSpXs6WCeBC4I7MPLdp1hXAydXtk4HLuy9PktSNBT30PQr4U+CWiLipavs4cDawISJOAe4FTuitRElSp7oO98z8IRBtZh/d7f1KknrnFaqSVCDDXZIKZLhLUoEMd0kqkOEuSQUy3CWpQIa7JBXIcJe
kAhnuklQgw12SCmS4S1KBDHdJKpDhLkkFMtwlqUCGuyQVyHCXpAIZ7pJUIMNdkgpkuEtSgQx3SSqQ4S5JBTLcJalAhrskFchwl6QCGe6SVCDDXZIKZLhLUoEW9LsATXLzBrjm0/DIJliyCo7+JBx6wtPzv/UR2PgVyB0QNXjFe+Ct53Z/f9prXXbjZs656i5+tXUbz1s6yEff9CLWHr6y32V1ZE/9DSU8V52KzOx3DQwPD+fIyEi/y+hcq+CEp9sGlzWmtz28e7A299213ENAAM3rpJpeshqWHwQ//8/d6xg+pRHwu+7zvhb308aS1U/XNfnvOfgY+Nl32/99vln0zWU3bubMS29h29iOXW2D9RqffcfLWoZWq3ADpgy86QKxef6SwToRsPWJMZYM1hnbsZPHn3q6tqWDdda97SW79Z/qb7jsxs18/NKbeWJsJwAR8JqDlvOL32xj89ZtRMB4fC1eWKNeG2DrtrEJr/xli+ocsuJZ/Oj/Hmq1VbFy0t81/jdt3rqNWgQ7MnctM93zNd36mos3l4jYmJnDLefNRbhHxLHA+UAN+FJmnj3V8vM63G/eAN85vQpeIAYgd0J9MYw9PnHZ2sLGq23nWOv7qi2Egfru/UrV/MYB078ZdvtmsQ/+d3LU2d9j89Ztu7UvW1Rn0cIFu0LkD188xDdu2DwhaKezeGGNZz6jxgOPPrXbvEX1ARYuqO0WoiVYOljn0d9tZ8fOmf1V9VpwzjtfPm1I/81lt3DRdb+c8FwN1mv88StWcu2doz0F/h4N94ioAT8F3ghsAn4MnJSZt7fr01W4d7NBd9rn5g1w2V+0D2vNzOByeMnb4Yavzuy5HFwOb/5c+3Uz+b+epx6DHU1BVB+EP/pC0QF/4BlXFhWse6uBgFcftJz/uechxt8TBusDfPYdh+767+PDX79pRutqqv+82tnT4f5qYF1mvqmaPhMgMz/brk/H4X7zBvjmB2Gsac+l1QbdbQhMGN5QX0y1biav+1aWrIYP3zp39fWo12GSgWrIQPPXskV1Hn6isx3DlUsH+e8zXj/j5acK97k4oLoSaE7FTcArZ/URrvn07hv32LZGe/MQQHMIjA+rTNWnVT/1R6t1A63XfSuPbJqbumbB5LHmzVu38dGLfwIJY9Xu3+at2zjz0lsAdu0BNvcx2Oe/ToMd4Fcthtq61bdTISPi1IgYiYiR0dHRzjq323Cb27sNgZn209xrtZ5nGtpLVs1uLbPonKvumnAQEWBsR+4K9nHbxnZwzlV3te0DUIsg5q5U7WHPWzo4a/c1F+G+GVjdNL2qapsgMy/IzOHMHB4aGursEdptuM3t3YbAPN7j2+e0Ws8zCe364NMHa+ehTvbOxpdt12dnJj8/+zhqYcTv7QJ2Dc/NhrkI9x8DB0fEgRGxEDgRuGJWH+HoTzY24GaTN+huQ2Ae7/HtU9oFdKt1P1BvHIQlGmPt8/xgaid7Z+PLtusz3u4wzd5h6WCdgTbvw+9+1fNn9dz7WQ/3zNwOfAC4CrgD2JCZt83qgxx6QmMDXrKatht0tyHQ7o3jHf/S+Gl+zOFTqulZFvv4hcODy9sHdKt1v/af4PSfw7qtjYOo8zjYobF3NlivTWir14L6pK1+sF7btSfXqk/z/JVtwr8+QNfDNrV2KbQPOuoFy/nF2cfxi7OP47x3HcbKpYMEjef9qBcsnxDY9YHG+mw2WK9x3rsO46azjuHcEw5j6WB917xli+qc967D+Lu1L5vVmsu+iKnb85877ff5l7Y+syZqjXPia/WJZ+k0Zjauysidra80nXwlatRg5+7nHc+ayacfNj8H9UUw9gQTzmpeuBhqz2gcqB4/93+q+4WJ1ws0t+9j56jD7F9UNJOLmtpddNTuHOvm5ZcuqvPk2A62je2+nuu1YPHCBWzdNsZAMOGUwJ0Jv9v+dJ/6AAwMDExomyvjF04BfOqbt004wNnuoqrxC5jG1SI46ZWrOw7ePXVF7B6/iKlT8/o
ippmYyamZvV5oM9VjwPRXyv7u0dbnmE93Trn2GnvTpfyT3zgy4ZFtY7u90bW6UnSmV+CW/vECYLjvGXviKsleHmMfvIpTKp3hLkkFmirc9/Ejd5JUJsNdkgpkuEtSgQx3SSqQ4S5JBZoXZ8tExChwb7/rmIH9gV/3u4gZsta5Ya1zw1q783uZ2fLDueZFuO8tImKk3WlH8421zg1rnRvWOvsclpGkAhnuklQgw70zF/S7gA5Y69yw1rlhrbPMMXdJKpB77pJUIMNdkgpkuHcoIv42Im6OiJsi4rsR8bx+19RKRJwTEXdWtX4jIpb2u6Z2IuJPIuK2iNgZEfPyFLOIODYi7oqIuyPijH7XM5WI+HJEPBgRt/a7lqlExOqIuDYibq/W/2n9rqmdiNgvIv43In5S1fqpftc0HcfcOxQRz87M31a3Pwgckpnv63NZu4mIY4DvZeb2iPgcQGae3ueyWoqI3wd2Av8M/HVmzqvPf46IGvBT4I3AJhrfE3xSZt7e18LaiIjXAo8BX83Ml/a7nnYiYgWwIjNviIhnARuBtfPxeY2IABZn5mMRUQd+CJyWmdf1ubS23HPv0HiwVxYz4bvn5o/M/G71fbYA1wHz9pu/M/OOzLyr33VM4Ujg7sy8JzOfAv4NOL7PNbWVmT8AHup3HdPJzC2ZeUN1+1Ea37k8L78+KRseqybr1c+83PbHGe5diIjPRMR9wLuBT/a7nhn4c+A7/S5iL7YSaP6S3E3M0xDaW0XEGuBw4Pr+VtJeRNQi4ibgQeDqzJy3tYLh3lJE/EdE3Nri53iAzPxEZq4GLgI+MF/rrJb5BLC9qrVvZlKr9k0R8UzgEuBDk/4znlcyc0dmHkbjv+AjI2LeDnkBLOh3AfNRZr5hhoteBHwbOGsOy2lrujoj4j3AW4Gjs88HVzp4TuejzcDqpulVVZt6VI1fXwJclJmX9ruemcjMrRFxLXAsMG8PWrvn3qGIOLhp8njgzn7VMpWIOBb4GPC2zHyi3/Xs5X4MHBwRB0bEQuBE4Io+17TXqw5SXgjckZnn9rueqUTE0PgZZxExSOPg+rzc9sd5tkyHIuIS4EU0zu64F3hfZs67vbiIuBt4BvCbqum6+XhWD0BEvB34B2AI2ArclJlv6m9VE0XEW4DzgBrw5cz8TJ9Laisivga8jsZH0z4AnJWZF/a1qBYi4g+A/wJuobE9AXw8M7/dv6pai4hDgfU01v8AsCEzP93fqqZmuEtSgRyWkaQCGe6SVCDDXZIKZLhLUoEMd0kqkOEuSQUy3CWpQP8PvPIMSZIPNZMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "id": "7ZF6c_Al8e7r", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 340 }, "outputId": "6b8e296d-7799-4e99-c280-380005d6cbf4" }, "source": [ "toy_X = np.concatenate([data_1,data_2,data_3])\n", "kmeans_model = KMeans(n_clusters=3, random_state=10).fit(toy_X)\n", "centers = kmeans_model.cluster_centers_\n", "labels = kmeans_model.labels_\n", "labels" ], "execution_count": 66, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 1, 1], dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 66 } ] }, { 
"cell_type": "markdown", "metadata": { "id": "2hb3F0UP8nZz", "colab_type": "text" }, "source": [ "0クラスタが200データ,1クラスタが2データ,2クラスタが200データと狙い通りにクラスタリングできている" ] }, { "cell_type": "code", "metadata": { "id": "sWGe7Z8886Oq", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 340 }, "outputId": "b90376c5-9ce1-4645-f9ec-20c404ef53c3" }, "source": [ "# Fit a decision tree with at most 3 leaves to the k-means labels and show its per-point predictions.\n", "dt = DecisionTreeClassifier(max_leaf_nodes=3, criterion='entropy')\n", "dt.fit(toy_X, labels)\n", "dt.predict(toy_X)" ], "execution_count": 67, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 1, 0], dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 67 } ] }, { "cell_type": 
"markdown", "metadata": { "id": "SF_GVSxy_axZ", "colab_type": "text" }, "source": [ "やはり最初の分割が論文中のような分割になっている(data_3の片方が右へ,片方が左へ)" ] }, { "cell_type": "code", "metadata": { "id": "MFR5mVET78Wx", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 380 }, "outputId": "a5cd580f-9236-4136-bd11-c391c571c974" }, "source": [ "dot_data = export_graphviz(\n", " dt,\n", " filled=True,\n", " )\n", "graph = graphviz.Source(dot_data)\n", "graph" ], "execution_count": 68, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ], "image/svg+xml": "\n\n\n\n\n\nTree\n\n\n\n0\n\nX[0] <= -0.053\nentropy = 1.04\nsamples = 402\nvalue = [200, 2, 200]\n\n\n\n1\n\nX[1] <= 50.724\nentropy = 0.045\nsamples = 201\nvalue = [0, 1, 200]\n\n\n\n0->1\n\n\nTrue\n\n\n\n2\n\nentropy = 0.045\nsamples = 201\nvalue = [200, 1, 0]\n\n\n\n0->2\n\n\nFalse\n\n\n\n3\n\nentropy = 0.0\nsamples = 200\nvalue = [0, 0, 200]\n\n\n\n1->3\n\n\n\n\n\n4\n\nentropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\n\n\n\n1->4\n\n\n\n\n\n" }, "metadata": { "tags": [] }, "execution_count": 68 } ] }, { "cell_type": "markdown", "metadata": { "id": "XWQjjP1NDwto", "colab_type": "text" }, "source": [ "

\n", "## v→∞へ" ] }, { "cell_type": "code", "metadata": { "id": "MxRUCylwAa7o", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 187 }, "outputId": "3373645f-21c7-4a6e-9efa-fecb4797cdab" }, "source": [ "# Sweep v and compare the optimal k-means objective with the k-means cost of the\n", "# partition induced by a 3-leaf decision tree that imitates the k-means labels.\n", "for v in range(1,10000,1000):\n", "    data_1,data_2,data_3 = make_toydata(v=v)\n", "    toy_X = np.concatenate([data_1,data_2,data_3])\n", "    kmeans_model = KMeans(n_clusters=3, random_state=10).fit(toy_X)\n", "    kmeans_labels = kmeans_model.labels_\n", "    dt = DecisionTreeClassifier(criterion='entropy',max_leaf_nodes=3).fit(toy_X,kmeans_labels)\n", "\n", "    # Tree cost for *this* v: each point is charged to the mean of the points in its own leaf.\n", "    # Computed cluster by cluster so it does not rely on the predicted labels being exactly 0..k-1.\n", "    labels = dt.predict(toy_X)\n", "    cost = 0\n", "    for cluster in np.unique(labels):\n", "        cluster_points = toy_X[labels == cluster]\n", "        cost += np.sum((cluster_points - cluster_points.mean(axis=0)) ** 2)\n", "\n", "    # KMeans.score returns the negated k-means objective, so the first value prints as negative.\n", "    print(\"Optimal Score:{} vs Decision Tree Score:{}\".format(kmeans_model.score(toy_X),cost))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "wf7Og4kEBaXT", "colab_type": "text" }, "source": [ "v→無限大につれて,近似スコアは無限大になってしまう(もちろん,葉をクラスタ数にしなければこれは起こらないが) \n", 
"\\#sklearnのkmeansは目的関数のスコアのマイナスをreturnする" ] } ] }