{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Minimum-Cost Bipartite Matching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Hungarian Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Hungarian algorithmを実装してみた。\n",
    "効率が悪い(good pathsを探すときの辺の選び方はもっと賢くできる)ので要改良。\n",
    "\n",
    "似非さんにより「パパ活アルゴリズム」と命名された。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import deque\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def Initialize_H(G):\n",
    "    '''\n",
    "    Initialize a matching M and prices of all vertices.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G = (V+W, E): a bipartite graph\n",
    "        V+W: vertices\n",
    "            price\n",
    "        E: edges\n",
    "            cost(nonnegative)\n",
    "    M: a list of the edges which are a subset of E, a graph attribute\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    H: initialized graph with zero prices and an empty matching M\n",
    "    '''\n",
    "\n",
    "    H = G.copy()\n",
    "\n",
    "    # Initialize a matching\n",
    "    # graph attributeとして定義\n",
    "    H.graph['M_edges'] = []\n",
    "    H.graph['M_vertices'] = []\n",
    "\n",
    "    # Initialize prices\n",
    "    for v in H.nodes():\n",
    "        H.node[v]['price'] = 0\n",
    "\n",
    "    return H\n",
    "\n",
    "\n",
    "def BFS_GP(G, r, V, W):\n",
    "    '''\n",
    "    Implementing BFS modified for a good path\n",
    "    Need to be modified in the future(BFS is poorly implemented.)\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G=(V+W, E): a bipartite graph with a current matching M\n",
    "        V+W: vertices\n",
    "            p: prices\n",
    "        M: a current matching\n",
    "    r: the first unmatched vertex r of V(the root of a BFS tree)\n",
    "    Return\n",
    "    ------\n",
    "    a list (path, S, N, match)\n",
    "        path: a list of vertices representing a stucked path P, [r, ..., v]\n",
    "        S: even level vertices of the BFS tree\n",
    "        N: odd level ones\n",
    "        match: False if the search tree contains an unmatched vertex w in W\n",
    "    '''\n",
    "\n",
    "    def candidate(G, v, level, past):\n",
    "        '''return a list of vertices(say, candidates) that satisfy the condition(*):\n",
    "            1. adding (v,p) to the current path keep the path M-alternating\n",
    "            2. (v,p) is tight\n",
    "            3. p has not been reached before\n",
    "        If the length of the returned list is zero, BFS is stucked.\n",
    "\n",
    "        Paremeters\n",
    "        ----------\n",
    "        G: a bipartite graph\n",
    "        v: a current node\n",
    "        level: a current level of BFS(v lies in level i of the BFS tree)\n",
    "\n",
    "        Return\n",
    "        ------\n",
    "\n",
    "        '''\n",
    "\n",
    "        # Check if stucked or not\n",
    "        # level: odd のとき(v is in W)は、matchingに含まれる枝の中から探す必要あり\n",
    "        # level: even のとき(v is in V)は、まだ訪れていない(pastを使う)tight edgesから探す必要あり\n",
    "        if level % 2 != 0:\n",
    "            # このときcandidatesには多くとも一点しか(vとmatchしている点)含まれないはず\n",
    "            candidates = [p for p in G.neighbors(\n",
    "                v) if ((v, p) in G.graph['M_edges'])]\n",
    "        else:\n",
    "            candidates = [p for p in G.neighbors(v) if ((p not in past) and (\n",
    "                G[v][p]['cost'] - G.node[v]['price'] - G.node[p]['price'] == 0))]\n",
    "\n",
    "        return candidates\n",
    "\n",
    "    # 過去に訪れた点を格納するためのリスト\n",
    "    past = []\n",
    "    # これから訪れる点を格納するためのqueue\n",
    "    future = deque()\n",
    "    # pathを格納するためのリスト\n",
    "    path = []\n",
    "    # BFS treeを格納するためのlist, S: even level, N: odd level\n",
    "    S = []\n",
    "    N = []\n",
    "\n",
    "    # set a distance from the root\n",
    "    for p in G.nodes():\n",
    "        if (p != r):\n",
    "            G.node[p]['distance'] = float('inf')\n",
    "        else:\n",
    "            G.node[p]['distance'] = 0\n",
    "\n",
    "    # set a current node v as r\n",
    "    v = r\n",
    "    # level of a BFS tree\n",
    "    level = 0\n",
    "    # candidates, a list of nodes\n",
    "    candidates = candidate(G, v, level, past)\n",
    "\n",
    "    # for debugging\n",
    "    print('before growing the BFS tree')\n",
    "    print('the root is: ' + v)\n",
    "    print('its price: ' + str(G.node[v]['price']))\n",
    "    print('its candidate: ' + str(candidates))\n",
    "\n",
    "    while(len(candidates) != 0 and not (set(V) <= set(G.graph['M_vertices']))):\n",
    "        # level = 0, or unstucked and there is a unmatched vertex of V\n",
    "        # v: current vertex\n",
    "        # level: current level\n",
    "\n",
    "        # for debugging\n",
    "        print('Grow the tree!')\n",
    "\n",
    "        past.append(v)\n",
    "        if (level % 2 == 0):\n",
    "            S.append(v)\n",
    "\n",
    "            # for debugging\n",
    "            print('level: ' + str(level))\n",
    "            print('add ' + v + ' to S')\n",
    "        else:\n",
    "            N.append(v)\n",
    "\n",
    "            # for debugging\n",
    "            print('level: ' + str(level))\n",
    "            print('add ' + v + ' to N')\n",
    "\n",
    "        level += 1\n",
    "\n",
    "        # for debugging\n",
    "        print('current node: ' + v)\n",
    "        print('candidates: ' + str(candidates))\n",
    "        print('future: ' + str(future))\n",
    "        print('distance: ' + str(G.node[v]['distance']))\n",
    "\n",
    "        for p in candidates:\n",
    "            if p not in future:\n",
    "                future.append(p)\n",
    "                print ('append ' + p + ' to future')\n",
    "                G.node[p]['distance'] = G.node[v]['distance'] + 1\n",
    "                print ('distance: ' + str(G.node[p]['distance']))\n",
    "\n",
    "        if(len(future) == 0):\n",
    "            break\n",
    "        else:\n",
    "            v = future.popleft()\n",
    "            print('the next node: ' + v)\n",
    "            print('level: ' + str(level))\n",
    "            print('past: ' + str(past))\n",
    "            candidates = candidate(G, v, level, past)\n",
    "            print('candidates: ' + str(candidates))\n",
    "\n",
    "            if (len(candidates) == 0):\n",
    "                if (level % 2 == 0):\n",
    "                    S.append(v)\n",
    "                else:\n",
    "                    N.append(v)\n",
    "\n",
    "    # After terminating the while loop, BFS search is stucked\n",
    "    # First, check if v is another unmatched vertex or not\n",
    "    match = False\n",
    "    if (v in G.graph['M_vertices']):\n",
    "        # In case that the last vertex v is matched\n",
    "        match = True\n",
    "\n",
    "    # Construct a good path from r to v\n",
    "    # 以下、sからeへのpathが存在する場合\n",
    "    # 終点から遡ってpathを形成する\n",
    "    pp = v\n",
    "\n",
    "    # for debugging\n",
    "    loopCount2 = 0\n",
    "    print ('initial pp: ' + pp)\n",
    "\n",
    "    while (1):\n",
    "        loopCount2 += 1\n",
    "        print ('loopCount2: ' + str(loopCount2))\n",
    "        if loopCount2 == 20:\n",
    "            return 'bugbug'\n",
    "\n",
    "        path.insert(0, pp)\n",
    "        print('add ' + pp +' to path')\n",
    "        if pp == r:\n",
    "            break\n",
    "\n",
    "        pred = G.predecessors(pp)\n",
    "\n",
    "        for p in pred:\n",
    "            if (G.node[p]['distance'] == G.node[pp]['distance'] - 1\n",
    "                and (G[p][pp]['cost'] - G.node[p]['price'] - G.node[pp]['price'] == 0)):\n",
    "                print('go backward')\n",
    "                print(pp)\n",
    "                print(G.node[pp]['distance'])\n",
    "                pp = p\n",
    "                print(pp)\n",
    "                print(G.node[pp]['distance'])\n",
    "                break\n",
    "\n",
    "    return [path, S, N, match]\n",
    "\n",
    "\n",
    "def Hungarian(G, V, W):\n",
    "    '''Solving min-cost bipartite matching problem\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G=(V+W, E): a bipartite graph, not modified (the work is done on a copy)\n",
    "        V+W: vertices\n",
    "        E: edges, stored in both directions\n",
    "            cost>=0\n",
    "            the reduced cost of (v,w) = cost of (v,w) - price of v - price of w\n",
    "    V, W: the two sides of the bipartition, as lists of vertices\n",
    "    Invariants:\n",
    "        1. the reduced costs of all edges in G >= 0\n",
    "        2. the reduced costs of all edges in M = 0(tight)\n",
    "\n",
    "    Return\n",
    "    ------\n",
    "    the min-cost matching M of G that covers every vertex of V, as a list of\n",
    "    edges in which each matched pair appears in both orientations\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError: no edge is left to make tight, i.e. V cannot be fully matched\n",
    "    RuntimeError: the iteration bound is exceeded (the search makes no progress)\n",
    "    '''\n",
    "\n",
    "    H = Initialize_H(G)\n",
    "\n",
    "    # networkx < 2.4 exposes node attributes as H.node; later releases only\n",
    "    # provide H.nodes, so use whichever this graph object has.\n",
    "    attrs = H.node if hasattr(H, 'node') else H.nodes\n",
    "    M_edges = H.graph['M_edges']\n",
    "    M_vertices = H.graph['M_vertices']\n",
    "\n",
    "    def reduced_cost(v, w):\n",
    "        return H[v][w]['cost'] - attrs[v]['price'] - attrs[w]['price']\n",
    "\n",
    "    # A run needs at most |V| augmentations with at most |W| + 1 price updates\n",
    "    # in between, so this bound is only hit when the search stops progressing.\n",
    "    max_rounds = (len(V) + len(W) + 2) ** 2\n",
    "    rounds = 0\n",
    "\n",
    "    while True:\n",
    "        # pick the first unmatched node r of V; done once V is fully matched\n",
    "        unmatched = [p for p in V if p not in M_vertices]\n",
    "        if not unmatched:\n",
    "            break\n",
    "        r = unmatched[0]\n",
    "\n",
    "        rounds += 1\n",
    "        if rounds > max_rounds:\n",
    "            raise RuntimeError('Hungarian: no progress after '\n",
    "                               + str(max_rounds) + ' rounds')\n",
    "\n",
    "        # S: even level vertices of the BFS tree\n",
    "        # N: odd level ones\n",
    "        # path: a good path, or the path on which the search got stuck\n",
    "        # match: False if the last vertex of the path is unmatched\n",
    "        path, S, N, match = BFS_GP(H, r, V, W)\n",
    "\n",
    "        if not match and len(path) > 1:\n",
    "            # The path ends in an unmatched vertex of W: flip it. Edges at\n",
    "            # even positions (V to W) join M, those at odd positions leave M.\n",
    "            for i in range(len(path) - 1):\n",
    "                arc = (path[i], path[i + 1])\n",
    "                back = (path[i + 1], path[i])\n",
    "                if i % 2 == 0:\n",
    "                    # store both orientations\n",
    "                    M_edges.append(arc)\n",
    "                    M_edges.append(back)\n",
    "                else:\n",
    "                    M_edges.remove(arc)\n",
    "                    M_edges.remove(back)\n",
    "            for p in path:\n",
    "                if p not in M_vertices:\n",
    "                    M_vertices.append(p)\n",
    "            continue\n",
    "\n",
    "        # The search is stuck. BFS_GP reports an empty tree when no tight\n",
    "        # edge leaves the root; the tree is then the root alone.\n",
    "        tree_V = S if S else [path[0]]\n",
    "\n",
    "        # Smallest reduced cost among the edges leaving the tree. Using the\n",
    "        # reduced cost (not the raw cost) keeps invariant 1 when prices in W\n",
    "        # are already negative.\n",
    "        diff = float('inf')\n",
    "        for v in tree_V:\n",
    "            for w in H.neighbors(v):\n",
    "                if w not in N:\n",
    "                    diff = min(diff, reduced_cost(v, w))\n",
    "        if diff == float('inf'):\n",
    "            raise ValueError('Hungarian: vertex ' + str(r)\n",
    "                             + ' cannot be matched, V has no full matching')\n",
    "\n",
    "        # Shift the prices: edges inside the tree stay tight, and at least\n",
    "        # one edge leaving the tree becomes tight.\n",
    "        for v in tree_V:\n",
    "            attrs[v]['price'] += diff\n",
    "        for w in N:\n",
    "            attrs[w]['price'] -= diff\n",
    "\n",
    "    return M_edges\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Figure 10 のグラフを作ってみる\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import deque\n",
    "import numpy as np\n",
    "\n",
    "# Figure 10 instance: two vertices per side, every V-W pair joined by an edge.\n",
    "V = ['v', 'x']\n",
    "W = ['w', 'y']\n",
    "edgelist = [('v', 'w'), ('v', 'y'), ('x', 'w'), ('x', 'y')]\n",
    "costs = [2, 3, 5, 7]\n",
    "\n",
    "# Each undirected edge is stored as two opposite arcs with the same cost,\n",
    "# because the solver looks up both G[v][w] and G[w][v].\n",
    "G = nx.DiGraph()\n",
    "for (left, right), cost in zip(edgelist, costs):\n",
    "    G.add_edge(left, right, cost=cost)\n",
    "    G.add_edge(right, left, cost=cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "exLoopCount: 1\n",
      "before growing the BFS tree\n",
      "the root is: v\n",
      "its price: 0\n",
      "its candidate: []\n",
      "initial pp: v\n",
      "loopCount2: 1\n",
      "add v to path\n",
      "path: ['v']\n",
      "S: []\n",
      "N: []\n",
      "match: False\n",
      "current matching(edges): []\n",
      "current matching(vertices): []\n",
      "add 2to vs price\n",
      "exLoopCount: 2\n",
      "before growing the BFS tree\n",
      "the root is: v\n",
      "its price: 2\n",
      "its candidate: ['w']\n",
      "Grow the tree!\n",
      "level: 0\n",
      "add v to S\n",
      "current node: v\n",
      "candidates: ['w']\n",
      "future: deque([])\n",
      "distance: 0\n",
      "append w to future\n",
      "distance: 1\n",
      "the next node: w\n",
      "level: 1\n",
      "past: ['v']\n",
      "candidates: []\n",
      "initial pp: w\n",
      "loopCount2: 1\n",
      "add w to path\n",
      "go backward\n",
      "w\n",
      "1\n",
      "v\n",
      "0\n",
      "loopCount2: 2\n",
      "add v to path\n",
      "path: ['v', 'w']\n",
      "S: ['v']\n",
      "N: ['w']\n",
      "match: False\n",
      "current matching(edges): []\n",
      "current matching(vertices): []\n",
      "exLoopCount: 3\n",
      "before growing the BFS tree\n",
      "the root is: x\n",
      "its price: 0\n",
      "its candidate: []\n",
      "initial pp: x\n",
      "loopCount2: 1\n",
      "add x to path\n",
      "path: ['x']\n",
      "S: []\n",
      "N: []\n",
      "match: False\n",
      "current matching(edges): [('v', 'w'), ('w', 'v')]\n",
      "current matching(vertices): ['v', 'w']\n",
      "add 5to xs price\n",
      "exLoopCount: 4\n",
      "before growing the BFS tree\n",
      "the root is: x\n",
      "its price: 5\n",
      "its candidate: ['w']\n",
      "Grow the tree!\n",
      "level: 0\n",
      "add x to S\n",
      "current node: x\n",
      "candidates: ['w']\n",
      "future: deque([])\n",
      "distance: 0\n",
      "append w to future\n",
      "distance: 1\n",
      "the next node: w\n",
      "level: 1\n",
      "past: ['x']\n",
      "candidates: ['v']\n",
      "Grow the tree!\n",
      "level: 1\n",
      "add w to N\n",
      "current node: w\n",
      "candidates: ['v']\n",
      "future: deque([])\n",
      "distance: 1\n",
      "append v to future\n",
      "distance: 2\n",
      "the next node: v\n",
      "level: 2\n",
      "past: ['x', 'w']\n",
      "candidates: []\n",
      "initial pp: v\n",
      "loopCount2: 1\n",
      "add v to path\n",
      "go backward\n",
      "v\n",
      "2\n",
      "w\n",
      "1\n",
      "loopCount2: 2\n",
      "add w to path\n",
      "go backward\n",
      "w\n",
      "1\n",
      "x\n",
      "0\n",
      "loopCount2: 3\n",
      "add x to path\n",
      "path: ['x', 'w', 'v']\n",
      "S: ['x', 'v']\n",
      "N: ['w']\n",
      "match: True\n",
      "current matching(edges): [('v', 'w'), ('w', 'v')]\n",
      "current matching(vertices): ['v', 'w']\n",
      "add 1 to xs price\n",
      "add 1 to vs price\n",
      "subtract 1 to ws price\n",
      "exLoopCount: 5\n",
      "before growing the BFS tree\n",
      "the root is: x\n",
      "its price: 6\n",
      "its candidate: ['w']\n",
      "Grow the tree!\n",
      "level: 0\n",
      "add x to S\n",
      "current node: x\n",
      "candidates: ['w']\n",
      "future: deque([])\n",
      "distance: 0\n",
      "append w to future\n",
      "distance: 1\n",
      "the next node: w\n",
      "level: 1\n",
      "past: ['x']\n",
      "candidates: ['v']\n",
      "Grow the tree!\n",
      "level: 1\n",
      "add w to N\n",
      "current node: w\n",
      "candidates: ['v']\n",
      "future: deque([])\n",
      "distance: 1\n",
      "append v to future\n",
      "distance: 2\n",
      "the next node: v\n",
      "level: 2\n",
      "past: ['x', 'w']\n",
      "candidates: ['y']\n",
      "Grow the tree!\n",
      "level: 2\n",
      "add v to S\n",
      "current node: v\n",
      "candidates: ['y']\n",
      "future: deque([])\n",
      "distance: 2\n",
      "append y to future\n",
      "distance: 3\n",
      "the next node: y\n",
      "level: 3\n",
      "past: ['x', 'w', 'v']\n",
      "candidates: []\n",
      "initial pp: y\n",
      "loopCount2: 1\n",
      "add y to path\n",
      "go backward\n",
      "y\n",
      "3\n",
      "v\n",
      "2\n",
      "loopCount2: 2\n",
      "add v to path\n",
      "go backward\n",
      "v\n",
      "2\n",
      "w\n",
      "1\n",
      "loopCount2: 3\n",
      "add w to path\n",
      "go backward\n",
      "w\n",
      "1\n",
      "x\n",
      "0\n",
      "loopCount2: 4\n",
      "add x to path\n",
      "path: ['x', 'w', 'v', 'y']\n",
      "S: ['x', 'v']\n",
      "N: ['w', 'y']\n",
      "match: False\n",
      "current matching(edges): [('v', 'w'), ('w', 'v')]\n",
      "current matching(vertices): ['v', 'w']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[('x', 'w'), ('w', 'x'), ('v', 'y'), ('y', 'v')]"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Solve the instance (G, V, W) defined in the previous cell; as the last\n",
    "# expression of the cell, the returned matching is displayed as its output.\n",
    "Hungarian(G, V, W)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "とりあえず、Figure 10の例でうまく動くことは確認(もしかしたらバグがまだ残っているかも)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "nav_menu": {
    "height": "48px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}