{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-26-deepwalk.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T384270%20%7C%20DeepWalk%20in%20python.ipynb","timestamp":1644674259116}],"collapsed_sections":[],"toc_visible":true,"authorship_tag":"ABX9TyOzNKbYgMjjVsOkMk1OZIyv"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"l5uXqP2G7ov1"},"source":["# DeepWalk in Python"]},{"cell_type":"markdown","metadata":{"id":"_Ue1g8397t0o"},"source":["## Imports"]},{"cell_type":"code","metadata":{"id":"V_CAT4tG7om_"},"source":["import random\n","import networkx as nx\n","from gensim.models import Word2Vec\n","\n","import numpy as np\n","from abc import ABC\n","import pandas as pd"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_GAkUKgmTkZR"},"source":["## DeepWalk"]},{"cell_type":"code","metadata":{"id":"6DW32rssTing"},"source":["class DeepWalk:\n","    \"\"\"\n","    Implement the DeepWalk algorithm.\n","    Reference paper: DeepWalk: Online Learning of Social Representations\n","    Link: https://arxiv.org/abs/1403.6652\n","    Use this class to learn a graph embedding model from your network data.\n","    \"\"\"\n","    def __init__(self, G=None, adjlist_path=None, edgelist_path=None):\n","        \"\"\"\n","        Parameters\n","        G : networkx graph : an existing networkx graph.\n","\n","        adjlist_path : str : path to a network file in adjacency-list format.\n","\n","        edgelist_path : str : path to a network file in edge-list format.\n","\n","        The first argument that is not None (in this order) is used.\n","        \"\"\"\n","        if G is None and adjlist_path is None and edgelist_path is None:\n","            raise ValueError('G, adjlist_path and edgelist_path are all None; please provide one of them.')\n","\n","        # Let read errors propagate so that a failed load is not silently ignored.\n","        if G is not None:\n","            self.G = G\n","        elif adjlist_path is not None:\n","            self.G = nx.read_adjlist(adjlist_path)\n","        else:\n","            self.G = nx.read_edgelist(edgelist_path)\n","\n","    def random_walk(self, iterations, start_node=None, random_walk_times=5):\n","        \"\"\"\n","        : Implementation of the random walk algorithm :\n","        Parameters\n","        ----------------------------------------\n","        iterations : int : number of random walks to generate\n","        start_node : node label : node each walk starts from (if None, a random start node is chosen for every walk)\n","        random_walk_times : int : maximum number of nodes in each walk\n","        ----------------------------------------\n","        Returns\n","        walk_records : list of walks, each a list of nodes\n","        \"\"\"\n","        walk_records = []\n","\n","        for _ in range(iterations):\n","            # Start from the given node, or from a random node for each walk.\n","            s_node = random.choice(list(self.G.nodes())) if start_node is None else start_node\n","            walk_path = [s_node]\n","\n","            current_node = s_node\n","            while len(walk_path) < random_walk_times:\n","                neighbors = list(self.G.neighbors(current_node))\n","                # Stop early at a node without neighbors (random.choice fails on an empty list).\n","                if not neighbors:\n","                    break\n","                current_node = random.choice(neighbors)\n","                walk_path.append(current_node)\n","\n","            walk_records.append(walk_path)\n","\n","        return walk_records\n","\n","    def buildWord2Vec(self, **kwargs):\n","        \"\"\"\n","        Use gensim to build a skip-gram word2vec model from the random walks.\n","        Parameters\n","        ----------------------------------------\n","        **kwargs\n","\n","        walk_path : list : random walk results\n","        size : int : embedding dimension, default: 100\n","        window : int : context window size, default: 5\n","        workers : int : number of worker threads, default: 2\n","        ----------------------------------------\n","        Returns\n","        embedding_model : gensim Word2Vec model, or None if walk_path is not given\n","        \"\"\"\n","        walk_path = kwargs.get('walk_path', None)\n","        if walk_path is None:\n","            return None\n","\n","        size = kwargs.get('size', 100)\n","        window = kwargs.get('window', 5)\n","        workers = kwargs.get('workers', 2)\n","\n","        # gensim >= 4.0 calls the embedding dimension `vector_size` (it was `size` before 4.0).\n","        # hs=1 with negative=0 trains with hierarchical softmax only, as in the DeepWalk paper.\n","        embedding_model = Word2Vec(walk_path, vector_size=size, window=window, min_count=0,\n","                                   workers=workers, sg=1, hs=1, negative=0)\n","\n","        return embedding_model"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"B7Kv5bfhaIsh"},"source":["## Hierarchical Softmax\n","First, we'll build the components required to use hierarchical softmax. From the paper:\n","\n","> Computing the partition function (normalization factor) is expensive. If we assign the vertices to the leaves of a binary tree, the prediction problem turns into maximizing the probability of a specific path in the tree.\n","\n","Thus, instead of having a classifier that predicts a probability for every word (vertex) in our vocabulary (besides the one we're currently iterating on), we can structure the loss function as a binary tree where every internal node contains its own binary classifier. 
Computing the loss (and gradient) can therefore be done in $O(\\log v)$ binary predictions rather than $O(v)$ (as is the case with $v$ labels), where $v$ is the number of vertices in our graph."]},{"cell_type":"code","metadata":{"id":"OPTTtTT9aIpg"},"source":["class Tree(ABC):\n","    @staticmethod\n","    def merge(dims, lr, batch_size, left=None, right=None):\n","        if left is not None: left.set_left()\n","        if right is not None: right.set_right()\n","        return InternalNode(dims, lr, batch_size, left, right)\n","\n","    @staticmethod\n","    def build_tree(nodes, dims, lr, batch_size):\n","        # Pad the leaf level in place, so the caller's list gains a trailing None when its length is odd.\n","        if len(nodes) % 2 != 0: nodes.append(None)\n","        while len(nodes) > 1:\n","            # Also pad higher levels; otherwise a level of odd length would silently drop its last node.\n","            if len(nodes) % 2 != 0: nodes = nodes + [None]\n","            nodes = [Tree.merge(dims, lr, batch_size, nodes[i], nodes[i+1]) for i in range(0, len(nodes), 2)]\n","        return nodes[0]\n","\n","    def set_parent(self, t):\n","        self.parent = t\n","\n","    def set_left(self): self.is_right = False\n","\n","    def set_right(self): self.is_right = True"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ivhz8tvMaWZg"},"source":["class InternalNode(Tree):\n","    def __init__(self, dims, lr, batch_size, left=None, right=None, parent=None, is_right=None):\n","        self.dims = dims\n","        self.set_left_child(left)\n","        self.set_right_child(right)\n","        self.set_parent(parent)\n","        self.is_right = is_right\n","        self.params = np.random.uniform(size=self.dims)\n","        self.gradients = []\n","        self.lr = lr\n","        self.batch_size = batch_size\n","\n","    def set_left_child(self, child: Tree):\n","        self.left = child\n","        if self.left is not None:\n","            self.left.set_parent(self)\n","            self.left.set_left()\n","\n","    def set_right_child(self, child: Tree):\n","        self.right = child\n","        if self.right is not None:\n","            self.right.set_parent(self)\n","            self.right.set_right()\n","\n","    def set_parent(self, parent: Tree):\n","        self.parent = parent\n","\n","    def predict(self, embedding, right=True):\n","        # Probability of branching right (or left): sigmoid(+/- params . embedding)\n","        d = self.params.dot(embedding) if right else -self.params.dot(embedding)\n","        return 1/(1+np.exp(-d))\n","\n","    def update_gradients(self, gradient: np.ndarray):\n","        # Accumulate gradients; take an averaged SGD step once batch_size of them are collected.\n","        self.gradients.append(gradient)\n","        if len(self.gradients) >= self.batch_size:\n","            avg_gradient = np.stack(self.gradients, axis=0).mean(axis=0)\n","            self.params = self.params - self.lr * avg_gradient\n","            self.gradients = []\n","\n","    def __eq__(self, other):\n","        # Structural equality; comparing with a non-InternalNode (e.g. None) falls back to identity.\n","        if not isinstance(other, InternalNode):\n","            return NotImplemented\n","        return (\n","            self.dims == other.dims and\n","            self.left == other.left and\n","            self.right == other.right and\n","            self.lr == other.lr and\n","            self.batch_size == other.batch_size\n","        )"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"s4cUO_97aZgW"},"source":["class Leaf(Tree):\n","    def __init__(self, vertex, parent: InternalNode = None, is_right=False):\n","        self.parent = parent\n","        self.is_right = is_right\n","        self.vertex = vertex\n","\n","    def update(self, anchor_vertex):\n","        \"\"\"\n","        One hierarchical-softmax step that increases the probability of this leaf's\n","        root-to-leaf path given anchor_vertex's embedding. Returns the path's negative\n","        log-likelihood under the current parameters.\n","        \"\"\"\n","        node = self\n","        total_cost = 0.\n","        emb_grads = []\n","        while node.parent is not None:\n","            is_right = node.is_right\n","            node = node.parent\n","            prob = node.predict(anchor_vertex.embedding, is_right)\n","            total_cost -= np.log(prob)\n","            # Loss at this node: -log(sigmoid(s * params . embedding)), with s = +1 (right) or -1 (left).\n","            # Its derivative with respect to (params . embedding) is -(1 - prob) * s.\n","            s = 1.0 if is_right else -1.0\n","            coeff = -(1.0 - prob) * s\n","            # The embedding gradient uses the node's params before they are updated.\n","            emb_grads.append(coeff * node.params)\n","            node.update_gradients(coeff * anchor_vertex.embedding)\n","        anchor_vertex.update_embedding(sum(emb_grads))\n","        return total_cost"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"mE_0ks6PabTD"},"source":["class Vertex(object):\n","    def __init__(self, dim, lr, batch_size):\n","        self.dim = dim\n","        self.embedding = np.random.uniform(size=dim)\n","        self.lr = lr\n","        self.gradients = []\n","        self.batch_size = batch_size\n","\n","    def update_embedding(self, gradient: np.ndarray):\n","        # Same mini-batch averaging as InternalNode.update_gradients.\n","        self.gradients.append(gradient)\n","        if len(self.gradients) >= self.batch_size:\n","            avg_gradient = np.stack(self.gradients, axis=0).mean(axis=0)\n","            self.embedding = self.embedding - self.lr * avg_gradient\n","            self.gradients = 
[]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ux9slFZ-adiw"},"source":["v = Vertex(8, 1e-1, 1)\n","v2 = Vertex(8, 1e-1, 1)\n","leaf = Leaf(v)\n","leaf2 = Leaf(v2)\n","i = InternalNode(8, 1e-1, 1, leaf, leaf2)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"3dp2DCMgai2j","executionInfo":{"status":"ok","timestamp":1633185052067,"user_tz":-330,"elapsed":6,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"41d60917-dfc5-4dc2-a955-d6a34fdadf7a"},"source":["before = leaf2.vertex.embedding\n","before_parent = leaf.parent.params\n","print(before)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[0.70466282 0.4704185 0.15045063 0.93010221 0.04333254 0.33917607\n"," 0.3072665 0.97709016]\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6nTMuDVBajHJ","executionInfo":{"status":"ok","timestamp":1633185058468,"user_tz":-330,"elapsed":703,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"9b3e5a23-41f5-4d17-eb78-946212109211"},"source":["leaf.update(leaf2.vertex)\n","after = leaf2.vertex.embedding\n","after_parent = leaf.parent.params\n","print(after)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[0.70405907 0.45365486 0.08264726 0.91088379 0.01787853 0.27183648\n"," 0.26271421 0.96241706]\n"]}]},{"cell_type":"markdown","metadata":{"id":"fSmBZypRakhm"},"source":["Leaves 1 and 2 should share parent i. 
Also, each should have its own vertex (v and v2 respectively)."]},{"cell_type":"code","metadata":{"id":"_pY0JOE_amtM"},"source":["assert leaf.vertex == v\n","assert leaf.vertex != v2\n","assert leaf2.vertex == v2\n","assert leaf2.vertex != v\n","assert leaf.parent == i\n","assert leaf2.parent == i"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NRvDC5rnanwm"},"source":["As a convenience, `Tree.merge` should produce the same node as passing the children to the `InternalNode` constructor manually, as above. Note that merging re-parents `leaf` and `leaf2` to the new node."]},{"cell_type":"code","metadata":{"id":"fb6NsUPlapka"},"source":["i2 = Tree.merge(8, 1e-1, 1, leaf, leaf2)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2w1yuw3Waqpr"},"source":["assert i2 == i"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"dYGVHNvYar0D"},"source":["We should be able to create an internal node with a single child."]},{"cell_type":"code","metadata":{"id":"yiveInG5atPm"},"source":["i3 = InternalNode(8, 0.01, 1, leaf)\n","assert i3.left == leaf\n","assert i3.right is None"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6lz38Mizaumb"},"source":["We should be able to combine two internal nodes under a third internal node."]},{"cell_type":"code","metadata":{"id":"tN4e7CGDawIn"},"source":["two_internal_nodes = Tree.merge(8, 0.01, 1, i, i2)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"iJpxRbHlaxsA"},"source":["assert two_internal_nodes.left == i\n","assert two_internal_nodes.right == i2\n","assert i.parent == two_internal_nodes\n","assert i2.parent == two_internal_nodes"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"fxkGgObGazPL"},"source":["p = Tree.merge(8, 1e-1, 1, leaf, 
leaf2)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"C6A0p1Mra0jv","executionInfo":{"status":"ok","timestamp":1633185128668,"user_tz":-330,"elapsed":532,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"4d798539-0996-47cd-ee4a-5170709eb09b"},"source":["leaf.parent == leaf2.parent"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["True"]},"metadata":{},"execution_count":16}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"hlHh21ooa1u-","executionInfo":{"status":"ok","timestamp":1633185133919,"user_tz":-330,"elapsed":716,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"2a7ca9bc-41dd-4291-8ac0-6e8e7424c848"},"source":["leaf.vertex.embedding"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([0.14202065, 0.94789345, 0.96777175, 0.12845376, 0.33301731,\n"," 0.25128948, 0.98405048, 0.66509886])"]},"metadata":{},"execution_count":17}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Y948sZOYa27u","executionInfo":{"status":"ok","timestamp":1633185139193,"user_tz":-330,"elapsed":424,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"9a5e0306-365f-43f2-b244-8636e2448fc8"},"source":["before = leaf2.vertex.embedding.copy()\n","before_parent = leaf.parent.params.copy()\n","leaf.update(leaf2.vertex)\n","after = leaf2.vertex.embedding\n","after_parent = leaf.parent.params\n","(before, after)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(array([0.70405907, 0.45365486, 0.08264726, 0.91088379, 
0.01787853,\n","        0.27183648, 0.26271421, 0.96241706]),\n"," array([ 0.66127081,  0.45437708,  0.00473115,  0.84110065, -0.02988149,\n","         0.22792929,  0.2618126 ,  0.92318361]))"]},"metadata":{},"execution_count":18}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"V7o7Ugtwa4VA","executionInfo":{"status":"ok","timestamp":1633185146746,"user_tz":-330,"elapsed":515,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"48cc8c72-7984-431b-a5a1-b9b080f5c973"},"source":["(before_parent, after_parent)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(array([0.55046806, 0.03145724, 0.89740018, 0.87697905, 0.54720746,\n","        0.52541901, 0.03329589, 0.53246957]),\n"," array([ 0.48884187, -0.00825111,  0.89016608,  0.79724952,  0.54564255,\n","         0.5016252 ,  0.01030055,  0.44822934]))"]},"metadata":{},"execution_count":19}]},{"cell_type":"code","metadata":{"id":"kaZ9yTnba6I1"},"source":["# P(left) + P(right) must sum to 1 at every internal node\n","assert np.isclose(leaf.parent.predict(leaf2.vertex.embedding, right=False) + leaf.parent.predict(leaf2.vertex.embedding), 1.0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rZQVRYJNa7Qg","executionInfo":{"status":"ok","timestamp":1633185158126,"user_tz":-330,"elapsed":425,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"27781dd6-6951-4189-a1f2-e9acfcca76a5"},"source":["leaf.parent.predict(leaf2.vertex.embedding)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["0.8188850193603718"]},"metadata":{},"execution_count":21}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zwNqR-Cpa88M","executionInfo":{"status":"ok","timestamp":1633185166246,"user_tz":-330,"elapsed":619,"user":{"displayName":"Sparsh 
Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"a4519140-c8c3-465a-c6c5-7ef3db21c18c"},"source":["new_leaf = Leaf(Vertex(8, 0.01, 1))\n","new_leaf2 = Leaf(Vertex(8, 0.01, 1))\n","merged = Tree.merge(8, 0.01, 1, new_leaf, new_leaf2)\n","before1 = new_leaf2.vertex.embedding.copy()\n","new_leaf.update(new_leaf2.vertex)\n","after1 = new_leaf2.vertex.embedding\n","(before1, after1)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(array([0.93594794, 0.00617317, 0.83736121, 0.53829744, 0.86273901,\n"," 0.25883039, 0.64857575, 0.48665158]),\n"," array([ 0.92966094, -0.00183856, 0.83637281, 0.52932184, 0.85800988,\n"," 0.25527446, 0.64106637, 0.47856129]))"]},"metadata":{},"execution_count":22}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"aQNA3TD1a-2D","executionInfo":{"status":"ok","timestamp":1633185172323,"user_tz":-330,"elapsed":645,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"9a767b11-b153-49ba-8ede-c2cba83097a5"},"source":["before2 = new_leaf.vertex.embedding.copy()\n","new_leaf2.update(new_leaf.vertex)\n","after2 = new_leaf.vertex.embedding\n","(before2, after2)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(array([0.01577104, 0.73203528, 0.51886159, 0.03187183, 0.52233666,\n"," 0.12200931, 0.67895869, 0.06495115]),\n"," array([0.01467268, 0.73063752, 0.51869031, 0.03030379, 0.52151184,\n"," 0.12138838, 0.67764856, 0.06353787]))"]},"metadata":{},"execution_count":23}]},{"cell_type":"code","metadata":{"id":"Jmyp5cHvbAYP"},"source":["emb_length = 10\n","lr = 1e-3\n","bs = 100\n","v1 = Vertex(emb_length, lr, bs)\n","v2 = Vertex(emb_length, lr, bs)\n","v3 = Vertex(emb_length, lr, bs)\n","random_walk = [v1, v2, v3]\n","leaves = list(map(lambda x: Leaf(x), 
random_walk))\n","tree = Tree.build_tree(leaves, emb_length, lr, bs)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"0fQmteX6bFb_","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1633185240163,"user_tz":-330,"elapsed":662,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"e3d0d2a7-05e8-4ab8-8c76-705d4a89b127"},"source":["leaves"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[<__main__.Leaf at 0x7fc68b836d50>,\n"," <__main__.Leaf at 0x7fc68b836a90>,\n"," <__main__.Leaf at 0x7fc68b836250>,\n"," None]"]},"metadata":{},"execution_count":25}]},{"cell_type":"code","metadata":{"id":"8ZjzHNozbFcA","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1633185241985,"user_tz":-330,"elapsed":10,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"15d2626b-42e2-4e7d-9c0c-2b197b5f5478"},"source":["tree.__class__"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["__main__.InternalNode"]},"metadata":{},"execution_count":26}]},{"cell_type":"code","metadata":{"id":"f36Fs_C1bFcC","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1633185242731,"user_tz":-330,"elapsed":9,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"fafb1c70-5d61-469d-f9bd-38d87564e6d7"},"source":["v1.embedding.shape, v2.embedding.shape, v3.embedding.shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["((10,), (10,), (10,))"]},"metadata":{},"execution_count":27}]},{"cell_type":"code","metadata":{"id":"MRKD0Yf3bFcD"},"source":["leaf1, leaf2, leaf3, 
empty_leaf = leaves"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"jKiUuANrbFcD","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1633185244387,"user_tz":-330,"elapsed":15,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"b44b881b-9cf7-4b8a-ff19-e633950f81b4"},"source":["leaf3.vertex.embedding"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([0.56054777, 0.8018071 , 0.99571521, 0.35924211, 0.63936966,\n"," 0.31256406, 0.49017507, 0.63942332, 0.55382691, 0.59270016])"]},"metadata":{},"execution_count":29}]},{"cell_type":"code","metadata":{"id":"OHqy3C6TbFcD","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1633185244842,"user_tz":-330,"elapsed":30,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"a531b6ba-2394-43f4-cb71-6be33db1b548"},"source":["leaf1.parent, leaf2.parent, leaf3.parent"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(<__main__.InternalNode at 0x7fc68b836610>,\n"," <__main__.InternalNode at 0x7fc68b836610>,\n"," <__main__.InternalNode at 0x7fc68b8360d0>)"]},"metadata":{},"execution_count":30}]},{"cell_type":"markdown","metadata":{"id":"zXoHf1r47yvC"},"source":["## Plots"]},{"cell_type":"code","metadata":{"id":"JpFNd3dSbFcE","colab":{"base_uri":"https://localhost:8080/","height":282},"executionInfo":{"status":"ok","timestamp":1633185246043,"user_tz":-330,"elapsed":1223,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"b39addd4-15e2-42ca-9aaf-99e1b4799ee2"},"source":["costs1 = []\n","costs3 = []\n","combined_cost = []\n","for i in range(10000):\n"," cost1 
= leaf1.update(leaf2.vertex)\n"," cost3 = leaf3.update(leaf2.vertex)\n"," if i % bs == 0:\n"," costs1.append(cost1) \n"," costs3.append(cost3)\n"," combined_cost.append(cost1+cost3) \n"," \n","pd.Series(costs1).plot(kind='line')"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":31},{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3xUVf7/8dcnCST0GooEDU0EBEEivQlSFBZQVHAtoCCy0hRdV3fXddX97epaKCKsiAUbdWkiokgPkTKRXqRLEU3o0ovn98cM+42YmAkkucnk/Xw85pGZWzKf+7j65ubcc88x5xwiIhK6wrwuQEREspaCXkQkxCnoRURCnIJeRCTEKehFREJchNcFpKZ06dIuNjbW6zJERHKNxMTEA8656NTW5cigj42NxefzeV2GiEiuYWbfpbVOTTciIiFOQS8iEuIU9CIiIU5BLyIS4hT0IiIhTkEvIhLiFPQiIiEupIL+jXlbWb/vqNdliIjkKCET9EdOnuWTFbvpNjqBSb49XpcjIpJjhEzQFy+Yn1kDm1H/mhI8NWUtz0xdx5nzF7wuS0TEcyET9AClCkfywUMN+EOrKoxfsZu7//M1+46c8rosERFPhVTQA0SEh/GnDtfx1v312ZF8gk4jlrBka7LXZYmIeCaooDezXWa2zsxWm9mvRhszs3vNbG1gmwQzuyHYfbNK+1rlmDGgKWWKRPHAuysYOX8rP/+s+XFFJO/JyBX9zc65us65uFTW7QRaOudqAy8CYzKwb5apHF2Yaf2b0PmGq3j1yy30+cDH0ZPnsrMEERHPZUrTjXMuwTl3OPBxGRCTGb83MxTMH8Gw7nV5oUstlmxNptPIJeqCKSJ5SrBB74AvzSzRzPqms21v4POM7mtmfc3MZ2a+5OTMbVM3Mx5oHMvERxpz/oLjjtEJTFy5O1O/Q0QkpzLn0m+3NrMKzrl9ZlYGmAsMdM4tTmW7m4FRQDPn3MGM7JtSXFycy6qJRw4eP8PgCauJ33aAu+NieKHL9UTlC8+S7xIRyS5mlphW83hQV/TOuX2Bn0nANKBBKl9SBxgLdLkY8sHum51KFY5k3EMNGNi6KpN8e7ljVAK7D570siQRkSyVbtCbWSEzK3LxPdAOWH/JNlcDU4H7nXNbMrKvF8LDjCfaVefdXnHsO3KKjm8sYe7GH70uS0QkSwRzRV8WiDezNcAK4DPn3Bwz62dm/QLb/A0oBYy6pBtlqvtm8jFcttbXlWXWwGZcU6ogD3/g46XPN3P+ws9elyUikqmCaqPPblnZRp+a0+cu8MKsjXyyfDcNKpVk5D31KFM0Ktu+X0TkSl1xG32oi8oXzj9vr83rd9/A2r1HuG1EPF9vP5j+jiIiuYCCPoU7boxhRv9mFC0Qwb1jlzFq4TY9TSsiuZ6C/hLVyxVh5oBm3Fa7PP+e8y0Pf+DjyMmzXpclInLZFPSpKBwZwRv31OOFLrVYvDWZjiPiWbPniNdliYhcFgV9Gi4+TTu5XxMA7vxPAuMSdpETb16LiPwWBX066lYszmeDmtG8WjTPzdzAwPGrOH7mvNdliYgETUEfh
OIF8zP2gTie6lCd2ev20/mNeDbtP+Z1WSIiQVHQBykszHi0VVXGP9yI42fO0/XNpUxcuVtNOSKS4ynoM6hh5VLMHtycm2JL8qf/ruOJyWs4eVZNOSKScynoL0PpwMBoj91SjWmr9tFl5FK2/viT12WJiKRKQX+ZwsOMx265lo96N+TwybN0HrmU/ybu9bosEZFfUdBfoaZVS/PZoObUiSnGE5PX8NSUNZw6e8HrskRE/kdBnwnKFo3i4z4NGXBzVSYn7qXrm0vZlnTc67JERAAFfaaJCA/jyfbVef/BBiQfP0PnkfFM/UZNOSLiPQV9Jmt5bTSzBzXn+grFGDJpDX+crKYcEfGWgj4LlCsWxSd9GjKwdVWmfLOXLm/Gq1eOiHhGQZ9FIsLDeKJddT54qAEHj/t75Uz27fG6LBHJgxT0Wax5tWg+H9ycGyoW449T1jJk0mpOaKwcEclGQQW9me0ys3WXzAebcr2Z2Qgz22Zma83sxhTreprZ1sCrZ2YWn1uUKRrFx30aMbiN/wGrziPj2fyDxsoRkeyRkSv6m51zddOYk/BWoFrg1RcYDWBmJYHngIZAA+A5MytxZSXnTuFhxuNtr+Xj3g05dvo8XUYu5ZPlGitHRLJeZjXddAE+cH7LgOJmVh5oD8x1zh1yzh0G5gIdMuk7c6UmVUsze1BzGlQqyZ+nrWPg+FX8dPqc12WJSAgLNugd8KWZJZpZ31TWVwBS3mncG1iW1vI8LbpIJOMebMAf21fn8/U/0HFEPGv3agYrEckawQZ9M+fcjfibaPqbWYvMLsTM+pqZz8x8ycnJmf3rc5ywMKP/zVWZ2LcR5y/8TLfRCbwTv1NNOSKS6YIKeufcvsDPJGAa/vb2lPYBFVN8jgksS2t5at8xxjkX55yLi46ODq76EBAXW5LZg5vT8toyvDhrI33G+Th8QpORi0jmSTfozayQmRW5+B5oB6y/ZLOZwAOB3jeNgKPOuf3AF0A7MysRuAnbLrBMUiheMD9vP1Cf535XkyVbD3Dr8CUs33HQ67JEJEQEc0VfFog3szXACuAz59wcM+tnZv0C28wGdgDbgLeBRwGcc4eAF4GVgdcLgWVyCTPjwaaVmPpoE6LyhXHP28sY/tVWLvysphwRuTKWE9uE4+LinM/3q+76ecbxM+f567R1TF/9PY0ql2RY93qUKxbldVkikoOZWWIa3d/1ZGxOVDgygqHd6/LKnXVYs+cotw5fzLxNP3pdlojkUgr6HMrMuCuuIrMGNaN8sQL0Hufj+U83cOa8RsIUkYxR0OdwVaILM/XRJvRqEst7S3dxx6gEdiRrUhMRCZ6CPheIyhfO3zvXYuwDcXx/5BSd3ohnSuJe9bkXkaAo6HORW2qW5fPBLahdoRhPTl7DYxNXa/gEEUmXgj6XKVcsik8ebsSQttfy6Zrv6TgintV7NHyCiKRNQZ8LhYcZg9pUY9Ijjbnws+PO0QmMXridn9XnXkRSoaDPxeJiSzJ7UHPa1yrHy3M288C7K0g6dtrrskQkh1HQ53LFCuZj5O/r8dIdtfF9d4gOw5eoz72I/IKCPgSYGT0aXM2sgc0oVzSK3uN8/H3mBk6fU597EVHQh5SqZYowrX8THmpaifcTdtH1zaVs+fEnr8sSEY8p6ENMZEQ4f/tdTd578CYOHD/D796I58Ovd6nPvUgepqAPUTdXL8Png1vQqHIpnp2xgYc/8HHw+BmvyxIRDyjoQ1h0kUje63UTf+tUk8VbDtBh+BKWbA392btE5JcU9CEuLMx4qFklpvdvSvEC+bj/nRX8Y9ZGDY4mkoco6POImlcVZeaAZtzX6GrGxu/k9jcT2JakG7UieYGCPg8pkD+cf3StzdgH4vjh2Gk6jojnw2Xf6UatSIhT0OdBt9Qsy5zBzWlQqSTPTl+vG7UiIU5Bn0eVKRrFuAcb8GzgRm37YUtY+G2S12WJSBYIOujNLNzMVpnZrFTWDTWz1YHXFjM7kmLdhRTrZmZW4
XLlwsKM3s0qMWNAU0oVyk+v91bqiVqREBSRgW0HA5uAopeucM49fvG9mQ0E6qVYfco5V/eyK5QsV6N8UWYMaMpLn2/m/YRdJGw/wPAe9ahR/lenWkRyoaCu6M0sBugIjA1i83uA8VdSlGS/i7NYjXuoAYdPnqPLyKWMXbJDQx+LhIBgm26GAU8BP//WRmZ2DVAJmJ9icZSZ+cxsmZl1/Y19+wa28yUn66Eer7S8Npo5g5vTsno0//hsE/e/u5z9R095XZaIXIF0g97MOgFJzrnEIH5fD2CKcy5lI+81zrk44PfAMDOrktqOzrkxzrk451xcdHR0MLVLFilVOJIx99fnpTtq8813R+gwbAmfrd3vdVkicpmCuaJvCnQ2s13ABKC1mX2UxrY9uKTZxjm3L/BzB7CQX7bfSw51cejj2YObE1u6EP0/+YYhkzRHrUhulG7QO+eecc7FOOdi8Qf5fOfcfZduZ2bXASWAr1MsK2FmkYH3pfH/o7Exk2qXbFCpdCGm9GvMoDbVmL5qH7cOX8KKnYe8LktEMuCy+9Gb2Qtm1jnFoh7ABPfLxyxrAD4zWwMsAF5yzinoc5l84WEMaXstk/s1IcyM7mO+5uU5mzl7/jdv2YhIDmE58fH3uLg45/P5vC5DUnH8zHle/HQjE317qHVVUYZ1r0u1skW8LkskzzOzxMD90F/Rk7GSIYUjI3j5zjq8dX999h89Tac34nlv6U51wxTJwRT0clna1yrHF4+1oGnV0jz/6UZ6vreCH46e9rosEUmFgl4uW3SRSN7pGcc/b6+Nb9dh2g9bzMw133tdlohcQkEvV8TM+H1DfzfMytGFGDR+FQPHr+LIybNelyYiAQp6yRSVShdi8iONeaLttXy+bj/thy3WtIUiOYSCXjJNRHgYA9tUY9qjTSkcGcH976zguRnrOXVWo2GKeElBL5mudkwxPhvUnAebxjLu6+/oOGIJq/ccSX9HEckSCnrJElH5wnnud7X4uE9DTp+7QLfRCbw+dwvnLughK5HspqCXLNW0amnmPN6CLnWvYsS8rdwxKoGtP2pScpHspKCXLFc0Kh+v312X/9x3I/uOnKLjG/Ea614kGynoJdt0uL48XzzWgpbX+se6v+ftZew5dNLrskRCnoJeslV0Ef9Y96/cWYcN3x+jw7DFTFixm5w45pJIqFDQS7YzM+6Kq8icx5pTJ6Y4T09dx0Pvr+THYxpCQSQrKOjFMzElCvJxn4b8/Xc1Sdh+kHZD/UMo6OpeJHMp6MVTYWFGr6aVfjGEwoBPVnHohIZQEMksCnrJEapEF2byI415qkN1vtz4A+2GLuLLDT94XZZISFDQS44RER7Go62qMnNAM8oUiaLvh4kMmbSao6c0T63IlVDQS45To3xRpvdvyqA21Zix+nvaD13Mwm+TvC5LJNdS0EuOlD/CP0/ttEebUCQqgl7vreSZqWv56bSu7kUyKuigN7NwM1tlZrNSWdfLzJLNbHXg1SfFup5mtjXw6plZhUveUCemOJ8ObEa/llWYuHIPHYYtYem2A16XJZKrZOSKfjCw6TfWT3TO1Q28xgKYWUngOaAh0AB4zsxKXHa1kidF5Qvn6VuvY3K/JkRGhHHv2OX8dfo6Tpw573VpIrlCUEFvZjFAR2BsBn9/e2Cuc+6Qc+4wMBfokMHfIQJA/WtKMHtwc/o0q8THy3fTYfhivt5+0OuyRHK8YK/ohwFPAb81xmw3M1trZlPMrGJgWQVgT4pt9gaW/YqZ9TUzn5n5kpM1M5GkLipfOH/tVJNJjzQm3Ix73l7GczPWc/Ksru5F0pJu0JtZJyDJOZf4G5t9CsQ65+rgv2ofl9FCnHNjnHNxzrm46OjojO4uecxNsSX5fHCL/01ucuvwJSzfoat7kdQEc0XfFOhsZruACUBrM/so5QbOuYPOuTOBj2OB+oH3+4CKKTaNCSwTuWIF8vsnN5nYtxHOQfcxy/j7zA26uhe5RLpB75x7xjkX45yLB
XoA851z96XcxszKp/jYmf+7afsF0M7MSgRuwrYLLBPJNA0rl2LOY83p1SSW9xN26epe5BKX3Y/ezF4ws86Bj4PMbIOZrQEGAb0AnHOHgBeBlYHXC4FlIpmqYP4I/t65FhN0dS/yK5YTRwqMi4tzPp/P6zIklzp59jz/nvMt7yfsomLJArzcrQ5NqpT2uiyRLGVmic65uNTW6clYCTkXr+4v9sz5/dv+fvfH1e9e8igFvYSsBpX8PXN6B/rdtx+6mPiteqpW8h4FvYS0AvnDebZTTab0a0xkRBj3vbOcZ6au5ZjGzJE8REEveUL9a0oye3BzHmlRmYkr99B+6GIWaERMySMU9JJnROUL55nbavDfPzShcGQED763kiGTVnPkpGazktCmoJc8p97VJZg1qBkDW1dlxurvaTt0MXPWazYrCV0KesmTIiPCeaJddWYOaEqZIpH0+yiR/h9/Q/JPZ9LfWSSXUdBLnlbrqmJM79+UP7avztyNP9J26CKmrdpLTny+RORyKeglz8sXHkb/m6vy2aBmVCpdiMcnruGh91fy/ZFTXpcmkikU9CIB1coWYUq/JjzbqSbLdhyi3dDFfLTsO37+WVf3krsp6EVSCA8zejerxBePteCGisX46/T13PP2MnYeOOF1aSKXTUEvkoqrSxXko94NeblbbTbuP0aHYYt5a9F2zl/4rbl3RHImBb1IGsyM7jddzVdDWtLy2mj+9flmbh+VwMbvj3ldmkiGKOhF0lG2aBRv3V+fUffeyP6jp+g8Mp5XvtjM6XMXvC5NJCgKepEgmBm31S7PV0Na0rVeBd5csJ3bRixhxU5NryA5n4JeJAOKF8zPq3fdwIe9G3D2/M/c/dbX/GXaOg2SJjmagl7kMjSvFs2Xj/uHQB6/YjftXl/M3I0/el2WSKoU9CKXqWD+CJ7tVJNpjzaleMF8PPyBj0c/TiTpp9NelybyC0EHvZmFm9kqM5uVyrohZrbRzNaa2TwzuybFugtmtjrwmplZhYvkFDdULM6nA5vxx/bV+WpTEre8togJK3ZrGAXJMTJyRT8Y2JTGulVAnHOuDjAF+HeKdaecc3UDr86p7y6Su10cRmHO4ObUKF+Up6euo8eYZexIPu51aSLBBb2ZxQAdgbGprXfOLXDOnQx8XAbEZE55IrlL5ejCjH+4Ef+6I/Cg1fAljJy/lbPn9aCVeCfYK/phwFNAMP+19gY+T/E5ysx8ZrbMzLqmtZOZ9Q1s50tOTg6yLJGcJyzMuKfB1cwb0pJbapTh1S+38Ls34vlm92GvS5M8Kt2gN7NOQJJzLjGIbe8D4oBXUiy+xjkXB/weGGZmVVLb1zk3xjkX55yLi46ODq56kRysTNEoRt1bn7cfiOPY6XN0G53AczPW85O6Yko2C+aKvinQ2cx2AROA1mb20aUbmdktwF+Azs65/83e4JzbF/i5A1gI1LvyskVyj7Y1yzJ3SEt6No7lg2Xf0fb1xXyxQTNaSfZJN+idc88452Kcc7FAD2C+c+6+lNuYWT3gLfwhn5RieQkziwy8L43/H42NmVi/SK5QODKCv3euxdQ/NKF4wXw88mEij3zo44ej6oopWe+y+9Gb2QtmdrEXzStAYWDyJd0oawA+M1sDLABecs4p6CXPqnd1CT4d2Iw/dbiOhd8mc8vrixiXsIsLGvNespDlxL6+cXFxzufzeV2GSJb67uAJ/jp9PUu2HqBuxeL8647a1Chf1OuyJJcys8TA/dBf0ZOxIh65plQhPnioAcO612XPoZN0eiOef32+iVNnNSqmZC4FvYiHzIyu9Sow74mW3FU/hrcW7aDt0EUs/DYp/Z1FgqSgF8kBihfMz0vd6jCxbyMiI8Lo9d5K+n/yDUnHdLNWrpyCXiQHaVi5FLMHN+eJttcyd+OPtHltER9+rZu1cmUU9CI5TGREOAPbVAtMUF6cZ2ds4I7RCWz4/qjXpUkupaAXyaEqlS7Eh739N2v3HT5J55FL+cesjZw4c
97r0iSXUdCL5GAXb9Z+NaQld8fFMDZ+J7e8vkhP1kqGKOhFcoHiBfPzrzvqMKVfY4oV8D9Z22ecj72HT6a/s+R5CnqRXCQutiSfDmzGM7dex9JtB2j7+mLeWrSdcxc0DLKkTUEvksvkCw/jkZZVmDukBU2rluZfn2+m04h4fLsOeV2a5FAKepFcKqZEQcb2jGPM/fU5fuY8d/7na56asoZDJ856XZrkMAp6kVyuXa1yzB3SgkdaVmbqN/to89pCJq7czc/qey8BCnqREFAwfwTP3FqDzwY1p2qZwvzpv+u4662v2bT/mNelSQ6goBcJIdXLFWHSI4155c467Dxwgk5vxPPirI0cV9/7PE1BLxJizIy74ioy/4mW3B1XkXeX7qTNawuZtfZ7cuKw5JL1FPQiIcrf9742U//QhNKFIxnwySoeeHcFO5KPe12aZDMFvUiIq3d1CWYOaMbznWuxes8ROgxbwqtffKtx7/MQBb1IHhAeZvRsEsv8J1rRqU55Ri7Yxi2vL+LLDT+oOScPUNCL5CHRRSJ5vXtdJvZtRKHIcPp+mEjvcT6+O3jC69IkCwUd9GYWbmarzGxWKusizWyimW0zs+VmFpti3TOB5d+aWfvMKVtErkTDyqX4bFBz/tqxBst3HKTt0MUMnbuF0+fUnBOKMnJFPxjYlMa63sBh51xVYCjwMoCZ1QR6ALWADsAoMwu//HJFJLPkCw+jT/PKzHuiFe1rlWP4vK20HbqIeZt+9Lo0yWRBBb2ZxQAdgbFpbNIFGBd4PwVoY2YWWD7BOXfGObcT2AY0uLKSRSQzlSsWxRv31OOThxsSGRFO73E++oxbye6DGhkzVAR7RT8MeApIa4i8CsAeAOfceeAoUCrl8oC9gWW/YmZ9zcxnZr7k5OQgyxKRzNKkSmlmD2rOn2+7joTtB7ll6CI154SIdIPezDoBSc65xKwsxDk3xjkX55yLi46OzsqvEpE05I8Io2+LKsy/pDln7sYf1TsnFwvmir4p0NnMdgETgNZm9tEl2+wDKgKYWQRQDDiYcnlATGCZiORgKZtzoiLCefgDHw++v5KdB9Q7JzdKN+idc88452Kcc7H4b6zOd87dd8lmM4Gegfd3BrZxgeU9Ar1yKgHVgBWZVr2IZKkmVUoze7C/d45v12HaD13MK19s5uRZjZ2Tm1x2P3oze8HMOgc+vgOUMrNtwBDgaQDn3AZgErARmAP0d86pwU8kF7nYO2f+ky3pVKc8by7YTpvXFmnsnFzEcuKJiouLcz6fz+syRCQVvl2H+NuMDWzcf4zGlUvx9861qF6uiNdl5Xlmluici0ttnZ6MFZEMuThv7Ytdr2fj/mPcNmIJz3+6gaOnznldmqRBQS8iGRYeZtzf6BoWPNmK7jdV5P2EXbR+dSGTVu7RzFY5kIJeRC5byUL5+efttfl0QDMqlS7EU/9dy+2jlrJq92GvS5MUFPQicsWur1CMyf0aM6x7XfYfPc3toxJ4cvIakn467XVpgoJeRDKJmdG1XgXmP9mKR1pWZsbqfbR+dRFvL97B2fNpPVQv2UFBLyKZqnCkf6LyLx9vSYNKJfl/szfRYfhiFnyb5HVpeZaCXkSyRKXShXi310281+smcPDgeyt5SE/XekJBLyJZ6ubryjDnsRb8+bbrWLHzEO2GLuJfszfx02l1x8wuCnoRyXL/GyztyZZ0rVuBtxbv4OZXFzHJp+6Y2UFBLyLZpkyRKF656wZm9G/K1SUL8NSUtXQdtZTE7w55XVpIU9CLSLa7oWJx/vuHJgzrXpekY2foNvprBo1fxfdHTnldWkhS0IuIJ/6vO2ZLBrWuypwNP9D6tYUM+2oLp85q7MPMpKAXEU8VzB/BkHbVmTekJW2uK8uwr7bS5rWFzFi9T6NjZhIFvYjkCBVLFuTNe29kYt9GlCiUn8ETVtNtdAJr9hzxurRcT0EvIjlKw8qlmDmgGS93q83uQyfp8uZShkxczQ9HNZzC5VLQi0iOE
\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"code","metadata":{"id":"Nyl5xXrybFcE","colab":{"base_uri":"https://localhost:8080/","height":282},"executionInfo":{"status":"ok","timestamp":1633185246914,"user_tz":-330,"elapsed":894,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"e58ab367-c65b-44d0-fb5f-b72d48c857c6"},"source":["pd.Series(costs3).plot(kind='line')"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":32},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"code","metadata":{"id":"_j2khUZ1bFcE","colab":{"base_uri":"https://localhost:8080/","height":282},"executionInfo":{"status":"ok","timestamp":1633185248847,"user_tz":-330,"elapsed":11,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"3f3e35d4-c6fe-4762-f26a-39e318594dc2"},"source":["pd.Series(combined_cost).plot(kind='line')"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":33},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"code","metadata":{"id":"EnNn5u0xbFcF"},"source":["# embedding length, learning rate, batch size\n","emb_length, lr, bs = 10, 1e-4, 100\n","# 100 toy vertices; each is wrapped in a Leaf in the next cell\n","leaves = [Vertex(emb_length, lr, bs) for _ in range(100)]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"6EKibUn-bFcF"},"source":["leaves = [Leaf(v) for v in leaves]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"flNh5_jRbFcF"},"source":["tree = Tree.build_tree(leaves, emb_length, lr, bs)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"gma5jbSzbFcF"},"source":["chosen_leaf = leaves[20]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Pam4uPUBbFcG","colab":{"base_uri":"https://localhost:8080/","height":72},"executionInfo":{"status":"ok","timestamp":1633185274998,"user_tz":-330,"elapsed":23762,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"b97035cc-c4a5-4416-9915-911b8e5741ea"},"source":["#slow\n","costs = []\n","num_iter = 3000\n","epoch_costs = []\n","for it in range(num_iter):\n","    for i in range(100):\n","        if i == 20:\n","            continue  # skip chosen_leaf (leaves[20]) itself\n","        # update leaf i against the chosen leaf's vertex and collect the returned cost\n","        costs.append(leaves[i].update(chosen_leaf.vertex))\n","    # record the mean cost for this epoch, then reset for the next one\n","    epoch_costs.append(np.mean(costs))\n","    costs = []\n","s = pd.Series(epoch_costs)\n","s.plot(kind='line')"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":38},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"AdYUmWzbbFcG"},"source":["This is an interesting result: it seems a little unusual to see the training loss go up, but there are a few things to consider:\n","* In the \"real\" version, the leaf embeddings will (hopefully) have some relationship with the internal nodes' model parameters. In this toy version, we've uniformly initialized all parameters and then trained the model on each of the other 99 leaves, all paired with the same vertex (`chosen_leaf.vertex`), for 3000 iterations. It's basically learning how to optimize random noise.\n","* We're using plain vanilla batch GD here, with no learning rate annealing (or any of the many other GD enhancements that exist). It's quite possible that the gradients are exploding or the optimization is diverging toward the end."]},{"cell_type":"markdown","metadata":{"id":"4Zu0lhlHbFcG"},"source":["The goal of hierarchical softmax is to make the scoring function run in $O(\\log |V|)$ rather than $O(|V|)$ by organizing the vertices as the leaves of a binary tree, with a binary classifier at each internal node. At a high level, we follow these steps:\n","1. We pick a leaf $u_k$ (a vertex) that appears within the context window of the current vertex $v_j$ in the current random walk.\n","2. We take that leaf's parent and compute the probability of having followed the correct branch (left or right) toward the leaf from step 1, using this internal node's model parameters combined with the features of the current vertex, $\\Phi(v_j)$ (a row of $\\Phi$).\n","3. We repeat step 2 for each remaining internal node on the path up to the root.\n","4. The product of all of these branch probabilities gives $\\Pr(u_k|\\Phi(v_j))$, the probability of observing $u_k$ as a co-occurring neighbor of $v_j$ given $v_j$'s current representation.\n",
"5. $J = -\\log \\Pr(u_k|\\Phi(v_j))$ is our loss function, where $\\Pr(u_k|\\Phi(v_j))$ is the probability we calculated in step 4.\n","6. We use the loss from step 5 to perform a gradient descent step that updates both the parameters of our model and $\\Phi(v_j)$:\n","\n","$$\\theta \\leftarrow \\theta - \\alpha_\\theta \\frac{\\partial J}{\\partial \\theta}$$\n",
"\n","$$\\Phi \\leftarrow \\Phi - \\alpha_\\Phi \\frac{\\partial J}{\\partial \\Phi}$$\n","\n","Where $\\theta$ represents the parameters of all of the models at the internal nodes of the tree, and $\\Phi$ is the matrix of latent vertex representations. Because $J$ depends on $\\Phi$ only through the row $\\Phi(v_j)$, the update changes only the current vertex's representation."]}]}