{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation - Graphical model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[```igraph```](https://github.com/igraph/python-igraph) package is required since it is used for internal representation of a probabilistic graphical model. [```pyvis```](https://pyvis.readthedocs.io) is optional and it is needed for creating fancy interactive visualizations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy  as np\n",
    "import igraph as ig\n",
    "%run ./2-ImplementationFactor.ipynb\n",
    "# optional package\n",
    "import pyvis.network as net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 PGM\n",
    "\n",
    "We are going to stick with factor graphs since every Bayesian network and Markov network can be converted to this representation.\n",
    "\n",
    "As already mentioned the core of factor graph data structure is an ```igraph``` graph. Each node necessarily has the following attributes:\n",
    "\n",
    "* ```name``` attribute serves as a unique ```string``` identifier either of a factor node or a variable node\n",
    "* ```is_factor``` is a ```boolean``` indicator\n",
    "* ```factor_``` attribute is `None` for a variable node and stores a ```factor``` data structure for a factor node\n",
    "* ```rank``` attribute is `None` for a factor node and is equal to variable's rank for a variable node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class factor_graph:\n",
    "    def __init__(self):\n",
    "        self._graph = ig.Graph()\n",
    "    \n",
    "    # ----------------------- Factor node functions ---------\n",
    "    def add_factor_node(self, f_name, factor_): pass\n",
    "    def change_factor_distribution(self, f_name, factor_): pass\n",
    "    def remove_factor(self, f_name, remove_zero_degree=False): pass\n",
    "    def __create_factor_node(self, f_name, factor_): pass\n",
    "    \n",
    "    # ----------------------- Rank functions -------\n",
    "    def __check_variable_ranks(self, f_name, factor_, allowded_v_degree): pass\n",
    "    def __set_variable_ranks(self, f_name, factor_): pass\n",
    "        \n",
    "    # ----------------------- Variable node functions -------\n",
    "    def add_variable_node(self, v_name): pass\n",
    "    def remove_variable(self, v_name): pass\n",
    "    def __create_variable_node(self, v_name, rank=None): pass\n",
    "\n",
    "    # ----------------------- Info --------------------------\n",
    "    def get_node_status(self, name): pass\n",
    "    \n",
    "    # ----------------------- Graph structure ---------------\n",
    "    def get_graph(self): pass\n",
    "    def is_connected(self): pass\n",
    "    def is_loop(self): pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 Factor node functions\n",
    "\n",
    "Main operations:\n",
    "\n",
    "* ```add_factor_node```\n",
    "* ```change_factor_distribution```\n",
    "* ```remove_factor```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_factor_node(self, f_name, factor_):\n",
    "    if (self.get_node_status(f_name) != False) or (f_name in factor_.get_variables()):\n",
    "        raise Exception('Invalid factor name')\n",
    "    if type(factor_) is not factor:\n",
    "        raise Exception('Invalid factor_')\n",
    "    for v_name in factor_.get_variables():\n",
    "        if self.get_node_status(v_name) == 'factor':\n",
    "            raise Exception('Invalid factor')\n",
    "    \n",
    "    # Check ranks\n",
    "    self.__check_variable_ranks(f_name, factor_, 1)\n",
    "    # Create variables\n",
    "    for v_name in factor_.get_variables():\n",
    "        if self.get_node_status(v_name) == False:\n",
    "            self.__create_variable_node(v_name)\n",
    "    # Set ranks\n",
    "    self.__set_variable_ranks(f_name, factor_)\n",
    "    # Add node and corresponding edges\n",
    "    self.__create_factor_node(f_name, factor_)\n",
    "    \n",
    "factor_graph.add_factor_node = add_factor_node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_factor_distribution(self, f_name, factor_):\n",
    "    if self.get_node_status(f_name) != 'factor':\n",
    "        raise Exception('Invalid variable name')\n",
    "    if set(factor_.get_variables()) != set(self._graph.vs[self._graph.neighbors(f_name)]['name']):\n",
    "        raise Exception('invalid factor distribution')\n",
    "    \n",
    "    # Check ranks\n",
    "    self.__check_variable_ranks(f_name, factor_, 0)\n",
    "    # Set ranks\n",
    "    self.__set_variable_ranks(f_name, factor_)\n",
    "    # Set data\n",
    "    self._graph.vs.find(name=f_name)['factor_'] = factor_\n",
    "    \n",
    "factor_graph.change_factor_distribution = change_factor_distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_factor(self, f_name, remove_zero_degree=False):\n",
    "    if self.get_node_status(f_name) != 'factor':\n",
    "        raise Exception('Invalid variable name')\n",
    "    \n",
    "    neighbors = self._graph.neighbors(f_name, mode=\"out\")\n",
    "    self._graph.delete_vertices(f_name)\n",
    "    \n",
    "    if remove_zero_degree:\n",
    "        for v_name in neighbors:\n",
    "            if self._graph.vs.find(v_name).degree() == 0:\n",
    "                self.remove_variable(v_name)\n",
    "    \n",
    "factor_graph.remove_factor = remove_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def __create_factor_node(self, f_name, factor_):\n",
    "    # Create node\n",
    "    self._graph.add_vertex(f_name)\n",
    "    self._graph.vs.find(name=f_name)['is_factor'] = True\n",
    "    self._graph.vs.find(name=f_name)['factor_']   = factor_\n",
    "    \n",
    "    # Create corresponding edges\n",
    "    start = self._graph.vs.find(name=f_name).index\n",
    "    edge_list = [tuple([start, self._graph.vs.find(name=i).index]) for i in factor_.get_variables()]\n",
    "    self._graph.add_edges(edge_list)\n",
    "    \n",
    "factor_graph.__create_factor_node = __create_factor_node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Rank functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def __check_variable_ranks(self, f_name, factor_, allowded_v_degree):\n",
    "    for counter, v_name in enumerate(factor_.get_variables()):\n",
    "        if (self.get_node_status(v_name) == 'variable') and (not factor_.is_none()):\n",
    "            if     (self._graph.vs.find(name=v_name)['rank'] != factor_.get_shape()[counter]) \\\n",
    "               and (self._graph.vs.find(name=v_name)['rank'] != None) \\\n",
    "               and (self._graph.vs.find(v_name).degree() > allowded_v_degree):\n",
    "                raise Exception('Invalid shape of factor_')\n",
    "\n",
    "factor_graph.__check_variable_ranks = __check_variable_ranks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def __set_variable_ranks(self, f_name, factor_):\n",
    "    for counter, v_name in enumerate(factor_.get_variables()):\n",
    "        if factor_.is_none():\n",
    "            self._graph.vs.find(name=v_name)['rank'] = None\n",
    "        else:\n",
    "            self._graph.vs.find(name=v_name)['rank'] = factor_.get_shape()[counter]\n",
    "            \n",
    "factor_graph.__set_variable_ranks = __set_variable_ranks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3 Variable node functions\n",
    "\n",
    "Main operations:\n",
    "\n",
    "* ```add_variable_node```\n",
    "* ```remove_variable```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_variable_node(self, v_name):\n",
    "    if self.get_node_status(v_name) != False:\n",
    "        raise Exception('Node already exists')\n",
    "    self.__create_variable_node(v_name)\n",
    "    \n",
    "factor_graph.add_variable_node = add_variable_node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_variable(self, v_name):\n",
    "    if self.get_node_status(v_name) != 'variable':\n",
    "        raise Exception('Invalid variable name')\n",
    "    if self._graph.vs.find(v_name).degree() != 0:\n",
    "        raise Exception('Can not delete variables with degree >0')\n",
    "    self._graph.delete_vertices(self._graph.vs.find(v_name).index)        \n",
    "    \n",
    "factor_graph.remove_variable = remove_variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def __create_variable_node(self, v_name, rank=None):\n",
    "    self._graph.add_vertex(v_name)\n",
    "    self._graph.vs.find(name=v_name)['is_factor'] = False\n",
    "    self._graph.vs.find(name=v_name)['rank'] = rank\n",
    "    \n",
    "factor_graph.__create_variable_node = __create_variable_node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.4 Info\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_node_status(self, name):\n",
    "    if len(self._graph.vs) == 0:\n",
    "        return False\n",
    "    elif len(self._graph.vs.select(name_eq=name)) == 0:\n",
    "        return False\n",
    "    else:\n",
    "        if self._graph.vs.find(name=name)['is_factor'] == True:\n",
    "            return 'factor'\n",
    "        else:\n",
    "            return 'variable'\n",
    "    \n",
    "factor_graph.get_node_status = get_node_status"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.5 Graph structure\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_graph(self):\n",
    "    return self._graph\n",
    "\n",
    "factor_graph.get_graph = get_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_connected(self):\n",
    "    return self._graph.is_connected()\n",
    "\n",
    "factor_graph.is_connected = is_connected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_loop(self):\n",
    "    return any(self._graph.is_loop())\n",
    "\n",
    "factor_graph.is_loop = is_loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example\n",
    "\n",
    "Let's create a simple graphical model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "pgm_1 = factor_graph()\n",
    "pgm_1.add_factor_node('p1', factor(['x1', 'x2', 'x3']))\n",
    "pgm_1.add_factor_node('p2', factor(['x2', 'x4']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 Factor graph from string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def string2factor_graph(str_):\n",
    "    res_factor_graph = factor_graph()\n",
    "    \n",
    "    str_ = [i.split('(') for i in str_.split(')') if i != '']\n",
    "    for i in range(len(str_)):\n",
    "        str_[i][1] = str_[i][1].split(',')\n",
    "        \n",
    "    for i in str_:\n",
    "        res_factor_graph.add_factor_node(i[0], factor(i[1]))\n",
    "    \n",
    "    return res_factor_graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "pgm_2 = string2factor_graph('phi_1(a,b,c)phi_2(b,c,d,e)psi_3(e,c)psi_4(d)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 Plotting\n",
    "\n",
    "The default way of plotting internal graph is ```plot(x.get_graph())```. ```plot_factor_graph``` does interactive plotting using the ```pyvis``` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_factor_graph(x):\n",
    "    graph = net.Network(notebook=True, width=\"100%\")\n",
    "    graph.toggle_physics(False)\n",
    "    \n",
    "    # Vertices\n",
    "    label = x.get_graph().vs['name']\n",
    "    color = ['#2E2E2E' if i is True else '#F2F2F2' for i in x.get_graph().vs['is_factor']]\n",
    "    graph.add_nodes(range(len(x.get_graph().vs)), label=label, color=color)\n",
    "    \n",
    "    # Edges\n",
    "    graph.add_edges(x.get_graph().get_edgelist())\n",
    "    \n",
    "    return graph.show(\"./img/3/graph.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"500px\"\n",
       "            src=\"./img/3/graph.html\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0xa2266d400>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "plot_factor_graph(pgm_2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}