{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Minimum-Cost Bipartite Matching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Hungarian Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An implementation of the Hungarian algorithm for the **minimum-cost perfect bipartite matching** problem on a bipartite graph $G = (V \\cup W, E)$ with nonnegative edge costs $c$ and $|V| = |W|$.\n",
    "\n",
    "Every vertex carries a price $p$, and the *reduced cost* of an edge $(v, w)$ is $c(v, w) - p(v) - p(w)$. The algorithm maintains two invariants:\n",
    "\n",
    "1. every edge has nonnegative reduced cost;\n",
    "2. every edge of the current matching $M$ is *tight* (reduced cost $0$).\n",
    "\n",
    "Each iteration grows a BFS tree of tight, $M$-alternating edges from an unmatched vertex $r \\in V$ (`BFS_GP`). If the tree reaches an unmatched vertex of $W$, the tree path is a *good path* and $M$ is augmented along it. Otherwise the tree is stuck: the prices of its $V$-vertices ($S$) are raised and those of its $W$-vertices ($N$) are lowered by the largest amount that keeps invariant 1, which preserves invariant 2 and makes at least one new edge tight. Once $M$ is perfect, its cost equals $\\sum_u p(u)$, a lower bound on the cost of every perfect matching, so $M$ is optimal.\n",
    "\n",
    "The search tree is rebuilt from scratch in every iteration; this keeps the code simple but is not the most efficient variant of the algorithm.\n",
    "\n",
    "似非さんにより「パパ活アルゴリズム」と命名された。(Nicknamed the *papa-katsu algorithm* by 似非さん.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import random\n",
    "from collections import deque\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An edge counts as tight when |reduced cost| <= TIGHT_TOL. With integer costs\n",
    "# every reduced cost is an integer, so this is the same as testing == 0; with\n",
    "# float costs it absorbs rounding error from the price updates.\n",
    "TIGHT_TOL = 1e-9\n",
    "\n",
    "# Seed for the randomized brute-force check at the end of the notebook.\n",
    "RANDOM_SEED = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _node_data(G):\n",
    "    '''Return the mapping node -> attribute dict of G.\n",
    "\n",
    "    networkx >= 2.0 provides it as G.nodes; networkx 1.x only has G.node,\n",
    "    which later networkx releases removed.\n",
    "    '''\n",
    "    return G.nodes if hasattr(G.nodes, '__getitem__') else G.node\n",
    "\n",
    "\n",
    "def _reduced_cost(G, v, w):\n",
    "    '''Reduced cost cost(v, w) - price(v) - price(w) of the edge (v, w).'''\n",
    "    nodes = _node_data(G)\n",
    "    return G[v][w]['cost'] - nodes[v]['price'] - nodes[w]['price']\n",
    "\n",
    "\n",
    "def Initialize_H(G):\n",
    "    '''\n",
    "    Initialize a matching M and prices of all vertices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G = (V+W, E): a bipartite networkx graph\n",
    "        E: edges with a nonnegative 'cost' attribute\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    H: a copy of G (G itself is not modified) with\n",
    "        H.graph['M_edges'] = []     the matching M; each matched pair is\n",
    "                                    stored in both directions\n",
    "        H.graph['M_vertices'] = []  the vertices covered by M\n",
    "        price 0 on every vertex (node attribute 'price')\n",
    "    '''\n",
    "    H = G.copy()\n",
    "    H.graph['M_edges'] = []\n",
    "    H.graph['M_vertices'] = []\n",
    "    nodes = _node_data(H)\n",
    "    for v in H:\n",
    "        nodes[v]['price'] = 0\n",
    "    return H"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BFS_GP(G, r, V, W, tol=TIGHT_TOL):\n",
    "    '''\n",
    "    Grow an M-alternating BFS tree of tight edges from r, looking for a good path.\n",
    "\n",
    "    Even levels of the tree lie in V and odd levels in W. From a vertex of V the\n",
    "    search follows tight edges to vertices of W that are not yet in the tree; from\n",
    "    a vertex of W it follows that vertex's matching edge back into V. The search\n",
    "    stops as soon as it reaches an unmatched vertex of W (a good path is found),\n",
    "    or when the tree cannot grow any further (the search is stuck).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G=(V+W, E): a bipartite graph prepared by Initialize_H\n",
    "        G.graph['M_edges']: the current matching M (each pair in both directions)\n",
    "        node attribute 'price', edge attribute 'cost'\n",
    "    r: an unmatched vertex of V (the root of the BFS tree)\n",
    "    V, W: the two sides of the bipartition (only W is needed here)\n",
    "    tol: an edge with |reduced cost| <= tol counts as tight\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    a list [path, S, N, match]\n",
    "        path: the tree path [r, ..., u] to the last vertex u added to the tree;\n",
    "              if match is False, u is an unmatched vertex of W\n",
    "        S: even-level vertices of the BFS tree (in V, starting with r)\n",
    "        N: odd-level vertices of the BFS tree (in W)\n",
    "        match: False if the tree reached an unmatched vertex of W, True if stuck\n",
    "    '''\n",
    "    W_set = set(W)\n",
    "    # M_edges holds every matched pair in both directions, so this maps each\n",
    "    # matched vertex to its partner.\n",
    "    mate = dict(G.graph['M_edges'])\n",
    "\n",
    "    parent = {r: None}  # tree parent of every vertex reached so far\n",
    "    S, N = [r], []\n",
    "    queue = deque([r])  # only vertices of V are queued\n",
    "    last, match = r, True\n",
    "\n",
    "    while queue and match:\n",
    "        v = queue.popleft()\n",
    "        for w in G[v]:\n",
    "            if w not in W_set or w in parent or abs(_reduced_cost(G, v, w)) > tol:\n",
    "                continue\n",
    "            parent[w] = v\n",
    "            N.append(w)\n",
    "            last = w\n",
    "            if w not in mate:\n",
    "                # unmatched vertex of W: the tree path r -> w is a good path\n",
    "                match = False\n",
    "                break\n",
    "            # w is matched; its partner u is not in the tree yet, because a\n",
    "            # non-root vertex of V can only be reached through its partner\n",
    "            u = mate[w]\n",
    "            parent[u] = w\n",
    "            S.append(u)\n",
    "            queue.append(u)\n",
    "            last = u\n",
    "\n",
    "    # walk the parent pointers back from the last vertex to the root\n",
    "    path = []\n",
    "    node = last\n",
    "    while node is not None:\n",
    "        path.append(node)\n",
    "        node = parent[node]\n",
    "    path.reverse()\n",
    "    return [path, S, N, match]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _augment(H, path):\n",
    "    '''Augment M along the good path [r, w1, v1, ..., wk] stored in H.'''\n",
    "    M_edges = H.graph['M_edges']\n",
    "    for i, (a, b) in enumerate(zip(path, path[1:])):\n",
    "        if i % 2 == 0:\n",
    "            # edge from V to W, not in M: it enters the matching\n",
    "            M_edges.append((a, b))\n",
    "            M_edges.append((b, a))\n",
    "        else:\n",
    "            # edge from W to V, in M: it leaves the matching\n",
    "            M_edges.remove((a, b))\n",
    "            M_edges.remove((b, a))\n",
    "    # interior vertices of the path stay matched; only the endpoints are new\n",
    "    for endpoint in (path[0], path[-1]):\n",
    "        if endpoint not in H.graph['M_vertices']:\n",
    "            H.graph['M_vertices'].append(endpoint)\n",
    "\n",
    "\n",
    "def _raise_prices(H, S, N):\n",
    "    '''Shift prices by the largest delta that keeps all reduced costs >= 0.\n",
    "\n",
    "    Prices on S rise and prices on N fall by delta, so tree edges (including\n",
    "    the matched ones) stay tight while some edge from S to W - N becomes tight.\n",
    "    Raises ValueError if no such edge exists: then every neighbor of S lies\n",
    "    in N, |N| = |S| - 1, and by Hall's theorem G has no perfect matching.\n",
    "    '''\n",
    "    nodes = _node_data(H)\n",
    "    in_N = set(N)\n",
    "    delta = min((_reduced_cost(H, v, w) for v in S for w in H[v] if w not in in_N),\n",
    "                default=None)\n",
    "    if delta is None:\n",
    "        raise ValueError('no perfect matching exists: all neighbors of {} lie in {}'.format(S, N))\n",
    "    for v in S:\n",
    "        nodes[v]['price'] += delta\n",
    "    for w in N:\n",
    "        nodes[w]['price'] -= delta\n",
    "\n",
    "\n",
    "def Hungarian(G, V, W, tol=TIGHT_TOL):\n",
    "    '''\n",
    "    Solve the minimum-cost perfect bipartite matching problem.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G=(V+W, E): a bipartite networkx Graph or DiGraph\n",
    "        E: edges with a nonnegative 'cost' attribute; in a DiGraph the edges\n",
    "           must (at least) be directed from V to W\n",
    "        the reduced cost of (v,w) = cost of (v,w) - price of v - price of w\n",
    "    V, W: the two sides of the bipartition (every vertex must be a node of G)\n",
    "    tol: an edge with |reduced cost| <= tol counts as tight\n",
    "\n",
    "    Invariants:\n",
    "        1. the reduced costs of all edges in G >= 0\n",
    "        2. the reduced costs of all edges in M = 0 (tight)\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    the min-cost perfect matching M of G as a list of edges in which every\n",
    "    matched pair appears in both directions, e.g. [('x', 'w'), ('w', 'x'), ...]\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError: if |V| != |W| or G has no perfect matching\n",
    "    '''\n",
    "    if len(V) != len(W):\n",
    "        raise ValueError('a perfect matching needs |V| == |W|, got {} and {}'.format(len(V), len(W)))\n",
    "\n",
    "    H = Initialize_H(G)\n",
    "    n_vertices = len(V) + len(W)\n",
    "    while len(H.graph['M_vertices']) < n_vertices:\n",
    "        matched = set(H.graph['M_vertices'])\n",
    "        r = next(v for v in V if v not in matched)\n",
    "        path, S, N, match = BFS_GP(H, r, V, W, tol)\n",
    "        if match:\n",
    "            # stuck: make a new edge tight and search again from the same root\n",
    "            _raise_prices(H, S, N)\n",
    "        else:\n",
    "            _augment(H, path)\n",
    "    return H.graph['M_edges']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: Figure 10\n",
    "\n",
    "The graph of Figure 10: $V = \\{v, x\\}$, $W = \\{w, y\\}$ and costs $c(v, w) = 2$, $c(v, y) = 3$, $c(x, w) = 5$, $c(x, y) = 7$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_bipartite_graph(costs, V, W, directed=True):\n",
    "    '''Build a bipartite graph from a dict {(v, w): cost} with v in V, w in W.\n",
    "\n",
    "    All vertices of V and W are added, even isolated ones. With directed=True\n",
    "    a DiGraph with every edge in both directions is returned (the layout used\n",
    "    for Figure 10); otherwise an undirected Graph.\n",
    "    '''\n",
    "    G = nx.DiGraph() if directed else nx.Graph()\n",
    "    G.add_nodes_from(V)\n",
    "    G.add_nodes_from(W)\n",
    "    for (v, w), c in costs.items():\n",
    "        G.add_edge(v, w, cost=c)\n",
    "        if directed:\n",
    "            G.add_edge(w, v, cost=c)\n",
    "    return G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "V = ['v', 'x']\n",
    "W = ['w', 'y']\n",
    "fig10_costs = {('v', 'w'): 2, ('v', 'y'): 3, ('x', 'w'): 5, ('x', 'y'): 7}\n",
    "G = build_bipartite_graph(fig10_costs, V, W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matching = Hungarian(G, V, W)\n",
    "matching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Checking the result\n",
    "\n",
    "The optimal matching for Figure 10 is $\\{(v, y), (x, w)\\}$ with cost $3 + 5 = 8$; the only other perfect matching, $\\{(v, w), (x, y)\\}$, costs $9$. Beyond this single example, the last cell compares `Hungarian` with brute-force enumeration of all perfect matchings on small random bipartite graphs. Some of these graphs have no perfect matching, in which case `Hungarian` must raise `ValueError`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def matching_cost(G, matching, V):\n",
    "    '''Total cost of a matching in the format returned by Hungarian.'''\n",
    "    V_set = set(V)\n",
    "    # count each matched pair once, from its V endpoint\n",
    "    return sum(G[a][b]['cost'] for a, b in matching if a in V_set)\n",
    "\n",
    "\n",
    "def brute_force_min_cost(costs, V, W):\n",
    "    '''Minimum cost over all perfect matchings by enumerating permutations of W.\n",
    "\n",
    "    costs maps (v, w) to the cost of that edge; returns None if there is no\n",
    "    perfect matching.\n",
    "    '''\n",
    "    best = None\n",
    "    for perm in itertools.permutations(W):\n",
    "        pairs = list(zip(V, perm))\n",
    "        if all(pair in costs for pair in pairs):\n",
    "            total = sum(costs[pair] for pair in pairs)\n",
    "            best = total if best is None else min(best, total)\n",
    "    return best"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert set(matching) == {('v', 'y'), ('y', 'v'), ('x', 'w'), ('w', 'x')}\n",
    "assert matching_cost(G, matching, V) == brute_force_min_cost(fig10_costs, V, W) == 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = random.Random(RANDOM_SEED)\n",
    "n_checked = 0\n",
    "for n in range(1, 6):\n",
    "    for trial in range(30):\n",
    "        V_rand = [('v', i) for i in range(n)]\n",
    "        W_rand = [('w', j) for j in range(n)]\n",
    "        # each possible edge is present with probability 0.7, so some\n",
    "        # instances have no perfect matching\n",
    "        costs_rand = {(v, w): rng.randint(0, 20)\n",
    "                      for v in V_rand for w in W_rand if rng.random() < 0.7}\n",
    "        G_rand = build_bipartite_graph(costs_rand, V_rand, W_rand, directed=False)\n",
    "        expected = brute_force_min_cost(costs_rand, V_rand, W_rand)\n",
    "        if expected is None:\n",
    "            try:\n",
    "                Hungarian(G_rand, V_rand, W_rand)\n",
    "            except ValueError:\n",
    "                pass\n",
    "            else:\n",
    "                raise AssertionError('no perfect matching, but no ValueError: {}'.format(costs_rand))\n",
    "        else:\n",
    "            M_rand = Hungarian(G_rand, V_rand, W_rand)\n",
    "            assert matching_cost(G_rand, M_rand, V_rand) == expected, costs_rand\n",
    "        n_checked += 1\n",
    "n_checked"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Running all cells checks the implementation. The Figure 10 example must yield the optimal matching of cost 8. On 150 small random instances (sizes 1–5, some without a perfect matching), `Hungarian` must agree with brute-force enumeration. An `AssertionError` in either check cell indicates a bug."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "nav_menu": {
    "height": "48px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}