{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "実例で動かすグラフアルゴリズムとグラフデータベース、04_ノードEmbedding:グラフ機械学習へ\n",
    "\n",
    "[ファイルはこちら](https://raw.githubusercontent.com/lightondust/topics_by_jupyter_notebook/master/neo4j_graph_04_embedding.ipynb)。\n",
    "\n",
    "今回使うNode Embedding LibraryはNeo4j Sandboxには搭載されておらず、実行するには無料のNeo4j DesktopかComunity版をインストールする必要があります。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Node Embeddingについて"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "今まではPageRankやLouvainアルゴリズムを使って重要なノードを見つけたり、つながりの強いグループを見つけたりしました。\n",
    "\n",
    "一方でノードやリレーションシップの特徴量を抽出できれば、従来の機械学習手法を適応しやすくなります。\n",
    "例えばアカウントをさまざまな視点で分類したり、属するグループだけではなくアカウント間の関係を数値化したり、グループ間の関係を数値化したり、可視化したりするなど行うことができます。\n",
    "\n",
    "特別に他の特徴量を用意しなくても繋がりさえあれば全体から見た場合の特徴量を使うことができるのはカラム志向に対してグラフ志向でデータを捉える利点の1つだと思います。\n",
    "\n",
    "Node Embeddingはノードの特徴を表すベクトルを算出します。\n",
    "NLPで有名なWord2Vecも一種のNode Embeddingとして捉えることもできます。\n",
    "基本的な考え方としてはノード間のリンクを再現できるように各ノードにベクトルを割り当てます。\n",
    "\n",
    "今回はNeo4j Data Science LibraryのNode Embeddingについて見てみます。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neo3j Graph Data Science LibraryとNode Embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Neo4j Graph Data Science Libraryはα版でNode Embeddingアルゴリズムを提供しています。\n",
    "\n",
    "https://neo4j.com/docs/graph-data-science/current/algorithms/node-embeddings/\n",
    "\n",
    "アルゴリズムの種類としてはWord2Vec流のNode2Vec、ノード属性を活用できるグラフニューラルネットワーク流のGraphSage、そして高速のRandom Projectionが用意されています。\n",
    "\n",
    "https://neo4j.com/developer/graph-data-science/graph-embeddings/#supported-graph-embeddings\n",
    "\n",
    "ここではNode2Vecを使います。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Node Embeddingの算出 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "グラフは前回まで使っていた政治家のみのネットワークを流用します。\n",
    "作成する場合は次のように行います。\n",
    "\n",
    "```\n",
    "CALL gds.graph.create.cypher( \n",
    "    'follow-net-politicians', \n",
    "    'MATCH (g:Group)--(u:User) WITH DISTINCT u RETURN id(u) AS id', \n",
    "    'MATCH (:Group)--(u:User)-[r:FOLLOW]->(u2:User)--(:Group) WITH DISTINCT r, u, u2 RETURN id(u) AS source, id(u2) AS target')\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "10次元(node2vec)と3次元(node2vec3d)、4次元(node2vec4d)のEmbeddingをそれぞれ計算して、属性として保存します。\n",
    "次元数の選び方に関しては、一般的なWord2Vecは数十万〜数百万の単語(ノード)を100~300次元へ落とし込むのが1つの目安になります。\n",
    "今回は300程度のアカウントなので、あまり高次元にすると過学習が起きやすくなってしまいます。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "CALL gds.alpha.node2vec.write('follow-net-politicians', \n",
    "{embeddingSize: 10,\n",
    "writeProperty: 'node2vec'})\n",
    "```\n",
    "\n",
    "```\n",
    "CALL gds.alpha.node2vec.write('follow-net-politicians', \n",
    "{embeddingSize: 3,\n",
    "writeProperty: 'node2vec3d'})\n",
    "```\n",
    "\n",
    "```\n",
    "CALL gds.alpha.node2vec.write('follow-net-politicians', \n",
    "{embeddingSize: 4,\n",
    "writeProperty: 'node2vec4d'})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 検証 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "作成したEmbeddingを見てみます"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from neo4j import GraphDatabase\n",
    "from tqdm.notebook import tqdm\n",
    "import json\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "auth_path = './data/neo4j_graph/auth.json'\n",
    "with open(auth_path, 'r') as f:\n",
    "    auth = json.load(f)\n",
    "\n",
    "# ローカルの場合は通常 uri: bolt(or neo4j)://localhost:7687, user: neo4j, pd: 設定したもの\n",
    "# サンドボックスの場合は作成画面から接続情報が見られます\n",
    "uri = 'neo4j://localhost:7687'\n",
    "driver = GraphDatabase.driver(uri=uri, auth=(auth['user'], auth['pd']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sandboxの場合はこんな感じ\n",
    "# uri = 'bolt://54.175.38.249:35275'\n",
    "# driver = GraphDatabase.driver(uri=uri, auth=('neo4j', 'spray-missile-sizing'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_no = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with driver.session() as session:\n",
    "    res = session.run('''\n",
    "    MATCH (u:User)--(:Group)\n",
    "    WITH DISTINCT u\n",
    "    RETURN u.screenName as screen_name, u.name as name, u.louvainCommunityUndirected as community, u.node2vec3d as embedding\n",
    "    ''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>screen_name</th>\n",
       "      <th>name</th>\n",
       "      <th>community</th>\n",
       "      <th>embedding</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>tadamori_oshima</td>\n",
       "      <td>大島理森</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.20086176693439484, 0.9452930688858032, 1.1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>YAMASHITA_OK</td>\n",
       "      <td>山下たかし</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.3143966794013977, 0.4608628451824188, 1.31...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>andouhiroshi</td>\n",
       "      <td>あんどう裕(ひろし)衆議院議員(自民党 京都6区 )</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.03988641873002052, 0.4570828974246979, 1.2...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>jimin_koho</td>\n",
       "      <td>自民党広報</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.564670205116272, 0.5222573280334473, 1.403...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>Matsukawa_Rui</td>\n",
       "      <td>松川るい =自民党=</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.5284397602081299, 0.6824808716773987, 1.36...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       screen_name                        name  community  \\\n",
       "0  tadamori_oshima                        大島理森          3   \n",
       "1     YAMASHITA_OK                       山下たかし          3   \n",
       "2     andouhiroshi  あんどう裕(ひろし)衆議院議員(自民党 京都6区 )          3   \n",
       "3       jimin_koho                       自民党広報          3   \n",
       "4    Matsukawa_Rui                  松川るい =自民党=          3   \n",
       "\n",
       "                                           embedding  \n",
       "0  [-0.20086176693439484, 0.9452930688858032, 1.1...  \n",
       "1  [-0.3143966794013977, 0.4608628451824188, 1.31...  \n",
       "2  [-0.03988641873002052, 0.4570828974246979, 1.2...  \n",
       "3  [-0.564670205116272, 0.5222573280334473, 1.403...  \n",
       "4  [-0.5284397602081299, 0.6824808716773987, 1.36...  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df = pd.DataFrame([r.data() for r in res])\n",
    "res_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Node Embeddingは関係のの近いノードを近い方向に配置し、ノードの重要性(出現頻度など)を絶対値に反映しているので、可視化する時にすべてのベクトルの大きさを1に揃えたほうがわかりやすいです。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def renorm_vector(v):\n",
    "    norm = np.sqrt(sum([c**2 for c in v]))\n",
    "    return np.array(v) / norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "ax_columns = ['x{}'.format(i) for i in range(1, dim_no+1)]\n",
    "\n",
    "# res_df[ax_columns] = res_df.embedding.apply(pd.Series)\n",
    "\n",
    "res_df[ax_columns] = res_df.embedding.apply(lambda x: pd.Series(renorm_vector(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>screen_name</th>\n",
       "      <th>name</th>\n",
       "      <th>community</th>\n",
       "      <th>embedding</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>tadamori_oshima</td>\n",
       "      <td>大島理森</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.20086176693439484, 0.9452930688858032, 1.1...</td>\n",
       "      <td>-0.132467</td>\n",
       "      <td>0.623415</td>\n",
       "      <td>0.770588</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>YAMASHITA_OK</td>\n",
       "      <td>山下たかし</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.3143966794013977, 0.4608628451824188, 1.31...</td>\n",
       "      <td>-0.220553</td>\n",
       "      <td>0.323301</td>\n",
       "      <td>0.920235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>andouhiroshi</td>\n",
       "      <td>あんどう裕(ひろし)衆議院議員(自民党 京都6区 )</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.03988641873002052, 0.4570828974246979, 1.2...</td>\n",
       "      <td>-0.030714</td>\n",
       "      <td>0.351969</td>\n",
       "      <td>0.935508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>jimin_koho</td>\n",
       "      <td>自民党広報</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.564670205116272, 0.5222573280334473, 1.403...</td>\n",
       "      <td>-0.352781</td>\n",
       "      <td>0.326283</td>\n",
       "      <td>0.876975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>Matsukawa_Rui</td>\n",
       "      <td>松川るい =自民党=</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.5284397602081299, 0.6824808716773987, 1.36...</td>\n",
       "      <td>-0.327500</td>\n",
       "      <td>0.422967</td>\n",
       "      <td>0.844892</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       screen_name                        name  community  \\\n",
       "0  tadamori_oshima                        大島理森          3   \n",
       "1     YAMASHITA_OK                       山下たかし          3   \n",
       "2     andouhiroshi  あんどう裕(ひろし)衆議院議員(自民党 京都6区 )          3   \n",
       "3       jimin_koho                       自民党広報          3   \n",
       "4    Matsukawa_Rui                  松川るい =自民党=          3   \n",
       "\n",
       "                                           embedding        x1        x2  \\\n",
       "0  [-0.20086176693439484, 0.9452930688858032, 1.1... -0.132467  0.623415   \n",
       "1  [-0.3143966794013977, 0.4608628451824188, 1.31... -0.220553  0.323301   \n",
       "2  [-0.03988641873002052, 0.4570828974246979, 1.2... -0.030714  0.351969   \n",
       "3  [-0.564670205116272, 0.5222573280334473, 1.403... -0.352781  0.326283   \n",
       "4  [-0.5284397602081299, 0.6824808716773987, 1.36... -0.327500  0.422967   \n",
       "\n",
       "         x3  \n",
       "0  0.770588  \n",
       "1  0.920235  \n",
       "2  0.935508  \n",
       "3  0.876975  \n",
       "4  0.844892  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ここで比較するために以前計算していたクラスタリングを取得しています。\n",
    "各クラスタに入っているノード数は次のようになります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "174    127\n",
       "3      107\n",
       "169     66\n",
       "120     29\n",
       "257     17\n",
       "Name: community, dtype: int64"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df.community = res_df.community.astype('category')\n",
    "res_df.community.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = px.scatter_3d(res_df,  \n",
    "                    x=\"x1\",  \n",
    "                    y=\"x2\",   \n",
    "                    z='x3', \n",
    "                    text=\"name\",  \n",
    "                    color='community',  \n",
    "                    log_x=False,  \n",
    "                    size_max=60)\n",
    "\n",
    "fig.update_traces(textposition='top center')\n",
    "\n",
    "fig.update_layout(\n",
    "    height=800,\n",
    "    title_text=''\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plotを保存する場合\n",
    "html_path = './data/neo4j_graph/politician_03_3dembedding.html'\n",
    "with open(html_path, 'w') as f:\n",
    "    f.write(fig.to_html())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "おおむね自民党系と民主党系が左右に固まり、維新系が真ん中にある構図が変わりませんでした。\n",
    "一方でマイナークラスタはトップ3に埋もれる形になりました。\n",
    "これは元データに対して次元圧縮をしすぎたため、表現力が足りないためと思われます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# figf = px.data.gapminder().query(\"year==2007 and continent=='Americas'\")\n",
    "\n",
    "# fig = px.scatter(res_df,  \n",
    "#                  x=\"x1\", \n",
    "#                  y=\"x2\", \n",
    "#                  text=\"name\", \n",
    "#                  color='community', \n",
    "#                  log_x=False, \n",
    "#                  size_max=60)\n",
    "\n",
    "# fig.update_traces(textposition='top center')\n",
    "\n",
    "# fig.update_layout(\n",
    "#     height=800,\n",
    "#     title_text=''\n",
    "# )\n",
    "\n",
    "# fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "では4dでどうなるかをみてみます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_no = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "with driver.session() as session:\n",
    "    res = session.run('''\n",
    "    MATCH (u:User)--(:Group)\n",
    "    WITH DISTINCT u\n",
    "    RETURN u.screenName as screen_name, u.name as name, u.louvainCommunityUndirected as community, u.node2vec4d as embedding\n",
    "    ''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>screen_name</th>\n",
       "      <th>name</th>\n",
       "      <th>community</th>\n",
       "      <th>embedding</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>tadamori_oshima</td>\n",
       "      <td>大島理森</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.9669536352157593, -0.4772112965583801, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>YAMASHITA_OK</td>\n",
       "      <td>山下たかし</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.944596529006958, -0.29907843470573425, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>andouhiroshi</td>\n",
       "      <td>あんどう裕(ひろし)衆議院議員(自民党 京都6区 )</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.9364525675773621, -0.093632273375988, -0.8...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>jimin_koho</td>\n",
       "      <td>自民党広報</td>\n",
       "      <td>3</td>\n",
       "      <td>[-1.0234124660491943, -0.5828015208244324, -0....</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>Matsukawa_Rui</td>\n",
       "      <td>松川るい =自民党=</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.9714489579200745, -0.48148995637893677, -0...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       screen_name                        name  community  \\\n",
       "0  tadamori_oshima                        大島理森          3   \n",
       "1     YAMASHITA_OK                       山下たかし          3   \n",
       "2     andouhiroshi  あんどう裕(ひろし)衆議院議員(自民党 京都6区 )          3   \n",
       "3       jimin_koho                       自民党広報          3   \n",
       "4    Matsukawa_Rui                  松川るい =自民党=          3   \n",
       "\n",
       "                                           embedding  \n",
       "0  [-0.9669536352157593, -0.4772112965583801, -0....  \n",
       "1  [-0.944596529006958, -0.29907843470573425, -0....  \n",
       "2  [-0.9364525675773621, -0.093632273375988, -0.8...  \n",
       "3  [-1.0234124660491943, -0.5828015208244324, -0....  \n",
       "4  [-0.9714489579200745, -0.48148995637893677, -0...  "
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df = pd.DataFrame([r.data() for r in res])\n",
    "res_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Node Embeddingは関係のの近いノードを近い方向に配置し、ノードの重要性(出現頻度など)を絶対値に反映しているので、可視化する時にすべてのベクトルの大きさを1に揃えたほうがわかりやすいです。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def renorm_vector(v):\n",
    "    norm = np.sqrt(sum([c**2 for c in v]))\n",
    "    return np.array(v) / norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "ax_columns = ['x{}'.format(i) for i in range(1, dim_no+1)]\n",
    "\n",
    "# res_df[ax_columns] = res_df.embedding.apply(pd.Series)\n",
    "res_df[ax_columns] = res_df.embedding.apply(lambda x: pd.Series(renorm_vector(x)))\n",
    "\n",
    "res_df.community = res_df.community.astype('category')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>screen_name</th>\n",
       "      <th>name</th>\n",
       "      <th>community</th>\n",
       "      <th>embedding</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>tadamori_oshima</td>\n",
       "      <td>大島理森</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.9669536352157593, -0.4772112965583801, -0....</td>\n",
       "      <td>-0.626852</td>\n",
       "      <td>-0.309364</td>\n",
       "      <td>-0.545948</td>\n",
       "      <td>-0.461835</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>YAMASHITA_OK</td>\n",
       "      <td>山下たかし</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.944596529006958, -0.29907843470573425, -0....</td>\n",
       "      <td>-0.678253</td>\n",
       "      <td>-0.214749</td>\n",
       "      <td>-0.628372</td>\n",
       "      <td>-0.314651</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>andouhiroshi</td>\n",
       "      <td>あんどう裕(ひろし)衆議院議員(自民党 京都6区 )</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.9364525675773621, -0.093632273375988, -0.8...</td>\n",
       "      <td>-0.736264</td>\n",
       "      <td>-0.073616</td>\n",
       "      <td>-0.657714</td>\n",
       "      <td>-0.141094</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>jimin_koho</td>\n",
       "      <td>自民党広報</td>\n",
       "      <td>3</td>\n",
       "      <td>[-1.0234124660491943, -0.5828015208244324, -0....</td>\n",
       "      <td>-0.691638</td>\n",
       "      <td>-0.393866</td>\n",
       "      <td>-0.530964</td>\n",
       "      <td>-0.290831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>Matsukawa_Rui</td>\n",
       "      <td>松川るい =自民党=</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.9714489579200745, -0.48148995637893677, -0...</td>\n",
       "      <td>-0.671192</td>\n",
       "      <td>-0.332670</td>\n",
       "      <td>-0.580522</td>\n",
       "      <td>-0.319101</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       screen_name                        name community  \\\n",
       "0  tadamori_oshima                        大島理森         3   \n",
       "1     YAMASHITA_OK                       山下たかし         3   \n",
       "2     andouhiroshi  あんどう裕(ひろし)衆議院議員(自民党 京都6区 )         3   \n",
       "3       jimin_koho                       自民党広報         3   \n",
       "4    Matsukawa_Rui                  松川るい =自民党=         3   \n",
       "\n",
       "                                           embedding        x1        x2  \\\n",
       "0  [-0.9669536352157593, -0.4772112965583801, -0.... -0.626852 -0.309364   \n",
       "1  [-0.944596529006958, -0.29907843470573425, -0.... -0.678253 -0.214749   \n",
       "2  [-0.9364525675773621, -0.093632273375988, -0.8... -0.736264 -0.073616   \n",
       "3  [-1.0234124660491943, -0.5828015208244324, -0.... -0.691638 -0.393866   \n",
       "4  [-0.9714489579200745, -0.48148995637893677, -0... -0.671192 -0.332670   \n",
       "\n",
       "         x3        x4  \n",
       "0 -0.545948 -0.461835  \n",
       "1 -0.628372 -0.314651  \n",
       "2 -0.657714 -0.141094  \n",
       "3 -0.530964 -0.290831  \n",
       "4 -0.580522 -0.319101  "
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = px.scatter_3d(res_df,  \n",
    "                    x=\"x1\",  \n",
    "                    y=\"x2\",   \n",
    "                    z='x4', \n",
    "                    text=\"name\",  \n",
    "                    color='community',  \n",
    "                    log_x=False,  \n",
    "                    size_max=60)\n",
    "\n",
    "fig.update_traces(textposition='top center')\n",
    "\n",
    "fig.update_layout(\n",
    "    height=800,\n",
    "    title_text=''\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "自民系は比較的に固まっていて、民主形はかなり広がりがあり、維新系は自民寄りになっていることが見て取れます。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "10dのEmbeddingでクラスタリングしてみます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "with driver.session() as session:\n",
    "    res = session.run('''\n",
    "    MATCH (u:User)--(:Group)\n",
    "    WITH DISTINCT u\n",
    "    RETURN u.screenName as screen_name, u.name as name, u.louvainCommunityUndirected as community, u.node2vec as embedding\n",
    "    ''')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_df = pd.DataFrame([r.data() for r in res])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "ax_columns = ['x{}'.format(i) for i in range(1, 11)]\n",
    "\n",
    "# res_df[ax_columns] = res_df.embedding.apply(pd.Series)\n",
    "\n",
    "res_df[ax_columns] = res_df.embedding.apply(lambda x: pd.Series(renorm_vector(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>screen_name</th>\n",
       "      <th>name</th>\n",
       "      <th>community</th>\n",
       "      <th>embedding</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>x5</th>\n",
       "      <th>x6</th>\n",
       "      <th>x7</th>\n",
       "      <th>x8</th>\n",
       "      <th>x9</th>\n",
       "      <th>x10</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>tadamori_oshima</td>\n",
       "      <td>大島理森</td>\n",
       "      <td>3</td>\n",
       "      <td>[0.8451248407363892, -0.3527873158454895, 0.79...</td>\n",
       "      <td>0.461382</td>\n",
       "      <td>-0.192599</td>\n",
       "      <td>0.433067</td>\n",
       "      <td>-0.204804</td>\n",
       "      <td>-0.220397</td>\n",
       "      <td>-0.109164</td>\n",
       "      <td>0.262750</td>\n",
       "      <td>0.550687</td>\n",
       "      <td>-0.237711</td>\n",
       "      <td>-0.176773</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>YAMASHITA_OK</td>\n",
       "      <td>山下たかし</td>\n",
       "      <td>3</td>\n",
       "      <td>[-0.21574953198432922, -0.9915562868118286, 0....</td>\n",
       "      <td>-0.103789</td>\n",
       "      <td>-0.477001</td>\n",
       "      <td>0.165211</td>\n",
       "      <td>-0.106283</td>\n",
       "      <td>-0.260988</td>\n",
       "      <td>-0.330459</td>\n",
       "      <td>0.587693</td>\n",
       "      <td>0.288013</td>\n",
       "      <td>-0.341218</td>\n",
       "      <td>0.032010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>andouhiroshi</td>\n",
       "      <td>あんどう裕(ひろし)衆議院議員(自民党 京都6区 )</td>\n",
       "      <td>3</td>\n",
       "      <td>[0.48249784111976624, -0.26339074969291687, -0...</td>\n",
       "      <td>0.340378</td>\n",
       "      <td>-0.185809</td>\n",
       "      <td>-0.154462</td>\n",
       "      <td>-0.271198</td>\n",
       "      <td>-0.215974</td>\n",
       "      <td>-0.281920</td>\n",
       "      <td>0.519558</td>\n",
       "      <td>0.533754</td>\n",
       "      <td>-0.059092</td>\n",
       "      <td>-0.260310</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>jimin_koho</td>\n",
       "      <td>自民党広報</td>\n",
       "      <td>3</td>\n",
       "      <td>[0.44367820024490356, -0.663674533367157, 0.38...</td>\n",
       "      <td>0.236184</td>\n",
       "      <td>-0.353295</td>\n",
       "      <td>0.203981</td>\n",
       "      <td>-0.113193</td>\n",
       "      <td>-0.199367</td>\n",
       "      <td>-0.334902</td>\n",
       "      <td>0.628517</td>\n",
       "      <td>0.415378</td>\n",
       "      <td>-0.167158</td>\n",
       "      <td>-0.132505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>Matsukawa_Rui</td>\n",
       "      <td>松川るい =自民党=</td>\n",
       "      <td>3</td>\n",
       "      <td>[0.19990792870521545, -0.42390283942222595, 0....</td>\n",
       "      <td>0.115597</td>\n",
       "      <td>-0.245123</td>\n",
       "      <td>0.126030</td>\n",
       "      <td>-0.102652</td>\n",
       "      <td>-0.334126</td>\n",
       "      <td>-0.305738</td>\n",
       "      <td>0.538014</td>\n",
       "      <td>0.570700</td>\n",
       "      <td>-0.282578</td>\n",
       "      <td>0.002598</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       screen_name                        name  community  \\\n",
       "0  tadamori_oshima                        大島理森          3   \n",
       "1     YAMASHITA_OK                       山下たかし          3   \n",
       "2     andouhiroshi  あんどう裕(ひろし)衆議院議員(自民党 京都6区 )          3   \n",
       "3       jimin_koho                       自民党広報          3   \n",
       "4    Matsukawa_Rui                  松川るい =自民党=          3   \n",
       "\n",
       "                                           embedding        x1        x2  \\\n",
       "0  [0.8451248407363892, -0.3527873158454895, 0.79...  0.461382 -0.192599   \n",
       "1  [-0.21574953198432922, -0.9915562868118286, 0.... -0.103789 -0.477001   \n",
       "2  [0.48249784111976624, -0.26339074969291687, -0...  0.340378 -0.185809   \n",
       "3  [0.44367820024490356, -0.663674533367157, 0.38...  0.236184 -0.353295   \n",
       "4  [0.19990792870521545, -0.42390283942222595, 0....  0.115597 -0.245123   \n",
       "\n",
       "         x3        x4        x5        x6        x7        x8        x9  \\\n",
       "0  0.433067 -0.204804 -0.220397 -0.109164  0.262750  0.550687 -0.237711   \n",
       "1  0.165211 -0.106283 -0.260988 -0.330459  0.587693  0.288013 -0.341218   \n",
       "2 -0.154462 -0.271198 -0.215974 -0.281920  0.519558  0.533754 -0.059092   \n",
       "3  0.203981 -0.113193 -0.199367 -0.334902  0.628517  0.415378 -0.167158   \n",
       "4  0.126030 -0.102652 -0.334126 -0.305738  0.538014  0.570700 -0.282578   \n",
       "\n",
       "        x10  \n",
       "0 -0.176773  \n",
       "1  0.032010  \n",
       "2 -0.260310  \n",
       "3 -0.132505  \n",
       "4  0.002598  "
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "174    127\n",
       "3      107\n",
       "169     66\n",
       "120     29\n",
       "257     17\n",
       "Name: community, dtype: int64"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_df.community = res_df.community.astype('category')\n",
    "res_df.community.value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Louvainと比較するためにクラス多数を同じ5とします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "k_m = KMeans(n_clusters=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(346, 10)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_arr = np.array([l for l in res_df.embedding])\n",
    "embedding_arr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "k_m.fit_predict(embedding_arr)\n",
    "res_df['k_community'] = res_df.embedding.apply(lambda x: k_m.predict(np.array(x).reshape(1, -1))[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 130)\n",
      "174    110\n",
      "257      9\n",
      "169      9\n",
      "3        2\n",
      "120      0\n",
      "Name: community, dtype: int64\n",
      "(1, 117)\n",
      "3      100\n",
      "169     12\n",
      "257      2\n",
      "174      2\n",
      "120      1\n",
      "Name: community, dtype: int64\n",
      "(2, 62)\n",
      "169    43\n",
      "174    12\n",
      "257     4\n",
      "120     2\n",
      "3       1\n",
      "Name: community, dtype: int64\n",
      "(3, 35)\n",
      "120    26\n",
      "3       4\n",
      "257     2\n",
      "169     2\n",
      "174     1\n",
      "Name: community, dtype: int64\n",
      "(4, 2)\n",
      "174    2\n",
      "257    0\n",
      "169    0\n",
      "120    0\n",
      "3      0\n",
      "Name: community, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "for i in res_df.k_community.value_counts().items():\n",
    "    print(i)\n",
    "    print(res_df[res_df.k_community == i[0]].community.value_counts())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "概ね自民、民主、維新系のまま、すこし趣の違う感じになりました。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### クラスタ数について"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Embeddingしてからクラスタリングを行う利点のひとつはクラスタ数を調整できることです。\n",
    "ここでは各クラス多数におけるフィット具合を見てみます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "572338fcc9074c2592626bd8856c13b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=28), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "models = []\n",
    "for i in tqdm(range(2, 30)):\n",
    "    k_m = KMeans(n_clusters=i)\n",
    "    k_m.fit(embedding_arr)\n",
    "    models.append(k_m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = [i for i in range(2, 30)]\n",
    "y = np.array([k.inertia_ for k in models])\n",
    "# plt.plot(x, y)\n",
    "# plt.plot([i for i in range(2, 15)], np.log(np.array([k.inertia_ for k in models])))\n",
    "\n",
    "px.line(x=x, y=y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "このグラフはこれは各点が自分が所属しているクラスタの中心までの距離の平均を表していて、クラスタ数が多くなれば距離が減りますが、その分粒度が細かくなります。\n",
    "クラスタ数を増やしても距離があまり減らなくなるところが一般的にいいとされています。\n",
    "この場合だと5と10になるでしょうか。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graph",
   "language": "python",
   "name": "graph"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc-autonumbering": true
 },
 "nbformat": 4,
 "nbformat_minor": 4
}